cassetta.core

Overview

This module contains core utilities that are mostly used internally. Their API may be less stable than that of the rest of the package, so use them at your own risk.

cassetta.core.typing

A set of cassetta-specific type hints, plus backward-compatible hints.

OneOrSeveral module-attribute

OneOrSeveral = Union[T, Sequence[T]]

Either a single value of a type, or a sequence of such values.
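A minimal sketch of how a function might accept a OneOrSeveral argument and expand it to one value per spatial dimension. The helper as_kernel_sizes is hypothetical and not part of cassetta; only the OneOrSeveral definition comes from the text above.

```python
from typing import List, Sequence, TypeVar, Union

T = TypeVar("T")
# Same definition as the OneOrSeveral type hint above
OneOrSeveral = Union[T, Sequence[T]]


def as_kernel_sizes(kernel_size: OneOrSeveral[int], ndim: int) -> List[int]:
    """Hypothetical helper: return one kernel size per spatial dimension."""
    if isinstance(kernel_size, int):
        # A single value is repeated for every dimension
        return [kernel_size] * ndim
    sizes = list(kernel_size)
    if len(sizes) != ndim:
        raise ValueError(f"expected {ndim} values, got {len(sizes)}")
    return sizes


print(as_kernel_sizes(3, ndim=2))       # [3, 3]
print(as_kernel_sizes((3, 5), ndim=2))  # [3, 5]
```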

DeviceType module-attribute

DeviceType = Union[str, device]

An instantiated torch.device, or a string that allows instantiating one, such as "cpu", "cuda" or "cuda:0".

DataType module-attribute

DataType = Union[str, np.dtype, torch.dtype, type]

A torch.dtype, an np.dtype, a Python type (such as float), or a string that represents one such data type.

See to_torch_dtype for details.

ActivationType module-attribute

ActivationType = Optional[Union[str, Module, Type[Module]]]

An instantiated nn.Module, a nn.Module subtype, or the name of an activation class in torch.nn or in cassetta.layers.activations.

See cassetta.layers.make_activation for details.

NormType module-attribute

NormType = Optional[Union[Literal['batch', 'instance', 'layer'], Module, Type[Module]]]

An instantiated nn.Module, or a nn.Module subtype, or one of:

String      Class
"batch"     cassetta.layers.BatchNorm
"instance"  cassetta.layers.InstanceNorm
"layer"     cassetta.layers.LayerNorm

See cassetta.layers.make_norm for details.

DropoutType module-attribute

DropoutType = Optional[Union[float, Module, Type[Module]]]

An instantiated nn.Module, or a nn.Module subtype, or a dropout probability between 0 and 1.

See cassetta.layers.make_dropout for details.

AttentionType module-attribute

AttentionType = Optional[Union[Literal['sqzex', 'cbam', 'dp', 'sdp', 'mha'], Module, Type[Module]]]

An instantiated nn.Module, or a nn.Module subtype, or one of:

String   Class
"sqzex"  cassetta.layers.SqzEx
"cbam"   cassetta.layers.BlockAttention
"dp"     cassetta.layers.DotProductAttention(scaled=False)
"sdp"    cassetta.layers.DotProductAttention(scaled=True)
"mha"    cassetta.layers.MultiHeadAttention

See cassetta.layers.make_attention for details.

InterpolationType module-attribute

InterpolationType = Union[int, Literal['nearest', 'linear', 'quadratic', 'cubic', 'fourth', 'fifth', 'sixth', 'seventh']]

The degree of the B-splines used for interpolation, as an integer between 0 and 7, or one of its string aliases:

  0. "nearest"
  1. "linear"
  2. "quadratic"
  3. "cubic"
  4. "fourth"
  5. "fifth"
  6. "sixth"
  7. "seventh"
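The correspondence between aliases and degrees can be sketched as follows. The helper interpolation_degree is hypothetical (cassetta's own conversion code is not shown on this page); it relies on the aliases being listed in order of increasing degree.

```python
from typing import Literal, Union

InterpolationType = Union[int, Literal['nearest', 'linear', 'quadratic', 'cubic',
                                       'fourth', 'fifth', 'sixth', 'seventh']]

# Aliases in order of increasing B-spline degree: index == degree
_ALIASES = ('nearest', 'linear', 'quadratic', 'cubic',
            'fourth', 'fifth', 'sixth', 'seventh')


def interpolation_degree(order: InterpolationType) -> int:
    """Hypothetical helper: convert an InterpolationType to a spline degree."""
    if isinstance(order, str):
        # tuple.index raises ValueError for unknown aliases
        return _ALIASES.index(order)
    if not 0 <= order <= 7:
        raise ValueError(f"degree must be between 0 and 7, got {order}")
    return order


print(interpolation_degree('cubic'))  # 3
print(interpolation_degree(1))        # 1
```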

ModelType module-attribute

ModelType = Union[str, Module, Type[Module]]

A model can be:

  • the name of a cassetta model, such as "SegNet";
  • the fully qualified path to a model, such as "cassetta.models.ElasticRegNet" or "monai.networks.nets.ResNet";
  • a nn.Module subclass, such as cassetta.models.SegNet;
  • an already instantiated nn.Module, such as SegNet(3, 1, 5).
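The four accepted forms can be resolved with a few lines of importlib. This is a minimal sketch, not cassetta's own loader: the function resolve and its namespace argument are hypothetical (for ModelType the namespace would presumably be cassetta.models), and standard-library classes stand in for nn.Module subclasses so that it runs without torch. The same pattern would apply to LossType and, without the instance case, to OptimType.

```python
import collections
import importlib
from typing import Any


def resolve(spec: Any, namespace: str, *args, **kwargs) -> Any:
    """Hypothetical resolver for ModelType-style arguments.

    - short name (no dot): looked up in the `namespace` module
    - dotted path such as "package.module.Class": imported
    - class: instantiated with *args and **kwargs
    - anything else: assumed to be an instance, returned unchanged
    """
    if isinstance(spec, str):
        if '.' in spec:
            module_name, _, attr = spec.rpartition('.')
        else:
            module_name, attr = namespace, spec
        spec = getattr(importlib.import_module(module_name), attr)
    if isinstance(spec, type):
        return spec(*args, **kwargs)
    return spec


print(type(resolve("OrderedDict", "collections")).__name__)  # OrderedDict
print(resolve("collections.Counter", "collections", "abca"))  # Counter({'a': 2, 'b': 1, 'c': 1})
existing = collections.deque([1, 2])
print(resolve(existing, "collections") is existing)          # True
```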

LossType module-attribute

LossType = Union[str, Module, Type[Module]]

A loss can be:

  • the name of a cassetta loss, such as "DiceLoss";
  • the fully qualified path to a loss, such as "cassetta.losses.DiceLoss" or "monai.losses.GeneralizedDiceLoss";
  • a nn.Module subclass, such as cassetta.losses.DiceLoss;
  • an already instantiated nn.Module, such as DiceLoss().

OptimType module-attribute

OptimType = Union[str, Type[Optimizer]]

An optimizer can be:

  • the name of a torch optimizer, such as "Adam";
  • the fully qualified path to an optimizer, such as "torch.optim.Adam";
  • an Optimizer subclass, such as torch.optim.Adam.

cassetta.core.utils

A set of various utility functions.

ensure_list

ensure_list(x, length=None, crop=True, **kwargs)

Ensure that an object is a list.

If length is provided, the output list has at least length elements; when crop is True (the default), it has exactly length elements. Missing elements are filled with the default keyword argument if it is provided, and otherwise with the last value of x (None if x is empty).

If x is a list, it is not copied up front, so any padding extends the caller's list in place. If it is a tuple, range, or generator, it is converted to a list. Otherwise, it is placed inside a list.

Source code in cassetta/core/utils.py
def ensure_list(
    x: Any,
    length: Optional[int] = None,
    crop: bool = True,
    **kwargs
) -> List:
    """
    Ensure that an object is a list

    The output list is of length at least `length`.
    When `crop` is `True`, its length is also at most `length`.
    If needed, the last value is replicated, unless `default` is provided.

    If x is a list, nothing is done (no copy triggered).
    If it is a tuple, range, or generator, it is converted to a list.
    Otherwise, it is placed inside a list.
    """
    if not isinstance(x, (list, tuple, range, generator)):
        x = [x]
    elif not isinstance(x, list):
        x = list(x)
    if length and len(x) < length:
        default = [kwargs.get('default', x[-1] if x else None)]
        x += default * (length - len(x))
    if length and crop:
        x = x[:length]
    return x
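The padding, cropping, and in-place behaviour described above can be checked with a condensed copy of the source. The only assumption is that the name generator in the source refers to types.GeneratorType.

```python
from types import GeneratorType
from typing import Any, List, Optional


def ensure_list(x: Any, length: Optional[int] = None, crop: bool = True,
                **kwargs) -> List:
    # Condensed copy of the source above (GeneratorType imported explicitly)
    if not isinstance(x, (list, tuple, range, GeneratorType)):
        x = [x]
    elif not isinstance(x, list):
        x = list(x)
    if length and len(x) < length:
        default = [kwargs.get('default', x[-1] if x else None)]
        x += default * (length - len(x))
    if length and crop:
        x = x[:length]
    return x


print(ensure_list(3, 2))               # [3, 3]
print(ensure_list((1, 2, 3), 2))       # [1, 2]
print(ensure_list([1], 3, default=0))  # [1, 0, 0]

# Pitfall: an input list is padded in place, then cropping returns a copy
a = [1]
b = ensure_list(a, 3)
print(a, b is a)                       # [1, 1, 1] False
```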

ensure_tuple

ensure_tuple(x, length=None, crop=True, **kwargs)

Ensure that an object is a tuple.

See ensure_list.

Source code in cassetta/core/utils.py
def ensure_tuple(
    x: Any,
    length: Optional[int] = None,
    crop: bool = True,
    **kwargs
) -> Tuple:
    """
    Ensure that an object is a tuple.

    See [`ensure_list`][cassetta.core.utils.ensure_list].
    """
    return tuple(ensure_list(x, length, crop, **kwargs))

make_vector

make_vector(input, length=None, crop=True, *, dtype=None, device=None, **kwargs)

Ensure that the input is a (tensor) vector and pad/crop if necessary.

PARAMETER DESCRIPTION
input

Input argument(s).

TYPE: scalar or sequence or GeneratorType

length

Target length.

TYPE: int DEFAULT: None

crop

Crop input sequence if longer than length.

TYPE: bool DEFAULT: True

Keyword Parameters

default

Value to pad with. If not provided, the last value of the input is replicated (0 if the input is empty).

TYPE: scalar, optional

dtype

Output data type.

TYPE: torch.dtype DEFAULT: None

device

Output device.

TYPE: torch.device or str DEFAULT: None

RETURNS DESCRIPTION
output

Output vector.

TYPE: tensor

Source code in cassetta/core/utils.py
def make_vector(
    input: Any,
    length: Optional[int] = None,
    crop: bool = True,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[DeviceType] = None,
    **kwargs
) -> Tensor:
    """
    Ensure that the input is a (tensor) vector and pad/crop if necessary.

    Parameters
    ----------
    input : scalar or sequence or generator
        Input argument(s).
    length : int, optional
        Target length.
    crop : bool, default=True
        Crop input sequence if longer than `length`.

    Keyword Parameters
    ------------------
    default : optional
        Default value to pad with.
        If not provided, replicate the last value.
    dtype : torch.dtype, optional
        Output data type.
    device : torch.device, optional
        Output device.

    Returns
    -------
    output : tensor
        Output vector.

    """
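    if isinstance(input, generator):
        # torch.as_tensor cannot consume generators, so materialize them first
        input = list(input)
    input = torch.as_tensor(input, dtype=dtype, device=device).flatten()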
    if length is None:
        return input
    if input.numel() >= length:
        return input[:length] if crop else input
    default = kwargs.get('default', input[-1] if input.numel() else 0)
    default = input.new_full([length - len(input)], default)
    return torch.cat([input, default])

torch_version

torch_version(mode, version)

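Compare the installed torch version against a reference version.

PARAMETER DESCRIPTION
mode

Comparison operator.

TYPE: {'<', '<=', '>', '>='}

version

Reference version, as a sequence of integers such as (major, minor, patch).

TYPE: tuple[int]

RETURNS DESCRIPTION
bool

True if "torch.__version__ <mode> version".
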
Source code in cassetta/core/utils.py
def torch_version(mode, version):
    """Check torch version

    Parameters
    ----------
    mode : {'<', '<=', '>', '>='}
    version : tuple[int]

    Returns
    -------
    True if "torch.version <mode> version"

    """
    current_version, *cuda_variant = torch.__version__.split('+')
    major, minor, patch, *_ = current_version.split('.')
    # strip alpha tags
    for x in 'abcdefghijklmnopqrstuvwxy':
        if x in patch:
            patch = patch[:patch.index(x)]
    current_version = (int(major), int(minor), int(patch))
    version = ensure_list(version)
    return _compare_versions(current_version, mode, version)
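The version parsing in torch_version can be reproduced without torch. The comparison helper _compare_versions is private and its source is not shown, so compare_versions below is a hypothetical stand-in based on Python's lexicographic tuple comparison; parse_torch_version mirrors the parsing steps above, but keeps only the leading digits of the patch field instead of looping over letters.

```python
import itertools
import operator
from typing import Sequence, Tuple

# Supported comparison modes
_OPS = {'<': operator.lt, '<=': operator.le, '>': operator.gt, '>=': operator.ge}


def parse_torch_version(version_string: str) -> Tuple[int, int, int]:
    """Parse strings such as '2.1.0+cu118' into (major, minor, patch)."""
    release = version_string.split('+')[0]  # drop the local/CUDA tag
    major, minor, patch, *_ = release.split('.')
    # Keep only the leading digits of the patch field ('0a0' -> '0')
    patch = ''.join(itertools.takewhile(str.isdigit, patch))
    return int(major), int(minor), int(patch)


def compare_versions(current: Sequence[int], mode: str,
                     reference: Sequence[int]) -> bool:
    """Hypothetical stand-in for the private _compare_versions helper."""
    return _OPS[mode](tuple(current), tuple(reference))


current = parse_torch_version("2.3.0a0+git1234")
print(current)                                   # (2, 3, 0)
print(compare_versions(current, '>=', (1, 13)))  # True
print(compare_versions(current, '<', (2, 0)))    # False
```

With tuple comparison, a shorter reference behaves like a prefix: (2, 0, 0) >= (2, 0) is True.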

to_torch_dtype

to_torch_dtype(dtype, upcast=False, trunc=False)

Transform a python or numpy dtype or dtype name to a torch dtype.

Python -> PyTorch convention

We follow the PyTorch convention and convert float to torch.float32, int to torch.long and complex to torch.complex64.

PARAMETER DESCRIPTION
dtype

Input data type

TYPE: str or type or np.dtype or torch.dtype

upcast

Upcast to nearest torch dtype if input dtype cannot be represented exactly. Else, raise a TypeError.

TYPE: bool DEFAULT: False

trunc

Truncate to the nearest torch dtype if the input dtype cannot be represented exactly. Else, raise a TypeError.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
dtype

Torch data type, or None if dtype is None.

TYPE: dtype

Source code in cassetta/core/utils.py
def to_torch_dtype(dtype, upcast=False, trunc=False):
    """
    Transform a python or numpy dtype or dtype name to a torch dtype.

    !!! warning "Python -> PyTorch convention"
        We follow the PyTorch convention and convert `float` to
        `torch.float32`, `int` to `torch.long` and `complex` to
        `torch.complex64`.

    Parameters
    ----------
    dtype : str or type or np.dtype or torch.dtype
        Input data type
    upcast : bool
        Upcast to nearest torch dtype if input dtype cannot be represented
        exactly. Else, raise a `TypeError`.
    trunc : bool
        Trunc to nearest torch dtype if input dtype cannot be represented
        exactly. Else, raise a `TypeError`.

    Returns
    -------
    dtype : torch.dtype
        Torch data type
    """
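    if dtype is None or (isinstance(dtype, str) and not dtype):
        # Not `if not dtype`: non-structured np.dtype objects have
        # len() == 0 and would be treated as falsy
        return None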

    # PyTorch data types
    if isinstance(dtype, torch.dtype):
        return dtype

    # Python builtin data types
    if dtype in _dtype_python2torch:
        return _dtype_python2torch[dtype]

    # Python number types
    if dtype in _dtype_numbers2torch:
        return _dtype_numbers2torch[dtype]

    # Strings for which we do not follow Numpy
    if dtype in _dtype_str2torch:
        return _dtype_str2torch[dtype]

    # Numpy data types
    dtype = np.dtype(dtype).type
    if dtype in _dtype_np2torch:
        return _dtype_np2torch[dtype]

    if dtype in _dtype_upcast2torch:
        if not upcast:
            raise TypeError('Cannot represent dtype in torch (upcast needed)')
        return _dtype_upcast2torch[dtype]

    if dtype in _dtype_trunc2torch:
        if not trunc:
            raise TypeError('Cannot represent dtype in torch (trunc needed)')
        return _dtype_trunc2torch[dtype]

    raise TypeError(f'Unknown type: {dtype}')

import_submodules

import_submodules(submodules, module, all=None, import_into=False)

Pre-import submodules into the parent module, so that we can do

```python
import pck
x = pck.submodule.subsubmodule.function(3)
```

instead of

```python
import pck.submodule.subsubmodule
x = pck.submodule.subsubmodule.function(3)
```

PARAMETER DESCRIPTION
submodules

Names of submodules to import

TYPE: list[str]

module

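Name of the parent module (typically __name__).

TYPE: str

all

Reference to the parent module's __all__, which is extended in place with the names of the imported submodules (and, with import_into, of their objects).

TYPE: list[str] DEFAULT: None

import_into

Also import all objects listed in each submodule's __all__ into the parent module.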

TYPE: bool DEFAULT: False

Source code in cassetta/core/utils.py
def import_submodules(submodules, module, all=None, import_into=False):
    """
    Pre-import submodules into parent module, so that we can do
    ```python
    import pck
    x = pck.submodule.subsubmodule.function(3)
    ```
    instead of
    ```python
    import pck.submodule.subsubmodule
    x = pck.submodule.subsubmodule.function(3)
    ```

    Parameters
    ----------
    submodules : list[str]
        Names of submodules to import
    module : str
        Path to parent module: `__name__`.
    all : list[str]
        Reference to the parent module's `__all__`, that then gets populated
    import_into : bool
        Also import all objects from the submodule into the parent module
    """
    parent_name = module
    parent = import_module(parent_name)
    for child_name in submodules:
        child = import_module('.' + child_name, parent_name)
        setattr(parent, child_name, child)
        if all is not None:
            all += [child_name]
        if import_into:
            for child_obj_name in child.__all__:
                setattr(parent, child_obj_name, getattr(child, child_obj_name))
                if all is not None:
                    all += [child_obj_name]
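A minimal, self-contained demonstration of import_submodules: it builds a throwaway package on disk (the names cassetta_demo_pkg, arith, and double are hypothetical), then calls a copy of the function above. In a real package, the call would presumably sit in the package's __init__.py, as import_submodules([...], __name__, __all__, import_into=True).

```python
import os
import sys
import tempfile
from importlib import import_module, invalidate_caches


def import_submodules(submodules, module, all=None, import_into=False):
    # Copy of the function above
    parent_name = module
    parent = import_module(parent_name)
    for child_name in submodules:
        child = import_module('.' + child_name, parent_name)
        setattr(parent, child_name, child)
        if all is not None:
            all += [child_name]
        if import_into:
            for child_obj_name in child.__all__:
                setattr(parent, child_obj_name, getattr(child, child_obj_name))
                if all is not None:
                    all += [child_obj_name]


with tempfile.TemporaryDirectory() as root:
    # Throwaway package: cassetta_demo_pkg/__init__.py and .../arith.py
    pkg_dir = os.path.join(root, "cassetta_demo_pkg")
    os.mkdir(pkg_dir)
    open(os.path.join(pkg_dir, "__init__.py"), "w").close()
    with open(os.path.join(pkg_dir, "arith.py"), "w") as f:
        f.write("__all__ = ['double']\n\ndef double(x):\n    return 2 * x\n")

    sys.path.insert(0, root)
    invalidate_caches()
    try:
        exported = []  # stands in for the parent's __all__
        import_submodules(["arith"], "cassetta_demo_pkg", exported, import_into=True)
        demo_pkg = sys.modules["cassetta_demo_pkg"]
    finally:
        sys.path.remove(root)

print(exported)            # ['arith', 'double']
print(demo_pkg.double(3))  # 6
```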

refresh_experiment_dir

refresh_experiment_dir(experiment_dir)

Check whether the directory has contents and, if so, delete them recursively. Then create empty predictions and checkpoints subdirectories inside it.

PARAMETER DESCRIPTION
experiment_dir

Path to the directory to be checked and cleared.

TYPE: str

RAISES DESCRIPTION
FileNotFoundError

If the directory does not exist.

Source code in cassetta/core/utils.py
def refresh_experiment_dir(experiment_dir: str) -> None:
    """
    Check if the directory has contents, and if so, delete them recursively.
    Then create empty `predictions` and `checkpoints` subdirectories.

    Parameters
    ----------
    experiment_dir : str
        Path to the directory to be checked and cleared.

    Raises
    ------
    FileNotFoundError
        If the directory does not exist.
    """
    if not os.path.exists(experiment_dir):
        raise FileNotFoundError(
            f"The directory {experiment_dir} does not exist."
            )

    # Check if directory is not empty
    if os.listdir(experiment_dir):
        # Recursively remove all contents of the directory
        for item in os.listdir(experiment_dir):
            item_path = os.path.join(experiment_dir, item)
            if os.path.isdir(item_path):
                shutil.rmtree(item_path)  # Remove directory and its contents
            else:
                os.remove(item_path)  # Remove file
        print(f"All contents of {experiment_dir} have been deleted.")
    else:
        print(f"The directory {experiment_dir} is already empty.")
    os.mkdir(os.path.join(experiment_dir, 'predictions'))
    os.mkdir(os.path.join(experiment_dir, 'checkpoints'))
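Because refresh_experiment_dir deletes everything under experiment_dir, it is easiest to try on a temporary directory. The sketch below uses a condensed copy of the function (status messages omitted, paths built with os.path.join); the leftover file names are hypothetical.

```python
import os
import shutil
import tempfile


def refresh_experiment_dir(experiment_dir: str) -> None:
    # Condensed copy of the function above, without the printed messages
    if not os.path.exists(experiment_dir):
        raise FileNotFoundError(f"The directory {experiment_dir} does not exist.")
    for item in os.listdir(experiment_dir):
        item_path = os.path.join(experiment_dir, item)
        if os.path.isdir(item_path):
            shutil.rmtree(item_path)  # remove subdirectory and its contents
        else:
            os.remove(item_path)      # remove file
    os.mkdir(os.path.join(experiment_dir, 'predictions'))
    os.mkdir(os.path.join(experiment_dir, 'checkpoints'))


with tempfile.TemporaryDirectory() as exp:
    # Leftovers from a previous run
    os.mkdir(os.path.join(exp, 'checkpoints'))
    open(os.path.join(exp, 'checkpoints', 'last-1.pt'), 'w').close()
    open(os.path.join(exp, 'log.txt'), 'w').close()

    refresh_experiment_dir(exp)
    layout = sorted(os.listdir(exp))
    leftover = os.listdir(os.path.join(exp, 'checkpoints'))

print(layout)    # ['checkpoints', 'predictions']
print(leftover)  # []
```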

delete_files_with_pattern

delete_files_with_pattern(directory, pattern)

Deletes all files in the specified directory that match the given pattern.

PARAMETER DESCRIPTION
directory

The path to the directory.

TYPE: str

pattern

The glob pattern to match files.

TYPE: str

Example

delete_files_with_pattern('/path/to/directory', '*last*')

Source code in cassetta/core/utils.py
def delete_files_with_pattern(directory, pattern):
    """
    Deletes all files in the specified directory that match the given pattern.

    Parameters
    ----------
    directory : str
        The path to the directory.
    pattern : str
        The glob pattern to match files.

    Example
    -------
    delete_files_with_pattern('/path/to/directory', '*last*')
    """
    # Construct the full search pattern
    search_pattern = os.path.join(directory, pattern)

    # Retrieve a list of files matching the pattern
    files_to_delete = glob.glob(search_pattern)

    if not files_to_delete:
        return  # nothing matches the pattern

    for file_path in files_to_delete:
        try:
            os.remove(file_path)
        except Exception as e:
            print(f"Error deleting file {file_path}: {e}")
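delete_files_with_pattern joins pattern to the directory and passes it to glob.glob, so a pattern without wildcards only matches a file with exactly that name. The stdlib-only snippet below lists, rather than deletes, the files each pattern would select; the checkpoint file names are hypothetical.

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as ckpt_dir:
    for name in ('last-epoch=9.pt', 'best-epoch=7.pt', 'notes.txt'):
        open(os.path.join(ckpt_dir, name), 'w').close()

    # Without wildcards, glob only matches a file named exactly "last"
    exact = glob.glob(os.path.join(ckpt_dir, 'last'))
    # With wildcards, it matches any file name containing "last"
    fuzzy = [os.path.basename(p)
             for p in glob.glob(os.path.join(ckpt_dir, '*last*'))]

print(exact)  # []
print(fuzzy)  # ['last-epoch=9.pt']
```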

find_files_with_pattern

find_files_with_pattern(directory, pattern)

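Search for files in the specified directory that match the given glob pattern.

PARAMETER DESCRIPTION
directory

The path to the directory to search.

TYPE: str

pattern

The glob pattern to search for (for example, "best-*.pt").

TYPE: str

RETURNS DESCRIPTION
paths

A list of file paths that match the pattern.

TYPE: list

RAISES DESCRIPTION
FileNotFoundError

If the directory does not exist.

NotADirectoryError

If the path is not a directory.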

Source code in cassetta/core/utils.py
def find_files_with_pattern(directory, pattern):
    """
    Search for files in the specified directory that match the given glob
    pattern.

    Parameters
    ----------
    directory : str
        The path to the directory to search.
    pattern : str
        The glob pattern to search for.

    Returns
    -------
    paths : list
        A list of file paths that match the pattern.
    """

    # Check if the directory exists
    if not os.path.exists(directory):
        raise FileNotFoundError(f"The directory '{directory}' does not exist.")

    # Check if the path is a directory
    if not os.path.isdir(directory):
        raise NotADirectoryError(f"The path '{directory}' is not a directory.")

    # Construct the search pattern
    search_pattern = os.path.join(directory, pattern)

    # Use glob to find files matching the pattern
    matching_files = glob.glob(search_pattern)

    return matching_files
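Since the search is delegated to glob.glob, pattern uses shell-style wildcards, not grep or regular-expression syntax. The snippet below (stdlib only, hypothetical file names) contrasts the two; to filter with a regular expression, match the file names yourself with re.

```python
import glob
import os
import re
import tempfile

with tempfile.TemporaryDirectory() as d:
    for name in ('best-3.pt', 'last-5.pt'):
        open(os.path.join(d, name), 'w').close()
    names = sorted(os.listdir(d))

    # Shell-style wildcard: the kind of pattern find_files_with_pattern expects
    as_glob = [os.path.basename(p) for p in glob.glob(os.path.join(d, 'best-*.pt'))]
    # A regex-style pattern is read by glob as a literal "best-." prefix
    as_regex_in_glob = glob.glob(os.path.join(d, 'best-.*'))
    # Regular expressions have to be applied to the names explicitly
    as_regex = [n for n in names if re.fullmatch(r'best-.*', n)]

print(as_glob)           # ['best-3.pt']
print(as_regex_in_glob)  # []
print(as_regex)          # ['best-3.pt']
```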

find_checkpoint

find_checkpoint(experiment_dir, checkpoint_type='best')

Find the checkpoint file in the specified experiment directory.

PARAMETER DESCRIPTION
experiment_dir

The path to the directory of the experiment.

TYPE: str

checkpoint_type

The type of checkpoint to find, either "last" or "best". Defaults to "best".

TYPE: str DEFAULT: 'best'

RETURNS DESCRIPTION
Optional[str]

The full path to the checkpoint file if found, otherwise None.

RAISES DESCRIPTION
ValueError

If checkpoint_type is not "last" or "best".

Source code in cassetta/core/utils.py
def find_checkpoint(
    experiment_dir: str,
    checkpoint_type: str = "best"
) -> Optional[str]:
    """
    Find the checkpoint file in the specified experiment directory.

    Parameters
    ----------
    experiment_dir : str
        The path to the directory of the experiment.
    checkpoint_type : str, optional
        The type of checkpoint to find, either "last" or "best".
        Defaults to "best".

    Returns
    -------
    Optional[str]
        The full path to the checkpoint file if found, otherwise None.

    Raises
    ------
    ValueError
        If checkpoint_type is not "last" or "best".
    """
    if checkpoint_type not in {"last", "best"}:
        raise ValueError("checkpoint_type must be 'last' or 'best'")

    # Define the pattern for file searching based on checkpoint type
    pattern = f"{checkpoint_type}-*.pt"
    # Search for files in the folder matching the pattern
    matches = glob.glob(os.path.join(experiment_dir, 'checkpoints', pattern))

    # If a match is found, return the first (there should only be one)
    return matches[0] if matches else None
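A short usage sketch with a copy of find_checkpoint and a temporary experiment directory. Only the "best-*.pt" / "last-*.pt" glob comes from the source; the file name best-epoch=7.pt is a hypothetical example of that scheme.

```python
import glob
import os
import tempfile
from typing import Optional


def find_checkpoint(experiment_dir: str, checkpoint_type: str = "best") -> Optional[str]:
    # Copy of the function above
    if checkpoint_type not in {"last", "best"}:
        raise ValueError("checkpoint_type must be 'last' or 'best'")
    pattern = f"{checkpoint_type}-*.pt"
    matches = glob.glob(os.path.join(experiment_dir, 'checkpoints', pattern))
    return matches[0] if matches else None


with tempfile.TemporaryDirectory() as exp:
    os.mkdir(os.path.join(exp, 'checkpoints'))
    # Hypothetical file following the "<type>-*.pt" naming scheme
    open(os.path.join(exp, 'checkpoints', 'best-epoch=7.pt'), 'w').close()
    best = find_checkpoint(exp)
    last = find_checkpoint(exp, 'last')
    best_name = os.path.basename(best)

print(best_name)  # best-epoch=7.pt
print(last)       # None
```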