# Source code for masknmf.compression.pmd_array

from masknmf.arrays.array_interfaces import LazyFrameLoader, FactorizedVideo, ArrayLike
from masknmf.utils import Serializer
import torch
from typing import Optional, Tuple, Union
import numpy as np


def test_slice_effect(my_slice: slice, spatial_dim: int) -> bool:
    """
    Returns True if slice will actually have an effect
    """

    if not (
        (isinstance(my_slice.start, int) and my_slice.start == 0)
        or my_slice.start is None
    ):
        return True
    elif not (
        (isinstance(my_slice.stop, int) and my_slice.stop >= spatial_dim)
        or my_slice.stop is None
    ):
        return True
    elif not (
        my_slice.step is None or (isinstance(my_slice.step, int) and my_slice.step == 1)
    ):
        return True
    return False


def test_range_effect(my_range: range, spatial_dim: int) -> bool:
    """
    Returns True if the range will actually have an effect.

    Args:
        my_range (range): The range object to test.
        spatial_dim (int): The size of the dimension that the range is applied to.

    Returns:
        bool: True if the range will affect the selection; False otherwise.
    """
    # Check if the range starts from the beginning
    if my_range.start != 0:
        return True
    # Check if the range stops at the end of the dimension
    elif my_range.stop != spatial_dim:
        return True
    # Check if the range step is not 1
    elif my_range.step != 1:
        return True
    return False
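

# Hedged usage sketch (illustrative, not part of the library API): the two
# predicates above let indexing code skip cropping work whenever a slice or
# range covers its full dimension.
def _demo_slice_and_range_effect() -> None:
    assert not test_slice_effect(slice(None, None, None), 128)  # full slice: no effect
    assert not test_slice_effect(slice(0, 128, 1), 128)  # explicit full slice
    assert test_slice_effect(slice(10, 20), 128)  # genuine crop
    assert not test_range_effect(range(0, 128), 128)  # full range: no effect
    assert test_range_effect(range(0, 128, 2), 128)  # stride changes the selection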


def test_spatial_crop_effect(
    my_tuple: Tuple[Union[int, np.integer, np.ndarray, slice, range], ...],
    spatial_dims: Tuple[int, ...],
) -> bool:
    """
    Returns True if the tuple used for spatial cropping actually has an effect
    on the underlying data. If it does not, cropping can be skipped, avoiding
    an expensive operation.

    Args:
        my_tuple (tuple): Per-dimension spatial indexers (arrays, ints, slices, or ranges).
        spatial_dims (tuple): The sizes of the spatial dimensions being indexed.

    Returns:
        bool: True if any indexer narrows its dimension; False otherwise.
    """
    for k in range(len(my_tuple)):
        if isinstance(my_tuple[k], np.ndarray):
            if my_tuple[k].shape[0] < spatial_dims[k]:
                return True

        # Integer indexing (Python int or numpy integer) always collapses a dimension
        if isinstance(my_tuple[k], (int, np.integer)):
            return True

        if isinstance(my_tuple[k], slice):
            if test_slice_effect(my_tuple[k], spatial_dims[k]):
                return True
        if isinstance(my_tuple[k], range):
            if test_range_effect(my_tuple[k], spatial_dims[k]):
                return True
    return False
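

# Hedged sketch (illustrative only): a spatial indexing tuple triggers cropping
# work only when at least one entry actually narrows its dimension.
def _demo_spatial_crop_effect() -> None:
    dims = (64, 64)
    assert not test_spatial_crop_effect((slice(None), slice(None)), dims)  # full FOV
    assert test_spatial_crop_effect((slice(0, 32), slice(None)), dims)  # crops rows
    assert test_spatial_crop_effect((5, slice(None)), dims)  # an int always crops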


def _construct_identity_torch_sparse_tensor(dimsize: int, device: str = "cpu"):
    """
    Constructs an identity torch.sparse_coo_tensor on the specified device.

    Args:
        dimsize (int): The number of rows (or equivalently columns) of the torch.sparse_coo_tensor.
        device (str): 'cpu' or 'cuda'. The device on which the sparse tensor is constructed.

    Returns:
        torch.sparse_coo_tensor: A (dimsize, dimsize) sparse identity matrix.
    """
    # Indices for diagonal elements (rows and cols are the same for diagonal)
    row_col = torch.arange(dimsize, device=device)
    indices = torch.stack([row_col, row_col], dim=0)

    # Values (all ones)
    values = torch.ones(dimsize, device=device)

    sparse_tensor = torch.sparse_coo_tensor(indices, values, (dimsize, dimsize))
    return sparse_tensor
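

# Hedged sketch (illustrative only): the constructed tensor acts as a
# multiplicative identity under torch.sparse.mm.
def _demo_sparse_identity() -> None:
    eye = _construct_identity_torch_sparse_tensor(4)
    dense = torch.randn(4, 3)
    assert torch.allclose(torch.sparse.mm(eye, dense), dense)
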

class PMDArray(FactorizedVideo, Serializer):
    """
    Factorized array representation of a PMD-compressed movie.
    """

    _serialized = {
        "shape",
        "u",
        "v",
        "u_local_projector",
        "mean_img",
        "var_img",
    }

    def __init__(
        self,
        shape: Tuple[int, int, int] | np.ndarray,
        u: torch.sparse_coo_tensor,
        v: torch.tensor,
        mean_img: torch.tensor,
        var_img: torch.tensor,
        u_local_projector: Optional[torch.sparse_coo_tensor] = None,
        device: str = "cpu",
        rescale: bool = True,
    ):
        """
        Key assumption: the spatial basis matrix U has n + k columns; the first n
        columns are block-sparse (these serve as a local spatial basis for the
        data) and the last k columns can have unconstrained spatial support
        (these serve as a global spatial basis for the data).

        Args:
            shape (tuple): (num_frames, fov_dim1, fov_dim2)
            u (torch.sparse_coo_tensor): shape (pixels, rank)
            v (torch.tensor): shape (rank, frames)
            mean_img (torch.tensor): shape (fov_dim1, fov_dim2). The pixelwise mean of the data.
            var_img (torch.tensor): shape (fov_dim1, fov_dim2). A pixelwise noise normalizer for the data.
            u_local_projector (Optional[torch.sparse_coo_tensor]): shape (pixels, rank)
            device (str): The device on which computations occur and data is stored.
            rescale (bool): True if we rescale the PMD data (i.e. multiply by the
                pixelwise normalizer and add back the mean) in __getitem__.
        """
        self._u = u.to(device).coalesce()
        self._device = self._u.device
        self._v = v.to(device)
        if u_local_projector is not None:
            self._u_local_projector = u_local_projector.to(device).coalesce()
        else:
            self._u_local_projector = None
        self._shape = tuple(shape)
        self.pixel_mat = torch.arange(
            self.shape[1] * self.shape[2], device=self.device
        ).reshape(self.shape[1], self.shape[2])
        self._order = "C"
        self._mean_img = mean_img.to(self.device).float()
        self._var_img = var_img.to(self.device).float()
        self._rescale = rescale

    @property
    def rescale(self) -> bool:
        return self._rescale

    @rescale.setter
    def rescale(self, new_state: bool):
        self._rescale = new_state

    @property
    def mean_img(self) -> torch.tensor:
        return self._mean_img

    @property
    def var_img(self) -> torch.tensor:
        return self._var_img

    @property
    def device(self) -> torch.device:
        return self._device

    def to(self, device: str):
        self._u = self._u.to(device)
        self._v = self._v.to(device)
        self._mean_img = self._mean_img.to(device)
        self._var_img = self._var_img.to(device)
        self.pixel_mat = self.pixel_mat.to(device)
        self._device = self._u.device
        if self.u_local_projector is not None:
            self._u_local_projector = self.u_local_projector.to(device)

    @property
    def u(self) -> torch.sparse_coo_tensor:
        return self._u

    @property
    def u_local_projector(self) -> Optional[torch.sparse_coo_tensor]:
        return self._u_local_projector

    @property
    def pmd_rank(self) -> int:
        return self.u.shape[1]

    @property
    def v(self) -> torch.tensor:
        return self._v

    @property
    def dtype(self) -> str:
        """
        Data type, default np.float32.
        """
        return np.float32

    @property
    def shape(self) -> Tuple[int, int, int]:
        """
        Array shape (n_frames, dims_x, dims_y).
        """
        return self._shape

    @property
    def order(self) -> str:
        """
        The spatial data is "flattened" from 2D into 1D. This is not
        user-modifiable; "F" ordering is undesirable in PyTorch.
        """
        return self._order

    @property
    def ndim(self) -> int:
        """
        Number of dimensions.
        """
        return len(self.shape)

    def calculate_rank_heatmap(self) -> torch.tensor:
        """
        Generates a rank heatmap image based on U, equal to the row sums of the
        binarized U matrix.

        Returns:
            rank_heatmap (torch.tensor): Shape (fov_dim1, fov_dim2).
        """
        binarized_u = torch.sparse_coo_tensor(
            self.u.indices(), torch.ones_like(self.u.values()), self.u.size()
        )
        row_sum_u = torch.sparse.sum(binarized_u, dim=1)
        return torch.reshape(row_sum_u.to_dense(), (self.shape[1], self.shape[2]))

    def project_frames(
        self, frames: torch.tensor, standardize: bool = True
    ) -> torch.tensor:
        """
        Projects frames onto the spatial basis using the u_local_projector
        property, which must be defined.

        Args:
            frames (torch.tensor): Shape (fov_dim1, fov_dim2, num_frames) or
                (fov_dim1 * fov_dim2, num_frames). Frames which we want to
                project onto the spatial basis.
            standardize (bool): Indicates whether the frames of data are
                standardized (mean-subtracted and noise-normalized) before the
                projection is performed.

        Returns:
            projected_frames (torch.tensor): Shape (rank, num_frames). The basis
                coefficients of the projected frames.
        """
        if self.u_local_projector is None:
            raise ValueError(
                "u_local_projector must be defined to project frames onto the spatial basis"
            )
        orig_device = frames.device
        frames = frames.to(self.device).float()
        if len(frames.shape) == 3:
            if standardize:
                # Normalize the frames pixelwise
                frames = (frames - self.mean_img[..., None]) / self.var_img[..., None]
                frames = torch.nan_to_num(frames, nan=0.0)
            frames = frames.reshape(self.shape[1] * self.shape[2], -1)
        else:
            if standardize:
                frames = (
                    frames - self.mean_img.flatten()[..., None]
                ) / self.var_img.flatten()[..., None]
                frames = torch.nan_to_num(frames, nan=0.0)
        projection = torch.sparse.mm(self.u_local_projector.T, frames)
        return projection.to(orig_device)

    def getitem_tensor(
        self,
        item: Union[int, list, np.ndarray, Tuple[Union[int, np.ndarray, slice, range]]],
    ) -> torch.tensor:
        # Step 1: index the frames (dimension 0)
        if isinstance(item, tuple):
            if len(item) > len(self.shape):
                raise IndexError(
                    f"Cannot index more dimensions than exist in the array. "
                    f"You have tried to index with <{len(item)}> dimensions, "
                    f"only <{len(self.shape)}> dimensions exist in the array"
                )
            frame_indexer = item[0]
        else:
            frame_indexer = item

        # Step 2: Do some basic error handling for frame_indexer before using it to slice
        if isinstance(frame_indexer, np.ndarray):
            pass
        elif isinstance(frame_indexer, list):
            pass
        elif isinstance(frame_indexer, int):
            pass
        # numpy int scalar
        elif isinstance(frame_indexer, np.integer):
            frame_indexer = frame_indexer.item()
        # treat slice and range the same
        elif isinstance(frame_indexer, (slice, range)):
            start = frame_indexer.start
            stop = frame_indexer.stop
            step = frame_indexer.step
            if start is not None:
                if start > self.shape[0]:
                    raise IndexError(
                        f"Cannot index beyond `n_frames`.\n"
                        f"Desired frame start index of <{start}> "
                        f"lies beyond `n_frames` <{self.shape[0]}>"
                    )
            if stop is not None:
                if stop > self.shape[0]:
                    raise IndexError(
                        f"Cannot index beyond `n_frames`.\n"
                        f"Desired frame stop index of <{stop}> "
                        f"lies beyond `n_frames` <{self.shape[0]}>"
                    )
            if step is None:
                step = 1
            # in case it was a range object
            frame_indexer = slice(start, stop, step)
        else:
            raise IndexError(
                f"Invalid indexing method, you have passed a: <{type(item)}>"
            )

        # Step 3: Now slice the data with frame_indexer
        # (careful: if ndim has shrunk, add a dim back)
        v_crop = self._v[:, frame_indexer]
        if v_crop.ndim < self._v.ndim:
            v_crop = v_crop.unsqueeze(1)

        # Step 4: Deal with remaining indices after lazily computing the frame(s)
        if isinstance(item, tuple) and test_spatial_crop_effect(
            item[1:], self.shape[1:]
        ):
            if isinstance(item[1], np.ndarray) and len(item[1]) == 1:
                term_1 = slice(int(item[1][0]), int(item[1][0]) + 1)
            elif isinstance(item[1], np.integer):
                term_1 = slice(int(item[1]), int(item[1]) + 1)
            elif isinstance(item[1], int):
                term_1 = slice(item[1], item[1] + 1)
            else:
                term_1 = item[1]

            # A 2-element tuple (frames, rows) leaves the column dimension untouched
            if len(item) < 3:
                term_2 = slice(None)
            elif isinstance(item[2], np.ndarray) and len(item[2]) == 1:
                term_2 = slice(int(item[2][0]), int(item[2][0]) + 1)
            elif isinstance(item[2], np.integer):
                term_2 = slice(int(item[2]), int(item[2]) + 1)
            elif isinstance(item[2], int):
                term_2 = slice(item[2], item[2] + 1)
            else:
                term_2 = item[2]

            spatial_crop_terms = (term_1, term_2)
            pixel_space_crop = self.pixel_mat[spatial_crop_terms]
            mean_img_crop = self.mean_img[spatial_crop_terms].flatten()
            var_img_crop = self.var_img[spatial_crop_terms].flatten()
            u_indices = pixel_space_crop.flatten()
            u_crop = torch.index_select(self._u, 0, u_indices)
            implied_fov = pixel_space_crop.shape
        else:
            u_crop = self._u
            mean_img_crop = self.mean_img.flatten()
            var_img_crop = self.var_img.flatten()
            implied_fov = self.shape[1], self.shape[2]

        # Lazy reconstruction: only the requested pixels and frames are expanded
        product = torch.sparse.mm(u_crop, v_crop)
        if self.rescale:
            product *= var_img_crop.unsqueeze(1)
            product += mean_img_crop.unsqueeze(1)
        product = product.reshape((implied_fov[0], implied_fov[1], -1))
        product = product.permute(2, 0, 1)
        return product

    def __getitem__(
        self,
        item: Union[int, list, np.ndarray, Tuple[Union[int, np.ndarray, slice, range]]],
    ) -> np.ndarray:
        product = self.getitem_tensor(item)
        product = product.cpu().numpy().astype(self.dtype).squeeze()
        return product
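

# Hedged end-to-end sketch (synthetic data; every size below is arbitrary and
# illustrative only): build a small PMDArray and exercise indexing, the rank
# heatmap, and frame projection.
def _demo_pmd_array() -> "PMDArray":
    num_frames, fov_dim1, fov_dim2, rank = 10, 8, 8, 5
    pixels = fov_dim1 * fov_dim2
    u = torch.rand(pixels, rank).to_sparse()  # random factors in sparse COO layout
    v = torch.randn(rank, num_frames)
    mean_img = torch.zeros(fov_dim1, fov_dim2)
    var_img = torch.ones(fov_dim1, fov_dim2)
    pmd = PMDArray(
        (num_frames, fov_dim1, fov_dim2),
        u,
        v,
        mean_img,
        var_img,
        u_local_projector=u,  # reusing u as a stand-in projector for the demo
    )
    frame = pmd[0]  # np.ndarray, shape (fov_dim1, fov_dim2)
    clip = pmd[2:5, 0:4, :]  # frames 2-4, top 4 rows: shape (3, 4, fov_dim2)
    heatmap = pmd.calculate_rank_heatmap()  # (fov_dim1, fov_dim2) component counts
    coeffs = pmd.project_frames(torch.randn(fov_dim1, fov_dim2, 3))  # (rank, 3)
    assert frame.shape == (fov_dim1, fov_dim2)
    assert clip.shape == (3, 4, fov_dim2)
    assert heatmap.shape == (fov_dim1, fov_dim2)
    assert coeffs.shape == (rank, 3)
    return pmd
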

class PMDResidualArray(ArrayLike):
    """
    Lazy array for the residual between the raw movie and its PMD reconstruction.
    """

    def __init__(
        self,
        raw_arr: ArrayLike,
        pmd_arr: PMDArray,
    ):
        """
        Args:
            raw_arr (ArrayLike): Any object that supports LazyFrameLoader functionality.
            pmd_arr (PMDArray): The PMD factorization of raw_arr.
        """
        self.pmd_arr = pmd_arr
        self.raw_arr = raw_arr
        self._shape = self.pmd_arr.shape
        if self.pmd_arr.shape != self.raw_arr.shape:
            raise ValueError("The two image stacks do not have the same shape")

    @property
    def dtype(self) -> str:
        """
        Data type, default np.float32.
        """
        return self.pmd_arr.dtype

    @property
    def shape(self) -> Tuple[int, int, int]:
        """
        Array shape (n_frames, dims_x, dims_y).
        """
        return self._shape

    @property
    def ndim(self) -> int:
        """
        Number of dimensions.
        """
        return len(self.shape)

    def __getitem__(
        self,
        item: Union[int, list, np.ndarray, Tuple[Union[int, np.ndarray, slice, range]]],
    ):
        # The residual is always computed against the rescaled PMD reconstruction,
        # so temporarily enable rescaling if it is switched off.
        if self.pmd_arr.rescale is False:
            self.pmd_arr.rescale = True
            switch = True
        else:
            switch = False
        output = self.raw_arr[item].astype(self.dtype) - self.pmd_arr[item].astype(
            self.dtype
        )
        if switch:
            self.pmd_arr.rescale = False
        return output
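

# Hedged sketch (illustrative only): PMDResidualArray lazily subtracts the
# rescaled PMD reconstruction from the raw movie. Feeding it frames
# reconstructed from the PMD factors themselves (via _demo_pmd_array above)
# should give a residual of zero.
def _demo_residual_array() -> None:
    pmd = _demo_pmd_array()
    raw = pmd[:]  # (num_frames, fov_dim1, fov_dim2) np.ndarray reconstruction
    resid = PMDResidualArray(raw, pmd)
    assert np.allclose(resid[0], 0.0, atol=1e-5)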