Source code for masknmf.compression.pmd_array

from masknmf.arrays.array_interfaces import LazyFrameLoader, FactorizedVideo, ArrayLike
from masknmf.utils import Serializer
import torch
from typing import Optional, Tuple, Union
import numpy as np


def test_slice_effect(my_slice: slice, spatial_dim: int) -> bool:
    """
    Returns True if slice will actually have an effect
    """

    if not (
        (isinstance(my_slice.start, int) and my_slice.start == 0)
        or my_slice.start is None
    ):
        return True
    elif not (
        (isinstance(my_slice.stop, int) and my_slice.stop >= spatial_dim)
        or my_slice.stop is None
    ):
        return True
    elif not (
        my_slice.step is None or (isinstance(my_slice.step, int) and my_slice.step == 1)
    ):
        return True
    return False
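
# Illustrative sketch (hypothetical dimension size): a full slice is a no-op,
# so downstream cropping can be skipped.
# >>> test_slice_effect(slice(None, None, None), 512)
# False
# >>> test_slice_effect(slice(0, 256, 1), 512)
# True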


def test_range_effect(my_range: range, spatial_dim: int) -> bool:
    """
    Returns True if the range will actually have an effect.

    Parameters:
    my_range (range): The range object to test.
    spatial_dim (int): The size of the dimension that the range is applied to.

    Returns:
    bool: True if the range will affect the selection; False otherwise.
    """
    # Check if the range starts from the beginning
    if my_range.start != 0:
        return True
    # Check if the range stops at the end of the dimension
    elif my_range.stop != spatial_dim:
        return True
    # Check if the range step is not 1
    elif my_range.step != 1:
        return True
    return False
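
# Illustrative sketch (hypothetical dimension size): only a range covering the
# whole dimension with step 1 is a no-op.
# >>> test_range_effect(range(0, 512, 1), 512)
# False
# >>> test_range_effect(range(0, 512, 2), 512)
# True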


def test_spatial_crop_effect(my_tuple: tuple, spatial_dims: Tuple[int, ...]) -> bool:
    """
    Returns True if the tuple used for spatial cropping actually has an effect on the
    underlying data; if it does not, cropping can be skipped, since it is an expensive
    and avoidable operation.
    """
    for k in range(len(my_tuple)):
        if isinstance(my_tuple[k], np.ndarray):
            if my_tuple[k].shape[0] < spatial_dims[k]:
                return True

        if isinstance(my_tuple[k], np.integer):
            return True

        if isinstance(my_tuple[k], int):
            return True

        if isinstance(my_tuple[k], slice):
            if test_slice_effect(my_tuple[k], spatial_dims[k]):
                return True
        if isinstance(my_tuple[k], range):
            if test_range_effect(my_tuple[k], spatial_dims[k]):
                return True
    return False
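
# Illustrative sketch (hypothetical FOV): full slices in both spatial dimensions
# have no effect; an integer index or a partial selection does.
# >>> test_spatial_crop_effect((slice(None), slice(None)), (512, 512))
# False
# >>> test_spatial_crop_effect((slice(None), range(0, 256)), (512, 512))
# True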


def _construct_identity_torch_sparse_tensor(dimsize: int, device: str = "cpu"):
    """
    Constructs an identity torch.sparse_coo_tensor on the specified device.

    Args:
        dimsize (int): The number of rows (or equivalently columns) of the torch.sparse_coo_tensor.
        device (str): 'cpu' or 'cuda'. The device on which the sparse tensor is constructed.

    Returns:
        torch.sparse_coo_tensor: A (dimsize, dimsize) sparse identity matrix.
    """
    # Indices for diagonal elements (rows and cols are the same for diagonal)
    row_col = torch.arange(dimsize, device=device)
    indices = torch.stack([row_col, row_col], dim=0)

    # Values (all ones)
    values = torch.ones(dimsize, device=device)

    sparse_tensor = torch.sparse_coo_tensor(indices, values, (dimsize, dimsize))
    return sparse_tensor
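
# Illustrative sketch: the result behaves as an identity under sparse matmul.
# >>> eye = _construct_identity_torch_sparse_tensor(4)
# >>> x = torch.randn(4, 3)
# >>> bool(torch.allclose(torch.sparse.mm(eye, x), x))
# True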


class PMDArray(FactorizedVideo, Serializer):
    """
    Factorized demixing array for a PMD movie.
    """

    _serialized = {
        "shape",
        "u",
        "v",
        "u_local_projector",
        "mean_img",
        "var_img",
    }

    def __init__(
        self,
        shape: Tuple[int, int, int] | np.ndarray,
        u: torch.sparse_coo_tensor,
        v: torch.tensor,
        mean_img: torch.tensor,
        var_img: torch.tensor,
        u_local_projector: Optional[torch.sparse_coo_tensor] = None,
        device: str = "cpu",
        rescale: bool = True,
    ):
        """
        Key assumption: the spatial basis matrix U has n + k columns; the first n
        columns are blocksparse (these serve as a local spatial basis for the data)
        and the last k columns can have unconstrained spatial support (these serve
        as a global spatial basis for the data).

        Args:
            shape (tuple): (num_frames, fov_dim1, fov_dim2)
            u (torch.sparse_coo_tensor): shape (pixels, rank)
            v (torch.tensor): shape (rank, frames)
            mean_img (torch.tensor): shape (fov_dim1, fov_dim2). The pixelwise mean of the data.
            var_img (torch.tensor): shape (fov_dim1, fov_dim2). A pixelwise noise normalizer for the data.
            u_local_projector (Optional[torch.sparse_coo_tensor]): shape (pixels, rank)
            device (str): The device on which computations occur and data is stored.
            rescale (bool): True if we rescale the PMD data in __getitem__ (i.e. multiply
                by the pixelwise normalizer and add back the mean).
        """
        self._u = u.to(device).coalesce()
        self._device = self._u.device
        self._v = v.to(device)
        if u_local_projector is not None:
            self._u_local_projector = u_local_projector.to(device).coalesce()
        else:
            self._u_local_projector = None
        self._shape = tuple(shape)
        self.pixel_mat = torch.arange(
            self.shape[1] * self.shape[2], device=self.device
        ).reshape(self.shape[1], self.shape[2])
        self._order = "C"
        self._mean_img = mean_img.to(self.device).float()
        self._var_img = var_img.to(self.device).float()
        self._rescale = rescale

    @property
    def rescale(self) -> bool:
        return self._rescale

    @rescale.setter
    def rescale(self, new_state: bool):
        self._rescale = new_state

    @property
    def mean_img(self) -> torch.tensor:
        return self._mean_img

    @property
    def var_img(self) -> torch.tensor:
        return self._var_img

    @property
    def device(self) -> torch.device:
        return self._device

    def to(self, device: str):
        self._u = self._u.to(device)
        self._v = self._v.to(device)
        self._mean_img = self._mean_img.to(device)
        self._var_img = self._var_img.to(device)
        self.pixel_mat = self.pixel_mat.to(device)
        self._device = self._u.device
        if self.u_local_projector is not None:
            self._u_local_projector = self.u_local_projector.to(device)

    @property
    def u(self) -> torch.sparse_coo_tensor:
        return self._u

    @property
    def u_local_projector(self) -> Optional[torch.sparse_coo_tensor]:
        return self._u_local_projector

    @property
    def pmd_rank(self) -> int:
        return self.u.shape[1]

    @property
    def v(self) -> torch.tensor:
        return self._v

    @property
    def dtype(self) -> str:
        """
        Data type, default np.float32
        """
        return np.float32

    @property
    def shape(self) -> Tuple[int, int, int]:
        """
        Array shape (n_frames, dims_x, dims_y)
        """
        return self._shape

    @property
    def order(self) -> str:
        """
        The spatial data is "flattened" from 2D into 1D. This is not user-modifiable;
        "F" ordering is undesirable in PyTorch.
        """
        return self._order

    @property
    def ndim(self) -> int:
        """
        Number of dimensions
        """
        return len(self.shape)
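
    # Construction sketch (hypothetical shapes; a real U/V pair comes from PMD
    # compression, and a truncated-identity U is used here purely for illustration):
    # >>> t, h, w, r = 100, 16, 16, 5
    # >>> u = torch.eye(h * w)[:, :r].to_sparse()
    # >>> v = torch.randn(r, t)
    # >>> arr = PMDArray((t, h, w), u, v, torch.zeros(h, w), torch.ones(h, w))
    # >>> arr.pmd_rank
    # 5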

    def calculate_rank_heatmap(self) -> torch.tensor:
        """
        Generates a rank heatmap image based on U, equal to the row sums of the
        binarized U matrix.

        Returns:
            rank_heatmap (torch.tensor): Shape (fov_dim1, fov_dim2).
        """
        binarized_u = torch.sparse_coo_tensor(
            self.u.indices(), torch.ones_like(self.u.values()), self.u.size()
        )
        row_sum_u = torch.sparse.sum(binarized_u, dim=1)
        return torch.reshape(row_sum_u.to_dense(), (self.shape[1], self.shape[2]))
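
    # Sketch (on the hypothetical `arr` above): each pixel counts how many
    # spatial components have support there.
    # >>> arr.calculate_rank_heatmap().shape
    # torch.Size([16, 16])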

    def project_frames(
        self, frames: torch.tensor, standardize: Optional[bool] = True
    ) -> torch.tensor:
        """
        Projects frames onto the spatial basis using the u_local_projector property,
        which must be defined.

        Args:
            frames (torch.tensor): Shape (fov_dim1, fov_dim2, num_frames) or
                (fov_dim1*fov_dim2, num_frames). The frames to project onto the
                spatial basis.
            standardize (Optional[bool]): Whether the frames are standardized
                (mean-subtracted and noise-normalized) before projection.

        Returns:
            projected_frames (torch.tensor): Shape (rank, num_frames).
        """
        if self.u_local_projector is None:
            raise ValueError(
                "u_local_projector must be defined to project frames onto the spatial basis"
            )
        orig_device = frames.device
        frames = frames.to(self.device).float()
        if len(frames.shape) == 3:
            if standardize:
                # Normalize the frames
                frames = (frames - self.mean_img[..., None]) / self.var_img[..., None]
                frames = torch.nan_to_num(frames, nan=0.0)
            frames = frames.reshape(self.shape[1] * self.shape[2], -1)
        else:
            if standardize:
                frames = (
                    frames - self.mean_img.flatten()[..., None]
                ) / self.var_img.flatten()[..., None]
                frames = torch.nan_to_num(frames, nan=0.0)
        projection = torch.sparse.mm(self.u_local_projector.T, frames)
        return projection.to(orig_device)
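
    # Sketch (requires a PMDArray constructed with u_local_projector; `frames`
    # is hypothetical). One rank-dimensional coefficient vector per frame:
    # >>> frames = torch.randn(16, 16, 10)
    # >>> arr.project_frames(frames).shape   # -> torch.Size([rank, 10])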

    def getitem_tensor(
        self,
        item: Union[int, list, np.ndarray, Tuple[Union[int, np.ndarray, slice, range]]],
    ) -> torch.tensor:
        # Step 1: index the frames (dimension 0)
        if isinstance(item, tuple):
            if len(item) > len(self.shape):
                raise IndexError(
                    f"Cannot index more dimensions than exist in the array. "
                    f"You have tried to index with <{len(item)}> dimensions, "
                    f"only <{len(self.shape)}> dimensions exist in the array"
                )
            frame_indexer = item[0]
        else:
            frame_indexer = item

        # Step 2: Do some basic error handling for frame_indexer before using it to slice
        if isinstance(frame_indexer, np.ndarray):
            pass
        elif isinstance(frame_indexer, list):
            pass
        elif isinstance(frame_indexer, int):
            pass
        # numpy int scalar
        elif isinstance(frame_indexer, np.integer):
            frame_indexer = frame_indexer.item()
        # treat slice and range the same
        elif isinstance(frame_indexer, (slice, range)):
            start = frame_indexer.start
            stop = frame_indexer.stop
            step = frame_indexer.step
            if start is not None:
                if start > self.shape[0]:
                    raise IndexError(
                        f"Cannot index beyond `n_frames`.\n"
                        f"Desired frame start index of <{start}> "
                        f"lies beyond `n_frames` <{self.shape[0]}>"
                    )
            if stop is not None:
                if stop > self.shape[0]:
                    raise IndexError(
                        f"Cannot index beyond `n_frames`.\n"
                        f"Desired frame stop index of <{stop}> "
                        f"lies beyond `n_frames` <{self.shape[0]}>"
                    )
            if step is None:
                step = 1
            # in case it was a range object
            frame_indexer = slice(start, stop, step)
        else:
            raise IndexError(
                f"Invalid indexing method, you have passed a: <{type(item)}>"
            )

        # Step 3: Now slice the data with frame_indexer
        # (careful: if the ndims has shrunk, add a dim)
        v_crop = self._v[:, frame_indexer]
        if v_crop.ndim < self._v.ndim:
            v_crop = v_crop.unsqueeze(1)

        # Step 4: Deal with remaining indices after lazily computing the frame(s)
        if isinstance(item, tuple) and test_spatial_crop_effect(
            item[1:], self.shape[1:]
        ):
            if isinstance(item[1], np.ndarray) and len(item[1]) == 1:
                term_1 = slice(int(item[1]), int(item[1]) + 1)
            elif isinstance(item[1], np.integer):
                term_1 = slice(int(item[1]), int(item[1]) + 1)
            elif isinstance(item[1], int):
                term_1 = slice(item[1], item[1] + 1)
            else:
                term_1 = item[1]

            if isinstance(item[2], np.ndarray) and len(item[2]) == 1:
                term_2 = slice(int(item[2]), int(item[2]) + 1)
            elif isinstance(item[2], np.integer):
                term_2 = slice(int(item[2]), int(item[2]) + 1)
            elif isinstance(item[2], int):
                term_2 = slice(item[2], item[2] + 1)
            else:
                term_2 = item[2]

            spatial_crop_terms = (term_1, term_2)
            pixel_space_crop = self.pixel_mat[spatial_crop_terms]
            mean_img_crop = self.mean_img[spatial_crop_terms].flatten()
            var_img_crop = self.var_img[spatial_crop_terms].flatten()
            u_indices = pixel_space_crop.flatten()
            u_crop = torch.index_select(self._u, 0, u_indices)
            implied_fov = pixel_space_crop.shape
        else:
            u_crop = self._u
            mean_img_crop = self.mean_img.flatten()
            var_img_crop = self.var_img.flatten()
            implied_fov = self.shape[1], self.shape[2]

        product = torch.sparse.mm(u_crop, v_crop) * var_img_crop.unsqueeze(1)
        if self.rescale:
            product += mean_img_crop.unsqueeze(1)
        product = product.reshape((implied_fov[0], implied_fov[1], -1))
        product = product.permute(2, 0, 1)
        return product

    def __getitem__(
        self,
        item: Union[int, list, np.ndarray, Tuple[Union[int, np.ndarray, slice, range]]],
    ) -> np.ndarray:
        product = self.getitem_tensor(item)
        product = product.cpu().numpy().astype(self.dtype).squeeze()
        return product
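
# Indexing sketch (on the hypothetical `arr` above): frame indexing is applied
# lazily to V, and spatial crops restrict U before the sparse matmul.
# >>> arr[0:10].shape          # first ten frames
# (10, 16, 16)
# >>> arr[0, 4:8, 4:8].shape   # spatial crop of a single frame
# (4, 4)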


def convert_dense_image_stack_to_pmd_format(
    img_stack: Union[torch.tensor, np.ndarray]
) -> PMDArray:
    """
    Adapter for converting a dense image stack into a PMDArray. Note that this does
    not run PMD compression; it simply reformats the data into the factorized (U, V)
    format needed to construct a PMDArray object. The resulting PMDArray contains
    identical data to img_stack (up to numerical precision errors). All tensors are
    constructed on CPU.

    Args:
        img_stack (Union[np.ndarray, torch.tensor]): A (frames, fov_dim1, fov_dim2)
            shaped image stack.

    Returns:
        pmd_array (masknmf.compression.PMDArray): img_stack expressed in the PMDArray
            format; contains the same data as img_stack.
    """
    if isinstance(img_stack, np.ndarray):
        img_stack = torch.from_numpy(img_stack)
    if isinstance(img_stack, torch.Tensor):
        num_frames, fov_dim1, fov_dim2 = img_stack.shape
        img_stack_t = img_stack.permute(1, 2, 0).reshape(
            (fov_dim1 * fov_dim2, num_frames)
        )
        u = _construct_identity_torch_sparse_tensor(fov_dim1 * fov_dim2, device="cpu")
        mean_img = torch.zeros(fov_dim1, fov_dim2, device="cpu", dtype=torch.float32)
        var_img = torch.ones(fov_dim1, fov_dim2, device="cpu", dtype=torch.float32)
        return PMDArray(
            img_stack.shape,
            u,
            img_stack_t,
            mean_img,
            var_img,
            u_local_projector=None,
            device="cpu",
        )
    else:
        raise ValueError(f"{type(img_stack)} is not a supported type")


class PMDResidualArray(ArrayLike):
    """
    Lazily evaluated residual movie: the raw data minus its PMD reconstruction.
    """

    def __init__(
        self,
        raw_arr: ArrayLike,
        pmd_arr: PMDArray,
    ):
        """
        Args:
            raw_arr (ArrayLike): Any object that supports LazyFrameLoader functionality.
            pmd_arr (PMDArray): The PMD decomposition of raw_arr.
        """
        self.pmd_arr = pmd_arr
        self.raw_arr = raw_arr
        self._shape = self.pmd_arr.shape
        if self.pmd_arr.shape != self.raw_arr.shape:
            raise ValueError("The two image stacks do not have the same shape")

    @property
    def dtype(self) -> str:
        """
        Data type, default np.float32
        """
        return self.pmd_arr.dtype

    @property
    def shape(self) -> Tuple[int, int, int]:
        """
        Array shape (n_frames, dims_x, dims_y)
        """
        return self._shape

    @property
    def ndim(self) -> int:
        """
        Number of dimensions
        """
        return len(self.shape)

    def __getitem__(
        self,
        item: Union[int, list, np.ndarray, Tuple[Union[int, np.ndarray, slice, range]]],
    ):
        # Temporarily enable rescaling so the PMD reconstruction is on the same
        # scale as the raw data before subtracting
        if self.pmd_arr.rescale is False:
            self.pmd_arr.rescale = True
            switch = True
        else:
            switch = False
        output = self.raw_arr[item].astype(self.dtype) - self.pmd_arr[item].astype(
            self.dtype
        )
        if switch:
            self.pmd_arr.rescale = False
        return output
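
# Round-trip sketch (hypothetical stack): the adapter reproduces the raw data up
# to float precision, so the residual is ~0 everywhere.
# >>> stack = np.random.rand(20, 8, 8).astype(np.float32)
# >>> pmd = convert_dense_image_stack_to_pmd_format(stack)
# >>> bool(np.allclose(pmd[:], stack, atol=1e-5))
# True
# >>> resid = PMDResidualArray(stack, pmd)
# >>> bool(np.abs(resid[0]).max() < 1e-4)
# True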