
cassetta.models

Overview

Models are end-to-end, task-specific networks. They usually rely on a more generic backbone architecture.

MODULE DESCRIPTION
segmentation

Models for semantic segmentation

registration

Models for image registration

make_model

make_model(model, *args, **kwargs)

Instantiate a model

A model can be:

  • the name of a cassetta model, such as "SegNet";
  • the fully qualified path to a model, such as "cassetta.models.ElasticRegNet" or "monai.networks.nets.ResNet";
  • an nn.Module subclass, such as SegNet;
  • an already instantiated nn.Module, such as SegNet(3, 1, 5).
PARAMETER DESCRIPTION
model

Instantiated or non-instantiated model

TYPE: ModelType

*args

Positional arguments passed to the model constructor

TYPE: tuple DEFAULT: ()

**kwargs

Keyword arguments passed to the model constructor

TYPE: dict DEFAULT: {}

RETURNS DESCRIPTION
model

Instantiated model

TYPE: Module

Source code in cassetta/models/__init__.py
def make_model(model: ModelType, *args, **kwargs):
    """
    Instantiate a model

    A model can be:

    - the name of a cassetta model, such as `"SegNet"`;
    - the fully qualified path to a model, such as
    `"cassetta.models.ElasticRegNet"`, or `"monai.networks.nets.ResNet"`;
    - a [`nn.Module`][torch.nn.Module] subclass, such as
    [`SegNet`][cassetta.models.SegNet];
    - an already instantiated [`nn.Module`][torch.nn.Module], such as
    [`SegNet(3, 1, 5)`][cassetta.models.SegNet].

    Parameters
    ----------
    model : ModelType
        Instantiated or non-instantiated model
    *args : tuple
        Positional arguments passed to the model constructor
    **kwargs : dict
        Keyword arguments passed to the model constructor

    Returns
    -------
    model : nn.Module
        Instantiated model
    """
    reentrant = kwargs.pop('__reentrant', False)
    if isinstance(model, str):
        if not reentrant:
            kwargs['__reentrant'] = True
            for prefix in ('', 'cassetta.', 'cassetta.models.', 'torch.nn.'):
                try:
                    return make_model(prefix + model, *args, **kwargs)
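                except Exception:
                    pass
            # None of the prefixes worked: remove the internal flag so that
            # it is not forwarded to the model constructor below.
            del kwargs['__reentrant']
        model = import_fullname(model)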
    if not isinstance(model, Module):
        model = model(*args, **kwargs)
    if not isinstance(model, Module):
        raise ValueError('Instantiated object is not a Module')
    return model
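The string branch of make_model tries a fixed list of prefixes ('', 'cassetta.', 'cassetta.models.', 'torch.nn.') before giving up. The stdlib-only sketch below mimics that resolution strategy so it can run without PyTorch: `import_fullname_sketch` and `make_object` are hypothetical stand-ins for cassetta's `import_fullname` and `make_model`, and `dict` subclasses from `collections` play the role of `nn.Module`.

```python
import importlib
from collections import Counter


def import_fullname_sketch(fullname):
    """Resolve 'pkg.module.Attr' by importing the longest importable module prefix."""
    parts = fullname.split('.')
    for i in range(len(parts) - 1, 0, -1):
        try:
            obj = importlib.import_module('.'.join(parts[:i]))
        except ImportError:
            continue  # not a module path; try a shorter prefix
        for attr in parts[i:]:
            obj = getattr(obj, attr)  # AttributeError if the name is missing
        return obj
    return importlib.import_module(fullname)  # bare name: treat it as a module


def make_object(obj, *args, prefixes=('', 'collections.'), base=dict, **kwargs):
    """Hypothetical analogue of make_model: name, class or instance -> instance."""
    if isinstance(obj, str):
        for prefix in prefixes:
            try:
                obj = import_fullname_sketch(prefix + obj)
                break
            except (ImportError, AttributeError):
                continue  # only lookup errors are swallowed
        else:
            raise ValueError(f'Could not resolve {obj!r}')
    if not isinstance(obj, base):
        obj = obj(*args, **kwargs)  # a class: instantiate it
    if not isinstance(obj, base):
        raise ValueError(f'Instantiated object is not a {base.__name__}')
    return obj


# The same call styles as make_model, with dict subclasses standing in for nn.Module
od = make_object('OrderedDict', [('a', 1)])   # short name, found via 'collections.'
od2 = make_object('collections.OrderedDict')  # fully qualified path
cnt = make_object(Counter, 'hello')           # class, instantiated with *args
same = make_object({'x': 1})                  # instance, returned as-is
```

Catching only lookup errors in this sketch (rather than any `Exception`) means that a constructor failing because of bad arguments reports its own error instead of being retried silently under another prefix.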

cassetta.models.segmentation

SegNet

SegNet(ndim, inp_channels, out_channels, kernel_size=3, activation='Softmax', backbone='UNet', opt_backbone=None)

Bases: LoadableMixin, Sequential

A generic segmentation network that works with any backbone

Diagram

flowchart LR
    i["C"]:::i ---fx("Conv"):::w-->
    fi["F"]    ---b("Backbone"):::w -->
    fo["F"]    ---fk("Conv 1x1x1"):::w-->
    l["K"]     ---s(("σ")):::d-->
    o["K"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef n fill:none,stroke:none;
PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: {2, 3}

inp_channels

Number of input channels

TYPE: int

out_channels

Number of output classes

TYPE: int

kernel_size

Kernel size of the initial feature extraction layer

TYPE: int DEFAULT: 3

activation

Final activation function

TYPE: str DEFAULT: 'Softmax'

backbone

Generic backbone module. Can be already instantiated.

Examples: UNet (default), ATrousNet, MeshNet.

TYPE: str or Module DEFAULT: 'UNet'

opt_backbone

Parameters of the backbone (if backbone is not pre-instantiated)

TYPE: dict DEFAULT: None

Source code in cassetta/models/segmentation.py
def __init__(
    self,
    ndim: int,
    inp_channels: int,
    out_channels: int,
    kernel_size: OneOrSeveral[int] = 3,
    activation: ActivationType = 'Softmax',
    backbone: Union[str, nn.Module] = 'UNet',
    opt_backbone: Optional[dict] = None,
):
    """

    Parameters
    ----------
    ndim : {2, 3}
        Number of spatial dimensions
    inp_channels : int
        Number of input channels
    out_channels : int
        Number of output classes
    kernel_size : int, default=3
        Kernel size **of the initial feature extraction** layer
    activation : str
        **Final** activation function
    backbone : str or Module
        Generic backbone module. Can be already instantiated.

        Examples:
        [`UNet`][cassetta.backbones.UNet] (default),
        [`ATrousNet`][cassetta.backbones.ATrousNet],
        [`MeshNet`][cassetta.backbones.MeshNet].
    opt_backbone : dict
        Parameters of the backbone (if backbone is not pre-instantiated)
    """
    # Backbone logic
    if isinstance(backbone, str):
        backbone_kls = getattr(backbones, backbone)
        backbone = backbone_kls(ndim, **(opt_backbone or {}))

    activation = make_activation(activation)

    feat = ConvBlock(
        ndim,
        inp_channels=inp_channels,
        out_channels=backbone.inp_channels,
        kernel_size=kernel_size,
        activation=None,
    )

    pred = ConvBlock(
        ndim,
        inp_channels=backbone.out_channels,
        out_channels=out_channels,
        kernel_size=1,
        activation=activation,
    )
    super().__init__(feat, backbone, pred)
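Because `feat` maps to `backbone.inp_channels` and `pred` reads from `backbone.out_channels`, any backbone exposing those two attributes can be plugged in. The sketch below traces the resulting tensor shapes, assuming (this is not stated in the source) that the initial convolution is padded to preserve spatial size and that the backbone does too; `segnet_shapes` and the channel counts are hypothetical.

```python
def segnet_shapes(batch, inp_channels, out_channels, size,
                  backbone_inp, backbone_out):
    """List the (B, C, *size) shape after each stage of a SegNet-like model."""
    size = tuple(size)  # len(size) == ndim
    return [
        (batch, inp_channels) + size,   # input image: C channels
        (batch, backbone_inp) + size,   # feat: Conv, C -> backbone.inp_channels
        (batch, backbone_out) + size,   # backbone: -> backbone.out_channels
        (batch, out_channels) + size,   # pred: 1x1 conv + activation, -> K classes
    ]


# Hypothetical 2D example: 3-channel input, 5 classes, a 16 -> 32 feature backbone
for shape in segnet_shapes(2, 3, 5, (64, 64), backbone_inp=16, backbone_out=32):
    print(shape)
```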

predict_logits

predict_logits(inp)

Run the forward pass and return the logits (pre-softmax)

PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, inp_channels, *size) tensor

RETURNS DESCRIPTION
out

Logits

TYPE: (B, out_channels, *size) tensor

Source code in cassetta/models/segmentation.py
def predict_logits(self, inp):
    """
    Run the forward pass and return the logits (pre-softmax)

    Parameters
    ----------
    inp : (B, inp_channels, *size) tensor
        Input tensor

    Returns
    -------
    out : (B, out_channels, *size) tensor
        Logits
    """
    if not hasattr(self[-1], 'activation'):
        return self(inp)
    feat, backbone, pred = self
    return pred.conv(backbone(feat(inp)))
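predict_logits skips only the final activation. At a single voxel, the last 1x1 convolution is a matrix-vector product over channels. The stdlib sketch below uses hypothetical helper names and made-up weights to contrast what predict_logits returns with what the forward pass returns under the default 'Softmax' activation, assuming the softmax is taken over the channel dimension.

```python
import math


def conv1x1_voxel(weight, bias, feat):
    """1x1 convolution at one voxel: a (K x F) matrix-vector product plus bias."""
    return [sum(w * f for w, f in zip(row, feat)) + b
            for row, b in zip(weight, bias)]


def softmax(logits):
    """Numerically stable softmax over the class (channel) axis."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


feat = [0.5, -1.0, 2.0]                        # F = 3 backbone features (made up)
weight = [[1.0, 0.0, 0.5], [0.0, -1.0, 1.0]]   # K = 2 classes (made up)
bias = [0.1, -0.2]

logits = conv1x1_voxel(weight, bias, feat)     # analogue of predict_logits
probs = softmax(logits)                        # analogue of forward with 'Softmax'
print(logits, probs)
```

Both outputs give the same hard label (softmax preserves the argmax). Logits are still useful because losses such as `torch.nn.CrossEntropyLoss` expect unnormalized scores.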

cassetta.models.registration

ElasticRegNet

ElasticRegNet(ndim, symmetric=False, nb_steps=0, inp_channels=1, kernel_size=3, activation=None, backbone='UNet', opt_backbone=None)

Bases: LoadableMixin, Sequential

A generic pairwise nonlinear registration network that works with any backbone

Not tested yet -- do not use

Diagram

flowchart LR
    subgraph "Prediction"
        cat(("c")):::d-->
        mf["C*2"]  ---fx("Conv"):::w-->
        fi["F"]    ---b("Backbone"):::w -->
        fo["F"]    ---fk("Conv 1x1x1"):::w
    end
    mov["C"]:::i & fix["C"]:::i ---cat
    fk --> o["D"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef n fill:none,stroke:none;
flowchart LR
    mf["(mov, fix)"]:::i ---p1("Prediction"):::w--> v1["D"]
    fm["(fix, mov)"]:::i ---p2("Prediction"):::w--> v2["D"]
    v1 & v2 ---minus(("-")):::d--> o["D"]:::o
    p1 -.-|"shared weights"| p2
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef n fill:none,stroke:none;
flowchart LR
    mov["<b>mov</b>\nC"]:::i & fix["<b>fix</b>\nC"]:::i ---
    pred("Prediction"):::w-->
    vel["<b>velocity</b>\nD"] ---exp("Exp"):::d -->
    flow["<b>flow</b>\nD"]
    mov & flow ---w("Pull"):::d--> out["<b>moved</b>\nC"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef n fill:none,stroke:none;
PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: {2, 3}

symmetric

Make the network symmetric by averaging model(mov, fix) and -model(fix, mov), so that swapping the two inputs negates the output.

TYPE: bool DEFAULT: False

nb_steps

Number of scaling and squaring steps.

TYPE: int DEFAULT: 0

inp_channels

Number of input channels, per image.

TYPE: int DEFAULT: 1

kernel_size

Kernel size of the initial feature extraction layer

TYPE: int DEFAULT: 3

activation

Final activation function

TYPE: str DEFAULT: None

backbone

Generic backbone module. Can be already instantiated.

Examples: UNet (default), ATrousNet, MeshNet.

TYPE: str or Module DEFAULT: 'UNet'

opt_backbone

Parameters of the backbone (if backbone is not pre-instantiated).

Note that, unless user-defined, we set activation="LeakyReLU".

TYPE: dict DEFAULT: None

Source code in cassetta/models/registration.py
def __init__(
    self,
    ndim: int,
    symmetric: bool = False,
    nb_steps: int = 0,
    inp_channels: int = 1,
    kernel_size: OneOrSeveral[int] = 3,
    activation: ActivationType = None,
    backbone: Union[str, nn.Module] = 'UNet',
    opt_backbone: Optional[dict] = None,
):
    """
    Parameters
    ----------
    ndim : {2, 3}
        Number of spatial dimensions
    symmetric : bool
        Make the network symmetric by averaging `model(mov, fix)` and
        `-model(fix, mov)`, so that swapping the inputs negates the output.
    nb_steps : int
        Number of scaling and squaring steps.
    inp_channels : int
        Number of input channels, per image.
    kernel_size : int, default=3
        Kernel size **of the initial feature extraction** layer
    activation : str
        **Final** activation function
    backbone : str or Module
        Generic backbone module. Can be already instantiated.

        Examples:
        [`UNet`][cassetta.backbones.UNet] (default),
        [`ATrousNet`][cassetta.backbones.ATrousNet],
        [`MeshNet`][cassetta.backbones.MeshNet].
    opt_backbone : dict
        Parameters of the backbone (if backbone is not pre-instantiated).

        Note that, unless user-defined, we set `activation="LeakyReLU"`.
    """
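    # Copy so that the default (None) works and the caller's dict is untouched
    opt_backbone = dict(opt_backbone or {})
    opt_backbone.setdefault('activation', 'LeakyReLU')
    if isinstance(backbone, str):
        backbone_kls = getattr(backbones, backbone)
        backbone = backbone_kls(ndim, **opt_backbone)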
    activation = make_activation(activation)
    feat = ConvBlock(
        ndim,
        inp_channels=inp_channels*2,
        out_channels=backbone.inp_channels,
        kernel_size=kernel_size,
        activation=None,
    )
    pred = ConvBlock(
        ndim,
        inp_channels=backbone.out_channels,
        out_channels=ndim,
        kernel_size=1,
        activation=activation,
    )
    super().__init__()
    self.net = nn.Sequential(feat, backbone, pred)
    self.exp = FlowExp(nb_steps)
    self.warp = FlowPull()
    self.symmetric = symmetric
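The documented opt_backbone behavior ("LeakyReLU unless user-defined") and the channel wiring of the prediction network can be summarized with two small helpers. This is a sketch of the intended semantics, not cassetta code: `backbone_options` and `regnet_channels` are hypothetical, and 'ELU' is only an illustrative activation name.

```python
def backbone_options(opt_backbone=None):
    """Merge user backbone options with the LeakyReLU default (hypothetical helper)."""
    opt = dict(opt_backbone or {})        # accept None; never mutate the caller's dict
    opt.setdefault('activation', 'LeakyReLU')
    return opt


def regnet_channels(ndim, inp_channels):
    """(input, output) channels of the prediction network."""
    # mov and fix are concatenated along channels; one displacement component per axis
    return 2 * inp_channels, ndim


user = {'activation': 'ELU'}
print(backbone_options())        # {'activation': 'LeakyReLU'}
print(backbone_options(user))    # {'activation': 'ELU'}
print(regnet_channels(3, 1))     # (2, 3)
```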

forward

forward(mov, fix)

Predict the encoding of the displacement field

PARAMETER DESCRIPTION
mov

Moving image

TYPE: (B, inp_channels, *size) tensor

fix

Fixed image

TYPE: (B, inp_channels, *size) tensor

RETURNS DESCRIPTION
vel
  • If nb_steps>0: stationary velocity field
  • Else: voxel displacement field

TYPE: (B, ndim, *size) tensor

Source code in cassetta/models/registration.py
def forward(self, mov, fix):
    """
    Predict the encoding of the displacement field

    Parameters
    ----------
    mov : (B, inp_channels, *size) tensor
        Moving image
    fix : (B, inp_channels, *size) tensor
        Fixed image

    Returns
    -------
    vel : (B, ndim, *size) tensor
        - If `nb_steps>0`: stationary velocity field
        - Else: voxel displacement field
    """
    vel = self.net(Cat()(mov, fix))
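    if self.symmetric:
        # Written out-of-place: an in-place update of the network output
        # would break autograd if the final activation saves its output
        vel = 0.5 * (vel - self.net(Cat()(fix, mov)))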
    return vel
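With symmetric=True, forward returns half the difference between the prediction for (mov, fix) and the prediction for (fix, mov). This makes the output antisymmetric: swapping the inputs negates the field, whatever the underlying network computes. The sketch below checks this with a hypothetical, deliberately non-symmetric toy "network" operating on 1D lists. A real network mixes neighboring voxels, but the argument does not depend on that.

```python
def symmetrize(net, mov, fix):
    """Antisymmetric combination: 0.5 * (net(mov, fix) - net(fix, mov))."""
    a = net(mov, fix)
    b = net(fix, mov)
    return [0.5 * (x - y) for x, y in zip(a, b)]


def toy_net(mov, fix):
    # Hypothetical stand-in for self.net: any non-symmetric function of the pair
    return [2.0 * f - 0.5 * m + m * f for m, f in zip(mov, fix)]


mov = [0.0, 1.0, 3.0]
fix = [1.0, 2.0, 0.5]
v_mf = symmetrize(toy_net, mov, fix)
v_fm = symmetrize(toy_net, fix, mov)
print(v_mf, v_fm)  # element-wise negatives of each other
```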

predict_flow

predict_flow(mov, fix)

Predict the forward displacement field, used to warp mov to fix

PARAMETER DESCRIPTION
mov

Moving image

TYPE: (B, inp_channels, *size) tensor

fix

Fixed image

TYPE: (B, inp_channels, *size) tensor

RETURNS DESCRIPTION
flow

Voxel displacement field

TYPE: (B, ndim, *size) tensor

Source code in cassetta/models/registration.py
def predict_flow(self, mov, fix):
    """
    Predict the forward displacement field, used to warp `mov` to `fix`

    Parameters
    ----------
    mov : (B, inp_channels, *size) tensor
        Moving image
    fix : (B, inp_channels, *size) tensor
        Fixed image

    Returns
    -------
    flow : (B, ndim, *size) tensor
        Voxel displacement field
    """
    return self.exp(self(mov, fix))
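FlowExp's implementation is not shown on this page. The sketch below illustrates the standard scaling-and-squaring technique it is named after, in 1D and in pure Python. The velocity is divided by 2**N, and the resulting small displacement is then composed with itself N times. Linear interpolation and border clamping are assumptions, and FlowExp's actual resampling may differ. `interp1d`, `compose` and `flow_exp` are hypothetical helpers.

```python
def interp1d(field, x):
    """Linearly interpolate a 1D field (len >= 2) at position x, clamped to the border."""
    n = len(field)
    x = min(max(x, 0.0), n - 1.0)
    i = min(int(x), n - 2)  # left neighbour
    t = x - i
    return (1.0 - t) * field[i] + t * field[i + 1]


def compose(u, w):
    """Displacement of (id + u) o (id + w) on grid points: w(x) + u(x + w(x))."""
    return [w[x] + interp1d(u, x + w[x]) for x in range(len(w))]


def flow_exp(vel, nb_steps):
    """Scaling and squaring: start from v / 2**N, then self-compose N times."""
    if nb_steps == 0:
        return list(vel)  # no integration: vel is already a displacement
    u = [v / 2 ** nb_steps for v in vel]  # scaling
    for _ in range(nb_steps):
        u = compose(u, u)                 # squaring: phi <- phi o phi
    return u


# A uniform velocity integrates to the same uniform displacement
print(flow_exp([1.5] * 8, nb_steps=3))
```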

predict_flows

predict_flows(mov, fix)

Predict the forward displacement field (used to warp mov to fix) and the backward displacement field (used to warp fix to mov)

PARAMETER DESCRIPTION
mov

Moving image

TYPE: (B, inp_channels, *size) tensor

fix

Fixed image

TYPE: (B, inp_channels, *size) tensor

RETURNS DESCRIPTION
flow_forward

Forward voxel displacement field

TYPE: (B, ndim, *size) tensor

flow_backward

Backward voxel displacement field

TYPE: (B, ndim, *size) tensor

Source code in cassetta/models/registration.py
def predict_flows(self, mov, fix):
    """
    Predict the forward displacement field (used to warp `mov` to `fix`)
    and the backward displacement field (used to warp `fix` to `mov`)

    Parameters
    ----------
    mov : (B, inp_channels, *size) tensor
        Moving image
    fix : (B, inp_channels, *size) tensor
        Fixed image

    Returns
    -------
    flow_forward : (B, ndim, *size) tensor
        Forward voxel displacement field
    flow_backward : (B, ndim, *size) tensor
        Backward voxel displacement field
    """
    vel = self(mov, fix)
    return self.exp(vel), self.exp(-vel)
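For a stationary velocity field v, exp(-v) is the exact inverse of exp(v) in the continuous setting. With nb_steps > 0, the two flows returned here are therefore inverses up to discretization error. With nb_steps=0 (assuming FlowExp(0) acts as the identity, as the forward docstring suggests), the backward flow is simply -u, which inverts u only to first order, and exactly only for translations. The 1D sketch below measures that residual. `interp1d`, `compose` and `inverse_residual` are hypothetical helpers, and linear interpolation with border clamping is an assumption.

```python
import math


def interp1d(field, x):
    """Linear interpolation of a 1D field at position x, clamped to the border."""
    n = len(field)
    x = min(max(x, 0.0), n - 1.0)
    i = min(int(x), n - 2)
    t = x - i
    return (1.0 - t) * field[i] + t * field[i + 1]


def compose(u, w):
    """Displacement of (id + u) o (id + w) on the grid: w(x) + u(x + w(x))."""
    return [w[x] + interp1d(u, x + w[x]) for x in range(len(w))]


def inverse_residual(u):
    """Max |(id - u) o (id + u) - id|: zero only if -u exactly inverts u."""
    neg = [-d for d in u]
    return max(abs(r) for r in compose(neg, u))


n = 33
const = [0.75] * n                                                 # uniform shift
wave = [2.0 * math.sin(math.pi * x / (n - 1)) for x in range(n)]   # smooth, non-uniform
print(inverse_residual(const))   # 0.0: a translation is undone by its negation
print(inverse_residual(wave))    # clearly non-zero
```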

predict_moved

predict_moved(mov, fix)

Predict the warped moving image

PARAMETER DESCRIPTION
mov

Moving image

TYPE: (B, inp_channels, *size) tensor

fix

Fixed image

TYPE: (B, inp_channels, *size) tensor

RETURNS DESCRIPTION
moved

Moved image

TYPE: (B, inp_channels, *size) tensor

Source code in cassetta/models/registration.py
def predict_moved(self, mov, fix):
    """
    Predict the warped moving image

    Parameters
    ----------
    mov : (B, inp_channels, *size) tensor
        Moving image
    fix : (B, inp_channels, *size) tensor
        Fixed image

    Returns
    -------
    moved : (B, inp_channels, *size) tensor
        Moved image
    """
    flow = self.predict_flow(mov, fix)
    return self.warp(mov, flow)
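FlowPull performs a "pull" (backward) warp. Each output voxel x samples the moving image at x + flow(x), so every output voxel receives a value, unlike push/splatting. A 1D sketch follows. Linear interpolation and border clamping are assumptions about FlowPull, and `pull1d` is a hypothetical helper.

```python
def pull1d(image, flow):
    """moved[x] = image[x + flow[x]], with linear interpolation and border clamping."""
    n = len(image)
    out = []
    for x, d in enumerate(flow):
        p = min(max(x + d, 0.0), n - 1.0)  # sampling position, clamped to the image
        i = min(int(p), n - 2)             # left neighbour
        t = p - i
        out.append((1.0 - t) * image[i] + t * image[i + 1])
    return out


image = [0.0, 10.0, 20.0, 30.0, 40.0]
print(pull1d(image, [0.0] * 5))  # [0.0, 10.0, 20.0, 30.0, 40.0]
print(pull1d(image, [1.0] * 5))  # [10.0, 20.0, 30.0, 40.0, 40.0]
print(pull1d(image, [0.5] * 5))  # [5.0, 15.0, 25.0, 35.0, 40.0]
```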

predict_both_moved

predict_both_moved(mov, fix)

Predict the warped moving image, and the warped fixed image.

PARAMETER DESCRIPTION
mov

Moving image

TYPE: (B, inp_channels, *size) tensor

fix

Fixed image

TYPE: (B, inp_channels, *size) tensor

RETURNS DESCRIPTION
warped_mov

Moving image warped to fixed space

TYPE: (B, inp_channels, *size) tensor

warped_fix

Fixed image warped to moving space

TYPE: (B, inp_channels, *size) tensor

Source code in cassetta/models/registration.py
def predict_both_moved(self, mov, fix):
    """
    Predict the warped moving image, and the warped fixed image.

    Parameters
    ----------
    mov : (B, inp_channels, *size) tensor
        Moving image
    fix : (B, inp_channels, *size) tensor
        Fixed image

    Returns
    -------
    warped_mov : (B, inp_channels, *size) tensor
        Moving image warped to fixed space
    warped_fix : (B, inp_channels, *size) tensor
        Fixed image warped to moving space
    """
    fwd, bwd = self.predict_flows(mov, fix)
    return self.warp(mov, fwd), self.warp(fix, bwd)