from masknmf.arrays.array_interfaces import LazyFrameLoader, FactorizedVideo, ArrayLike
from masknmf.utils import Serializer
import torch
from typing import Tuple, Union, Optional
import numpy as np
def test_slice_effect(my_slice: slice, spatial_dim: int) -> bool:
"""
Returns True if slice will actually have an effect
"""
if not (
(isinstance(my_slice.start, int) and my_slice.start == 0)
or my_slice.start is None
):
return True
elif not (
(isinstance(my_slice.stop, int) and my_slice.stop >= spatial_dim)
or my_slice.stop is None
):
return True
elif not (
my_slice.step is None or (isinstance(my_slice.step, int) and my_slice.step == 1)
):
return True
return False
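# Illustrative sketch (hypothetical sizes, not part of the library API): a slice
# covering the whole dimension with unit step is a no-op, so the crop can be skipped.
def _demo_slice_effect() -> None:
    assert not test_slice_effect(slice(None, None, None), 128)  # full coverage
    assert not test_slice_effect(slice(0, 128, 1), 128)  # explicit full coverage
    assert test_slice_effect(slice(10, 50), 128)  # genuine crop
    assert test_slice_effect(slice(None, None, 2), 128)  # step changes selection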
def test_range_effect(my_range: range, spatial_dim: int) -> bool:
"""
Returns True if the range will actually have an effect.
    Args:
my_range (range): The range object to test.
spatial_dim (int): The size of the dimension that the range is applied to.
Returns:
bool: True if the range will affect the selection; False otherwise.
"""
# Check if the range starts from the beginning
if my_range.start != 0:
return True
# Check if the range stops at the end of the dimension
elif my_range.stop != spatial_dim:
return True
# Check if the range step is not 1
elif my_range.step != 1:
return True
return False
def test_spatial_crop_effect(
    my_tuple: Tuple[Union[int, np.integer, np.ndarray, slice, range], ...],
    spatial_dims: Tuple[int, ...],
) -> bool:
    """
    Returns True if the tuple used for spatial cropping actually has an effect
    on the underlying data; otherwise the crop is a no-op that can be skipped,
    since cropping is an expensive and avoidable operation.
    """
for k in range(len(my_tuple)):
if isinstance(my_tuple[k], np.ndarray):
if my_tuple[k].shape[0] < spatial_dims[k]:
return True
if isinstance(my_tuple[k], np.integer):
return True
if isinstance(my_tuple[k], int):
return True
if isinstance(my_tuple[k], slice):
if test_slice_effect(my_tuple[k], spatial_dims[k]):
return True
if isinstance(my_tuple[k], range):
if test_range_effect(my_tuple[k], spatial_dims[k]):
return True
return False
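# Illustrative sketch (hypothetical sizes, not part of the library API):
# full-coverage slices and ranges are no-ops, while integer, small-array, and
# partial indexers trigger a genuine crop.
def _demo_spatial_crop_effect() -> None:
    dims = (64, 64)
    assert not test_spatial_crop_effect((slice(None), range(0, 64)), dims)
    assert test_spatial_crop_effect((5, slice(None)), dims)  # int selects one row
    assert test_spatial_crop_effect((np.arange(10), slice(None)), dims)  # row subset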
def _construct_identity_torch_sparse_tensor(dimsize: int, device: str = "cpu"):
"""
Constructs an identity torch.sparse_coo_tensor on the specified device.
Args:
dimsize (int): The number of rows (or equivalently columns) of the torch.sparse_coo_tensor.
device (str): 'cpu' or 'cuda'. The device on which the sparse tensor is constructed
Returns:
        torch.sparse_coo_tensor: A (dimsize, dimsize) sparse identity matrix.
"""
# Indices for diagonal elements (rows and cols are the same for diagonal)
row_col = torch.arange(dimsize, device=device)
indices = torch.stack([row_col, row_col], dim=0)
# Values (all ones)
values = torch.ones(dimsize, device=device)
sparse_tensor = torch.sparse_coo_tensor(indices, values, (dimsize, dimsize))
return sparse_tensor
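# Minimal sanity-check sketch (illustrative only): multiplying by the sparse
# identity should leave a dense matrix unchanged.
def _demo_sparse_identity() -> None:
    eye = _construct_identity_torch_sparse_tensor(4)
    x = torch.randn(4, 3)
    assert torch.allclose(torch.sparse.mm(eye, x), x)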
class PMDArray(FactorizedVideo, Serializer):
"""
Factorized demixing array for PMD movie
"""
_serialized = {
"shape",
"u",
"v",
"u_local_projector",
"mean_img",
"var_img"
}
def __init__(
self,
shape: Tuple[int, int, int] | np.ndarray,
        u: torch.sparse_coo_tensor,
        v: torch.Tensor,
        mean_img: torch.Tensor,
        var_img: torch.Tensor,
u_local_projector: Optional[torch.sparse_coo_tensor] = None,
device: str = "cpu",
rescale: bool = True,
):
"""
        Key assumption: the spatial basis matrix U has n + k columns; the first n columns are block-sparse (these
        serve as a local spatial basis for the data) and the last k columns can have unconstrained spatial support
        (these serve as a global spatial basis for the data).
Args:
shape (tuple): (num_frames, fov_dim1, fov_dim2)
u (torch.sparse_coo_tensor): shape (pixels, rank)
            v (torch.Tensor): shape (rank, frames)
            mean_img (torch.Tensor): shape (fov_dim1, fov_dim2). The pixelwise mean of the data
            var_img (torch.Tensor): shape (fov_dim1, fov_dim2). A pixelwise noise normalizer for the data
            u_local_projector (Optional[torch.sparse_coo_tensor]): shape (pixels, rank)
device (str): The device on which computations occur/data is stored
rescale (bool): True if we rescale the PMD data (i.e. multiply by the pixelwise normalizer
and add back the mean) in __getitem__
"""
        self._u = u.to(device).coalesce()
        self._v = v.to(device)
        if u_local_projector is not None:
            self._u_local_projector = u_local_projector.to(device).coalesce()
        else:
            self._u_local_projector = None
        self._device = self._u.device
self._shape = tuple(shape)
self.pixel_mat = torch.arange(
self.shape[1] * self.shape[2], device=self.device
).reshape(self.shape[1], self.shape[2])
self._order = "C"
self._mean_img = mean_img.to(self.device).float()
self._var_img = var_img.to(self.device).float()
self._rescale = rescale
@property
def rescale(self) -> bool:
return self._rescale
@rescale.setter
def rescale(self, new_state: bool):
self._rescale = new_state
@property
    def mean_img(self) -> torch.Tensor:
return self._mean_img
@property
    def var_img(self) -> torch.Tensor:
return self._var_img
@property
def device(self) -> torch.device:
return self._device
    def to(self, device: str):
        """
        Moves all internal tensors (and the sparse projector, if any) to the given device.
        """
self._u = self._u.to(device)
self._v = self._v.to(device)
self._mean_img = self._mean_img.to(device)
self._var_img = self._var_img.to(device)
self.pixel_mat = self.pixel_mat.to(device)
self._device = self._u.device
if self.u_local_projector is not None:
self._u_local_projector = self.u_local_projector.to(device)
@property
def u(self) -> torch.sparse_coo_tensor:
return self._u
@property
def u_local_projector(self) -> Optional[torch.sparse_coo_tensor]:
return self._u_local_projector
@property
def pmd_rank(self) -> int:
return self.u.shape[1]
@property
    def v(self) -> torch.Tensor:
return self._v
    @property
    def dtype(self) -> type:
        """
        Data type of returned arrays, default np.float32
        """
        return np.float32
@property
def shape(self) -> Tuple[int, int, int]:
"""
Array shape (n_frames, dims_x, dims_y)
"""
return self._shape
@property
def order(self) -> str:
"""
The spatial data is "flattened" from 2D into 1D.
This is not user-modifiable; "F" ordering is undesirable in PyTorch
"""
return self._order
@property
def ndim(self) -> int:
"""
Number of dimensions
"""
return len(self.shape)
    def calculate_rank_heatmap(self) -> torch.Tensor:
        """
        Generates a rank heatmap image from U: the row-wise sum of the binarized U
        matrix, i.e. the number of spatial components supported at each pixel.
        Returns:
            rank_heatmap (torch.Tensor): Shape (fov_dim1, fov_dim2).
        """
binarized_u = torch.sparse_coo_tensor(
self.u.indices(),
torch.ones_like(self.u.values()),
self.u.size()
)
row_sum_u = torch.sparse.sum(binarized_u, dim=1)
return torch.reshape(row_sum_u.to_dense(),
(self.shape[1],self.shape[2]))
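    # Interpretation sketch: a pixel whose row of U has three nonzero entries is
    # supported by three spatial components, so its heatmap value is 3.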
    def project_frames(
        self, frames: torch.Tensor, standardize: bool = True
    ) -> torch.Tensor:
        """
        Projects frames onto the local spatial basis using the u_local_projector
        property; u_local_projector must be defined.
        Args:
            frames (torch.Tensor): Shape (fov_dim1, fov_dim2, num_frames) or (fov_dim1*fov_dim2, num_frames).
                Frames which we want to project onto the spatial basis.
            standardize (bool): Indicates whether the frames of data are standardized before projection is performed
        Returns:
            projected_frames (torch.Tensor): Shape (rank, num_frames). The coefficients of each frame in the basis.
        """
if self.u_local_projector is None:
            raise ValueError(
                "u_local_projector must be defined to project frames onto the spatial basis"
            )
orig_device = frames.device
frames = frames.to(self.device).float()
if len(frames.shape) == 3:
if standardize:
frames = (frames - self.mean_img[..., None]) / self.var_img[
..., None
] # Normalize the frames
frames = torch.nan_to_num(frames, nan=0.0)
frames = frames.reshape(self.shape[1] * self.shape[2], -1)
else:
if standardize:
frames = (
frames - self.mean_img.flatten()[..., None]
) / self.var_img.flatten()[..., None]
frames = torch.nan_to_num(frames, nan=0.0)
projection = torch.sparse.mm(self.u_local_projector.T, frames)
return projection.to(orig_device)
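    # Usage sketch (hypothetical names and shapes): project two raw frames onto
    # the local spatial basis to obtain their PMD-basis coefficients.
    #     frames = torch.randn(fov_dim1, fov_dim2, 2)
    #     coeffs = pmd_array.project_frames(frames)  # shape (rank, 2)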
    def getitem_tensor(
        self,
        item: Union[int, list, np.ndarray, Tuple[Union[int, np.ndarray, slice, range], ...]],
    ) -> torch.Tensor:
# Step 1: index the frames (dimension 0)
if isinstance(item, tuple):
if len(item) > len(self.shape):
raise IndexError(
f"Cannot index more dimensions than exist in the array. "
f"You have tried to index with <{len(item)}> dimensions, "
f"only <{len(self.shape)}> dimensions exist in the array"
)
frame_indexer = item[0]
else:
frame_indexer = item
# Step 2: Do some basic error handling for frame_indexer before using it to slice
        # ndarray, list, and int indexers can be used directly
        if isinstance(frame_indexer, (np.ndarray, list, int)):
            pass
# numpy int scalar
elif isinstance(frame_indexer, np.integer):
frame_indexer = frame_indexer.item()
# treat slice and range the same
elif isinstance(frame_indexer, (slice, range)):
start = frame_indexer.start
stop = frame_indexer.stop
step = frame_indexer.step
if start is not None:
if start > self.shape[0]:
raise IndexError(
f"Cannot index beyond `n_frames`.\n"
f"Desired frame start index of <{start}> "
f"lies beyond `n_frames` <{self.shape[0]}>"
)
if stop is not None:
if stop > self.shape[0]:
raise IndexError(
f"Cannot index beyond `n_frames`.\n"
f"Desired frame stop index of <{stop}> "
f"lies beyond `n_frames` <{self.shape[0]}>"
)
if step is None:
step = 1
frame_indexer = slice(start, stop, step) # in case it was a range object
        else:
            raise IndexError(
                f"Invalid indexing method, you have passed a: <{type(item)}>"
            )
# Step 3: Now slice the data with frame_indexer (careful: if the ndims has shrunk, add a dim)
v_crop = self._v[:, frame_indexer]
if v_crop.ndim < self._v.ndim:
v_crop = v_crop.unsqueeze(1)
# Step 4: Deal with remaining indices after lazy computing the frame(s)
        if isinstance(item, tuple) and test_spatial_crop_effect(
            item[1:], self.shape[1:]
        ):
            # Pad with full slices so 2-element tuples (frames, rows) are valid
            spatial_items = tuple(item[1:]) + (slice(None),) * (3 - len(item))
            def to_crop_term(term):
                # Convert scalar-like indexers to length-1 slices so the
                # spatial dimensions are preserved during cropping
                if isinstance(term, np.ndarray) and len(term) == 1:
                    return slice(int(term[0]), int(term[0]) + 1)
                if isinstance(term, (int, np.integer)):
                    return slice(int(term), int(term) + 1)
                return term
            spatial_crop_terms = (
                to_crop_term(spatial_items[0]),
                to_crop_term(spatial_items[1]),
            )
pixel_space_crop = self.pixel_mat[spatial_crop_terms]
mean_img_crop = self.mean_img[spatial_crop_terms].flatten()
var_img_crop = self.var_img[spatial_crop_terms].flatten()
u_indices = pixel_space_crop.flatten()
u_crop = torch.index_select(self._u, 0, u_indices)
implied_fov = pixel_space_crop.shape
else:
u_crop = self._u
mean_img_crop = self.mean_img.flatten()
var_img_crop = self.var_img.flatten()
implied_fov = self.shape[1], self.shape[2]
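        # Factorized evaluation: for flattened pixel p and frame t (with rescale=True),
        #     movie[t, p] = var_img[p] * (U @ V)[p, t] + mean_img[p]
        # Cropping U's rows before the matmul keeps the cost proportional to the
        # requested region rather than the full field of view.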
product = torch.sparse.mm(u_crop, v_crop)
if self.rescale:
product *= var_img_crop.unsqueeze(1)
product += mean_img_crop.unsqueeze(1)
product = product.reshape((implied_fov[0], implied_fov[1], -1))
product = product.permute(2, 0, 1)
return product
    def __getitem__(
        self,
        item: Union[int, list, np.ndarray, Tuple[Union[int, np.ndarray, slice, range], ...]],
    ) -> np.ndarray:
        """
        Evaluates the requested frames/crop and returns a squeezed numpy array.
        """
        product = self.getitem_tensor(item)
        product = product.cpu().numpy().astype(self.dtype).squeeze()
        return product
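# End-to-end usage sketch (hypothetical sizes, not part of the library API):
# build a tiny PMDArray from a random rank-2 factorization and index it like a
# (frames, height, width) movie.
def _demo_pmd_array() -> None:
    n_frames, fov1, fov2, rank = 10, 8, 8, 2
    u = torch.rand(fov1 * fov2, rank).to_sparse()
    v = torch.randn(rank, n_frames)
    pmd = PMDArray(
        (n_frames, fov1, fov2), u, v,
        mean_img=torch.zeros(fov1, fov2), var_img=torch.ones(fov1, fov2),
    )
    frame = pmd[0]  # single frame -> (fov1, fov2) numpy array
    clip = pmd[2:5, 0:4, :]  # cropped clip -> (3, 4, fov2) numpy array
    assert frame.shape == (fov1, fov2) and clip.shape == (3, 4, fov2)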
class PMDResidualArray(ArrayLike):
"""
Factorized video for the spatial and temporal extracted sources from the data
"""
    def __init__(
        self,
        raw_arr: ArrayLike,
        pmd_arr: PMDArray,
    ):
        """
        Args:
            raw_arr (ArrayLike): Any object that supports LazyFrameLoader-style indexing
            pmd_arr (PMDArray): The PMD decomposition of the same movie
        """
        self.pmd_arr = pmd_arr
        self.raw_arr = raw_arr
        self._shape = self.pmd_arr.shape
        if self.pmd_arr.shape != self.raw_arr.shape:
            raise ValueError(
                f"The two image stacks do not have the same shape: "
                f"{self.pmd_arr.shape} vs {self.raw_arr.shape}"
            )
    @property
    def dtype(self) -> type:
        """
        Data type of returned arrays, default np.float32
        """
        return self.pmd_arr.dtype
@property
def shape(self) -> Tuple[int, int, int]:
"""
Array shape (n_frames, dims_x, dims_y)
"""
return self._shape
@property
def ndim(self) -> int:
"""
Number of dimensions
"""
return len(self.shape)
def __getitem__(
self,
        item: Union[int, list, np.ndarray, Tuple[Union[int, np.ndarray, slice, range], ...]],
    ) -> np.ndarray:
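        # The residual is only meaningful in raw-data units, so temporarily
        # enable rescaling on the PMD array if it is switched off, then
        # restore its previous state afterwards.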
if self.pmd_arr.rescale is False:
self.pmd_arr.rescale = True
switch = True
else:
switch = False
output = self.raw_arr[item].astype(self.dtype) - self.pmd_arr[item].astype(self.dtype)
if switch:
self.pmd_arr.rescale = False
return output
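# Usage sketch (hypothetical data, not part of the library API): pair a raw
# numpy movie with a matching PMDArray and read residual frames lazily. Using
# the PMD reconstruction itself as the "raw" movie gives a near-zero residual.
def _demo_pmd_residual() -> None:
    n_frames, fov1, fov2, rank = 10, 8, 8, 2
    u = torch.rand(fov1 * fov2, rank).to_sparse()
    v = torch.randn(rank, n_frames)
    pmd = PMDArray(
        (n_frames, fov1, fov2), u, v,
        mean_img=torch.zeros(fov1, fov2), var_img=torch.ones(fov1, fov2),
    )
    raw = np.asarray(pmd[:])  # shape (n_frames, fov1, fov2)
    residual = PMDResidualArray(raw, pmd)
    assert np.allclose(residual[0], 0.0, atol=1e-5)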