
cassetta.layers

Overview

Layers are relatively simple modules or sequences of modules. They should be used as basic blocks for building backbones.

MODULE DESCRIPTION
activations

Activation functions

attention

Attention layers (squeeze & excite, dot-product, multi-head, ...)

conv

Basic N-dimensional convolution layers

convblocks

Building blocks for convolutional networks

dropout

N-dimensional dropout layers

interpol

N-dimensional interpolation and grid sampling

linear

Linear layer (slightly more practical than PyTorch's)

simple

A bunch of embarrassingly simple layers (Cat, Sum, ...)

updown

Different ways to upsample and downsample

cassetta.layers.activations

SymExp

Bases: Module

Symmetric Exponential Activation

SymExp(x) = sign(x) * (exp(abs(x)) - 1)

SymLog

Bases: Module

Symmetric Logarithmic Activation

SymLog(x) = sign(x) * log(1 + abs(x))
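The following scalar sketch of the two formulas uses plain `math` on Python floats, not the library's tensor modules. The names `symlog` and `symexp` are illustrative. It shows that SymExp is the inverse of SymLog:

```python
import math


def symlog(x: float) -> float:
    # SymLog(x) = sign(x) * log(1 + |x|); log1p stays accurate near zero
    return math.copysign(math.log1p(abs(x)), x)


def symexp(x: float) -> float:
    # SymExp(x) = sign(x) * (exp(|x|) - 1); expm1 stays accurate near zero
    return math.copysign(math.expm1(abs(x)), x)


# SymExp undoes SymLog (up to floating-point error)
for value in (-100.0, -1.5, 0.0, 2.0, 1e4):
    assert math.isclose(symexp(symlog(value)), value, rel_tol=1e-9, abs_tol=1e-12)
```

Near zero, both functions behave like the identity (slope 1). For large |x|, SymLog grows only logarithmically.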

make_activation

make_activation(activation, **kwargs)

Instantiate an activation module.

To be accepted in a nn.Sequential module or in a nn.ModuleList, an activation must be a nn.Module. This function takes other forms of "activation parameters" that are typically passed to the constructor of larger models, and generates the corresponding instantiated Module.

An activation-like argument can be a nn.Module subclass, which is then instantiated, or a callable function that returns an instantiated Module.

Accepting these different forms makes it possible to:

  • have a learnable activation specific to this module
  • have a learnable activation shared with other modules
  • have a non-learnable activation
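The resolution order can be sketched without PyTorch. This is a simplified, hypothetical mirror of the branches in the source below. `Module`, `ReLU`, `LeakyReLU` and the `REGISTRY` dict stand in for `torch.nn`, and the sketch omits the Softmax default and the fuzzy name lookup:

```python
from typing import Any, Dict, Type


class Module:
    """Hypothetical stand-in for torch.nn.Module (keeps the sketch dependency-free)."""


class ReLU(Module):
    pass


class LeakyReLU(Module):
    def __init__(self, negative_slope: float = 0.01):
        self.negative_slope = negative_slope


# Hypothetical registry playing the role of the `torch.nn` namespace
REGISTRY: Dict[str, Type[Module]] = {"ReLU": ReLU, "LeakyReLU": LeakyReLU}


def resolve_activation(activation: Any, **kwargs: Any) -> Any:
    if not activation:
        return None                          # no activation requested
    if isinstance(activation, Module):
        return activation                    # already instantiated: shared as-is
    if isinstance(activation, str):
        if activation not in REGISTRY:
            raise ValueError(f'Unknown activation "{activation}"')
        activation = REGISTRY[activation]    # name -> class
    if isinstance(activation, type):
        if not issubclass(activation, Module):
            raise TypeError("Activation should be a Module subclass")
        activation = activation(**kwargs)    # class -> fresh instance
    elif callable(activation):
        activation = activation(**kwargs)    # factory -> instance
    if not isinstance(activation, Module):
        raise ValueError("Activation did not instantiate a Module")
    return activation
```

Passing an instance shares it (and its parameters) with every caller. Passing a class or a name produces a fresh instance on each call.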
PARAMETER DESCRIPTION
activation

An already instantiated nn.Module, or a nn.Module subclass, or a callable that returns an instantiated nn.Module, or the name of an activation type from nn. For example: "ReLU", "LeakyReLU", "ELU", "GELU", "Tanh", etc.

TYPE: ActivationType

kwargs

Additional parameters to pass to the constructor or function.

TYPE: dict DEFAULT: {}

RETURNS DESCRIPTION
activation

An instantiated nn.Module.

TYPE: Module

Source code in cassetta/layers/activations.py
def make_activation(activation, **kwargs):
    """
    Instantiate an activation module.

    To be accepted in a `nn.Sequential` module or in a `nn.ModuleList`,
    an activation **must** be a `nn.Module`. This function takes other
    forms of "activation parameters" that are typically passed to the
    constructor of larger models, and generates the corresponding
    instantiated Module.

    An activation-like argument can be a `nn.Module` subclass, which
    is then instantiated, or a callable function that returns an
    instantiated Module.

    Accepting these different forms makes it possible to:

    * have a learnable activation specific to this module
    * have a learnable activation shared with other modules
    * have a non-learnable activation

    Parameters
    ----------
    activation : ActivationType
        An already instantiated `nn.Module`, or a `nn.Module` subclass,
        or a callable that returns an instantiated `nn.Module`, or the
        name of an activation type from `nn`. For example:
        `"ReLU"`, `"LeakyReLU"`, `"ELU"`, `"GELU"`, `"Tanh"`, etc.
    kwargs : dict
        Additional parameters to pass to the constructor or function.

    Returns
    -------
    activation : Module
        An instantiated `nn.Module`.
    """
    if not activation:
        return None

    if isinstance(activation, nn.Module):
        return activation

    if isinstance(activation, str):
        if hasattr(nn, activation):
            activation = getattr(nn, activation)
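        elif activation in globals():
            # module-level names (e.g. SymExp, SymLog); locals() would only
            # see this function's own variables
            activation = globals()[activation]
        else:
            inp_act = activation
            activation = _find_act(activation, [nn.__dict__, globals()])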
            if not activation:
                raise ValueError(f'Unknown activation "{inp_act}"')

    if isinstance(activation, type):
        if not issubclass(activation, nn.Module):
            raise TypeError('Activation should be a Module subclass')
        if activation is nn.Softmax:
            kwargs.setdefault('dim', 1)
        activation = activation(**kwargs)

    elif callable(activation):
        activation = activation(**kwargs)

    if not isinstance(activation, nn.Module):
        raise ValueError('Activation did not instantiate a Module')
    return activation

cassetta.layers.attention

ChannelSqzEx

ChannelSqzEx(channels, compression=16, activation='ReLU', device=None, dtype=None, **unused_kwargs)

Bases: Sequential

Spatial Squeeze & Channel Excitation layer

Diagram

flowchart LR
    subgraph Spatial Squeeze
    1["`[C, W]`"]    ---2("`MeanPool`"):::d-->  3["`[C, 1]`"]
    end
    subgraph MLP
                        4("`Linear`"):::w   -->
    5["`[C//r, 1]`"] ---6("`ReLU`"):::d     -->
    7["`[C//r, 1]`"] ---8("`Linear`"):::w
    end
    subgraph Channel Excitation
    9["`[C, 1]`"]    ---10("`Sigmoid`"):::d -->
    11["`[C, 1]`"]   ---12(("*")):::d       -->
    13["`[C, W]`"]
    end
    3 --- 4
    8 --> 9
    1 --> 12
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;
Note that the batch dimension is not represented, but must be present.
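The diagram can be read as the following computation, sketched on nested Python lists of shape `[C][W]` for a single batch element. The function name and weight layout are illustrative assumptions. As in the constructor, both linear maps are bias-free and the hidden width is `max(1, C // compression)`:

```python
import math
from typing import List

Matrix = List[List[float]]  # indexed as [row][col]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def channel_sqzex(x: Matrix, w1: Matrix, w2: Matrix) -> Matrix:
    """x: [C][W]; w1: [C//r][C]; w2: [C][C//r] (bias-free, as in the layer)."""
    # Spatial squeeze: average each channel over its W positions -> [C]
    z = [sum(row) / len(row) for row in x]
    # MLP: C -> C//r (ReLU) -> C, then sigmoid
    h = [max(0.0, sum(a * b for a, b in zip(wi, z))) for wi in w1]
    s = [sigmoid(sum(a * b for a, b in zip(wi, h))) for wi in w2]
    # Channel excitation: rescale every position of channel c by s[c]
    return [[s_c * v for v in row] for s_c, row in zip(s, x)]
```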

Reference

  1. Hu, J, et al. "Squeeze-and-Excitation Networks." CVPR (2018), TPAMI (2019). arxiv:1709.01507

  2. Roy, AG, et al. "Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks." MICCAI (2018). arxiv:1803.02579

PARAMETER DESCRIPTION
channels

Number of input and output channels

TYPE: int

compression

Compression ratio for the number of channels in the squeeze

TYPE: int DEFAULT: 16

activation

Activation function

TYPE: ActivationType DEFAULT: 'ReLU'

Source code in cassetta/layers/attention.py
def __init__(
    self,
    channels: int,
    compression: int = 16,
    activation: ActivationType = 'ReLU',
    device: Optional[DeviceType] = None,
    dtype: Optional[DataType] = None,
    **unused_kwargs,
):
    """
    Parameters
    ----------
    channels : int
        Number of input and output channels
    compression : int
        Compression ratio for the number of channels in the squeeze
    activation : ActivationType
        Activation function
    """
    opt = dict(bias=False, device=device, dtype=to_torch_dtype(dtype))
    super().__init__(
        GlobalPool(reduction='mean', keepdim=True),
        Linear(channels, max(1, channels//compression), **opt),
        make_activation(activation),
        Linear(max(1, channels//compression), channels, **opt),
        nn.Sigmoid(),
    )

SpatialSqzEx

SpatialSqzEx(channels, device=None, dtype=None, **unused_kwargs)

Bases: Sequential

Channel Squeeze & Spatial Excitation layer

Diagram

flowchart LR
    subgraph Channel Squeeze
    1["`[C, W]`"]    ---2("`Linear`"):::w-->  3["`[1, W]`"]
    end
    subgraph Spatial Excitation
    4("`Sigmoid`"):::d--> 5["`[1, W]`"] ---6(("*")):::d-->7["`[C, W]`"]
    end
    3 --- 4
    1 --> 6
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;
Note that the batch dimension is not represented, but must be present.
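The following list-based sketch shows the same flow for one batch element. It assumes that `w` holds the `C` weights of the bias-free `Linear(channels, 1)` projection. The names are illustrative:

```python
import math
from typing import List


def spatial_sqzex(x: List[List[float]], w: List[float]) -> List[List[float]]:
    """x: [C][W]; w: [C] weights of the bias-free projection C -> 1."""
    n_pos = len(x[0])
    # Channel squeeze: project the C values at each position to one score
    q = [sum(w[c] * x[c][i] for c in range(len(x))) for i in range(n_pos)]
    # Spatial excitation: sigmoid gate, broadcast back over all channels
    gate = [1.0 / (1.0 + math.exp(-qi)) for qi in q]
    return [[g * v for g, v in zip(gate, row)] for row in x]
```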

Reference

  1. Roy, AG, et al. "Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks." MICCAI (2018). arxiv:1803.02579
PARAMETER DESCRIPTION
channels

Number of input and output channels

TYPE: int

Source code in cassetta/layers/attention.py
def __init__(
    self,
    channels: int,
    device: Optional[DeviceType] = None,
    dtype: Optional[DataType] = None,
    **unused_kwargs,
):
    """
    Parameters
    ----------
    channels : int
        Number of input and output channels
    """
    opt = dict(bias=False, device=device, dtype=to_torch_dtype(dtype))
    super().__init__(
        Linear(channels, 1, **opt),
        nn.Sigmoid(),
    )

forward

forward(inp)
PARAMETER DESCRIPTION
inp

TYPE: (B, C, *spatial) tensor

RETURNS DESCRIPTION
out

TYPE: (B, C, *spatial) tensor

Source code in cassetta/layers/attention.py
def forward(self, inp: Tensor) -> Tensor:
    """
    Parameters
    ----------
    inp : (B, C, *spatial) tensor

    Returns
    -------
    out : (B, C, *spatial) tensor
    """
    return inp * super().forward(inp)

SqzEx

SqzEx(channels, mode='+', compression=16, activation='ReLU', device=None, dtype=None, **unused_kwargs)

Bases: Sequential

Concurrent Spatial and Channel Squeeze & Excitation layer

Diagram

Mode '+':

flowchart LR
    6(("+")):::d  --> 7["`[C, W]`"]
    1["`[C, W]`"] ---2("Channel Squeeze & Excite"):::w-->
    3["`[C, W]`"]
    1             ---4("Spatial Squeeze & Excite"):::w-->
    5["`[C, W]`"]
    3 --- 6
    5 --> 6
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;

Mode 'cs':

flowchart LR
    1["`[C, W]`"] ---2("Channel Squeeze & Excite"):::w-->
    3["`[C, W]`"] ---4("Spatial Squeeze & Excite"):::w-->
    5["`[C, W]`"]
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;

Mode 'sc':

flowchart LR
    1["`[C, W]`"] ---2("Spatial Squeeze & Excite"):::w-->
    3["`[C, W]`"] ---4("Channel Squeeze & Excite"):::w-->
    5["`[C, W]`"]
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;

Note that the batch dimension is not represented, but must be present.

Reference

  1. Roy, AG, et al. "Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks." MICCAI (2018). arxiv:1803.02579
PARAMETER DESCRIPTION
channels

Number of input and output channels

TYPE: int

mode

Squeeze and excitation mode:

  • 'c' : channel only
  • 's' : spatial only
  • 'cs' : channel, then spatial
  • 'sc' : spatial, then channel
  • '+' : concurrent spatial and channel

TYPE: {'+', 'c', 's', 'sc', 'cs'} DEFAULT: '+'

compression

Compression ratio for the number of channels in the squeeze

TYPE: int DEFAULT: 16

activation

Activation function

TYPE: ActivationType DEFAULT: 'ReLU'
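The following sketch shows how `mode` selects and combines the two branches. Plain callables stand in for `SpatialSqzEx` and `ChannelSqzEx`, and the `combine` helper is hypothetical. The sketch follows the constructor's mode table and the `forward` below: `'+'` sums the branch outputs, while every other mode chains them:

```python
from functools import reduce
from typing import Callable, Dict, List

Layer = Callable[[float], float]


def combine(mode: str, spatial: Layer, channel: Layer) -> Layer:
    """Hypothetical sketch of how the mode string selects and combines branches."""
    orders: Dict[str, List[Layer]] = {
        "s": [spatial], "c": [channel],
        "cs": [channel, spatial], "sc": [spatial, channel],
        "+": [spatial, channel],
    }
    mode = mode.lower()
    if mode not in orders:
        raise ValueError(f'Unknown mode "{mode}"')
    layers = orders[mode]
    if mode == "+":
        # concurrent: every branch sees the same input, outputs are summed
        return lambda x: sum(layer(x) for layer in layers)
    # sequential: the output of one branch feeds the next
    return lambda x: reduce(lambda acc, layer: layer(acc), layers, x)
```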

Source code in cassetta/layers/attention.py
def __init__(
    self,
    channels: int,
    mode: Literal['+', 'c', 's', 'cs', 'sc'] = '+',
    compression: int = 16,
    activation: ActivationType = 'ReLU',
    device: Optional[DeviceType] = None,
    dtype: Optional[DataType] = None,
    **unused_kwargs,
):
    """
    Parameters
    ----------
    channels : int
        Number of input and output channels
    mode : {'+', 'c', 's', 'sc', 'cs'}
        Squeeze and excitation mode:

        - `'c'` : channel only
        - `'s'` : spatial only
        - `'cs'` : channel, then spatial
        - `'sc'` : spatial, then channel
        - `'+'` : concurrent spatial and channel
    compression : int
        Compression ratio for the number of channels in the squeeze
    activation : ActivationType
        Activation function
    """
    mode = mode.lower()
    opt = dict(device=device, dtype=to_torch_dtype(dtype))
    if 's' in mode or '+' in mode:
        s = SpatialSqzEx(channels, **opt)
    if 'c' in mode or '+' in mode:
        c = ChannelSqzEx(channels, compression, activation, **opt)
    if mode == 's':
        layers = [s]
    elif mode == 'c':
        layers = [c]
    elif mode == 'cs':
        layers = [c, s]
    elif mode == 'sc':
        layers = [s, c]
    elif mode == '+':
        layers = [s, c]
    else:
        raise ValueError(f'Unknown mode "{mode}"')
    super().__init__(*layers)
    self.mode = mode

forward

forward(inp)
PARAMETER DESCRIPTION
inp

TYPE: (B, C, *spatial) tensor

RETURNS DESCRIPTION
out

TYPE: (B, C, *spatial) tensor

Source code in cassetta/layers/attention.py
def forward(self, inp):
    """
    Parameters
    ----------
    inp : (B, C, *spatial) tensor

    Returns
    -------
    out : (B, C, *spatial) tensor
    """
    if self.mode == '+':
        return sum([layer(inp) for layer in self])
    else:
        return super().forward(inp)

ChannelBlockAttention

ChannelBlockAttention(channels, compression=16, activation='ReLU', device=None, dtype=None, **unused_kwargs)

Bases: Sequential

Channel Attention for Convolutional Block Attention Module

Diagram

flowchart LR
    subgraph Spatial Squeeze
    1["`[C, W]`"]    ---2("`MeanPool`"):::d-->  3["`[C, 1]`"]
    1                ---4("`MaxPool`"):::d-->   5["`[C, 1]`"]
    end
    subgraph MLP - shared weights
                        6("`Linear`"):::w   -->
    7["`[C//r, 1]`"] ---8("`ReLU`"):::d     -->
    9["`[C//r, 1]`"] ---10("`Linear`"):::w
    end
    subgraph MLP - shared weights
                         11("`Linear`"):::w   -->
    12["`[C//r, 1]`"] ---13("`ReLU`"):::d     -->
    14["`[C//r, 1]`"] ---15("`Linear`"):::w
    end
    subgraph Channel Attention
    20("`Sigmoid`"):::d -->
    21["`[C, 1]`"] --- 22(("*")):::d  --> 23["`[C, W]`"]
    end
    16["`[C, 1]`"] & 17["`[C, 1]`"] ---18(("+")):::d--> 19["`[C, 1]`"]
    3 --- 6
    5 --- 11
    10 --> 16
    15 --> 17
    19 --- 20
    1 --> 22
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;
Note that the batch dimension is not represented, but must be present.
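The following list-based sketch shows this channel attention for one `[C][W]` input. It assumes that `w1` (shape `[C//r][C]`) and `w2` (shape `[C][C//r]`) are the two bias-free linear maps. The key point is that the same MLP weights process both the mean-pooled and the max-pooled descriptors:

```python
import math
from typing import List

Matrix = List[List[float]]


def mlp(z: List[float], w1: Matrix, w2: Matrix) -> List[float]:
    # Shared bias-free MLP: C -> C//r (ReLU) -> C
    h = [max(0.0, sum(a * b for a, b in zip(row, z))) for row in w1]
    return [sum(a * b for a, b in zip(row, h)) for row in w2]


def channel_block_attention(x: Matrix, w1: Matrix, w2: Matrix) -> Matrix:
    """x: [C][W]. The same MLP weights process both pooled descriptors."""
    avg = [sum(row) / len(row) for row in x]   # mean pool over W
    mx = [max(row) for row in x]               # max pool over W
    logits = [a + b for a, b in zip(mlp(avg, w1, w2), mlp(mx, w1, w2))]
    gate = [1.0 / (1.0 + math.exp(-g)) for g in logits]
    # Rescale every position of channel c by gate[c]
    return [[g * v for v in row] for g, row in zip(gate, x)]
```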

Reference

  1. Woo, S, et al. "CBAM: Convolutional Block Attention Module." ECCV (2018). arxiv:1807.06521
PARAMETER DESCRIPTION
channels

Number of input and output channels

TYPE: int

compression

Compression ratio for the number of channels in the squeeze

TYPE: int DEFAULT: 16

activation

Activation function

TYPE: ActivationType DEFAULT: 'ReLU'

Source code in cassetta/layers/attention.py
def __init__(
    self,
    channels: int,
    compression: int = 16,
    activation: ActivationType = 'ReLU',
    device: Optional[DeviceType] = None,
    dtype: Optional[DataType] = None,
    **unused_kwargs,
):
    """
    Parameters
    ----------
    channels : int
        Number of input and output channels
    compression : int
        Compression ratio for the number of channels in the squeeze
    activation : ActivationType
        Activation function
    """
    opt = dict(device=device, dtype=to_torch_dtype(dtype))
    super().__init__(
        GlobalPool(keepdim=False, reduction='mean'),
        GlobalPool(keepdim=False, reduction='max'),
        Linear(channels, max(1, channels//compression), bias=False, **opt),
        make_activation(activation),
        Linear(max(1, channels//compression), channels, bias=False, **opt),
        nn.Sigmoid(),
    )

forward

forward(inp)
PARAMETER DESCRIPTION
inp

TYPE: (B, C, *spatial) tensor

RETURNS DESCRIPTION
out

TYPE: (B, C, *spatial) tensor

Source code in cassetta/layers/attention.py
def forward(self, inp: Tensor) -> Tensor:
    """
    Parameters
    ----------
    inp : (B, C, *spatial) tensor

    Returns
    -------
    out : (B, C, *spatial) tensor
    """
    ndim = inp.ndim - 2
    meanpool, maxpool, *mlp, sigmoid = self
    mlp = nn.Sequential(*mlp)
    out = sigmoid(mlp(meanpool(inp)) + mlp(maxpool(inp)))
    out = out.reshape(out.shape + (1,) * ndim)
    out = inp * out
    return out

SpatialBlockAttention

SpatialBlockAttention(ndim, kernel_size=7, device=None, dtype=None, **unused_kwargs)

Bases: Sequential

Spatial Attention for Convolutional Block Attention Module

Diagram

flowchart LR
    subgraph Channel Squeeze
    1["`[C, W]`"]    ---2("`ChannelMean`"):::d-->  3["`[1, W]`"]
    1                ---4("`ChannelMax`"):::d-->   5["`[1, W]`"]
    end
    3 & 5 ---6(("c")):::d--> 7["`[2, W]`"]
    7 ---8("`Conv 7`"):::w--> 9["`[1, W]`"]
    subgraph Spatial Excitation
    10("`Sigmoid`"):::d--> 11["`[1, W]`"] ---12(("*")):::d-->
    13["`[C, W]`"]
    end
    9 --> 10
    1 --> 12
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;
Note that the batch dimension is not represented, but must be present.
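The following 1D sketch shows the spatial attention for one `[C][W]` input. It applies explicit 'same' zero padding, so the gate has width `W` and can multiply the input elementwise. `kernel` is an assumed `[2][k]` weight array for the bias-free 2-to-1 convolution (odd `k`). Like PyTorch's convolutions, the sketch computes a cross-correlation:

```python
import math
from typing import List


def spatial_block_attention(x: List[List[float]], kernel: List[List[float]]) -> List[List[float]]:
    """x: [C][W]; kernel: [2][k] weights of a bias-free 2 -> 1 conv, k odd."""
    width, k = len(x[0]), len(kernel[0])
    # Channel squeeze: mean and max over channels at each position -> [2][W]
    mean = [sum(col) / len(col) for col in zip(*x)]
    mx = [max(col) for col in zip(*x)]
    feats = [mean, mx]
    # 'same' zero padding (k // 2 on each side) keeps the output width equal to W
    half = k // 2
    logits = []
    for i in range(width):
        acc = 0.0
        for ch in range(2):
            for j in range(k):
                pos = i + j - half
                if 0 <= pos < width:  # out-of-range taps read zero padding
                    acc += kernel[ch][j] * feats[ch][pos]
        logits.append(acc)
    gate = [1.0 / (1.0 + math.exp(-g)) for g in logits]
    # Spatial excitation: the same gate multiplies every channel
    return [[g * v for g, v in zip(gate, row)] for row in x]
```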

Reference

  1. Woo, S, et al. "CBAM: Convolutional Block Attention Module." ECCV (2018). arxiv:1807.06521
PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

kernel_size

Kernel size of the convolution layer

TYPE: [list of] int DEFAULT: 7

Source code in cassetta/layers/attention.py
def __init__(
    self,
    ndim: int,
    kernel_size: OneOrSeveral[int] = 7,
    device: Optional[DeviceType] = None,
    dtype: Optional[DataType] = None,
    **unused_kwargs,
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    kernel_size : [list of] int
        Kernel size of the convolution layer
    """
    opt = dict(bias=False, device=device, dtype=to_torch_dtype(dtype))
    super().__init__(
        GlobalPool(keepdim=True, dim=1, reduction='mean'),
        GlobalPool(keepdim=True, dim=1, reduction='max'),
        Conv(ndim, 2, 1, kernel_size=kernel_size, padding='same', **opt),
        nn.Sigmoid(),
    )

forward

forward(inp)
PARAMETER DESCRIPTION
inp

TYPE: (B, C, *spatial) tensor

RETURNS DESCRIPTION
out

TYPE: (B, C, *spatial) tensor

Source code in cassetta/layers/attention.py
def forward(self, inp: Tensor) -> Tensor:
    """
    Parameters
    ----------
    inp : (B, C, *spatial) tensor

    Returns
    -------
    out : (B, C, *spatial) tensor
    """
    meanpool, maxpool, conv, sigmoid = self
    out = sigmoid(conv(Cat()(meanpool(inp), maxpool(inp))))
    out = inp * out
    return out

BlockAttention

BlockAttention(ndim, channels, mode='cs', compression=16, activation='ReLU', kernel_size=7, **unused_kwargs)

Bases: Sequential

Channel + Spatial Attention layer

Diagram

Mode '+':

flowchart LR
    6(("+")):::d  --> 7["`[C, W]`"]
    1["`[C, W]`"] ---2("Channel Attention"):::w--> 3["`[C, W]`"]
    1             ---4("Spatial Attention"):::w--> 5["`[C, W]`"]
    3 --- 6
    5 --> 6
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;

Mode 'cs':

flowchart LR
    1["`[C, W]`"] ---2("Channel Attention"):::w-->
    3["`[C, W]`"] ---4("Spatial Attention"):::w-->
    5["`[C, W]`"]
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;

Mode 'sc':

flowchart LR
    1["`[C, W]`"] ---2("Spatial Attention"):::w-->
    3["`[C, W]`"] ---4("Channel Attention"):::w-->
    5["`[C, W]`"]
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;

Note that the batch dimension is not represented, but must be present.

Reference

  1. Woo, S, et al. "CBAM: Convolutional Block Attention Module." ECCV (2018). arxiv:1807.06521
PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

channels

Number of input and output channels

TYPE: int

mode

Attention mode:

  • 'c' : channel only
  • 's' : spatial only
  • 'cs' : channel, then spatial
  • 'sc' : spatial, then channel
  • '+' : concurrent spatial and channel

TYPE: {'cs', 'sc', 'c', 's', '+'} DEFAULT: 'cs'

compression

Compression ratio for the number of channels in the squeeze

TYPE: int DEFAULT: 16

activation

Activation function

TYPE: ActivationType DEFAULT: 'ReLU'

kernel_size

Kernel size of the convolution layer

TYPE: [list of] int DEFAULT: 7

Source code in cassetta/layers/attention.py
def __init__(
    self,
    ndim: int,
    channels: int,
    mode: Literal['c', 's', 'cs', 'sc', '+'] = 'cs',
    compression: int = 16,
    activation: ActivationType = 'ReLU',
    kernel_size: OneOrSeveral[int] = 7,
    **unused_kwargs,
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    channels : int
        Number of input and output channels
    mode : {'cs', 'sc', 'c', 's', '+'}
        Attention mode:

        - `'c'` : channel only
        - `'s'` : spatial only
        - `'cs'` : channel, then spatial
        - `'sc'` : spatial, then channel
        - `'+'` : concurrent spatial and channel
    compression : int
        Compression ratio for the number of channels in the squeeze
    activation : ActivationType
        Activation function
    kernel_size : [list of] int
        Kernel size of the convolution layer
    """
    mode = mode.lower()
    if 's' in mode or '+' in mode:
        s = SpatialBlockAttention(ndim, kernel_size)
    if 'c' in mode or '+' in mode:
        c = ChannelBlockAttention(channels, compression, activation)
    if mode == 's':
        layers = [s]
    elif mode == 'c':
        layers = [c]
    elif mode == 'cs':
        layers = [c, s]
    elif mode == 'sc':
        layers = [s, c]
    elif mode == '+':
        layers = [s, c]
    else:
        raise ValueError(f'Unknown mode "{mode}"')
    super().__init__(*layers)
    self.mode = mode

forward

forward(inp)
PARAMETER DESCRIPTION
inp

TYPE: (B, C, *spatial) tensor

RETURNS DESCRIPTION
out

TYPE: (B, C, *spatial) tensor

Source code in cassetta/layers/attention.py
def forward(self, inp):
    """
    Parameters
    ----------
    inp : (B, C, *spatial) tensor

    Returns
    -------
    out : (B, C, *spatial) tensor
    """
    if self.mode == '+':
        return sum([layer(inp) for layer in self])
    else:
        return super().forward(inp)

DotProductAttention

DotProductAttention(key_channels, val_channels, scaled=True, **unused_kwargs)

Bases: Module

Under construction -- do not use

References

  1. Vaswani, A, et al. "Attention Is All You Need." NeurIPS (2017). arxiv:1706.03762
PARAMETER DESCRIPTION
key_channels

Number of keys

TYPE: int

val_channels

Number of values

TYPE: int

scaled

Scale the dot product

TYPE: bool DEFAULT: True

Source code in cassetta/layers/attention.py
def __init__(
    self,
    key_channels: int,
    val_channels: int,
    scaled=True,
    **unused_kwargs,
):
    """
    Parameters
    ----------
    key_channels : int
        Number of keys
    val_channels : int
        Number of values
    scaled : bool
        Scale the dot product
    """
    super().__init__()
    self.key_channels = key_channels
    self.val_channels = val_channels
    self.scaled = scaled

forward

forward(inp)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, K+K*V+V, *spatial) tensor

RETURNS DESCRIPTION
out

TYPE: (B, V, *spatial) tensor
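The following per-position sketch shows the intended computation. It assumes that the `K + K*V + V` input channels are ordered as query, keys, values (as in the `split` call below) and that the softmax is applied to the scores. The class is marked as under construction, so treat this as an illustration of the code's logic rather than a reference implementation:

```python
import math
from typing import List


def dot_product_attention(vec: List[float], nk: int, nv: int, scaled: bool = True) -> List[float]:
    """vec: the K + K*V + V channels at one position; returns V values."""
    if len(vec) != nk + nk * nv + nv:
        raise ValueError("expected K + K*V + V channels")
    q = vec[:nk]
    k_flat = vec[nk:nk + nk * nv]
    v = vec[nk + nk * nv:]
    # Reshape keys to [K][V] (row-major, like reshape(..., (nk, nv)))
    k = [k_flat[i * nv:(i + 1) * nv] for i in range(nk)]
    # q (1 x K) @ k (K x V) -> V scores
    scores = [sum(q[i] * k[i][j] for i in range(nk)) for j in range(nv)]
    if scaled:
        scores = [s / math.sqrt(nk) for s in scores]
    # Numerically stable softmax over the V scores
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weight each value elementwise
    return [w * vj for w, vj in zip(weights, v)]
```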

Source code in cassetta/layers/attention.py
def forward(self, inp):
    """
    Parameters
    ----------
    inp : (B, K+K*V+V, *spatial) tensor
        Input tensor

    Returns
    -------
    out : (B, V, *spatial) tensor
    """
    nk, nv = self.key_channels, self.val_channels
    q, k, v = inp.split([nk, nk*nv, nv], dim=1)
    q, k, v = q.movedim(1, -1), k.movedim(1, -1), v.movedim(1, -1)
    q, k = q.unsqueeze(-2), k.reshape(k.shape[:-1] + (nk, nv))
    qk = q.matmul(k).squeeze(-2)
    if self.scaled:
        qk /= self.key_channels ** 0.5
    qk = qk.softmax(dim=-1)  # apply the softmax to the scores
    return (qk * v).movedim(-1, 1)

MultiHeadAttention

MultiHeadAttention(inp_channels, key_channels, val_channels, nb_heads, scaled=True, **unused_kwargs)

Bases: Module

Under construction -- do not use

References

  1. Vaswani, A, et al. "Attention Is All You Need." NeurIPS (2017). arxiv:1706.03762
PARAMETER DESCRIPTION
inp_channels

Number of input channels

TYPE: int

key_channels

Number of keys

TYPE: int

val_channels

Number of values

TYPE: int

nb_heads

Number of heads

TYPE: int

scaled

Scale the dot product

TYPE: bool DEFAULT: True

Source code in cassetta/layers/attention.py
def __init__(
    self,
    inp_channels: int,
    key_channels: int,
    val_channels: int,
    nb_heads: int,
    scaled=True,
    **unused_kwargs,
):
    """
    Parameters
    ----------
    inp_channels : int
        Number of input channels
    key_channels : int
        Number of keys
    val_channels : int
        Number of values
    nb_heads : int
        Number of heads
    scaled : bool
        Scale the dot product
    """
    super().__init__()
    qkv_channels = (
        key_channels * val_channels + key_channels + val_channels
    )
    self.heads = nn.ModuleList([
        nn.Sequential(
            MoveDim(1, -1),
            nn.Linear(inp_channels, qkv_channels, bias=False),
            MoveDim(-1, 1),
            DotProductAttention(key_channels, val_channels, scaled=scaled)
        )
        for _ in range(nb_heads)
    ])
    self.combine = nn.Sequential(
        MoveDim(1, -1),
        nn.Linear(val_channels*nb_heads, inp_channels),
        MoveDim(-1, 1),
    )
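The layer sizes above imply the parameter count computed below with a hypothetical helper. Each head owns a bias-free `Linear(inp_channels, K*V + K + V)`, and the combining layer is a `Linear(V*nb_heads, inp_channels)` with bias. Because the keys are stored as a `K x V` block, the per-head projection grows with the product of the two sizes:

```python
def mha_param_count(inp_channels: int, key_channels: int, val_channels: int, nb_heads: int) -> int:
    # Width of each head's projection: K*V (keys) + K (query) + V (values)
    qkv_channels = key_channels * val_channels + key_channels + val_channels
    # One bias-free Linear(inp_channels -> qkv_channels) per head
    heads = nb_heads * inp_channels * qkv_channels
    # Combining Linear(val_channels * nb_heads -> inp_channels), with bias
    combine = val_channels * nb_heads * inp_channels + inp_channels
    return heads + combine


print(mha_param_count(8, 4, 4, 2))  # 2*8*24 + (4*2*8 + 8) = 384 + 72 = 456
```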

forward

forward(inp)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, C, *spatial) tensor

RETURNS DESCRIPTION
out

TYPE: (B, C, *spatial) tensor

Source code in cassetta/layers/attention.py
def forward(self, inp):
    """
    Parameters
    ----------
    inp : (B, C, *spatial) tensor
        Input tensor

    Returns
    -------
    out : (B, C, *spatial) tensor
    """
    out = Cat()([head(inp) for head in self.heads])
    out = self.combine(out)
    return out

make_attention

make_attention(attention, channels, ndim=None, **kwargs)

Instantiate an attention layer

PARAMETER DESCRIPTION
attention

An already instantiated nn.Module, or a nn.Module subclass, or a callable that returns an instantiated nn.Module, or the name of an attention type:

  • "sqzex" : Squeeze & Excite
  • "cbam" : Convolutional Block Attention Module
  • "dp" : Dot-Product Attention
  • "sdp" : Scaled Dot-Product Attention
  • "mha" : Multi-Head Attention

TYPE: AttentionType
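The name lookup can be sketched as a small table. The table is hypothetical and returns class names instead of classes, so it runs without the library. User-supplied keyword arguments take precedence over the defaults, which matches the effect of `kwargs.setdefault('scaled', ...)` in the source:

```python
from typing import Any, Dict, Tuple

# Hypothetical table: name -> (class name, default keyword arguments)
ATTENTIONS: Dict[str, Tuple[str, Dict[str, Any]]] = {
    "sqzex": ("SqzEx", {}),
    "cbam": ("BlockAttention", {}),
    "dp": ("DotProductAttention", {"scaled": False}),
    "sdp": ("DotProductAttention", {"scaled": True}),
    "mha": ("MultiHeadAttention", {}),
}


def resolve_attention_name(name: str, **kwargs: Any) -> Tuple[str, Dict[str, Any]]:
    key = name.lower()  # names are case-insensitive
    if key not in ATTENTIONS:
        raise ValueError(f'Unknown attention "{name}"')
    cls_name, defaults = ATTENTIONS[key]
    # User kwargs win over defaults (same effect as kwargs.setdefault)
    return cls_name, {**defaults, **kwargs}
```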

channels

Number of channels

TYPE: int

ndim

Number of spatial dimensions

TYPE: int DEFAULT: None

RETURNS DESCRIPTION
attention

An attention layer

TYPE: Module

Source code in cassetta/layers/attention.py
def make_attention(
    attention: AttentionType,
    channels: int,
    ndim: int = None,
    **kwargs
):
    """
    Instantiate an attention layer

    Parameters
    ----------
    attention : AttentionType
        An already instantiated `nn.Module`, or a `nn.Module` subclass,
        or a callable that returns an instantiated `nn.Module`, or the
        name of an attention type:

        - `"sqzex"` : Squeeze & Excite
        - `"cbam"` : Convolutional Block Attention Module
        - `"dp"` : Dot-Product Attention
        - `"sdp"` : Scaled Dot-Product Attention
        - `"mha"` : Multi-Head Attention
    channels : int
        Number of channels
    ndim : int
        Number of spatial dimensions

    Returns
    -------
    attention : Module
        An attention layer
    """
    if not attention:
        return None
    if isinstance(attention, nn.Module):
        return attention
    if isinstance(attention, str):
        attention = attention.lower()
        if attention == 'sqzex':
            attention = SqzEx
        elif attention == 'cbam':
            attention = BlockAttention
        elif attention == 'dp':
            attention = DotProductAttention
            kwargs.setdefault('scaled', False)
        elif attention == 'sdp':
            attention = DotProductAttention
            kwargs.setdefault('scaled', True)
        elif attention == 'mha':
            attention = MultiHeadAttention
        else:
            raise ValueError(f'Unknown attention "{attention}"')
    kwargs['ndim'] = ndim
    kwargs['channels'] = channels
    attention = attention(**kwargs)
    if not isinstance(attention, nn.Module):
        raise ValueError('Attention did not instantiate a Module')
    return attention

cassetta.layers.conv

Conv

Conv(ndim, inp_channels, out_channels=None, kernel_size=3, stride=1, padding=0, dilation=1, output_padding=0, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)

Bases: _Conv

PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

inp_channels

Number of input channels

TYPE: int

out_channels

Number of output channels

TYPE: int DEFAULT: `inp_channels`

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

stride

Space between output elements

TYPE: [list of] int DEFAULT: 1

padding

Amount of padding to apply

TYPE: {'same', 'valid'} or [list of] int DEFAULT: 0

dilation

Space between kernel elements

TYPE: [list of] int DEFAULT: 1

output_padding

Amount of padding to apply to the output of a transposed convolution.

TYPE: [list of] int DEFAULT: 0

groups

Number of groups

TYPE: int DEFAULT: 1

bias

Add a learnable bias term

TYPE: bool DEFAULT: True

padding_mode

How to pad the tensor

TYPE: BoundType DEFAULT: 'zeros'
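Since `forward` delegates to `F.conv{ndim}d`, the output size along each spatial dimension follows PyTorch's usual formula, sketched below. With the default `padding=0`, a kernel of size 3 shrinks each dimension by 2. `same_padding` is a hypothetical helper that gives the total padding preserving the size at stride 1. Splitting that total left/right with the extra element on the right is an assumption about the convention:

```python
from typing import Tuple


def conv_out_size(n: int, k: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    # floor((n + 2p - d*(k-1) - 1) / s) + 1
    return (n + 2 * padding - dilation * (k - 1) - 1) // stride + 1


def same_padding(k: int, dilation: int = 1) -> Tuple[int, int]:
    # Total padding d*(k-1) keeps the size at stride 1; split left/right
    total = dilation * (k - 1)
    return total // 2, total - total // 2


print(conv_out_size(32, 3))              # 30 (default padding=0)
print(conv_out_size(32, 3, padding=1))   # 32
print(same_padding(4))                   # (1, 2)
```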

Source code in cassetta/layers/conv.py
def __init__(
    self,
    ndim: int,
    inp_channels: int,
    out_channels: Optional[int] = None,
    kernel_size: OneOrSeveral[int] = 3,
    stride: OneOrSeveral[int] = 1,
    padding: Union[Literal['same', 'valid'], OneOrSeveral[int]] = 0,
    dilation: OneOrSeveral[int] = 1,
    output_padding: OneOrSeveral[int] = 0,
    groups: int = 1,
    bias: bool = True,
    padding_mode: BoundType = 'zeros',
    device: Optional[DeviceType] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    inp_channels : int
        Number of input channels
    out_channels : int, default=`inp_channels`
        Number of output channels
    kernel_size : [list of] int
        Kernel size
    stride : [list of] int
        Space between output elements
    padding : {'same', 'valid'} or [list of] int
        Amount of padding to apply
    dilation : [list of] int
        Space between kernel elements
    output_padding : [list of] int
        Amount of padding to apply to the output of a transposed
        convolution.
    groups : int
        Number of groups
    bias : bool
        Add a learnable bias term
    padding_mode : BoundType
        How to pad the tensor
    """
    factory_kwargs = {'device': device, 'dtype': dtype}
    super().__init__()

    self.ndim = ndim
    self.inp_channels = inp_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.stride = stride
    self.padding = padding
    self.dilation = dilation
    self.output_padding = output_padding
    self.groups = groups
    self.padding_mode = padding_mode

    self.check_parameters()

    if self.transposed:
        self.weight = nn.Parameter(torch.empty(
            (
                self.inp_channels,
                self.out_channels // self.groups,
                *self.kernel_size
            ),
            **factory_kwargs
        ))
    else:
        self.weight = nn.Parameter(torch.empty(
            (
                self.out_channels,
                self.inp_channels // self.groups,
                *self.kernel_size
            ),
            **factory_kwargs
        ))
    if bias:
        self.bias = nn.Parameter(torch.empty(
            self.out_channels, **factory_kwargs
        ))
    else:
        self.register_parameter('bias', None)

    self.reset_parameters()

forward

forward(inp)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, inp_channels, *inp_size) tensor

RETURNS DESCRIPTION
out

Convolved tensor

TYPE: (B, out_channels, *out_size) tensor

Source code in cassetta/layers/conv.py
def forward(self, inp: Tensor) -> Tensor:
    """
    Parameters
    ----------
    inp : (B, inp_channels, *inp_size) tensor
        Input tensor

    Returns
    -------
    out : (B, out_channels, *out_size) tensor
        Convolved tensor
    """
    ndim = len(self.kernel_size)
    padding = self.padding
    conv = getattr(F, f'conv{ndim}d')
    if to_enum(self.padding_mode) != BoundEnum.zeros:
        inp = pad(inp, self._padding_lr, mode=self.padding_mode)
        padding = 0
    return conv(inp, self.weight, self.bias, self.stride,
                padding, self.dilation, self.groups)

ConvTransposed

ConvTransposed(ndim, inp_channels, out_channels=None, kernel_size=3, stride=1, padding=0, dilation=1, output_padding=0, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)

Bases: _Conv

PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

inp_channels

Number of input channels

TYPE: int

out_channels

Number of output channels

TYPE: int DEFAULT: `inp_channels`

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

stride

Space between output elements

TYPE: [list of] int DEFAULT: 1

padding

Amount of padding to apply

TYPE: {same, valid} or [list of] int DEFAULT: 0

dilation

Space between kernel elements

TYPE: [list of] int DEFAULT: 1

output_padding

Amount of padding to apply to the output of a transposed convolution.

TYPE: [list of] int DEFAULT: 0

groups

Number of groups

TYPE: int DEFAULT: 1

bias

Add a learnable bias term

TYPE: bool DEFAULT: True

padding_mode

How to pad the tensor

TYPE: BoundType DEFAULT: 'zeros'

Source code in cassetta/layers/conv.py
def __init__(
    self,
    ndim: int,
    inp_channels: int,
    out_channels: Optional[int] = None,
    kernel_size: OneOrSeveral[int] = 3,
    stride: OneOrSeveral[int] = 1,
    padding: Union[Literal['same', 'valid'], OneOrSeveral[int]] = 0,
    dilation: OneOrSeveral[int] = 1,
    output_padding: OneOrSeveral[int] = 0,
    groups: int = 1,
    bias: bool = True,
    padding_mode: BoundType = 'zeros',
    device: Optional[DeviceType] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    inp_channels : int
        Number of input channels
    out_channels : int, default=`inp_channels`
        Number of output channels
    kernel_size : [list of] int
        Kernel size
    stride : [list of] int
        Space between output elements
    padding : {'same', 'valid'} or [list of] int
        Amount of padding to apply
    dilation : [list of] int
        Space between kernel elements
    output_padding : [list of] int
        Amount of padding to apply to the output of a transposed
        convolution.
    groups : int
        Number of groups
    bias : bool
        Add a learnable bias term
    padding_mode : BoundType
        How to pad the tensor
    """
    factory_kwargs = {'device': device, 'dtype': dtype}
    super().__init__()

    self.ndim = ndim
    self.inp_channels = inp_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.stride = stride
    self.padding = padding
    self.dilation = dilation
    self.output_padding = output_padding
    self.groups = groups
    self.padding_mode = padding_mode

    self.check_parameters()

    if self.transposed:
        self.weight = nn.Parameter(torch.empty(
            (
                self.inp_channels,
                self.out_channels // self.groups,
                *self.kernel_size
            ),
            **factory_kwargs
        ))
    else:
        self.weight = nn.Parameter(torch.empty(
            (
                self.out_channels,
                self.inp_channels // self.groups,
                *self.kernel_size
            ),
            **factory_kwargs
        ))
    if bias:
        self.bias = nn.Parameter(torch.empty(
            self.out_channels, **factory_kwargs
        ))
    else:
        self.register_parameter('bias', None)

    self.reset_parameters()
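The two branches above allocate weights in PyTorch's layout. A regular convolution stores `(out_channels, inp_channels // groups, *kernel_size)`, and a transposed convolution stores `(inp_channels, out_channels // groups, *kernel_size)`. The sketch below reproduces that shape logic, together with PyTorch's requirement that both channel counts be divisible by `groups`. The helper `weight_shape` is hypothetical.

```python
from typing import Sequence, Tuple


def weight_shape(inp_channels: int, out_channels: int,
                 kernel_size: Sequence[int], groups: int = 1,
                 transposed: bool = False) -> Tuple[int, ...]:
    """Shape of the weight tensor allocated by a (transposed) convolution."""
    # PyTorch requires both channel counts to be divisible by `groups`
    if inp_channels % groups or out_channels % groups:
        raise ValueError('inp_channels and out_channels must be divisible by groups')
    if transposed:
        return (inp_channels, out_channels // groups, *kernel_size)
    return (out_channels, inp_channels // groups, *kernel_size)


print(weight_shape(8, 16, (3, 3)))                   # (16, 8, 3, 3)
print(weight_shape(8, 16, (3, 3), groups=4))         # (16, 2, 3, 3)
print(weight_shape(8, 16, (3, 3), transposed=True))  # (8, 16, 3, 3)
```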

forward

forward(inp, out_size=None)
PARAMETER DESCRIPTION
inp

Input tensor

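TYPE: (B, inp_channels, *inp_size) tensor

out_size

Target spatial size of the output, used to infer the output padding

TYPE: list[int] DEFAULT: None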

RETURNS DESCRIPTION
out

Output tensor

TYPE: (B, out_channels, *out_size) tensor

Source code in cassetta/layers/conv.py
def forward(self, inp: Tensor, out_size: Optional[List[int]] = None
            ) -> Tensor:
    """
    Parameters
    ----------
    inp : (B, inp_channels, *inp_size) tensor
        Input tensor
    out_size : list[int], optional
        Target spatial size of the output, used to infer the
        output padding

    Returns
    -------
    out : (B, out_channels, *out_size) tensor
        Output tensor

    """
    ndim = len(self.kernel_size)
    output_padding = self._output_padding(inp, out_size)
    conv_transpose = getattr(F, f'conv_transpose{ndim}d')
    return conv_transpose(
        inp, self.weight, self.bias, self.stride, self.padding,
        output_padding, self.groups, self.dilation)
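`forward` delegates the choice of output padding to `self._output_padding(inp, out_size)`, whose body is not shown in this section. The sketch below is one plausible version for a single dimension. It assumes the method inverts PyTorch's transposed-convolution size formula, `(inp - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1`. `infer_output_padding` is a hypothetical name. PyTorch only accepts an output padding smaller than the stride or the dilation, and the sketch enforces the same limit.

```python
def transposed_output_size(inp_size: int, kernel_size: int, stride: int = 1,
                           padding: int = 0, dilation: int = 1,
                           output_padding: int = 0) -> int:
    """Output length of a transposed convolution along one dimension."""
    return ((inp_size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


def infer_output_padding(inp_size: int, out_size: int, kernel_size: int,
                         stride: int = 1, padding: int = 0,
                         dilation: int = 1) -> int:
    """Output padding needed to reach `out_size` (hypothetical helper)."""
    smallest = transposed_output_size(inp_size, kernel_size, stride,
                                      padding, dilation)
    output_padding = out_size - smallest
    # PyTorch rejects output_padding >= max(stride, dilation)
    if not 0 <= output_padding < max(stride, dilation):
        raise ValueError(f'Cannot reach size {out_size} from {inp_size}')
    return output_padding


# Upsampling by 2 with a 3-tap kernel: 16 -> 31 by default, 32 with output padding
print(transposed_output_size(16, 3, stride=2, padding=1))    # 31
print(infer_output_padding(16, 32, 3, stride=2, padding=1))  # 1
```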

LazyConv

LazyConv(out_channels=None, kernel_size=3, stride=1, padding=0, dilation=1, output_padding=0, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)

Bases: LazyModuleMixin, Conv

A convolution layer whose ndim and inp_channels are guessed lazily at run time.
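This section does not show how the lazy parameters are resolved. Below is a minimal sketch of the likely logic, assuming inputs are laid out as `(B, C, *spatial)` as everywhere else on this page:

- `ndim` is the number of spatial dimensions.
- `inp_channels` is the size of dimension 1.
- `out_channels` is an int, `None` (defaulting to `inp_channels`), or a callable applied to `inp_channels`.

The function `materialize_conv_shape` is hypothetical.

```python
from typing import Callable, Optional, Sequence, Tuple, Union


def materialize_conv_shape(
    inp_shape: Sequence[int],
    out_channels: Optional[Union[int, Callable[[int], int]]] = None,
) -> Tuple[int, int, int]:
    """Return (ndim, inp_channels, out_channels) for an input of shape (B, C, *spatial)."""
    ndim = len(inp_shape) - 2          # drop batch and channel dimensions
    inp_channels = inp_shape[1]
    if out_channels is None:           # default: keep the number of channels
        out_channels = inp_channels
    elif callable(out_channels):       # e.g. lambda c: 2 * c
        out_channels = out_channels(inp_channels)
    return ndim, inp_channels, out_channels


print(materialize_conv_shape((4, 8, 32, 32)))                       # (2, 8, 8)
print(materialize_conv_shape((4, 8, 16, 16, 16), lambda c: 2 * c))  # (3, 8, 16)
print(materialize_conv_shape((4, 8, 64), 5))                        # (1, 8, 5)
```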

PARAMETER DESCRIPTION
out_channels

Number of output channels. If a function, takes the materialized number of input channels and returns the materialized number of output channels.

TYPE: int or callable DEFAULT: `inp_channels`

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

stride

Space between output elements

TYPE: [list of] int DEFAULT: 1

padding

Amount of padding to apply

TYPE: {same, valid} or [list of] int DEFAULT: 0

dilation

Space between kernel elements

TYPE: [list of] int DEFAULT: 1

groups

Number of groups

TYPE: int DEFAULT: 1

bias

Add a learnable bias term

TYPE: bool DEFAULT: True

padding_mode

How to pad the tensor

TYPE: BoundType DEFAULT: 'zeros'

Source code in cassetta/layers/conv.py
def __init__(
    self,
    out_channels: Optional[Union[int, Callable]] = None,
    kernel_size: OneOrSeveral[int] = 3,
    stride: OneOrSeveral[int] = 1,
    padding: Union[Literal['same', 'valid'], OneOrSeveral[int]] = 0,
    dilation: OneOrSeveral[int] = 1,
    output_padding: OneOrSeveral[int] = 0,
    groups: int = 1,
    bias: bool = True,
    padding_mode: BoundType = 'zeros',
    device: Optional[DeviceType] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    """
    Parameters
    ----------
    out_channels : int or callable, default=`inp_channels`
        Number of output channels.
        If a function, takes the materialized number of input
        channels and returns the materialized number of output
        channels.
    kernel_size : [list of] int
        Kernel size
    stride : [list of] int
        Space between output elements
    padding : {'same', 'valid'} or [list of] int
        Amount of padding to apply
    dilation : [list of] int
        Space between kernel elements
    groups : int
        Number of groups
    bias : bool
        Add a learnable bias term
    padding_mode : BoundType
        How to pad the tensor
    """
    conv_kwargs = dict(
        kernel_size=kernel_size, stride=stride, padding=padding,
        dilation=dilation, groups=groups, padding_mode=padding_mode
    )
    factory_kwargs = dict(dtype=dtype, device=device)
    # ndim=3 to avoid losing user parameters
    super().__init__(3, 0, 0, bias=False, **conv_kwargs)
    self.out_channels = out_channels
    self.weight = nn.UninitializedParameter(**factory_kwargs)
    if bias:
        self.bias = nn.UninitializedParameter(**factory_kwargs)

LazyConvTransposed

LazyConvTransposed(out_channels=None, kernel_size=3, stride=1, padding=0, dilation=1, output_padding=0, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)

Bases: LazyModuleMixin, Conv

A transposed convolution layer whose ndim and inp_channels are guessed lazily at run time.

PARAMETER DESCRIPTION
out_channels

Number of output channels. If a function, takes the materialized number of input channels and returns the materialized number of output channels.

TYPE: int or callable DEFAULT: `inp_channels`

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

stride

Space between output elements

TYPE: [list of] int DEFAULT: 1

padding

Amount of padding to apply

TYPE: {same, valid} or [list of] int DEFAULT: 0

dilation

Space between kernel elements

TYPE: [list of] int DEFAULT: 1

output_padding

Amount of padding to apply to the output of a transposed convolution.

TYPE: [list of] int DEFAULT: 0

groups

Number of groups

TYPE: int DEFAULT: 1

bias

Add a learnable bias term

TYPE: bool DEFAULT: True

padding_mode

How to pad the tensor

TYPE: BoundType DEFAULT: 'zeros'

Source code in cassetta/layers/conv.py
def __init__(
    self,
    out_channels: Optional[Union[int, Callable]] = None,
    kernel_size: OneOrSeveral[int] = 3,
    stride: OneOrSeveral[int] = 1,
    padding: Union[Literal['same', 'valid'], OneOrSeveral[int]] = 0,
    dilation: OneOrSeveral[int] = 1,
    output_padding: OneOrSeveral[int] = 0,
    groups: int = 1,
    bias: bool = True,
    padding_mode: BoundType = 'zeros',
    device: Optional[DeviceType] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    """
    Parameters
    ----------
    out_channels : int or callable, default=`inp_channels`
        Number of output channels.
        If a function, takes the materialized number of input
        channels and returns the materialized number of output
        channels.
    kernel_size : [list of] int
        Kernel size
    stride : [list of] int
        Space between output elements
    padding : {'same', 'valid'} or [list of] int
        Amount of padding to apply
    dilation : [list of] int
        Space between kernel elements
    output_padding : [list of] int
        Amount of padding to apply to the output of a transposed
        convolution.
    groups : int
        Number of groups
    bias : bool
        Add a learnable bias term
    padding_mode : BoundType
        How to pad the tensor
    """
    conv_kwargs = dict(
        kernel_size=kernel_size, stride=stride, padding=padding,
        dilation=dilation, output_padding=output_padding, groups=groups,
        padding_mode=padding_mode
    )
    factory_kwargs = dict(dtype=dtype, device=device)
    # ndim=3 to avoid losing user parameters
    super().__init__(3, 0, 0, bias=False, **conv_kwargs)
    self.out_channels = out_channels
    self.weight = nn.UninitializedParameter(**factory_kwargs)
    if bias:
        self.bias = nn.UninitializedParameter(**factory_kwargs)

SeparableConv

SeparableConv(ndim, inp_channels, out_channels=None, kernel_size=3, dilation=1, bias=True, padding='same', padding_mode='zeros')

Bases: Sequential

Separable Convolution.

Implements an N-D convolution (e.g., WxHxD) as a series of 1D convolutions (e.g., Wx1x1, 1xHx1, 1x1xD).

Intermediate convolutions use max(inp_channels, out_channels) channels. Only the last convolution maps to out_channels, and only the last one has a bias.

Padding is 'same' by default.

PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

inp_channels

Number of input channels

TYPE: int

out_channels

Number of output channels

TYPE: int DEFAULT: `inp_channels`

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

dilation

Space between kernel elements

TYPE: [list of] int DEFAULT: 1

bias

Add a learnable bias term

TYPE: bool DEFAULT: True

padding

Amount of padding to apply

TYPE: {same, valid} or [list of] int DEFAULT: 'same'

padding_mode

How to pad the tensor

TYPE: BoundType DEFAULT: 'zeros'

Source code in cassetta/layers/conv.py
def __init__(
    self,
    ndim: int,
    inp_channels: int,
    out_channels: Optional[int] = None,
    kernel_size: OneOrSeveral[int] = 3,
    dilation: OneOrSeveral[int] = 1,
    bias: bool = True,
    padding: Union[int, Literal['same']] = 'same',
    padding_mode: str = 'zeros',
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    inp_channels : int
        Number of input channels
    out_channels : int, default=`inp_channels`
        Number of output channels
    kernel_size : [list of] int
        Kernel size
    dilation : [list of] int
        Space between kernel elements
    bias : bool
        Add a learnable bias term
    padding : {'same', 'valid'} or [list of] int
        Amount of padding to apply
    padding_mode : BoundType
        How to pad the tensor
    """
    out_channels = out_channels or inp_channels
    mid_channels = max(inp_channels, out_channels)

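    layers = []
    # Expand scalar kernel_size/dilation to one value per spatial dimension
    kernel_size = ensure_tuple(kernel_size, ndim)
    dilation = ensure_tuple(dilation, ndim)
    for dim, (K, D) in enumerate(zip(kernel_size, dilation)):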
        K1 = [1] * ndim
        K1[dim] = K
        inpch = inp_channels if dim == 0 else mid_channels
        outch = out_channels if (dim == ndim-1) else mid_channels
        kwargs = dict(kernel_size=K1, dilation=D,
                      padding=padding, bias=bias and (dim == ndim-1),
                      padding_mode=padding_mode)
        layers.append(Conv(ndim, inpch, outch, **kwargs))

    super().__init__(*layers)
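The loop above builds one convolution per spatial dimension. Each kernel is 1 everywhere except along its own axis, and channels are routed `inp -> mid -> ... -> mid -> out` with `mid = max(inp_channels, out_channels)`. The sketch below mirrors that bookkeeping and compares the weight count with that of a full convolution. The helper `separable_plan` is hypothetical, and the comparison counts weights only, ignoring biases.

```python
from math import prod
from typing import List, Tuple


def separable_plan(ndim: int, inp_channels: int, out_channels: int,
                   kernel_size: int = 3) -> List[Tuple[int, int, List[int]]]:
    """(inp, out, kernel) of each 1D convolution in a separable stack."""
    mid_channels = max(inp_channels, out_channels)
    plan = []
    for dim in range(ndim):
        kernel = [1] * ndim
        kernel[dim] = kernel_size          # non-trivial only along this axis
        inpch = inp_channels if dim == 0 else mid_channels
        outch = out_channels if dim == ndim - 1 else mid_channels
        plan.append((inpch, outch, kernel))
    return plan


plan = separable_plan(3, 16, 16)
for inpch, outch, kernel in plan:
    print(inpch, outch, kernel)
# 16 16 [3, 1, 1]
# 16 16 [1, 3, 1]
# 16 16 [1, 1, 3]

separable_weights = sum(i * o * prod(k) for i, o, k in plan)
full_weights = 16 * 16 * 3 ** 3
print(separable_weights, full_weights)  # 2304 6912
```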

CrossHairConv

CrossHairConv(ndim, inp_channels, out_channels=None, kernel_size=3, dilation=1, bias=True, padding='same', padding_mode='zeros')

Bases: SeparableConv

Separable Cross-Hair Convolution.

Separable convolution, where the input tensor is convolved with a set of 1D convolutions, and all outputs are summed together (e.g., Wx1x1 + 1xHx1 + 1x1xD).

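Padding must be 'same'. Every 1D branch receives the same input tensor and the branch outputs are summed, so out_channels should equal inp_channels.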

Reference

Tetteh, Giles, et al. "Deepvesselnet: Vessel segmentation, centerline prediction, and bifurcation detection in 3-d angiographic volumes." Frontiers in Neuroscience 14 (2020). 10.3389/fnins.2020.592352

bibtex
@article{tetteh2020deepvesselnet,
    title={Deepvesselnet: Vessel segmentation, centerline prediction, and bifurcation detection in 3-d angiographic volumes},
    author={Tetteh, Giles and Efremov, Velizar and Forkert, Nils D and Schneider, Matthias and Kirschke, Jan and Weber, Bruno and Zimmer, Claus and Piraud, Marie and Menze, Bj{\"o}rn H},
    journal={Frontiers in Neuroscience},
    volume={14},
    pages={592352},
    year={2020},
    publisher={Frontiers}
}
PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

inp_channels

Number of input channels

TYPE: int

out_channels

Number of output channels

TYPE: int DEFAULT: `inp_channels`

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

dilation

Space between kernel elements

TYPE: [list of] int DEFAULT: 1

bias

Add a learnable bias term

TYPE: bool DEFAULT: True

padding

Amount of padding to apply

TYPE: {same, valid} or [list of] int DEFAULT: 'same'

padding_mode

How to pad the tensor

TYPE: BoundType DEFAULT: 'zeros'

Source code in cassetta/layers/conv.py
def __init__(
    self,
    ndim: int,
    inp_channels: int,
    out_channels: Optional[int] = None,
    kernel_size: OneOrSeveral[int] = 3,
    dilation: OneOrSeveral[int] = 1,
    bias: bool = True,
    padding: Union[int, Literal['same']] = 'same',
    padding_mode: str = 'zeros',
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    inp_channels : int
        Number of input channels
    out_channels : int, default=`inp_channels`
        Number of output channels
    kernel_size : [list of] int
        Kernel size
    dilation : [list of] int
        Space between kernel elements
    bias : bool
        Add a learnable bias term
    padding : {'same', 'valid'} or [list of] int
        Amount of padding to apply
    padding_mode : BoundType
        How to pad the tensor
    """
    out_channels = out_channels or inp_channels
    mid_channels = max(inp_channels, out_channels)

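    layers = []
    # Expand scalar kernel_size/dilation to one value per spatial dimension
    kernel_size = ensure_tuple(kernel_size, ndim)
    dilation = ensure_tuple(dilation, ndim)
    for dim, (K, D) in enumerate(zip(kernel_size, dilation)):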
        K1 = [1] * ndim
        K1[dim] = K
        inpch = inp_channels if dim == 0 else mid_channels
        outch = out_channels if (dim == ndim-1) else mid_channels
        kwargs = dict(kernel_size=K1, dilation=D,
                      padding=padding, bias=bias and (dim == ndim-1),
                      padding_mode=padding_mode)
        layers.append(Conv(ndim, inpch, outch, **kwargs))

    super().__init__(*layers)

forward

forward(inp)
PARAMETER DESCRIPTION
inp

TYPE: (B, inp_channels, *spatial)

RETURNS DESCRIPTION
out

TYPE: (B, out_channels, *spatial)

Source code in cassetta/layers/conv.py
def forward(self, inp: Tensor) -> Tensor:
    """
    Parameters
    ----------
    inp : (B, inp_channels, *spatial)

    Returns
    -------
    out : (B, out_channels, *spatial)
    """
    out = 0
    for layer in self:
        out += layer(inp)
    return out
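Convolution is linear. Summing the outputs of the 1D branches, as `forward` does, is therefore equivalent to a single convolution with a cross-shaped kernel, provided every branch uses centred 'same' zero padding. The small pure-Python 2D check below illustrates that identity. `correlate2d_same` is an illustrative helper, not part of cassetta, and like PyTorch it computes a cross-correlation.

```python
from typing import List

Grid = List[List[float]]


def correlate2d_same(image: Grid, kernel: Grid) -> Grid:
    """2D cross-correlation with zero 'same' padding (odd kernel sizes)."""
    h, w = len(image), len(image[0])
    kh, kw = len(kernel), len(kernel[0])
    ph, pw = kh // 2, kw // 2
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            acc = 0.0
            for a in range(kh):
                for b in range(kw):
                    y, x = i + a - ph, j + b - pw
                    if 0 <= y < h and 0 <= x < w:   # zero padding outside
                        acc += kernel[a][b] * image[y][x]
            out[i][j] = acc
    return out


image = [[float(3 * i + j) for j in range(4)] for i in range(3)]
vertical = [[1.0], [2.0], [3.0]]    # 3x1 kernel (first branch)
horizontal = [[4.0, 5.0, 6.0]]      # 1x3 kernel (second branch)

# Cross-hair convolution: sum of the two 1D branches
branch_v = correlate2d_same(image, vertical)
branch_h = correlate2d_same(image, horizontal)
crosshair = [[v + h for v, h in zip(rv, rh)] for rv, rh in zip(branch_v, branch_h)]

# Same result with a single 3x3 kernel shaped like a cross
cross_kernel = [[0.0, 1.0, 0.0],
                [4.0, 2.0 + 5.0, 6.0],
                [0.0, 3.0, 0.0]]
single = correlate2d_same(image, cross_kernel)
print(crosshair == single)  # True
```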

make_conv

make_conv(ndim, inp_channels, out_channels=None, kernel_size=3, stride=1, padding=0, dilation=1, output_padding=0, groups=1, bias=True, separable=False, transpose=False, padding_mode='zeros', device=None, dtype=None)

Instantiate a convolution layer.

PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

inp_channels

Number of input channels

TYPE: int

out_channels

Number of output channels

TYPE: int DEFAULT: `inp_channels`

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

stride

Space between output elements

TYPE: [list of] int DEFAULT: 1

padding

Amount of padding to apply

TYPE: {same, valid} or [list of] int DEFAULT: 0

dilation

Space between kernel elements

TYPE: [list of] int DEFAULT: 1

output_padding

Amount of padding to apply to the output of a transposed convolution.

TYPE: [list of] int DEFAULT: 0

groups

Number of groups

TYPE: int DEFAULT: 1

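bias

Add a learnable bias term

TYPE: bool DEFAULT: True

separable

Use a separable (True) or cross-hair ('crosshair') convolution. Cannot be combined with transpose or with a stride other than 1.

TYPE: bool or {crosshair} DEFAULT: False

transpose

Build a transposed convolution

TYPE: bool DEFAULT: False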

padding_mode

How to pad the tensor

TYPE: BoundType DEFAULT: 'zeros'

device

Weights' device

TYPE: device DEFAULT: None

dtype

Weights' data type

TYPE: dtype DEFAULT: None

RETURNS DESCRIPTION
layer

TYPE: Conv or ConvTransposed or SeparableConv or CrossHairConv

Source code in cassetta/layers/conv.py
def make_conv(
    ndim: int,
    inp_channels: int,
    out_channels: Optional[int] = None,
    kernel_size: OneOrSeveral[int] = 3,
    stride: OneOrSeveral[int] = 1,
    padding: Union[Literal['same', 'valid'], OneOrSeveral[int]] = 0,
    dilation: OneOrSeveral[int] = 1,
    output_padding: OneOrSeveral[int] = 0,
    groups: int = 1,
    bias: bool = True,
    separable: Union[Literal['crosshair'], bool] = False,
    transpose=False,
    padding_mode: BoundType = 'zeros',
    device: Optional[DeviceType] = None,
    dtype: Optional[torch.dtype] = None,
):
    """
    Instantiate a convolution layer.

    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    inp_channels : int
        Number of input channels
    out_channels : int, default=`inp_channels`
        Number of output channels
    kernel_size : [list of] int
        Kernel size
    stride : [list of] int
        Space between output elements
    padding : {'same', 'valid'} or [list of] int
        Amount of padding to apply
    dilation : [list of] int
        Space between kernel elements
    output_padding : [list of] int
        Amount of padding to apply to the output of a transposed
        convolution.
    groups : int
        Number of groups
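    bias : bool
        Add a learnable bias term
    separable : bool or {'crosshair'}
        Use a separable (True) or cross-hair ('crosshair') convolution.
        Cannot be combined with `transpose` or with strides other than 1.
    transpose : bool
        Build a transposed convolution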
    padding_mode : BoundType
        How to pad the tensor
    device : torch.device
        Weights' device
    dtype : torch.dtype
        Weights' data type

    Returns
    -------
    layer : Conv or ConvTransposed or SeparableConv or CrossHairConv
    """
    opt = dict(
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=output_padding,
        groups=groups,
        bias=bias,
        padding_mode=padding_mode,
        device=device,
        dtype=dtype,
    )
    if transpose:
        klass = ConvTransposed
        if separable:
            raise ValueError('Separable convolutions cannot be transposed')
    else:
        opt.pop('output_padding')
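        if separable:
            if isinstance(separable, str):
                if separable == 'crosshair':
                    klass = CrossHairConv
                else:
                    raise ValueError(f'Unknown separable value "{separable}"')
            else:
                klass = SeparableConv
            if ensure_tuple(stride, ndim) != (1,) * ndim:
                # TODO: implement strided separable convs
                raise ValueError('Separable convolutions cannot be strided')
            if groups != 1:
                raise ValueError('Separable convolutions cannot be grouped')
            # SeparableConv/CrossHairConv do not accept these arguments
            # (their weights use the default device and dtype)
            for key in ('stride', 'groups', 'device', 'dtype'):
                opt.pop(key)
        else:
            klass = Conv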
    return klass(ndim, inp_channels, out_channels, **opt)
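The branching in `make_conv` reduces to a small decision table. The torch-free sketch below returns the name of the class that `make_conv` would instantiate for a given `transpose`/`separable`/`stride` combination, and raises the same errors. For brevity, stride is accepted as an int or a sequence of ints. `pick_conv_class` is a hypothetical helper.

```python
from typing import Sequence, Union


def pick_conv_class(transpose: bool = False,
                    separable: Union[bool, str] = False,
                    stride: Union[int, Sequence[int]] = 1) -> str:
    """Name of the layer class that make_conv would build."""
    if transpose:
        if separable:
            raise ValueError('Separable convolutions cannot be transposed')
        return 'ConvTransposed'
    if not separable:
        return 'Conv'
    if isinstance(separable, str):
        if separable != 'crosshair':
            raise ValueError(f'Unknown separable value "{separable}"')
        klass = 'CrossHairConv'
    else:
        klass = 'SeparableConv'
    # Separable variants only support unit strides
    strides = [stride] if isinstance(stride, int) else list(stride)
    if any(s != 1 for s in strides):
        raise ValueError('Separable convolutions cannot be strided')
    return klass


print(pick_conv_class())                       # Conv
print(pick_conv_class(transpose=True))         # ConvTransposed
print(pick_conv_class(separable=True))         # SeparableConv
print(pick_conv_class(separable='crosshair'))  # CrossHairConv
```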

cassetta.layers.convblocks

ConvBlockBase

ConvBlockBase(ndim, inp_channels, out_channels=None, activation='ReLU', norm=None, dropout=None, attention=None, order='ncdax', optc=None, opta=None, optn=None, optd=None, optx=None)

Bases: Sequential

Base class for unstrided convolution blocks that contain any of these layers:

  • Norm (n)
  • Conv (c)
  • Dropout (d)
  • Activation (a)
  • Attention (x)
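Whether the norm and attention layers see `inp_channels` or `out_channels` depends only on whether they come before or after the convolution in `order`. The sketch below illustrates that rule. The helper `block_channels` is hypothetical, and it assumes `order` contains all five letters, as the default 'ncdax' does.

```python
from typing import Dict


def block_channels(order: str, inp_channels: int,
                   out_channels: int) -> Dict[str, int]:
    """Number of channels seen by the norm ('n') and attention ('x') layers."""
    conv_pos = order.index('c')
    return {
        'norm': inp_channels if order.index('n') < conv_pos else out_channels,
        'attention': inp_channels if order.index('x') < conv_pos else out_channels,
    }


print(block_channels('ncdax', 8, 16))  # {'norm': 8, 'attention': 16}
print(block_channels('cndax', 8, 16))  # {'norm': 16, 'attention': 16}
print(block_channels('xncda', 8, 16))  # {'norm': 8, 'attention': 8}
```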
Source code in cassetta/layers/convblocks.py
def __init__(
    self,
    ndim: int,
    inp_channels: int,
    out_channels: Optional[int] = None,
    activation: ActivationType = 'ReLU',
    norm: NormType = None,
    dropout: DropoutType = None,
    attention: AttentionType = None,
    order: str = 'ncdax',
    optc: Optional[dict] = None,
    opta: Optional[dict] = None,
    optn: Optional[dict] = None,
    optd: Optional[dict] = None,
    optx: Optional[dict] = None,
):
    super().__init__()
    self.order = self.fix_order(order)
    out_channels = out_channels or inp_channels
    norm_channels = (
        inp_channels if self.order.index('n') < self.order.index('c') else
        out_channels
    )
    attention_channels = (
        inp_channels if self.order.index('x') < self.order.index('c') else
        out_channels
    )

    conv = make_conv(ndim, inp_channels, out_channels, **(optc or {}))
    norm = make_norm(norm, norm_channels, **(optn or {}))
    dropout = make_dropout(dropout, **(optd or {}))
    activation = make_activation(activation, **(opta or {}))
    attention = make_attention(attention, channels=attention_channels,
                               ndim=ndim, **(optx or {}))

    # Assign submodules in order
    for o in self.order:
        if o == 'n':
            self.norm = norm
        elif o == 'c':
            self.conv = conv
        elif o == 'd':
            self.dropout = dropout
        elif o == 'a':
            self.activation = activation
        elif o == 'x':
            self.attention = attention

ConvBlock

ConvBlock(ndim, inp_channels, out_channels, kernel_size=3, dilation=1, bias=True, activation='ReLU', norm=None, dropout=None, attention=None, compression=16, order='ncdax', separable=False, optc=None, optn=None, optd=None, opta=None, optx=None)

Bases: ConvBlockBase

A single convolution, in a Norm + Conv + Dropout + Activation + Attention group.

Diagram

flowchart LR
    subgraph ConvBlock
        n("Norm"):::w-->
        on["C<sub>inp</sub>"]    ---c("Conv"):::w-->
        oc["C<sub>out</sub>"]    ---d("Dropout"):::d-->
        od["C<sub>out</sub>"]    ---a("Activation"):::d-->
        oa["C<sub>out</sub>"]    --- x("Attention"):::w
    end
    i["C<sub>inp</sub>"]:::i --- n
    x --> ox["C<sub>out</sub>"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;
flowchart LR
    subgraph ConvBlock
        c("Conv"):::w-->
        oc["C<sub>out</sub>"]    ---n("Norm"):::w-->
        on["C<sub>out</sub>"]    ---d("Dropout"):::d-->
        od["C<sub>out</sub>"]    ---a("Activation"):::d-->
        oa["C<sub>out</sub>"]    --- x("Attention"):::w
    end
    i["C<sub>inp</sub>"]:::i --- c
    x --> ox["C<sub>out</sub>"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;

Padding is always 'same'

Ordering

The order of the Norm/Conv/Dropout/Activation layers can be chosen with the argument order. For example:

  • order='ncdax': Norm -> Conv -> Dropout -> Activation -> Attention
  • order='andc': Activation -> Norm -> Dropout -> Conv
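Each letter of `order` names one layer, and the block applies the layers from left to right. The sketch below expands an order string into its layer sequence and rejects unknown or repeated letters. `describe_order` is hypothetical, and it is an assumption that cassetta's own `fix_order` accepts subsets such as 'andc' in exactly this way.

```python
LAYER_NAMES = {
    'n': 'Norm',
    'c': 'Conv',
    'd': 'Dropout',
    'a': 'Activation',
    'x': 'Attention',
}


def describe_order(order: str) -> str:
    """Expand an order string such as 'ncdax' into a readable layer chain."""
    unknown = set(order) - set(LAYER_NAMES)
    if unknown:
        raise ValueError(f'Unknown layer codes: {sorted(unknown)}')
    if len(set(order)) != len(order):
        raise ValueError('Each layer code may appear at most once')
    return ' -> '.join(LAYER_NAMES[o] for o in order)


print(describe_order('ncdax'))  # Norm -> Conv -> Dropout -> Activation -> Attention
print(describe_order('andc'))   # Activation -> Norm -> Dropout -> Conv
```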
PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

inp_channels

Number of input features

TYPE: int

out_channels

Number of output features

TYPE: int

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

dilation

Dilation size

TYPE: [list of] int DEFAULT: 1

bias

Include a bias term

TYPE: bool DEFAULT: True

activation

Activation function

TYPE: ActivationType DEFAULT: 'ReLU'

norm

Normalization function ('batch', 'instance', 'layer')

TYPE: NormType DEFAULT: None

dropout

Dropout probability

TYPE: DropoutType DEFAULT: None

attention

Attention layer

TYPE: bool or {sqzex, bcam, mha} DEFAULT: None

compression

Compression ratio of the attention layer

TYPE: int DEFAULT: 16

order

Modules order (permutation of 'ncdax')

TYPE: str DEFAULT: 'ncdax'

separable

Use a separable (or cross-hair) convolution

TYPE: bool or {crosshair} DEFAULT: False

OTHER PARAMETER DESCRIPTION
optc

Other convolution parameters

TYPE: dict

optn

Other normalization parameters

TYPE: dict

optd

Other dropout parameters

TYPE: dict

opta

Other activation parameters

TYPE: dict

optx

Other attention parameters

TYPE: dict

Source code in cassetta/layers/convblocks.py
def __init__(
    self,
    ndim: int,
    inp_channels: int,
    out_channels: int,
    kernel_size: OneOrSeveral[int] = 3,
    dilation: OneOrSeveral[int] = 1,
    bias: bool = True,
    activation: ActivationType = 'ReLU',
    norm: NormType = None,
    dropout: DropoutType = None,
    attention: AttentionType = None,
    compression: int = 16,
    order: str = 'ncdax',
    separable: Union[bool, Literal['crosshair']] = False,
    optc: Optional[dict] = None,
    optn: Optional[dict] = None,
    optd: Optional[dict] = None,
    opta: Optional[dict] = None,
    optx: Optional[dict] = None,
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    inp_channels : int
        Number of input features
    out_channels : int
        Number of output features
    kernel_size : [list of] int
        Kernel size
    dilation : [list of] int
        Dilation size
    bias : bool
        Include a bias term
    activation : ActivationType
        Activation function
    norm : NormType
        Normalization function ('batch', 'instance', 'layer')
    dropout : DropoutType
        Dropout probability
    attention : bool or {'sqzex', 'bcam', 'mha'}
        Attention layer
    compression : int
        Compression ratio of the attention layer
    order : str
        Modules order (permutation of 'ncdax')
    separable : bool or {'crosshair'}
        Use a separable (or cross-hair) convolution

    Other Parameters
    ----------------
    optc : dict
        Other convolution parameters
    optn : dict
        Other normalization parameters
    optd : dict
        Other dropout parameters
    opta : dict
        Other activation parameters
    optx : dict
        Other attention parameters
    """
    # Copy so that the caller's dictionaries are not modified in place
    optc = dict(optc or {})
    optn = dict(optn or {})
    optd = dict(optd or {})
    opta = dict(opta or {})
    optx = dict(optx or {})
    optc.update(dict(
        kernel_size=kernel_size,
        dilation=dilation,
        bias=bias,
        separable=separable,
        padding='same',
    ))
    optx.update(dict(
        compression=compression
    ))
    super().__init__(
        ndim=ndim,
        inp_channels=inp_channels,
        out_channels=out_channels,
        activation=activation,
        norm=norm,
        dropout=dropout,
        attention=attention,
        order=order,
        optc=optc,
        optn=optn,
        optd=optd,
        opta=opta,
        optx=optx,
    )

forward

forward(inp)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, inp_channels, *inp_size) tensor

RETURNS DESCRIPTION
out

Output tensor

TYPE: (B, out_channels, *out_size) tensor

Source code in cassetta/layers/convblocks.py
def forward(self, inp: Tensor) -> Tensor:
    """
    Parameters
    ----------
    inp : (B, inp_channels, *inp_size) tensor
        Input tensor

    Returns
    -------
    out : (B, out_channels, *out_size) tensor
        Output tensor
    """
    return super().forward(inp)

ConvGroup

ConvGroup(ndim, channels, nb_conv=1, recurrent=False, residual=False, kernel_size=3, dilation=1, bias=True, activation='ReLU', norm=None, dropout=None, attention=None, compression=16, order='ncdax', separable=False, skip=0)

Bases: ModuleGroup

Multiple convolution blocks stacked together

Diagram

flowchart LR
    subgraph nb_blocks
        2("ConvBlock 1"):::w  --> 3["C"] ---
        4("ConvBlock 2"):::w  --> 5["C"] ---
        6("..."):::n      --> 7["C"] ---
        8("ConvBlock N"):::w
    end
    1["C [+S]"]:::i --- 2
    8 ---> 9["C"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;
flowchart LR
    subgraph nb_blocks
        2("ConvBlock 1"):::w  --> 3["C"] ---
        4(("+")):::d      --> 5["C"] ---
        6("ConvBlock 2"):::w  --> 7["C"] ---
        8(("+")):::d      --> 9["C"] ---
        10("..."):::n     --> 11["C"] ---
        12("ConvBlock N"):::w --> 13["C"] ---
        14(("+")):::d
    end
    1["C [+S]"]:::i --- 2
    1 --- 4
    5 --- 8
    11 --- 14
    14 ---> 15["C"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;

The recurrent variant shares weights across blocks

The number of channels is preserved throughout

Padding is always 'same'
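The `ConvGroup` constructor first adds one block that absorbs the `skip` extra channels. It then adds `nb_conv - 1` blocks (or `nb_conv` blocks when `skip=0`) that keep `channels` fixed. In the recurrent variant, all of these repeated blocks are the same module. The torch-free sketch below mirrors that bookkeeping. `conv_group_plan` is hypothetical; it returns the `(inp, out)` channels of each block and the number of distinct blocks that hold weights.

```python
from typing import List, Tuple


def conv_group_plan(channels: int, nb_conv: int = 1, skip: int = 0,
                    recurrent: bool = False) -> Tuple[List[Tuple[int, int]], int]:
    """(inp, out) channels of each block, and the number of distinct blocks."""
    blocks: List[Tuple[int, int]] = []
    distinct = 0
    if skip:
        nb_conv -= 1
        blocks.append((channels + skip, channels))   # absorbs skipped channels
        distinct += 1
    blocks += [(channels, channels)] * nb_conv
    if nb_conv > 0:
        # Recurrent groups reuse a single block; otherwise each block is new
        distinct += 1 if recurrent else nb_conv
    return blocks, distinct


print(conv_group_plan(16, nb_conv=3))
# ([(16, 16), (16, 16), (16, 16)], 3)
print(conv_group_plan(16, nb_conv=3, skip=8, recurrent=True))
# ([(24, 16), (16, 16), (16, 16)], 2)
```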

PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

channels

Number of input and output features

TYPE: int

nb_conv

Number of convolution blocks

TYPE: int DEFAULT: 1

recurrent

Recurrent network: share weights across blocks

TYPE: bool DEFAULT: False

residual

Use residual connections between blocks

TYPE: bool DEFAULT: False

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

dilation

Dilation size

TYPE: [list of] int DEFAULT: 1

bias

Include a bias term

TYPE: bool DEFAULT: True

activation

Activation function

TYPE: ActivationType DEFAULT: 'ReLU'

norm

Normalization function ('batch', 'instance', 'layer')

TYPE: NormType DEFAULT: None

dropout

Dropout probability

TYPE: DropoutType DEFAULT: None

attention

Attention layer

TYPE: bool or {sqzex, bcam, mha} DEFAULT: None

compression

Compression ratio of the attention layer

TYPE: int DEFAULT: 16

order

Modules order (permutation of 'ncdax')

TYPE: str DEFAULT: 'ncdax'

separable

Use a separable (or cross-hair) convolution

TYPE: bool or {crosshair} DEFAULT: False

skip

Number of additional skipped channels in the input tensor.

TYPE: int DEFAULT: 0

Source code in cassetta/layers/convblocks.py
def __init__(
    self,
    ndim: int,
    channels: int,
    nb_conv: int = 1,
    recurrent: bool = False,
    residual: bool = False,
    kernel_size: OneOrSeveral[int] = 3,
    dilation: OneOrSeveral[int] = 1,
    bias: bool = True,
    activation: ActivationType = 'ReLU',
    norm: NormType = None,
    dropout: DropoutType = None,
    attention: AttentionType = None,
    compression: int = 16,
    order: str = 'ncdax',
    separable: Union[bool, Literal['crosshair']] = False,
    skip: int = 0,
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    channels : int
        Number of input and output features
    nb_conv : int
        Number of convolution blocks
    recurrent : bool
        Recurrent network: share weights across blocks
    residual : bool
        Use residual connections between blocks
    kernel_size : [list of] int
        Kernel size
    dilation : [list of] int
        Dilation size
    bias : bool
        Include a bias term
    activation : ActivationType
        Activation function
    norm : NormType
        Normalization function ('batch', 'instance', 'layer')
    dropout : DropoutType
        Dropout probability
    attention : bool or {'sqzex', 'bcam', 'mha'}
        Attention layer
    compression : int
        Compression ratio of the attention layer
    order : str
        Modules order (permutation of 'ncdax')
    separable : bool or {'crosshair'}
        Use a separable (or cross-hair) convolution
    skip : int
        Number of additional skipped channels in the input tensor.
    """

    OneConv = partial(
        ConvBlock,
        ndim,
        out_channels=channels,
        kernel_size=kernel_size,
        dilation=dilation,
        bias=bias,
        activation=activation,
        norm=norm,
        dropout=dropout,
        attention=attention,
        compression=compression,
        order=order,
        separable=separable,
    )

    layers = []
    if skip:
        nb_conv -= 1
        layers = [OneConv(channels + skip)]

    if recurrent:
        layers += [OneConv(channels)] * nb_conv
    else:
        layers += [OneConv(channels) for _ in range(nb_conv)]
    super().__init__(layers, residual=residual, skip=skip)
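The `recurrent` flag depends on a Python detail. `[OneConv(channels)] * nb_conv` repeats a single module object, so every block shares the same weights. The list comprehension instead builds independent modules. The sketch below shows the difference in plain PyTorch through the parameter count. `make_blocks` is a hypothetical helper, and a bare `nn.Conv2d` stands in for `ConvBlock`:

```python
import torch
import torch.nn as nn


def make_blocks(channels: int, nb_conv: int, recurrent: bool) -> nn.Sequential:
    """Hypothetical helper: stack `nb_conv` 3x3 convolutions (stand-ins for ConvBlock)."""

    def one_conv() -> nn.Module:
        return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    if recurrent:
        # One module object repeated: every position shares the same weights.
        layers = [one_conv()] * nb_conv
    else:
        # A fresh module per position: independent weights.
        layers = [one_conv() for _ in range(nb_conv)]
    return nn.Sequential(*layers)


def count_params(module: nn.Module) -> int:
    # parameters() yields each shared tensor only once.
    return sum(p.numel() for p in module.parameters())


shared = make_blocks(8, nb_conv=3, recurrent=True)
separate = make_blocks(8, nb_conv=3, recurrent=False)
x = torch.randn(1, 8, 16, 16)
assert shared(x).shape == separate(x).shape == x.shape
print(count_params(shared), count_params(separate))  # 584 1752
```

Both variants apply three convolutions in the forward pass. The recurrent one stores only a third of the parameters.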

DownGroup

DownGroup(module_down, module_block)

Bases: Sequential

A downsampling step followed by a bunch of layers.

Diagram

flowchart LR
    i["[C<sub>inp</sub>, W]"]:::i ---d("Down"):::w-->
    1["[C<sub>mid</sub>, W/2]"] ---b("Block"):::w-->
    o["[C<sub>out</sub>, W/2]"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;
PARAMETER DESCRIPTION
module_down

A downsampling layer, such as [DownConv][cassetta.layers.DownConv], [DownPool][cassetta.layers.DownPool], or [DownInterpol][cassetta.layers.DownInterpol].

TYPE: Module

module_block

A block of layers that typically preserves the input shape, such as [ConvBlock][cassetta.layers.ConvBlock] or [ConvGroup][cassetta.layers.ConvGroup].

TYPE: Module

Source code in cassetta/layers/convblocks.py
def __init__(self, module_down: nn.Module, module_block: nn.Module):
    """
    Parameters
    ----------
    module_down : Module
        A downsampling layer, such as
        [`DownConv`][cassetta.layers.DownConv],
        [`DownPool`][cassetta.layers.DownPool], or
        [`DownInterpol`][cassetta.layers.DownInterpol].
    module_block : Module
        A block of layers that typically preserves the input shape,
        such as [`ConvBlock`][cassetta.layers.ConvBlock] or
        [`ConvGroup`][cassetta.layers.ConvGroup].
    """
    super().__init__(module_down, module_block)

forward

forward(inp)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, inp_channels, *inp_size) tensor

RETURNS DESCRIPTION
out

Output downsampled tensor

TYPE: (B, out_channels, *out_size) tensor

indices

Indices, if return_indices=True.

TYPE: (B, out_channels, *out_size) tensor[long]

Source code in cassetta/layers/convblocks.py
def forward(self, inp: Tensor) -> OneOrSeveral[Tensor]:
    """
    Parameters
    ----------
    inp : (B, inp_channels, *inp_size) tensor
        Input tensor

    Returns
    -------
    out : (B, out_channels, *out_size) tensor
        Output downsampled tensor
    indices : (B, out_channels, *out_size) tensor[long]
        Indices, if `return_indices=True`.
    """
    down, block = self
    if getattr(down, 'return_indices', False):
        out, indices = down(inp)
        out = block(out)
        return out, indices
    else:
        return super().forward(inp)
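Some downsampling layers return indices along with their values. PyTorch's `nn.MaxPool2d(..., return_indices=True)` is one example. In that case only the pooled values should go through the block, and the indices should be passed through for a later unpooling step. The sketch below shows this pattern with stock PyTorch layers. `DownThenBlock` is a hypothetical name, not a cassetta class:

```python
import torch
import torch.nn as nn


class DownThenBlock(nn.Sequential):
    """Hypothetical minimal analogue of DownGroup built on nn.Sequential."""

    def __init__(self, down: nn.Module, block: nn.Module):
        super().__init__(down, block)

    def forward(self, inp: torch.Tensor):
        down, block = self
        if getattr(down, 'return_indices', False):
            # The downsampling layer returns (values, argmax indices):
            # run the block on the values only and pass the indices through.
            out, indices = down(inp)
            return block(out), indices
        return super().forward(inp)


x = torch.randn(1, 4, 8, 8)
layer = DownThenBlock(nn.MaxPool2d(2, return_indices=True),
                      nn.Conv2d(4, 4, kernel_size=3, padding=1))
out, idx = layer(x)
# The indices can later drive an unpooling layer back to the input size.
print(nn.MaxUnpool2d(2)(out, idx).shape)  # torch.Size([1, 4, 8, 8])
```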

DownConvGroup

DownConvGroup(ndim, inp_channels, out_channels=None, factor=2, mode='interpol', nb_conv=1, kernel_size=3, dilation=1, recurrent=False, residual=False, bias=True, activation='ReLU', norm=None, dropout=None, attention=None, compression=16, order='ncdax', separable=False, **down_options)

Bases: DownGroup

A downsampling step followed by a series of convolution blocks

Diagram

flowchart LR
    subgraph "<code>nb_conv</code>"
        2("ConvBlock 1"):::w  --> 3["[C<sub>out</sub>, W/2]"] ---
        4("ConvBlock 2"):::w  --> 5["[C<sub>out</sub>, W/2]"] ---
        6("..."):::n      --> 7["[C<sub>out</sub>, W/2]"] ---
        8("ConvBlock N"):::w
    end
    i["[C<sub>inp</sub>, W]"]:::i ---d("Down"):::w-->
    od["[C<sub>out</sub>, W/2]"]  ---2
    8 ---> 9["[C<sub>out</sub>, W/2]"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;
flowchart LR
    subgraph "<code>nb_conv</code>"
        2("ConvBlock 1"):::w  --> 3["[C<sub>out</sub>, W/2]"] ---
        4(("+")):::d          --> 5["[C<sub>out</sub>, W/2]"] ---
        6("ConvBlock 2"):::w  --> 7["[C<sub>out</sub>, W/2]"] ---
        8(("+")):::d          --> 9["[C<sub>out</sub>, W/2]"] ---
        10("..."):::n         --> 11["[C<sub>out</sub>, W/2]"] ---
        12("ConvBlock N"):::w --> 13["[C<sub>out</sub>, W/2]"] ---
        14(("+")):::d
    end
    i["[C<sub>inp</sub>, W]"]:::i ---d("Down"):::w-->
    od["[C<sub>out</sub>, W/2]"]  ---2
    od --- 4
    5 --- 8
    11 --- 14
    14 ---> 15["[C<sub>out</sub>, W/2]"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;
PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

inp_channels

Number of input channels

TYPE: int

out_channels

Number of output channels

TYPE: int DEFAULT: `inp_channels`

factor

Downsampling factor

TYPE: [list of] int DEFAULT: 2

mode

Downsampling mode

TYPE: (conv, interpol, pool) DEFAULT: 'interpol'

nb_conv

Number of convolution blocks

TYPE: int DEFAULT: 1

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

dilation

Dilation size

TYPE: [list of] int DEFAULT: 1

recurrent

Recurrent network: share weights across blocks

TYPE: bool DEFAULT: False

residual

Use residual connections between blocks

TYPE: bool DEFAULT: False

bias

Include a bias term

TYPE: bool DEFAULT: True

activation

Activation function

TYPE: ActivationType DEFAULT: 'ReLU'

norm

Normalization function ('batch', 'instance', 'layer')

TYPE: NormType DEFAULT: None

dropout

Dropout probability

TYPE: DropoutType DEFAULT: None

attention

Attention layer

TYPE: bool or {sqzex, bcam, mha} DEFAULT: None

compression

Compression ratio of the attention layer

TYPE: int DEFAULT: 16

order

Modules order (permutation of 'ncdax')

TYPE: str DEFAULT: 'ncdax'

separable

Use a separable (or cross-hair) convolution

TYPE: bool or {crosshair} DEFAULT: False

PARAMETER DESCRIPTION
interpolation

Spline order

TYPE: InterpolationType, if `mode="interpol"`

bound

Boundary conditions

TYPE: BoundType, if `mode="interpol"`

prefilter

Perform proper interpolation by applying spline prefiltering

TYPE: bool, if `mode="interpol"`

return_indices

Return argmax indices

TYPE: bool, if `mode="pool"`

Source code in cassetta/layers/convblocks.py
def __init__(
    self,
    ndim: int,
    inp_channels: int,
    out_channels: Optional[int] = None,
    factor: OneOrSeveral[int] = 2,
    mode: Literal['conv', 'interpol', 'pool'] = 'interpol',
    nb_conv: int = 1,
    kernel_size: OneOrSeveral[int] = 3,
    dilation: OneOrSeveral[int] = 1,
    recurrent: bool = False,
    residual: bool = False,
    bias: bool = True,
    activation: ActivationType = 'ReLU',
    norm: NormType = None,
    dropout: DropoutType = None,
    attention: AttentionType = None,
    compression: int = 16,
    order: str = 'ncdax',
    separable: Union[bool, Literal['crosshair']] = False,
    **down_options,
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    inp_channels : int
        Number of input channels
    out_channels : int, default=`inp_channels`
        Number of output channels
    factor : [list of] int
        Downsampling factor
    mode : {'conv', 'interpol', 'pool'}
        Downsampling mode
    nb_conv : int
        Number of convolution blocks
    kernel_size : [list of] int
        Kernel size
    dilation : [list of] int
        Dilation size
    recurrent : bool
        Recurrent network: share weights across blocks
    residual : bool
        Use residual connections between blocks
    bias : bool
        Include a bias term
    activation : ActivationType
        Activation function
    norm : NormType
        Normalization function ('batch', 'instance', 'layer')
    dropout : DropoutType
        Dropout probability
    attention : bool or {'sqzex', 'bcam', 'mha'}
        Attention layer
    compression : int
        Compression ratio of the attention layer
    order : str
        Modules order (permutation of 'ncdax')
    separable : bool or {'crosshair'}
        Use a separable (or cross-hair) convolution

    Other Parameters
    ----------------
    interpolation : InterpolationType, if `mode="interpol"`
        Spline order
    bound : BoundType, if `mode="interpol"`
        Boundary conditions
    prefilter : bool, if `mode="interpol"`
        Perform proper interpolation by applying spline prefiltering
    return_indices : bool, if `mode="pool"`
        Return argmax indices
    """
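    out_channels = out_channels or inp_channels
    mode = mode[0].lower()
    Down = (
        DownConv if mode == 'c' else
        DownPool if mode == 'p' else
        DownInterpol if mode == 'i' else
        None
    )
    if Down is None:
        raise ValueError(f'Unknown downsampling mode: {mode!r}')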
    down = Down(
        ndim=ndim,
        inp_channels=inp_channels,
        out_channels=out_channels,
        factor=factor,
        **down_options,
    )
    conv = ConvGroup(
        ndim=ndim,
        channels=out_channels,
        nb_conv=nb_conv,
        kernel_size=kernel_size,
        dilation=dilation,
        bias=bias,
        activation=activation,
        norm=norm,
        dropout=dropout,
        attention=attention,
        compression=compression,
        order=order,
        separable=separable,
        residual=residual,
        recurrent=recurrent,
    )
    super().__init__(down, conv)
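DownConvGroup dispatches on the first letter of `mode` and stacks a `ConvGroup` on top of the chosen downsampling layer. The sketch below reproduces that structure in 2D with stock PyTorch layers. The helper name is hypothetical. The 1x1 convolution that changes the channel count after pooling or interpolation is an assumption made for this sketch, and `DownPool` or `DownInterpol` may do it differently:

```python
import torch
import torch.nn as nn


def down_conv_group_2d(inp_channels: int, out_channels: int = None, factor: int = 2,
                       mode: str = 'interpol', nb_conv: int = 1) -> nn.Sequential:
    """Hypothetical 2D analogue of DownConvGroup built from stock PyTorch layers."""
    out_channels = out_channels or inp_channels
    key = mode[0].lower()  # same first-letter dispatch as DownConvGroup
    if key == 'c':
        # Strided convolution: learnable downsampling that also maps channels.
        down = nn.Conv2d(inp_channels, out_channels, kernel_size=factor, stride=factor)
    elif key in ('p', 'i'):
        resample = (nn.MaxPool2d(factor) if key == 'p' else
                    nn.Upsample(scale_factor=1 / factor, mode='bilinear',
                                align_corners=False))
        # Assumption: a 1x1 convolution maps the channels after resampling.
        down = nn.Sequential(resample, nn.Conv2d(inp_channels, out_channels, kernel_size=1))
    else:
        raise ValueError(f'Unknown mode: {mode!r}')
    # Shape-preserving convolution blocks, as in ConvGroup (no norm/dropout here).
    blocks = []
    for _ in range(nb_conv):
        blocks += [nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.ReLU()]
    return nn.Sequential(down, *blocks)


x = torch.randn(1, 3, 16, 16)
for mode in ('conv', 'pool', 'interpol'):
    net = down_conv_group_2d(3, 8, mode=mode, nb_conv=2)
    assert net(x).shape == (1, 8, 8, 8)
```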

UpGroup

UpGroup(module_up, module_block, skip=False)

Bases: Sequential

An upsampling step followed by a bunch of layers.

Diagram

flowchart LR
    i["[C<sub>inp</sub>, W]"]:::i ---d("Up"):::w-->
    1["[C<sub>mid</sub>, W*2]"] ---b("Block"):::w-->
    o["[C<sub>out</sub>, W*2]"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;
flowchart LR
    i["[C<sub>inp</sub>, W]"]:::i ---d("Up"):::w-->
    1["[C<sub>mid</sub>, W*2]"]
    s["[C<sub>mid</sub>, W*2]"]:::i
    1 & s --- c(("c")):::d --->
    2["[C<sub>mid</sub>*2, W*2]"] ---b("Block"):::w-->
    o["[C<sub>out</sub>, W*2]"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;
flowchart LR
    i["[C<sub>inp</sub>, W]"]:::i ---d("Up"):::w-->
    1["[C<sub>mid</sub>, W*2]"]
    s["[C<sub>mid</sub>, W*2]"]:::i
    1 & s --- c(("+")):::d --->
    2["[C<sub>mid</sub>, W*2]"] ---b("Block"):::w-->
    o["[C<sub>out</sub>, W*2]"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;
PARAMETER DESCRIPTION
module_up

An upsampling layer, such as [UpConv][cassetta.layers.UpConv], [UpPool][cassetta.layers.UpPool], or [UpInterpol][cassetta.layers.UpInterpol].

TYPE: Module

module_block

A block of layers that typically preserves the input shape, such as [ConvBlock][cassetta.layers.ConvBlock] or [ConvGroup][cassetta.layers.ConvGroup].

TYPE: Module

skip

Whether to concatenate (skip=True) or add (skip=False) any skip connections.

TYPE: bool DEFAULT: False

Source code in cassetta/layers/convblocks.py
def __init__(self, module_up: nn.Module, module_block: nn.Module,
             skip: bool = False):
    """
    Parameters
    ----------
    module_up : Module
        An upsampling layer, such as
        [`UpConv`][cassetta.layers.UpConv],
        [`UpPool`][cassetta.layers.UpPool], or
        [`UpInterpol`][cassetta.layers.UpInterpol].
    module_block : Module
        A block of layers that typically preserves the input shape,
        such as [`ConvBlock`][cassetta.layers.ConvBlock] or
        [`ConvGroup`][cassetta.layers.ConvGroup].
    skip : bool
        Whether to concatenate (`skip=True`) or add (`skip=False`)
        any skip connections.
    """
    super().__init__(module_up, module_block)
    self.skip = skip

forward

forward(inp, *skips, indices=None)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, inp_channels, *inp_size) tensor

*skips

Skipped tensors

TYPE: (B, skip_channels, *out_size) tensor DEFAULT: ()

PARAMETER DESCRIPTION
indices

Unpool indices. Only if module_up is an [UpPool][cassetta.layers.UpPool].

TYPE: (B, out_tensor, *inp_size) tensor[long]

RETURNS DESCRIPTION
out

Output upsampled tensor

TYPE: (B, out_channels, *out_size) tensor

Source code in cassetta/layers/convblocks.py
def forward(
    self,
    inp: Tensor,
    *skips,
    indices: Optional[Tensor] = None
) -> Tensor:
    """
    Parameters
    ----------
    inp : (B, inp_channels, *inp_size) tensor
        Input tensor
    *skips : (B, skip_channels, *out_size) tensor
        Skipped tensors

    Other Parameters
    ----------------
    indices : (B, out_tensor, *inp_size) tensor[long]
        Unpool indices. Only if `module_up` is an
        [`UpPool`][cassetta.layers.UpPool].

    Returns
    -------
    out : (B, out_channels, *out_size) tensor
        Output upsampled tensor
    """
    up, conv = self
    kwargs = dict(indices=indices) if indices is not None else {}
    out = up(inp, **kwargs)
    if skips:
        if not self.skip:
            for skip in skips:
                out = out + skip
        else:
            out = torch.cat([out, *skips], dim=1)
    return conv(out)
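The skip-merging rule in `forward` can be isolated as a small function. This sketch assumes the skipped tensors already match the upsampled spatial size. `merge_skips` is a hypothetical helper. Addition requires each skip to have the same number of channels as `out`, while concatenation stacks the channels:

```python
import torch


def merge_skips(out: torch.Tensor, *skips: torch.Tensor, cat: bool = False) -> torch.Tensor:
    """Hypothetical helper mirroring UpGroup's skip handling."""
    if not skips:
        return out
    if cat:
        # Concatenate along the channel dimension: channel counts add up.
        return torch.cat([out, *skips], dim=1)
    for skip in skips:
        # Out-of-place sum: shapes must match (or broadcast).
        out = out + skip
    return out


out = torch.ones(1, 4, 8, 8)
skip = torch.ones(1, 4, 8, 8)
print(merge_skips(out, skip).shape)            # torch.Size([1, 4, 8, 8])
print(merge_skips(out, skip, cat=True).shape)  # torch.Size([1, 8, 8, 8])
```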

UpConvGroup

UpConvGroup(ndim, inp_channels, out_channels=None, factor=2, skip=0, mode='interpol', nb_conv=1, kernel_size=3, dilation=1, recurrent=False, residual=False, bias=True, activation='ReLU', norm=None, dropout=False, attention=None, compression=16, order='ncdax', separable=False, **up_options)

Bases: UpGroup

An upsampling step followed by a series of convolution blocks, optionally with a skip connection.

Diagram

flowchart LR
    subgraph "<code>nb_conv</code>"
        2("ConvBlock 1"):::w  --> 3["[C<sub>out</sub>, W*2]"] ---
        4("ConvBlock 2"):::w  --> 5["[C<sub>out</sub>, W*2]"] ---
        6("..."):::n      --> 7["[C<sub>out</sub>, W*2]"] ---
        8("ConvBlock N"):::w
    end
    i["[C<sub>inp</sub>, W]"]:::i ---d("Up"):::w-->
    1["[C<sub>out</sub>, W*2]"] --- 2
    8 --->  o["[C<sub>out</sub>, W*2]"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;
flowchart LR
    subgraph "<code>nb_conv</code>"
        2("ConvBlock 1"):::w  --> 3["[C<sub>out</sub>, W*2]"] ---
        4("ConvBlock 2"):::w  --> 5["[C<sub>out</sub>, W*2]"] ---
        6("..."):::n      --> 7["[C<sub>out</sub>, W*2]"] ---
        8("ConvBlock N"):::w
    end
    i["[C<sub>inp</sub>, W]"]:::i ---d("Up"):::w-->
    1["[C<sub>out</sub>, W*2]"]
    s["[C<sub>out</sub>, W*2]"]:::i
    1 & s --- c(("c")):::d ---> x["[C<sub>out</sub>*2, W*2]"] --- 2
    8 ---> o["[C<sub>out</sub>, W*2]"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;
flowchart LR
    subgraph "<code>nb_conv</code>"
        2("ConvBlock 1"):::w  --> 3["[C<sub>out</sub>, W*2]"] ---
        4("ConvBlock 2"):::w  --> 5["[C<sub>out</sub>, W*2]"] ---
        6("..."):::n      --> 7["[C<sub>out</sub>, W*2]"] ---
        8("ConvBlock N"):::w
    end
    i["[C<sub>inp</sub>, W]"]:::i ---d("Up"):::w-->
    1["[C<sub>out</sub>, W*2]"]
    s["[C<sub>out</sub>, W*2]"]:::i
    1 & s --- c(("+")):::d ---> x["[C<sub>out</sub>, W*2]"] --- 2
    8 ---> o["[C<sub>out</sub>, W*2]"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;
PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

inp_channels

Number of input channels

TYPE: int

out_channels

Number of output channels

TYPE: int DEFAULT: `inp_channels`

factor

Upsampling factor

TYPE: [list of] int DEFAULT: 2

skip

Number of channels to concatenate in the skip connection. If 0 and skip tensors are provided, they are added instead of concatenated.

TYPE: int DEFAULT: 0

mode

Upsampling mode

TYPE: (conv, interpol, pool) DEFAULT: 'interpol'

nb_conv

Number of convolution blocks

TYPE: int DEFAULT: 1

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

dilation

Dilation size

TYPE: [list of] int DEFAULT: 1

recurrent

Recurrent network: share weights across blocks

TYPE: bool DEFAULT: False

residual

Use residual connections between blocks

TYPE: bool DEFAULT: False

bias

Include a bias term

TYPE: bool DEFAULT: True

activation

Activation function

TYPE: ActivationType DEFAULT: 'ReLU'

norm

Normalization function ('batch', 'instance', 'layer')

TYPE: NormType DEFAULT: None

dropout

Dropout probability

TYPE: DropoutType DEFAULT: False

order

Modules order (permutation of 'ncdax')

TYPE: str DEFAULT: 'ncdax'

separable

Use a separable (or cross-hair) convolution

TYPE: bool or {crosshair} DEFAULT: False

PARAMETER DESCRIPTION
interpolation

Spline order

TYPE: InterpolationType, if `mode="interpol"`

bound

Boundary conditions

TYPE: BoundType, if `mode="interpol"`

prefilter

Perform proper interpolation by applying spline prefiltering

TYPE: bool, if `mode="interpol"`

Source code in cassetta/layers/convblocks.py
def __init__(
    self,
    ndim: int,
    inp_channels: int,
    out_channels: Optional[int] = None,
    factor: OneOrSeveral[int] = 2,
    skip: int = 0,
    mode: Literal['conv', 'interpol', 'pool'] = 'interpol',
    nb_conv: int = 1,
    kernel_size: OneOrSeveral[int] = 3,
    dilation: OneOrSeveral[int] = 1,
    recurrent: bool = False,
    residual: bool = False,
    bias: bool = True,
    activation: ActivationType = 'ReLU',
    norm: NormType = None,
    dropout: DropoutType = False,
    attention: AttentionType = None,
    compression: int = 16,
    order: str = 'ncdax',
    separable: Union[bool, Literal['crosshair']] = False,
    **up_options,
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    inp_channels : int
        Number of input channels
    out_channels : int, default=`inp_channels`
        Number of output channels
    factor : [list of] int
        Upsampling factor
    skip : int
        Number of channels to concatenate in the skip connection.
        If 0 and skip tensors are provided, they are added instead
        of concatenated.
    mode : {'conv', 'interpol', 'pool'}
        Upsampling mode
    nb_conv : int
        Number of convolution blocks
    kernel_size : [list of] int
        Kernel size
    dilation : [list of] int
        Dilation size
    recurrent : bool
        Recurrent network: share weights across blocks
    residual : bool
        Use residual connections between blocks
    bias : bool
        Include a bias term
    activation : ActivationType
        Activation function
    norm : NormType
        Normalization function ('batch', 'instance', 'layer')
    dropout : DropoutType
        Dropout probability
    order : str
        Modules order (permutation of 'ncdax')
    separable : bool or {'crosshair'}
        Use a separable (or cross-hair) convolution

    Other Parameters
    ----------------
    interpolation : InterpolationType, if `mode="interpol"`
        Spline order
    bound : BoundType, if `mode="interpol"`
        Boundary conditions
    prefilter : bool, if `mode="interpol"`
        Perform proper interpolation by applying spline prefiltering
    """
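    out_channels = out_channels or inp_channels
    mode = mode[0].lower()
    Up = (
        UpConv if mode == 'c' else
        UpPool if mode in ('p', 'u') else
        UpInterpol if mode == 'i' else
        None
    )
    if Up is None:
        raise ValueError(f'Unknown upsampling mode: {mode!r}')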
    up = Up(
        ndim=ndim,
        inp_channels=inp_channels,
        out_channels=out_channels,
        factor=factor,
        **up_options,
    )
    conv = ConvGroup(
        ndim=ndim,
        channels=out_channels,
        nb_conv=nb_conv,
        kernel_size=kernel_size,
        dilation=dilation,
        bias=bias,
        activation=activation,
        norm=norm,
        dropout=dropout,
        attention=attention,
        compression=compression,
        order=order,
        separable=separable,
        residual=residual,
        recurrent=recurrent,
        skip=skip,
    )
    super().__init__(up, conv, skip=skip)
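With `skip > 0`, ConvGroup widens only its first block to `out_channels + skip` input channels, because the concatenated tensor is wider than the upsampled one. The 2D sketch below shows that channel bookkeeping, using `nn.ConvTranspose2d` as the upsampling layer. `TinyUpConvGroup` is hypothetical and simplified, with no norm, dropout or attention:

```python
import torch
import torch.nn as nn


class TinyUpConvGroup(nn.Module):
    """Hypothetical, simplified 2D analogue of UpConvGroup."""

    def __init__(self, inp_channels: int, out_channels: int, skip: int = 0,
                 nb_conv: int = 2, factor: int = 2):
        super().__init__()
        # Transposed convolution: upsample by `factor` and map channels.
        self.up = nn.ConvTranspose2d(inp_channels, out_channels,
                                     kernel_size=factor, stride=factor)
        self.skip = skip
        convs = []
        for i in range(nb_conv):
            # Only the first block sees the concatenated (wider) tensor.
            in_ch = out_channels + skip if i == 0 else out_channels
            convs += [nn.Conv2d(in_ch, out_channels, kernel_size=3, padding=1), nn.ReLU()]
        self.convs = nn.Sequential(*convs)

    def forward(self, inp: torch.Tensor, *skips: torch.Tensor) -> torch.Tensor:
        out = self.up(inp)
        if skips:
            if self.skip:
                out = torch.cat([out, *skips], dim=1)
            else:
                for s in skips:
                    out = out + s
        return self.convs(out)


x = torch.randn(1, 8, 4, 4)
s = torch.randn(1, 3, 8, 8)
net = TinyUpConvGroup(8, 4, skip=3)
assert net(x, s).shape == (1, 4, 8, 8)
```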

cassetta.layers.dropout

ChannelDropout

Bases: _DropoutNd

forward

forward(inp)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, C, *size) tensor

RETURNS DESCRIPTION
out

Output tensor

TYPE: (B, C, *size) tensor

Source code in cassetta/layers/dropout.py
def forward(self, inp: Tensor) -> Tensor:
    """
    Parameters
    ----------
    inp : (B, C, *size) tensor
        Input tensor

    Returns
    -------
    out : (B, C, *size) tensor
        Output tensor
    """
    ndim = inp.ndim - 2
    dropout = getattr(F, f'dropout{ndim}d')
    return dropout(inp, self.p, self.training, self.inplace)
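ChannelDropout picks `F.dropout1d`, `F.dropout2d` or `F.dropout3d` based on the number of spatial dimensions. As a result, it zeroes whole channels rather than individual voxels, and the surviving channels are rescaled by `1 / (1 - p)`. The check below uses the same functional call on a 2D batch. `channel_dropout` is a hypothetical stand-in for the module:

```python
import torch
import torch.nn.functional as F


def channel_dropout(inp: torch.Tensor, p: float = 0.5, training: bool = True) -> torch.Tensor:
    """Hypothetical functional stand-in for ChannelDropout.forward."""
    ndim = inp.ndim - 2                        # number of spatial dimensions
    dropout = getattr(F, f'dropout{ndim}d')    # dropout1d / dropout2d / dropout3d
    return dropout(inp, p, training, False)


torch.manual_seed(0)
x = torch.ones(2, 16, 5, 5)
y = channel_dropout(x, p=0.5)
per_channel = y.flatten(2)                     # (B, C, H*W)
# Every channel is either entirely dropped or entirely kept (and scaled by 2).
assert torch.all((per_channel == 0).all(-1) | (per_channel == 2).all(-1))
# In evaluation mode, the input passes through unchanged.
assert torch.equal(channel_dropout(x, training=False), x)
```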

make_dropout

make_dropout(dropout, **kwargs)

Instantiate a (channel) dropout module.

PARAMETER DESCRIPTION
dropout

An already instantiated nn.Module, or a nn.Module subclass, or a callable that returns an instantiated nn.Module, or the dropout probability.

TYPE: DropoutType

kwargs

Additional parameters to pass to the constructor or function.

TYPE: dict DEFAULT: {}

RETURNS DESCRIPTION
dropout

An instantiated nn.Module.

TYPE: Module

Source code in cassetta/layers/dropout.py
def make_dropout(dropout: DropoutType, **kwargs):
    """
    Instantiate a (channel) dropout module.

    Parameters
    ----------
    dropout : DropoutType
        An already instantiated `nn.Module`, or a `nn.Module` subclass,
        or a callable that returns an instantiated `nn.Module`, or the
        dropout probability.
    kwargs : dict
        Additional parameters to pass to the constructor or function.

    Returns
    -------
    dropout : Module
        An instantiated `nn.Module`.
    """
    if not dropout:
        return None

    if isinstance(dropout, nn.Module):
        return dropout

    if isinstance(dropout, float):
        return ChannelDropout(dropout, **kwargs)

    dropout = dropout(**kwargs)

    if not isinstance(dropout, nn.Module):
        raise ValueError('Dropout did not instantiate a Module')
    return dropout

cassetta.layers.interpol

GridPull

GridPull(interpolation='linear', bound='zero', extrapolate=False, prefilter=False)

Bases: Module

Deform an image using a coordinates field.

PARAMETER DESCRIPTION
interpolation

Interpolation order.

TYPE: [list of] InterpolationType DEFAULT: 'linear'

bound

Boundary conditions.

TYPE: [list of] BoundType DEFAULT: 'zero'

extrapolate

Extrapolate out-of-bound data.

TYPE: bool or int DEFAULT: False

prefilter

Apply spline pre-filter (= interpolates the input)

TYPE: bool DEFAULT: False

Source code in cassetta/layers/interpol.py
def __init__(
    self,
    interpolation: InterpolationType = 'linear',
    bound: BoundType = 'zero',
    extrapolate: bool = False,
    prefilter: bool = False
):
    """
    Parameters
    ----------
    interpolation : [list of] InterpolationType
        Interpolation order.
    bound : [list of] BoundType
        Boundary conditions.
    extrapolate : bool or int
        Extrapolate out-of-bound data.
    prefilter : bool
        Apply spline pre-filter (= interpolates the input)

    """
    super().__init__()
    self.interpolation = interpolation
    self.bound = bound
    self.extrapolate = extrapolate
    self.prefilter = prefilter

forward

forward(input, grid)

Sample an image.

If the input dtype is not a floating point type, the input image is assumed to contain labels. Then, unique labels are extracted and resampled individually, making them soft labels. Finally, the label map is reconstructed from the individual soft labels by assigning the label with maximum soft value.

PARAMETER DESCRIPTION
input

Input image.

TYPE: (batch, channel, *inshape) tensor

grid

Coordinate field, in voxels.

TYPE: (batch, ndim, *outshape) tensor

RETURNS DESCRIPTION
output

Deformed image.

TYPE: (batch, channel, *outshape) tensor

Source code in cassetta/layers/interpol.py
def forward(self, input, grid):
    """
    Sample an image.

    If the input dtype is not a floating point type, the input image is
    assumed to contain labels. Then, unique labels are extracted
    and resampled individually, making them soft labels. Finally,
    the label map is reconstructed from the individual soft labels by
    assigning the label with maximum soft value.

    Parameters
    ----------
    input : (batch, channel, *inshape) tensor
        Input image.
    grid : (batch, ndim, *outshape) tensor
        Coordinate field, in voxels.

    Returns
    -------
    output : (batch, channel, *outshape) tensor
        Deformed image.
    """
    grid = torch.movedim(grid, 1, -1)
    return torch_interpol.grid_pull(input, grid, **self._options)
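GridPull expects a channel-first field of absolute voxel coordinates. Before calling `torch_interpol`, it moves the coordinate axis last. To illustrate this convention without that package, the sketch below swaps in PyTorch's own `F.grid_sample`, using linear interpolation and zero padding, which is roughly analogous to `bound='zero'`. It converts voxel coordinates to the normalized `[-1, 1]` range that `grid_sample` expects. This is a 2D illustration only. It does not handle label maps or spline orders above 1. `pull_2d` and `identity_grid_2d` are hypothetical helpers:

```python
import torch
import torch.nn.functional as F


def pull_2d(image: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Sample `image` (B, C, H, W) at voxel coordinates `grid` (B, 2, H_out, W_out)."""
    H, W = image.shape[-2:]
    grid = torch.movedim(grid, 1, -1)          # channel-last, as in GridPull.forward
    i, j = grid.unbind(-1)                     # row (H) and column (W) coordinates
    # With align_corners=True, voxel centers 0 and size-1 map to -1 and +1.
    x = 2 * j / (W - 1) - 1
    y = 2 * i / (H - 1) - 1
    norm_grid = torch.stack([x, y], dim=-1)    # grid_sample expects (x, y) = (W, H) order
    return F.grid_sample(image, norm_grid, mode='bilinear',
                         padding_mode='zeros', align_corners=True)


def identity_grid_2d(H: int, W: int) -> torch.Tensor:
    """(1, 2, H, W) grid whose values are the voxel coordinates themselves."""
    ii, jj = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                            torch.arange(W, dtype=torch.float32), indexing='ij')
    return torch.stack([ii, jj])[None]


img = torch.rand(1, 3, 5, 6)
assert torch.allclose(pull_2d(img, identity_grid_2d(5, 6)), img, atol=1e-5)
```

Sampling at the identity grid returns the image unchanged. Shifting the column coordinate by one voxel shifts the image left and fills the last column with zeros.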

FlowPull

FlowPull(interpolation='linear', bound='zero', extrapolate=False, prefilter=False)

Bases: GridPull

Deform an image using a displacement field

PARAMETER DESCRIPTION
interpolation

Interpolation order.

TYPE: [list of] InterpolationType DEFAULT: 'linear'

bound

Boundary conditions.

TYPE: [list of] BoundType DEFAULT: 'zero'

extrapolate

Extrapolate out-of-bound data.

TYPE: bool or int DEFAULT: False

prefilter

Apply spline pre-filter (= interpolates the input)

TYPE: bool DEFAULT: False

Source code in cassetta/layers/interpol.py
def __init__(
    self,
    interpolation: InterpolationType = 'linear',
    bound: BoundType = 'zero',
    extrapolate: bool = False,
    prefilter: bool = False
):
    """
    Parameters
    ----------
    interpolation : [list of] InterpolationType
        Interpolation order.
    bound : [list of] BoundType
        Boundary conditions.
    extrapolate : bool or int
        Extrapolate out-of-bound data.
    prefilter : bool
        Apply spline pre-filter (= interpolates the input)

    """
    super().__init__()
    self.interpolation = interpolation
    self.bound = bound
    self.extrapolate = extrapolate
    self.prefilter = prefilter

forward

forward(input, flow)

Sample an image.

If the input dtype is not a floating point type, the input image is assumed to contain labels. Then, unique labels are extracted and resampled individually, making them soft labels. Finally, the label map is reconstructed from the individual soft labels by assigning the label with maximum soft value.

PARAMETER DESCRIPTION
input

Input image.

TYPE: (batch, channel, *inshape) tensor

flow

Displacement field, in voxels.

TYPE: (batch, ndim, *outshape) tensor

RETURNS DESCRIPTION
output

Deformed image.

TYPE: (batch, channel, *outshape) tensor

Source code in cassetta/layers/interpol.py
def forward(self, input, flow):
    """
    Sample an image.

    If the input dtype is not a floating point type, the input image is
    assumed to contain labels. Then, unique labels are extracted
    and resampled individually, making them soft labels. Finally,
    the label map is reconstructed from the individual soft labels by
    assigning the label with maximum soft value.

    Parameters
    ----------
    input : (batch, channel, *inshape) tensor
        Input image.
    flow : (batch, ndim, *outshape) tensor
        Displacement field, in voxels.

    Returns
    -------
    output : (batch, channel, *outshape) tensor
        Deformed image.
    """
    flow = torch.movedim(flow, 1, -1)
    flow = torch_interpol.add_identity_grid(flow)
    flow = torch.movedim(flow, -1, 1)
    return super().forward(input, flow)
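FlowPull differs from GridPull only in that it adds the identity grid to the displacement before sampling, so a zero flow leaves the image unchanged. The sketch below shows that conversion in plain PyTorch, keeping cassetta's channel-first layout. `identity_grid` and `flow_to_grid` are hypothetical helpers. The source instead moves the coordinate axis last and calls `torch_interpol.add_identity_grid`. FlowPush performs the same conversion before splatting:

```python
import torch


def identity_grid(shape) -> torch.Tensor:
    """(ndim, *shape) grid of voxel coordinates, channel-first like cassetta's fields."""
    ranges = [torch.arange(s, dtype=torch.float32) for s in shape]
    return torch.stack(torch.meshgrid(*ranges, indexing='ij'))


def flow_to_grid(flow: torch.Tensor) -> torch.Tensor:
    """Convert a (B, ndim, *shape) displacement field into absolute voxel coordinates."""
    return flow + identity_grid(flow.shape[2:]).to(flow)


grid = flow_to_grid(torch.zeros(1, 2, 3, 4))
print(grid[0, :, 2, 3].tolist())  # [2.0, 3.0]
```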

GridPush

GridPush(interpolation='linear', bound='zero', extrapolate=False, prefilter=False)

Bases: Module

Splat an image using a coordinates field.

PARAMETER DESCRIPTION
interpolation

Interpolation order.

TYPE: [list of] InterpolationType DEFAULT: 'linear'

bound

Boundary conditions.

TYPE: BoundType or sequence[BoundType] DEFAULT: 'zero'

extrapolate

Extrapolate out-of-bound data.

TYPE: bool or int DEFAULT: False

prefilter

Apply spline pre-filter (= interpolates the input)

TYPE: bool DEFAULT: False

Source code in cassetta/layers/interpol.py
def __init__(
        self,
        interpolation='linear',
        bound='zero',
        extrapolate=False,
        prefilter=False
):
    """

    Parameters
    ----------
    interpolation : [list of] InterpolationType
        Interpolation order.
    bound : BoundType or sequence[BoundType]
        Boundary conditions.
    extrapolate : bool or int
        Extrapolate out-of-bound data.
    prefilter : bool
        Apply spline pre-filter (= interpolates the input)

    """
    super().__init__()
    self.interpolation = interpolation
    self.bound = bound
    self.extrapolate = extrapolate
    self.prefilter = prefilter

forward

forward(input, grid, shape=None)

Splat an image.

PARAMETER DESCRIPTION
input

Input image.

TYPE: (batch, channel, *inshape) tensor

grid

Coordinate field, in voxels.

TYPE: (batch, ndim, *inshape) tensor

shape

Output spatial shape.

TYPE: sequence[int] DEFAULT: inshape

RETURNS DESCRIPTION
output

Splatted image.

TYPE: (batch, channel, *shape) tensor

Source code in cassetta/layers/interpol.py
def forward(self, input, grid, shape=None):
    """
    Splat an image.

    Parameters
    ----------
    input : (batch, channel, *inshape) tensor
        Input image.
    grid : (batch, ndim, *inshape) tensor
        Coordinate field, in voxels.
    shape : sequence[int], default=inshape
        Output spatial shape.

    Returns
    -------
    output : (batch, channel, *shape) tensor
        Splatted image.
    """
    grid = torch.movedim(grid, 1, -1)
    return torch_interpol.grid_push(input, grid, shape, **self._options)
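Splatting is the adjoint of sampling. Instead of reading each output voxel from a location in the input, it writes each input voxel to its target location and accumulates the values. The 2D nearest-neighbour sketch below uses `scatter_add_` to show the idea. It is a simplification, since cassetta defaults to linear splatting through `torch_interpol`. `push_nearest_2d` is a hypothetical name. Targets outside the output are dropped, which mimics a zero boundary:

```python
from typing import Optional, Tuple

import torch


def push_nearest_2d(image: torch.Tensor, grid: torch.Tensor,
                    shape: Optional[Tuple[int, int]] = None) -> torch.Tensor:
    """Nearest-neighbour splat of image (B, C, H, W) to voxel coordinates grid (B, 2, H, W)."""
    B, C, H, W = image.shape
    Ho, Wo = shape if shape is not None else (H, W)
    idx = torch.round(grid).long()                       # nearest target voxel
    i, j = idx[:, 0], idx[:, 1]                          # (B, H, W) each
    inside = (i >= 0) & (i < Ho) & (j >= 0) & (j < Wo)   # zero boundary: drop the rest
    flat = (i * Wo + j).clamp(0, Ho * Wo - 1)            # flat target index
    flat = flat.flatten(1).unsqueeze(1).repeat(1, C, 1)  # (B, C, H*W)
    vals = (image * inside.unsqueeze(1)).flatten(2)      # out-of-bounds values -> 0
    out = image.new_zeros(B, C, Ho * Wo)
    out.scatter_add_(2, flat, vals)                      # colliding voxels accumulate
    return out.view(B, C, Ho, Wo)


img = torch.rand(2, 3, 4, 5)
ii, jj = torch.meshgrid(torch.arange(4.0), torch.arange(5.0), indexing='ij')
ident = torch.stack([ii, jj])[None].expand(2, -1, -1, -1)
assert torch.allclose(push_nearest_2d(img, ident), img)
```

If every voxel is sent to `(0, 0)`, the whole image mass piles up in that voxel. This accumulation is what distinguishes a push from a pull.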

FlowPush

FlowPush(interpolation='linear', bound='zero', extrapolate=False, prefilter=False)

Bases: GridPush

Splat an image using a displacement field

PARAMETER DESCRIPTION
interpolation

Interpolation order.

TYPE: [list of] InterpolationType DEFAULT: 'linear'

bound

Boundary conditions.

TYPE: BoundType or sequence[BoundType] DEFAULT: 'zero'

extrapolate

Extrapolate out-of-bound data.

TYPE: bool or int DEFAULT: False

prefilter

Apply spline pre-filter (= interpolates the input)

TYPE: bool DEFAULT: False

Source code in cassetta/layers/interpol.py
def __init__(
        self,
        interpolation='linear',
        bound='zero',
        extrapolate=False,
        prefilter=False
):
    """

    Parameters
    ----------
    interpolation : [list of] InterpolationType
        Interpolation order.
    bound : BoundType or sequence[BoundType]
        Boundary conditions.
    extrapolate : bool or int
        Extrapolate out-of-bound data.
    prefilter : bool
        Apply spline pre-filter (= interpolates the input)

    """
    super().__init__()
    self.interpolation = interpolation
    self.bound = bound
    self.extrapolate = extrapolate
    self.prefilter = prefilter

forward

forward(input, flow, shape=None)

Splat an image.

PARAMETER DESCRIPTION
input

Input image.

TYPE: (batch, channel, *inshape) tensor

flow

Displacement field, in voxels.

TYPE: (batch, ndim, *inshape) tensor

shape

Output spatial shape

TYPE: sequence[int] DEFAULT: inshape

RETURNS DESCRIPTION
output

Splatted image.

TYPE: (batch, channel, *shape) tensor

Source code in cassetta/layers/interpol.py
def forward(self, input, flow, shape=None):
    """
    Splat an image.

    Parameters
    ----------
    input : (batch, channel, *inshape) tensor
        Input image.
    flow : (batch, ndim, *inshape) tensor
        Displacement field, in voxels.
    shape : sequence[int], default=inshape
        Output spatial shape

    Returns
    -------
    output : (batch, channel, *shape) tensor
        Splatted image.
    """
    flow = torch.movedim(flow, 1, -1)
    flow = torch_interpol.add_identity_grid(flow)
    flow = torch.movedim(flow, -1, 1)
    return super().forward(input, flow, shape)

Resize

Resize(factor=None, shape=None, anchor='edge', interpolation='linear', bound='zero', extrapolate=True, prefilter=True)

Bases: Module

Resize (interpolate) an image

Notes
  • At least one of factor and shape must be specified.
  • If anchor in ('center', 'edge'), exactly one of factor or shape must be specified.
  • If anchor in ('first', 'last'), factor must be provided even if shape is specified.
  • Because of rounding, resize(resize(x, f), 1/f) is not guaranteed to return a tensor with the same shape as x.
     edge           center          first           last
e - + - + - e   + - + - + - +   + - + - + - +   + - + - + - +
| . | . | . |   | c | . | c |   | f | . | . |   | . | . | . |
+ _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
| . | . | . |   | . | . | . |   | . | . | . |   | . | . | . |
+ _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
| . | . | . |   | c | . | c |   | . | . | . |   | . | . | l |
e _ + _ + _ e   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
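The rounding note above can be checked with plain arithmetic. The sketch assumes that the output length along an axis is the input length multiplied by the factor and then truncated to an integer. That rounding rule is an assumption, not taken from torch-interpol. Under it, a length-5 axis does not survive a round trip through factor 0.5 and back:

```python
import math


def resized_length(length, factor):
    """Output length of one axis, assuming truncation (floor) after scaling."""
    return math.floor(length * factor)


n = 5
down = resized_length(n, 0.5)   # 5 * 0.5 = 2.5 -> 2
up = resized_length(down, 2.0)  # 2 * 2.0 = 4.0 -> 4
print(n, down, up)  # 5 2 4
```

Rounding up instead gives 5 -> 3 -> 6. The mismatch therefore does not depend on the rounding direction.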
PARAMETER DESCRIPTION
factor

Resizing factor

  • > 1 : larger image <-> smaller voxels
  • < 1 : smaller image <-> larger voxels

TYPE: float or list[float] DEFAULT: None

shape

Output shape

TYPE: (ndim,) list[int] DEFAULT: None

anchor
  • In cases 'c' (center) and 'e' (edge), the volume shape is multiplied by the zoom factor (and possibly truncated), and two anchor points are used to determine the voxel size.
  • In cases 'f' (first) and 'l' (last), a single anchor point is used so that the voxel size is exactly divided by the zoom factor. This case with an integer factor corresponds to subslicing the volume (e.g., vol[::f, ::f, ::f]).
  • A list of anchors (one per dimension) can also be provided.

TYPE: {'center', 'edge', 'first', 'last'} or list DEFAULT: 'edge'

interpolation

Interpolation order.

TYPE: int or sequence[int] DEFAULT: 'linear'

bound

Boundary conditions.

TYPE: BoundType or sequence[BoundType] DEFAULT: 'zero'

extrapolate

Extrapolate out-of-bound data.

TYPE: bool or int DEFAULT: True

prefilter

Apply spline pre-filter (= interpolates the input)

TYPE: bool DEFAULT: True
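The sketch below makes the anchor modes concrete. It maps each output voxel index to a fractional input coordinate along one axis. The formulas are an assumption, reconstructed from the description above and from the displacement scaling in `ResizeFlow.forward`, which uses (oshape-1)/(ishape-1) for 'center', oshape/ishape for 'edge', and the factor otherwise. `output_to_input_coordinate` is a hypothetical helper, not a cassetta or torch-interpol function.

```python
def output_to_input_coordinate(i, n_in, n_out, anchor, factor=None):
    """Map output voxel index i to an input voxel coordinate (one axis).

    Hypothetical helper illustrating the four anchor modes (assumed formulas).
    """
    if anchor == 'center':
        # centers of the first and last voxels stay aligned
        return i * (n_in - 1) / (n_out - 1)
    if anchor == 'edge':
        # outer edges of the field of view stay aligned
        return (i + 0.5) * n_in / n_out - 0.5
    if anchor == 'first':
        # first voxel aligned; voxel size divided by the factor
        return i / factor
    if anchor == 'last':
        # last voxel aligned; voxel size divided by the factor
        return (n_in - 1) - (n_out - 1 - i) / factor
    raise ValueError(f'unknown anchor: {anchor}')


# Upsample a 3-voxel axis to 6 voxels (factor 2) with each anchor
for anchor in ('center', 'edge', 'first', 'last'):
    coords = [output_to_input_coordinate(i, 3, 6, anchor, factor=2)
              for i in range(6)]
    print(anchor, coords)
```

With 'edge', the first and last output voxels land at -0.25 and 2.25, outside [0, n_in - 1]. This is where `bound` and `extrapolate` come into play.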

Source code in cassetta/layers/interpol.py
def __init__(
        self,
        factor=None,
        shape=None,
        anchor='edge',
        interpolation='linear',
        bound='zero',
        extrapolate=True,
        prefilter=True
):
    """
    Notes
    -----

    * At least one of `factor` and `shape` must be specified.
    * If `anchor in ('center', 'edge')`, exactly one of `factor`
      or `shape` must be specified.
    * If `anchor in ('first', 'last')`, `factor` must be provided
      even if `shape` is specified.
    * Because of rounding, `resize(resize(x, f), 1/f)` is not
      guaranteed to return a tensor with the same shape as `x`.

    ```
         edge           center          first           last
    e - + - + - e   + - + - + - +   + - + - + - +   + - + - + - +
    | . | . | . |   | c | . | c |   | f | . | . |   | . | . | . |
    + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    | . | . | . |   | . | . | . |   | . | . | . |   | . | . | . |
    + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    | . | . | . |   | c | . | c |   | . | . | . |   | . | . | l |
    e _ + _ + _ e   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    ```

    Parameters
    ----------
    factor : float or list[float], optional
        Resizing factor

        * `> 1` : larger image <-> smaller voxels
        * `< 1` : smaller image <-> larger voxels
    shape : (ndim,) list[int], optional
        Output shape
    anchor : {'center', 'edge', 'first', 'last'} or list
        * In cases 'c' (center) and 'e' (edge), the volume shape is
          multiplied by the zoom factor (and possibly truncated), and
          two anchor points are used to determine the voxel size.
        * In cases 'f' (first) and 'l' (last), a single anchor point
          is used so that the voxel size is exactly divided by the
          zoom factor.
          This case with an integer factor corresponds to subslicing
          the volume (e.g., `vol[::f, ::f, ::f]`).
        * A list of anchors (one per dimension) can also be provided.
    interpolation : int or sequence[int]
        Interpolation order.
    bound : BoundType or sequence[BoundType]
        Boundary conditions.
    extrapolate : bool or int
        Extrapolate out-of-bound data.
    prefilter : bool
        Apply spline pre-filter (= interpolates the input)

    """
    super().__init__()
    self.factor = factor
    self.shape = shape
    self.anchor = anchor
    self.interpolation = interpolation
    self.bound = bound
    self.extrapolate = extrapolate
    self.prefilter = prefilter

forward

forward(input, **kwargs)

Resize an image

PARAMETER DESCRIPTION
input

Input image

TYPE: (batch, channel, *inshape)

PARAMETER DESCRIPTION
shape

Output shape. If not provided at call time, use self.shape

TYPE: (ndim,) list[int]

RETURNS DESCRIPTION
output

Resized image

TYPE: (batch, channel, *shape)

Source code in cassetta/layers/interpol.py
def forward(self, input, **kwargs):
    """
    Resize an image

    Parameters
    ----------
    input : (batch, channel, *inshape)
        Input image

    Other Parameters
    ----------------
    shape : (ndim,) list[int], optional
        Output shape. If not provided at call time, use self.shape

    Returns
    -------
    output : (batch, channel, *shape)
        Resized image
    """
    options = self._options
    options.update(kwargs)
    if isinstance(options['factor'], (int, float)):
        options['factor'] = [options['factor']] * (input.ndim-2)
    if isinstance(options['shape'], (int, float)):
        options['shape'] = [options['shape']] * (input.ndim-2)
    return torch_interpol.resize(input, **options)

Restrict

Restrict(factor=None, shape=None, anchor='edge', interpolation='linear', bound='zero', reduce_sum=False)

Bases: Module

Restrict an image (adjoint of resize)

Notes
  • At least one of factor and shape must be specified.
  • If anchor in ('center', 'edge'), exactly one of factor or shape must be specified.
  • If anchor in ('first', 'last'), factor must be provided even if shape is specified.
  • Because of rounding, resize(resize(x, f), 1/f) is not guaranteed to return a tensor with the same shape as x.
    edge           center           first           last
e - + - + - e   + - + - + - +   + - + - + - +   + - + - + - +
| . | . | . |   | c | . | c |   | f | . | . |   | . | . | . |
+ _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
| . | . | . |   | . | . | . |   | . | . | . |   | . | . | . |
+ _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
| . | . | . |   | c | . | c |   | . | . | . |   | . | . | l |
e _ + _ + _ e   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
PARAMETER DESCRIPTION
factor

Resizing factor

  • > 1 : larger image <-> smaller voxels
  • < 1 : smaller image <-> larger voxels

TYPE: float or list[float] DEFAULT: None

shape

Output shape

TYPE: (ndim,) list[int] DEFAULT: None

anchor
  • In cases 'c' (center) and 'e' (edge), the volume shape is multiplied by the zoom factor (and possibly truncated), and two anchor points are used to determine the voxel size.
  • In cases 'f' (first) and 'l' (last), a single anchor point is used so that the voxel size is exactly divided by the zoom factor. This case with an integer factor corresponds to subslicing the volume (e.g., vol[::f, ::f, ::f]).
  • A list of anchors (one per dimension) can also be provided.

TYPE: {'center', 'edge', 'first', 'last'} or list DEFAULT: 'edge'

interpolation

Interpolation order.

TYPE: int or sequence[int] DEFAULT: 'linear'

bound

Boundary conditions.

TYPE: BoundType or sequence[BoundType] DEFAULT: 'zero'

reduce_sum

If True, do not normalize by the number of values accumulated in each voxel (return the raw sum).

TYPE: bool DEFAULT: False
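Restrict is described as the adjoint of resize. If resizing is a linear map A, restricting applies its transpose. The pure-Python sketch below builds A for 1D linear interpolation with the 'center' anchor (3 -> 5 voxels). It then checks the adjoint identity <A x, y> = <x, A^T y> and normalizes the restricted values by the total weight accumulated in each output voxel. Treating that weight sum as the normalization applied when reduce_sum=False is an assumption. All helper names are hypothetical.

```python
def linear_resize_matrix(n_in, n_out):
    """Rows: output voxels; columns: input voxels ('center' anchor, order 1).

    Hypothetical helper; assumes n_in >= 2 and n_out >= 2.
    """
    mat = [[0.0] * n_in for _ in range(n_out)]
    for i in range(n_out):
        x = i * (n_in - 1) / (n_out - 1)  # input coordinate of output voxel i
        left = min(int(x), n_in - 2)
        t = x - left
        mat[i][left] += 1 - t
        mat[i][left + 1] += t
    return mat


def matvec(mat, vec):
    return [sum(m * v for m, v in zip(row, vec)) for row in mat]


def transpose(mat):
    return [list(col) for col in zip(*mat)]


A = linear_resize_matrix(3, 5)   # resize: 3 -> 5 voxels
At = transpose(A)                # restrict: 5 -> 3 voxels

# Adjoint identity: <A x, y> == <x, A^T y>
x = [1.0, -2.0, 3.0]
y = [0.5, 1.0, -1.0, 2.0, 0.0]
lhs = sum(a * b for a, b in zip(matvec(A, x), y))
rhs = sum(a * b for a, b in zip(x, matvec(At, y)))
print(lhs, rhs)  # 3.0 3.0

# Raw restriction sums contributions; dividing by the accumulated
# weights (assumed normalization) maps a constant image to itself
weights = matvec(At, [1.0] * 5)          # [1.5, 2.0, 1.5]
restricted = matvec(At, [7.0] * 5)       # [10.5, 14.0, 10.5]
normalized = [r / w for r, w in zip(restricted, weights)]
print(normalized)  # [7.0, 7.0, 7.0]
```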

Source code in cassetta/layers/interpol.py
def __init__(
        self,
        factor=None,
        shape=None,
        anchor='edge',
        interpolation='linear',
        bound='zero',
        reduce_sum=False,
):
    """
    Notes
    -----

    * At least one of `factor` and `shape` must be specified.
    * If `anchor in ('center', 'edge')`, exactly one of `factor`
      or `shape` must be specified.
    * If `anchor in ('first', 'last')`, `factor` must be provided
      even if `shape` is specified.
    * Because of rounding, `resize(resize(x, f), 1/f)` is not
      guaranteed to return a tensor with the same shape as `x`.

    ```
        edge           center           first           last
    e - + - + - e   + - + - + - +   + - + - + - +   + - + - + - +
    | . | . | . |   | c | . | c |   | f | . | . |   | . | . | . |
    + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    | . | . | . |   | . | . | . |   | . | . | . |   | . | . | . |
    + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    | . | . | . |   | c | . | c |   | . | . | . |   | . | . | l |
    e _ + _ + _ e   + _ + _ + _ +   + _ + _ + _ +   + _ + _ + _ +
    ```

    Parameters
    ----------
    factor : float or list[float], optional
        Resizing factor

        * `> 1` : larger image <-> smaller voxels
        * `< 1` : smaller image <-> larger voxels
    shape : (ndim,) list[int], optional
        Output shape
    anchor : {'center', 'edge', 'first', 'last'} or list
        * In cases 'c' (center) and 'e' (edge), the volume shape is
          multiplied by the zoom factor (and possibly truncated), and
          two anchor points are used to determine the voxel size.
        * In cases 'f' (first) and 'l' (last), a single anchor point
          is used so that the voxel size is exactly divided by the
          zoom factor.
          This case with an integer factor corresponds to subslicing
          the volume (e.g., `vol[::f, ::f, ::f]`).
        * A list of anchors (one per dimension) can also be provided.
    interpolation : int or sequence[int]
        Interpolation order.
    bound : BoundType or sequence[BoundType]
        Boundary conditions.
    reduce_sum : bool
        If True, do not normalize by the number of values accumulated
        in each voxel (return the raw sum).

    """
    super().__init__()
    self.factor = factor
    self.shape = shape
    self.anchor = anchor
    self.interpolation = interpolation
    self.bound = bound
    self.reduce_sum = reduce_sum

forward

forward(input, **kwargs)

Restrict an image

PARAMETER DESCRIPTION
input

Input image

TYPE: (batch, channel, *inshape)

PARAMETER DESCRIPTION
shape

Output shape. If not provided at call time, use self.shape

TYPE: (ndim,) list[int]

RETURNS DESCRIPTION
output

Restricted image

TYPE: (batch, channel, *shape)

Source code in cassetta/layers/interpol.py
def forward(self, input, **kwargs):
    """
    Restrict an image

    Parameters
    ----------
    input : (batch, channel, *inshape)
        Input image

    Other Parameters
    ----------------
    shape : (ndim,) list[int], optional
        Output shape. If not provided at call time, use self.shape

    Returns
    -------
    output : (batch, channel, *shape)
        Restricted image
    """
    options = self._options
    options.update(kwargs)
    return torch_interpol.restrict(input, **options)

ResizeFlow

ResizeFlow(factor=None, shape=None, anchor='edge', interpolation='linear', bound='dft', extrapolate=True, prefilter=True)

Bases: Resize

Resize (interpolate) a displacement field

Source code in cassetta/layers/interpol.py
def __init__(
        self,
        factor=None,
        shape=None,
        anchor='edge',
        interpolation='linear',
        bound='dft',
        extrapolate=True,
        prefilter=True,
):
    super().__init__(
        factor=factor,
        shape=shape,
        anchor=anchor,
        interpolation=interpolation,
        bound=bound,
        extrapolate=extrapolate,
        prefilter=prefilter,
    )

forward

forward(flow, **kwargs)

Resize a displacement field. The magnitude of the displacements gets rescaled as well.

PARAMETER DESCRIPTION
flow

Input displacement field

TYPE: (batch, ndim, *inshape)

PARAMETER DESCRIPTION
shape

Output shape. If not provided at call time, use self.shape

TYPE: (ndim,) list[int]

RETURNS DESCRIPTION
output

Resized displacement field

TYPE: (batch, ndim, *shape)

Source code in cassetta/layers/interpol.py
def forward(self, flow, **kwargs):
    """
    Resize a displacement field. The magnitude of the displacements
    gets rescaled as well.

    Parameters
    ----------
    flow : (batch, ndim, *inshape)
        Input displacement field

    Other Parameters
    ----------------
    shape : (ndim,) list[int], optional
        Output shape. If not provided at call time, use self.shape

    Returns
    -------
    output : (batch, ndim, *shape)
        Resized displacement field
    """
    ndim = flow.shape[1]
    ishape = flow.shape[2:]
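    # forward call-time options (e.g., shape) to Resize.forward
    iflow = super().forward(flow, **kwargs)
    oflow = torch.empty_like(iflow)
    oshape = oflow.shape[2:]
    # call-time options take precedence over those set at construction
    anchor = kwargs.get('anchor', self.anchor)[0].lower()
    if anchor == 'c':
        for d in range(ndim):
            oflow[:, d] = iflow[:, d] * ((oshape[d] - 1) / (ishape[d] - 1))
    elif anchor == 'e':
        for d in range(ndim):
            oflow[:, d] = iflow[:, d] * (oshape[d] / ishape[d])
    else:
        factor = ensure_list(kwargs.get('factor', self.factor), ndim)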
        for d in range(ndim):
            oflow[:, d] = iflow[:, d] * factor[d]
    return oflow
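Displacements are expressed in voxels, so resizing the grid also changes their magnitude. The per-axis multiplier depends only on the anchor and on the input and output shapes, or on the factor for 'first' and 'last'. The pure-Python helper below restates the branches of the source above so the rule can be checked by hand. `displacement_scale` is a hypothetical name, not a cassetta function.

```python
def displacement_scale(anchor, ishape, oshape, factor=None):
    """Per-axis multiplier applied to voxel displacements (hypothetical helper).

    Mirrors the branches of ResizeFlow.forward for a single (string) anchor.
    """
    anchor = anchor[0].lower()
    ndim = len(ishape)
    if anchor == 'c':
        # centers of the corner voxels are aligned
        return [(oshape[d] - 1) / (ishape[d] - 1) for d in range(ndim)]
    if anchor == 'e':
        # edges of the field of view are aligned
        return [oshape[d] / ishape[d] for d in range(ndim)]
    # 'first' / 'last': voxel size is exactly divided by the factor
    if not isinstance(factor, (list, tuple)):
        factor = [factor] * ndim
    return list(factor)


print(displacement_scale('center', (5, 5), (9, 9)))            # [2.0, 2.0]
print(displacement_scale('edge', (4, 6), (8, 3)))              # [2.0, 0.5]
print(displacement_scale('first', (4, 4), (8, 8), factor=2))   # [2, 2]
```

With the 'center' anchor, a one-voxel shift on a 5-voxel axis becomes a two-voxel shift on the resized 9-voxel axis.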

ValueToCoeff

ValueToCoeff(interpolation='linear', bound='zero')

Bases: Module

Compute spline coefficients from values

PARAMETER DESCRIPTION
interpolation

Interpolation order.

TYPE: int or sequence[int] DEFAULT: 'linear'

bound

Boundary conditions.

TYPE: BoundType or sequence[BoundType] DEFAULT: 'zero'

Source code in cassetta/layers/interpol.py
def __init__(
        self,
        interpolation='linear',
        bound='zero',
):
    """
    Parameters
    ----------
    interpolation : int or sequence[int]
        Interpolation order.
    bound : BoundType or sequence[BoundType]
        Boundary conditions.
    """
    super().__init__()
    self.interpolation = interpolation if interpolation != 'fd' else 1
    self.bound = bound

forward

forward(input)
PARAMETER DESCRIPTION
input

Input image of values

TYPE: (batch, channel, *shape) tensor

RETURNS DESCRIPTION
output

Image of spline coefficients

TYPE: (batch, channel, *shape) tensor

Source code in cassetta/layers/interpol.py
def forward(self, input):
    """
    Parameters
    ----------
    input : (batch, channel, *shape) tensor
        Input image of values

    Returns
    -------
    output : (batch, channel, *shape) tensor
        Image of spline coefficients
    """
    ndim = input.ndim - 2
    return torch_interpol.spline_coeff_nd(input, **self._options, dim=ndim)

CoeffToValue

CoeffToValue(interpolation='linear', bound='zero')

Bases: Module

Compute values from spline coefficients

PARAMETER DESCRIPTION
interpolation

Interpolation order.

TYPE: int or sequence[int] DEFAULT: 'linear'

bound

Boundary conditions.

TYPE: BoundType or sequence[BoundType] DEFAULT: 'zero'

Source code in cassetta/layers/interpol.py
def __init__(
        self,
        interpolation='linear',
        bound='zero',
):
    """
    Parameters
    ----------
    interpolation : int or sequence[int]
        Interpolation order.
    bound : BoundType or sequence[BoundType]
        Boundary conditions.
    """
    super().__init__()
    self.interpolation = interpolation if interpolation != 'fd' else 1
    self.bound = bound

forward

forward(input)
PARAMETER DESCRIPTION
input

Input image of spline coefficients

TYPE: (batch, channel, *shape) tensor

RETURNS DESCRIPTION
output

Image of values

TYPE: (batch, channel, *shape) tensor

Source code in cassetta/layers/interpol.py
def forward(self, input):
    """
    Parameters
    ----------
    input : (batch, channel, *shape) tensor
        Input image of spline coefficients

    Returns
    -------
    output : (batch, channel, *shape) tensor
        Image of values
    """
    grid = torch_interpol.identity_grid(
        input.shape[2:], dtype=input.dtype, device=input.device)
    return torch_interpol.grid_pull(input, grid, **self._options)
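For interpolation orders above 1, spline coefficients differ from the sampled values, which is why ValueToCoeff and CoeffToValue exist. For order 1 ('linear'), the two coincide. The pure-Python 1D sketch below uses cubic B-splines and assumes that the 'zero' boundary means coefficients outside the field of view are zero. Sampling the spline at the grid nodes convolves the coefficients with [1/6, 4/6, 1/6]. Recovering the coefficients means solving the corresponding tridiagonal system, done here with the Thomas algorithm. This illustrates the operation; it is not torch-interpol's implementation.

```python
def coeff_to_value(coeff):
    """Evaluate a 1D cubic B-spline at its nodes (assumed zero coeffs outside)."""
    n = len(coeff)
    padded = [0.0] + list(coeff) + [0.0]
    # cubic B-spline at integer offsets: B(0) = 4/6, B(+-1) = 1/6
    return [(padded[i] + 4 * padded[i + 1] + padded[i + 2]) / 6
            for i in range(n)]


def value_to_coeff(value):
    """Invert coeff_to_value by solving the tridiagonal system (Thomas)."""
    n = len(value)
    off = 1 / 6   # sub- and super-diagonal entries
    diag = 4 / 6  # diagonal entries
    c_prime = [0.0] * n
    d_prime = [0.0] * n
    c_prime[0] = off / diag
    d_prime[0] = value[0] / diag
    for i in range(1, n):  # forward sweep
        denom = diag - off * c_prime[i - 1]
        c_prime[i] = off / denom
        d_prime[i] = (value[i] - off * d_prime[i - 1]) / denom
    coeff = [0.0] * n
    coeff[-1] = d_prime[-1]
    for i in range(n - 2, -1, -1):  # back substitution
        coeff[i] = d_prime[i] - c_prime[i] * coeff[i + 1]
    return coeff


values = [0.0, 1.0, 4.0, 9.0, 16.0]
coeff = value_to_coeff(values)
roundtrip = coeff_to_value(coeff)  # recovers `values` up to rounding error
```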

FlowExp

FlowExp(nsteps=8, interpolation='linear', bound='dft', extrapolate=False, coeff=False)

Bases: Module

Exponentiate a stationary velocity field

PARAMETER DESCRIPTION
nsteps

Number of scaling and squaring steps

TYPE: int DEFAULT: 8

interpolation

Interpolation order.

TYPE: int or sequence[int] DEFAULT: 'linear'

bound

Boundary conditions.

TYPE: BoundType or sequence[BoundType] DEFAULT: 'dft'

extrapolate

Extrapolate out-of-bound data.

TYPE: bool or int DEFAULT: False

coeff

If True, the input velocity image contains spline coefficients, and spline coefficients will also be returned. If False, the input velocity image contains actual values, and values will also be returned.

TYPE: bool DEFAULT: False

Source code in cassetta/layers/interpol.py
def __init__(
        self,
        nsteps=8,
        interpolation='linear',
        bound='dft',
        extrapolate=False,
        coeff=False,
):
    """
    Parameters
    ----------
    nsteps : int
        Number of scaling and squaring steps
    interpolation : int or sequence[int]
        Interpolation order.
    bound : BoundType or sequence[BoundType]
        Boundary conditions.
    extrapolate : bool or int
        Extrapolate out-of-bound data.
    coeff : bool
        If True, the input velocity image contains spline coefficients,
        and spline coefficients will also be returned.
        If False, the input velocity image contains actual values,
        and values will also be returned.

    """
    super().__init__()
    self.nsteps = nsteps
    self.interpolation = interpolation
    self.bound = bound
    self.extrapolate = extrapolate
    self.coeff = coeff

forward

forward(flow)

Exponentiate the SVF

PARAMETER DESCRIPTION
flow

Stationary velocity field

TYPE: (batch, ndim, *shape) tensor

Source code in cassetta/layers/interpol.py
def forward(self, flow):
    """
    Exponentiate the SVF

    Parameters
    ----------
    flow : (batch, ndim, *shape) tensor
        Stationary velocity field
    """
    # helpers
    to_coeff = ValueToCoeff(
        interpolation=self.interpolation, bound=self.bound)
    from_coeff = CoeffToValue(
        interpolation=self.interpolation, bound=self.bound)
    compose = FlowPull(
        interpolation=self.interpolation, bound=self.bound,
        extrapolate=self.extrapolate, prefilter=False)
    flow = flow / 2**self.nsteps
    # init
    if not self.coeff:
        coeff = to_coeff(flow)
    else:
        coeff = flow
        flow = from_coeff(coeff)
    # scale and square
    for _ in range(self.nsteps):
        flow = compose(coeff, flow)
        coeff = to_coeff(flow)
    # final
    return coeff if self.coeff else flow
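FlowExp relies on scaling and squaring. The velocity is divided by 2**nsteps so that it is small enough to be treated as a displacement. The resulting transform is then composed with itself nsteps times. The pure-Python 1D sketch below implements the textbook recursion u <- u + u o (id + u) on values (not spline coefficients). It uses linear interpolation and a circular boundary, which is taken here as the analogue of 'dft' (an assumption). It is not cassetta's code. A constant velocity gives an easy sanity check, because its exponential is a constant shift by the same amount.

```python
import math


def sample_circular(field, x):
    """Linearly interpolate a 1D periodic field at fractional position x."""
    n = len(field)
    left = math.floor(x)
    t = x - left
    return (1 - t) * field[left % n] + t * field[(left + 1) % n]


def exp_velocity(velocity, nsteps=8):
    """Scaling and squaring of a 1D stationary velocity field (in voxels).

    Sketch only: circular boundary assumed, values instead of coefficients.
    """
    disp = [v / 2 ** nsteps for v in velocity]  # scaling
    for _ in range(nsteps):                     # squaring: u + u(x + u(x))
        disp = [u + sample_circular(disp, i + u) for i, u in enumerate(disp)]
    return disp


constant = exp_velocity([0.8] * 16)
print(all(math.isclose(u, 0.8) for u in constant))  # True
```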

FlowMomentum

FlowMomentum(absolute=0, membrane=0, bending=0, div=0, shears=0, norm=True, interpolation='linear', bound='dft')

Bases: Module

Compute the momentum of a displacement field

PARAMETER DESCRIPTION
absolute

Penalty on absolute displacement

TYPE: float DEFAULT: 0

membrane

Penalty on first derivatives

TYPE: float DEFAULT: 0

bending

Penalty on second derivatives

TYPE: float DEFAULT: 0

div

Penalty on volume changes

TYPE: float DEFAULT: 0

shears

Penalty on shears

TYPE: float DEFAULT: 0

norm

If True, compute the average energy across the field of view. Otherwise, compute the sum (integral) of the energy across the FOV.

TYPE: bool DEFAULT: True

interpolation

Spline order

TYPE: int DEFAULT: 'linear'

bound

Boundary conditions

TYPE: bound_like DEFAULT: 'dft'

Source code in cassetta/layers/interpol.py
def __init__(
        self,
        absolute=0,
        membrane=0,
        bending=0,
        div=0,
        shears=0,
        norm=True,
        interpolation='linear',
        bound='dft',
):
    """
    Parameters
    ----------
    absolute : float
        Penalty on absolute displacement
    membrane : float
        Penalty on first derivatives
    bending : float
        Penalty on second derivatives
    div : float
        Penalty on volume changes
    shears : float
        Penalty on shears
    norm : bool
        If True, compute the average energy across the field of view.
        Otherwise, compute the sum (integral) of the energy across the FOV.
    interpolation : int
        Spline order
    bound : bound_like
        Boundary conditions
    """
    if not hasattr(torch_interpol, 'flowmom'):
        raise NotImplementedError(
            'FlowMomentum requires torch-interpol >= 1')
    super().__init__()
    self.absolute = absolute
    self.membrane = membrane
    self.bending = bending
    self.div = div
    self.shears = shears
    self.norm = norm
    self.interpolation = interpolation
    self.bound = bound

forward

forward(flow)
PARAMETER DESCRIPTION
flow

Spline coefficients of a displacement field

TYPE: (batch, ndim, *shape) tensor

RETURNS DESCRIPTION
mom

Momentum field

TYPE: (batch, ndim, *shape) tensor

Source code in cassetta/layers/interpol.py
def forward(self, flow):
    """
    Parameters
    ----------
    flow : (batch, ndim, *shape) tensor
        Spline coefficients of a displacement field

    Returns
    -------
    mom : (batch, ndim, *shape) tensor
        Momentum field
    """
    flow = torch.movedim(flow, 1, -1)
    flow = torch_interpol.flowmom(flow, **self._options)
    flow = torch.movedim(flow, -1, 1)
    return flow

FlowLoss

FlowLoss(absolute=0, membrane=0, bending=0, div=0, shears=0, norm=True, interpolation='linear', bound='dft')

Bases: FlowMomentum

Compute the regularization loss of a displacement field

PARAMETER DESCRIPTION
absolute

Penalty on absolute displacement

TYPE: float DEFAULT: 0

membrane

Penalty on first derivatives

TYPE: float DEFAULT: 0

bending

Penalty on second derivatives

TYPE: float DEFAULT: 0

div

Penalty on volume changes

TYPE: float DEFAULT: 0

shears

Penalty on shears

TYPE: float DEFAULT: 0

norm

If True, compute the average energy across the field of view. Otherwise, compute the sum (integral) of the energy across the FOV.

TYPE: bool DEFAULT: True

interpolation

Spline order

TYPE: int DEFAULT: 'linear'

bound

Boundary conditions

TYPE: bound_like DEFAULT: 'dft'

Source code in cassetta/layers/interpol.py
def __init__(
        self,
        absolute=0,
        membrane=0,
        bending=0,
        div=0,
        shears=0,
        norm=True,
        interpolation='linear',
        bound='dft',
):
    """
    Parameters
    ----------
    absolute : float
        Penalty on absolute displacement
    membrane : float
        Penalty on first derivatives
    bending : float
        Penalty on second derivatives
    div : float
        Penalty on volume changes
    shears : float
        Penalty on shears
    norm : bool
        If True, compute the average energy across the field of view.
        Otherwise, compute the sum (integral) of the energy across the FOV.
    interpolation : int
        Spline order
    bound : bound_like
        Boundary conditions
    """
    if not hasattr(torch_interpol, 'flowmom'):
        raise NotImplementedError(
            'FlowMomentum requires torch-interpol >= 1')
    super().__init__()
    self.absolute = absolute
    self.membrane = membrane
    self.bending = bending
    self.div = div
    self.shears = shears
    self.norm = norm
    self.interpolation = interpolation
    self.bound = bound

forward

forward(flow)
PARAMETER DESCRIPTION
flow

Spline coefficients of a displacement field

TYPE: (batch, ndim, *shape) tensor

RETURNS DESCRIPTION
loss

Regularization loss, averaged across batch elements.

TYPE: scalar tensor

Source code in cassetta/layers/interpol.py
def forward(self, flow):
    """
    Parameters
    ----------
    flow : (batch, ndim, *shape) tensor
        Spline coefficients of a displacement field

    Returns
    -------
    loss : scalar tensor
        Regularization loss, averaged across batch elements.
    """
    mom = super().forward(flow)
    nbatch = len(flow)
    return mom.flatten().dot(flow.flatten()) / nbatch
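`FlowLoss.forward` returns the inner product of the field with its momentum, divided by the batch size. For the membrane term, the momentum operator is, up to discretization details, the negative Laplacian, so the loss reduces to a sum of squared finite differences. The pure-Python sketch below checks that identity for a 1D field with a circular boundary. The stencil is a simplifying assumption, as is the absence of voxel-size and averaging (`norm`) factors. This is not torch-interpol's `flowmom`, and the helper names are hypothetical.

```python
def membrane_momentum(u, weight=1.0):
    """Apply -weight * Laplacian to a 1D periodic field (assumed stencil)."""
    n = len(u)
    # u[i - 1] wraps around for i = 0 thanks to Python's negative indexing
    return [weight * (2 * u[i] - u[i - 1] - u[(i + 1) % n]) for i in range(n)]


def flow_loss(batch, weight=1.0):
    """Batch mean of <u, L u>, like mom.flatten().dot(flow.flatten()) / nbatch."""
    total = 0.0
    for u in batch:
        mom = membrane_momentum(u, weight)
        total += sum(m * x for m, x in zip(mom, u))
    return total / len(batch)


u = [0.0, 1.0, 3.0, 2.0]
squared_diffs = sum((u[(i + 1) % 4] - u[i]) ** 2 for i in range(4))
print(flow_loss([u]), squared_diffs)  # 10.0 10.0
```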

SplineUp2

SplineUp2(interpolation='linear', bound='dft')

Bases: Module

MSE-minimizing upsampling of a displacement field by a factor of 2

PARAMETER DESCRIPTION
interpolation

Spline order

TYPE: int DEFAULT: 'linear'

bound

Boundary conditions

TYPE: bound_like DEFAULT: 'dft'

Source code in cassetta/layers/interpol.py
def __init__(
        self,
        interpolation='linear',
        bound='dft',
):
    """
    Parameters
    ----------
    interpolation : int
        Spline order
    bound : bound_like
        Boundary conditions
    """
    if not hasattr(torch_interpol, 'flow_upsample2'):
        raise NotImplementedError(
            'SplineUp2 requires torch-interpol >= 1')
    super().__init__()
    self.interpolation = interpolation
    self.bound = bound

forward

forward(flow)
PARAMETER DESCRIPTION
flow

Spline coefficients of a displacement field

TYPE: (batch, ndim, *shape) tensor

RETURNS DESCRIPTION
flow2

Spline coefficients of a larger displacement field

TYPE: (batch, ndim, *shape_twice) tensor

Source code in cassetta/layers/interpol.py
def forward(self, flow):
    """
    Parameters
    ----------
    flow : (batch, ndim, *shape) tensor
        Spline coefficients of a displacement field

    Returns
    -------
    flow2 : (batch, ndim, *shape_twice) tensor
        Spline coefficients of a larger displacement field
    """
    flow = torch.movedim(flow, 1, -1)
    flow = torch_interpol.flow_upsample2(flow, **self._options)
    flow = torch.movedim(flow, -1, 1)
    return flow

cassetta.layers.linear

Linear

Linear(inp_channels, out_channels, bias=True, dim=1, *, device=None, dtype=None)

Bases: Module

Linear layer.

We reimplement nn.Linear so that the dimension that it operates upon is the second (by default) instead of the last. This makes it more compatible with vision applications.

PARAMETER DESCRIPTION
inp_channels

Number of input channels

TYPE: int

out_channels

Number of output channels. If None, inp_channels is used.

TYPE: int or None

bias

Include a bias term

TYPE: bool DEFAULT: True

dim

Dimension along which to operate

TYPE: int DEFAULT: 1

PARAMETER DESCRIPTION
device

Weight's device

TYPE: device

dtype

Weight's data type

TYPE: dtype

Source code in cassetta/layers/linear.py
def __init__(
        self,
        inp_channels: int,
        out_channels: Optional[int],
        bias: bool = True,
        dim: int = 1,
        *,
        device: Optional[DeviceType] = None,
        dtype: Optional[torch.dtype] = None,
) -> None:
    """
    Parameters
    ----------
    inp_channels : int
        Number of input channels
    out_channels : int or None
        Number of output channels. If None, use `inp_channels`.
    bias : bool
        Include a bias term
    dim : int
        Dimension along which to operate

    Other Parameters
    ----------------
    device : torch.device
        Weight's device
    dtype : torch.dtype
        Weight's data type

    """
    factory_kwargs = {'device': device, 'dtype': dtype}
    super().__init__()
    self.inp_channels = inp_channels
    self.out_channels = out_channels or inp_channels
    self.dim = dim
    # use the resolved number of output channels (out_channels may be None)
    self.weight = nn.Parameter(
        torch.empty((self.out_channels, inp_channels), **factory_kwargs)
    )
    if bias:
        self.bias = nn.Parameter(
            torch.empty(self.out_channels, **factory_kwargs)
        )
    else:
        self.register_parameter('bias', None)
    self.reset_parameters()

forward

forward(inp)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, Cinp, *spatial) tensor

RETURNS DESCRIPTION
out

Output tensor

TYPE: (B, Cout, *spatial) tensor

Source code in cassetta/layers/linear.py
def forward(self, inp: Tensor) -> Tensor:
    """
    Parameters
    ----------
    inp : (B, Cinp, *spatial) tensor
        Input tensor

    Returns
    -------
    out : (B, Cout, *spatial) tensor
        Output tensor
    """
    inp = inp.movedim(self.dim, -1)
    out = F.linear(inp, self.weight, self.bias)
    out = out.movedim(-1, self.dim)
    return out
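Applying a linear map along dimension 1 of a (B, C, *spatial) tensor is the same as applying the weight matrix independently at every spatial location, i.e., a 1x1 convolution. The pure-Python sketch below spells this out for a (B, C, N) nested list, to illustrate what the movedim / F.linear / movedim sequence in `forward` computes. `channel_linear` is a hypothetical name, not cassetta code.

```python
def channel_linear(inp, weight, bias=None):
    """Hypothetical helper: out[b, o, n] = sum_c weight[o][c] * inp[b][c][n] + bias[o].

    inp: nested list of shape (B, Cinp, N); weight: (Cout, Cinp); bias: (Cout,)
    Returns a nested list of shape (B, Cout, N).
    """
    cout, cinp = len(weight), len(weight[0])
    out = []
    for sample in inp:
        npos = len(sample[0])
        out_sample = []
        for o in range(cout):
            row = []
            for n in range(npos):
                acc = sum(weight[o][c] * sample[c][n] for c in range(cinp))
                row.append(acc + (bias[o] if bias is not None else 0.0))
            out_sample.append(row)
        out.append(out_sample)
    return out


x = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]    # B=1, Cinp=2, N=3
w = [[1.0, 0.0], [1.0, 1.0], [0.0, -1.0]]   # Cout=3, Cinp=2
b = [0.0, 0.5, 1.0]
result = channel_linear(x, w, b)
print(result)
# [[[1.0, 2.0, 3.0], [5.5, 7.5, 9.5], [-3.0, -4.0, -5.0]]]
```

With PyTorch tensors, the same computation is `F.linear(x.movedim(1, -1), weight, bias).movedim(-1, 1)`, which is what `forward` does when `dim=1`.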

LazyLinear

LazyLinear(out_channels, bias=True, dim=1, *, device=None, dtype=None)

Bases: LazyModuleMixin, Linear

A linear layer whose weights get allocated lazily, on first call.

This allows the number of input channels to be automatically determined at run time.

We reimplement nn.LazyLinear so that the dimension that it operates upon is the second (by default) instead of the last. This makes it more compatible with vision applications.

We also allow the number of output channels to be set lazily.

PARAMETER DESCRIPTION
out_channels

Number of output channels. If a function, takes the number of input channels and returns the number of output channels.

TYPE: int or callable DEFAULT: `inp_channels`

bias

Include a bias term

TYPE: bool DEFAULT: True

dim

Dimension along which to operate

TYPE: int DEFAULT: 1

PARAMETER DESCRIPTION
device

Weight's device

TYPE: device

dtype

Weight's data type

TYPE: dtype

Source code in cassetta/layers/linear.py
def __init__(
        self,
        out_channels: Optional[Union[int, Callable[[int], int]]],
        bias: bool = True,
        dim: int = 1,
        *,
        device: Optional[DeviceType] = None,
        dtype: Optional[torch.dtype] = None,
) -> None:
    """
    Parameters
    ----------
    out_channels : int or callable, default=`inp_channels`
        Number of output channels.
        If a function, takes the number of input channels and
        returns the number of output channels.
    bias : bool
        Include a bias term
    dim : int
        Dimension along which to operate

    Other Parameters
    ----------------
    device : torch.device
        Weight's device
    dtype : torch.dtype
        Weight's data type

    """
    factory_kwargs = {'device': device, 'dtype': dtype}
    # bias is hardcoded to False to avoid creating tensor
    # that will soon be overwritten.
    super().__init__(0, 0, False, dim=dim)
    self.weight = nn.UninitializedParameter(**factory_kwargs)
    self.out_channels = out_channels
    if bias:
        self.bias = nn.UninitializedParameter(**factory_kwargs)
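The lazy behaviour comes down to two things: deferring the weight shape until the first input is seen, and optionally deriving `out_channels` from the input channel count. The pure-Python sketch below shows only that resolution logic. `LazyChannelSpec` is a hypothetical class. Falling back to the input channel count when `out_channels` is None is an assumption based on the documented default.

```python
from typing import Callable, Optional, Union


class LazyChannelSpec:
    """Hypothetical helper: resolve the (out, inp) weight shape on first call."""

    def __init__(self, out_channels: Optional[Union[int, Callable[[int], int]]]):
        self.out_channels = out_channels
        self.weight_shape = None  # unknown until the first input arrives

    def resolve(self, inp_channels: int):
        if self.weight_shape is None:
            out = self.out_channels
            if out is None:
                out = inp_channels          # assumed default
            elif callable(out):
                out = out(inp_channels)     # e.g. lambda c: 2 * c
            self.weight_shape = (out, inp_channels)
        return self.weight_shape


spec = LazyChannelSpec(lambda c: 2 * c)
print(spec.resolve(16))  # (32, 16)
```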

cassetta.layers.norm

BatchNorm

BatchNorm(channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)

Bases: _BatchNorm

PARAMETER DESCRIPTION
channels

Number of input channels.

TYPE: int

eps

Value added to the denominator for numerical stability.

TYPE: float DEFAULT: 1e-05

momentum

The value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average).

TYPE: float DEFAULT: 0.1

affine

Use learnable affine parameters.

TYPE: bool DEFAULT: True

track_running_stats

Track the running mean and variance. If False, this module does not track such statistics, and initializes statistics buffers as None. When these buffers are None, this module always uses batch statistics in both training and eval modes.

TYPE: bool DEFAULT: True
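In PyTorch's batch normalization, which this class builds on through `_BatchNorm`, each channel is normalized with statistics pooled over the batch and all spatial positions. The running mean is updated as running = (1 - momentum) * running + momentum * batch_mean. With momentum=None, the update instead uses a cumulative average over the batches seen so far (factor 1/n for the n-th batch). The pure-Python sketch below covers the running-mean update only. The running variance is updated the same way from the unbiased batch variance and is omitted here. Function names are hypothetical.

```python
def channel_means(batch):
    """Per-channel mean over batch and positions; batch shape (B, C, N)."""
    nchannels = len(batch[0])
    means = []
    for c in range(nchannels):
        values = [v for sample in batch for v in sample[c]]
        means.append(sum(values) / len(values))
    return means


def update_running_mean(running, batch_mean, momentum, num_batches_tracked):
    """One running-mean update, following PyTorch's convention (sketch)."""
    if momentum is None:
        # cumulative moving average: weight 1/n for the n-th batch
        momentum = 1.0 / num_batches_tracked
    return [(1 - momentum) * r + momentum * m
            for r, m in zip(running, batch_mean)]


batch = [[[1.0, 3.0], [0.0, 0.0]],
         [[5.0, 7.0], [2.0, 2.0]]]          # B=2, C=2, N=2
mean = channel_means(batch)                 # [4.0, 1.0]
running = update_running_mean([0.0, 0.0], mean, 0.1, 1)
print(mean, running)  # [4.0, 1.0] [0.4, 0.1]
```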

Source code in cassetta/layers/norm.py
def __init__(
    self,
    channels: int,
    eps: float = 1e-5,
    momentum: float = 0.1,
    affine: bool = True,
    track_running_stats: bool = True,
    device: Optional[DeviceType] = None,
    dtype: Optional[DataType] = None,
) -> None:
    """
    Parameters
    ----------
    channels : int
        Number of input channels.
    eps : float
        Value added to the denominator for numerical stability.
    momentum : float
        The value used for the `running_mean` and `running_var`
        computation. Can be set to `None` for cumulative moving average
        (i.e. simple average).
    affine : bool
        Use learnable affine parameters.
    track_running_stats : bool
        Track the running mean and variance. If `False`,
        this module does not track such statistics, and initializes
        statistics buffers as `None`. When these buffers are `None`,
        this module always uses batch statistics in both training
        and eval modes.
    """
    dtype = to_torch_dtype(dtype)
    super().__init__(
        channels,
        eps=eps,
        momentum=momentum,
        affine=affine,
        track_running_stats=track_running_stats,
        dtype=dtype,
        device=device,
    )
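The momentum parameter controls how running_mean is updated in training mode. The sketch below uses PyTorch's own nn.BatchNorm1d as a stand-in (the class above also derives from PyTorch's _BatchNorm, so the update rule is assumed to be the same). It contrasts the exponential moving average obtained with momentum=0.1 with the cumulative average obtained with momentum=None.

```python
import torch
from torch import nn

torch.manual_seed(0)
batches = [torch.randn(16, 3) + shift for shift in (1.0, 3.0)]

# Exponential moving average: running = (1 - m) * running + m * batch_mean
ema = nn.BatchNorm1d(3, momentum=0.1)
# Cumulative moving average: running = plain mean of all batch means seen so far
cma = nn.BatchNorm1d(3, momentum=None)

ema.train()
cma.train()
for b in batches:
    ema(b)
    cma(b)

m1, m2 = (b.mean(dim=0) for b in batches)
expected_ema = 0.9 * (0.1 * m1) + 0.1 * m2  # running_mean starts at zero
expected_cma = (m1 + m2) / 2

print(torch.allclose(ema.running_mean, expected_ema, atol=1e-5))  # True
print(torch.allclose(cma.running_mean, expected_cma, atol=1e-5))  # True
```

With momentum=None every batch has the same weight, which is why the documentation describes it as a simple average.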

InstanceNorm

InstanceNorm(channels, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False, device=None, dtype=None)

Bases: _InstanceNorm

PARAMETER DESCRIPTION
channels

Number of input channels.

TYPE: int

eps

Value added to the denominator for numerical stability.

TYPE: float DEFAULT: 1e-05

momentum

The value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average).

TYPE: float DEFAULT: 0.1

affine

Use learnable affine parameters.

TYPE: bool DEFAULT: False

track_running_stats

Track the running mean and variance. If False, this module does not track such statistics, and initializes statistics buffers as None. When these buffers are None, this module always uses batch statistics in both training and eval modes.

TYPE: bool DEFAULT: False

Source code in cassetta/layers/norm.py
def __init__(
    self,
    channels: int,
    eps: float = 1e-5,
    momentum: float = 0.1,
    affine: bool = False,
    track_running_stats: bool = False,
    device: Optional[DeviceType] = None,
    dtype: Optional[DataType] = None,
) -> None:
    """
    Parameters
    ----------
    channels : int
        Number of input channels.
    eps : float
        Value added to the denominator for numerical stability.
    momentum : float
        The value used for the `running_mean` and `running_var`
        computation. Can be set to `None` for cumulative moving average
        (i.e. simple average).
    affine : bool
        Use learnable affine parameters.
    track_running_stats : bool
        Track the running mean and variance. If `False`,
        this module does not track such statistics, and initializes
        statistics buffers as `None`. When these buffers are `None`,
        this module always uses batch statistics in both training
        and eval modes.
    """
    dtype = to_torch_dtype(dtype)
    super().__init__(
        channels,
        eps=eps,
        momentum=momentum,
        affine=affine,
        track_running_stats=track_running_stats,
        dtype=dtype,
        device=device,
    )

forward

forward(inp)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, channels, *size) tensor

RETURNS DESCRIPTION
out

Output tensor

TYPE: (B, channels, *size) tensor

Source code in cassetta/layers/norm.py
def forward(self, inp: Tensor) -> Tensor:
    """
    Parameters
    ----------
    inp : (B, channels, *size) tensor
        Input tensor

    Returns
    -------
    out : (B, channels, *size) tensor
        Output tensor
    """
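    # `input` is the Python builtin here; the argument is `inp`, and
    # `_NormBase` stores the channel count as `num_features`.
    if self.affine and inp.size(1) != self.num_features:
        raise ValueError(
            f"Expected {self.num_features} channels, got {inp.size(1)}"
        )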
    return self._apply_instance_norm(inp)
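With the defaults above (affine=False, track_running_stats=False), instance normalization standardizes every (sample, channel) plane independently over its spatial positions. The check below uses PyTorch's nn.InstanceNorm2d as a stand-in for the N-dimensional class and recomputes the same result by hand.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = 5 * torch.randn(2, 4, 8, 8) + 3  # (B, C, H, W)

norm = nn.InstanceNorm2d(4)  # affine=False, track_running_stats=False
out = norm(x)

# Same computation by hand: statistics per (b, c) over H and W
mean = x.mean(dim=(2, 3), keepdim=True)
var = x.var(dim=(2, 3), unbiased=False, keepdim=True)  # biased variance
manual = (x - mean) / torch.sqrt(var + norm.eps)

print(torch.allclose(out, manual, atol=1e-4))  # True
```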

LayerNorm

LayerNorm(channels, eps=1e-05, affine=True, device=None, dtype=None)

Bases: GroupNorm

PARAMETER DESCRIPTION
channels

Number of input channels.

TYPE: int

eps

Value added to the denominator for numerical stability.

TYPE: float DEFAULT: 1e-05

affine

Use learnable affine parameters.

TYPE: bool DEFAULT: True

Source code in cassetta/layers/norm.py
def __init__(
    self,
    channels: int,
    eps: float = 1e-5,
    affine: bool = True,
    device: Optional[DeviceType] = None,
    dtype: Optional[DataType] = None
) -> None:
    """
    Parameters
    ----------
    channels : int
        Number of input channels.
    eps : float
        Value added to the denominator for numerical stability.
    affine : bool
        Use learnable affine parameters.
    """
    dtype = to_torch_dtype(dtype)
    super().__init__(
        1, channels, eps, affine, device, dtype  # one group: normalize over (C, *size)
    )
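In PyTorch, nn.GroupNorm with a single group normalizes each sample over all channels and spatial positions (layer normalization for convolutional feature maps), whereas one group per channel reduces it to instance normalization. The sketch below checks both equivalences against a hand-written standardization; manual_norm is a hypothetical helper introduced only for this comparison.

```python
import torch
from torch import nn

torch.manual_seed(0)
C = 6
x = torch.randn(2, C, 5, 5)

def manual_norm(x, dims, eps=1e-5):
    # Standardize x over the given dimensions (biased variance, as in PyTorch)
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

# One group: statistics over (C, H, W) for each sample -> layer norm
one_group = nn.GroupNorm(1, C, affine=False)
# One group per channel: statistics over (H, W) for each (b, c) -> instance norm
per_channel = nn.GroupNorm(C, C, affine=False)

print(torch.allclose(one_group(x), manual_norm(x, (1, 2, 3)), atol=1e-5))  # True
print(torch.allclose(per_channel(x), manual_norm(x, (2, 3)), atol=1e-5))   # True
```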

make_norm

make_norm(norm, channels, affine=True, **kwargs)

Instantiate a normalization module.

To be accepted in a nn.Sequential module or in a nn.ModuleList, a norm must be a nn.Module. This function takes other forms of "norm parameters" that are typically passed to the constructor of larger models, and generates the corresponding instantiated Module.

A norm-like value can be a nn.Module subclass, which is then instantiated, or a callable function that returns an instantiated Module. It can also be the name of a known normalization ("batch", "layer", or "instance"), an integer (the number of groups of a nn.GroupNorm), or True (shorthand for "batch"). A falsy value (e.g., None or False) disables normalization and returns None.

Accepting all these forms makes it possible to:

  • have a learnable norm specific to this module
  • have a learnable norm shared with other modules
  • have a non-learnable norm
PARAMETER DESCRIPTION
norm

An already instantiated nn.Module, or a nn.Module subclass, or a callable that returns an instantiated nn.Module, or the name of a known normalization: "batch", "layer", or "instance". An integer selects a nn.GroupNorm with that many groups.

TYPE: NormType

channels

Number of channels

TYPE: int

affine

Include a learnable affine transform.

TYPE: bool DEFAULT: True

kwargs

Additional parameters to pass to the constructor or function.

TYPE: dict DEFAULT: {}

RETURNS DESCRIPTION
norm

An instantiated nn.Module, or None if norm is falsy.

TYPE: Module or None

Source code in cassetta/layers/norm.py
def make_norm(
    norm: NormType,
    channels: int,
    affine: bool = True,
    **kwargs
):
    """
    Instantiate a normalization module.

    To be accepted in a `nn.Sequential` module or in a `nn.ModuleList`,
    a norm **must** be a `nn.Module`. This function takes other
    forms of "norm parameters" that are typically passed to the
    constructor of larger models, and generates the corresponding
    instantiated Module.

    A norm-like value can be a `nn.Module` subclass, which is
    then instantiated, or a callable function that returns an
    instantiated Module. It can also be the name of a known
    normalization (`"batch"`, `"layer"`, or `"instance"`), an integer
    (the number of groups of a `nn.GroupNorm`), or `True` (shorthand
    for `"batch"`). A falsy value disables normalization.

    Accepting all these forms makes it possible to:

    * have a learnable norm specific to this module
    * have a learnable norm shared with other modules
    * have a non-learnable norm

    Parameters
    ----------
    norm : NormType
        An already instantiated `nn.Module`, or a `nn.Module` subclass,
        or a callable that returns an instantiated `nn.Module`, or the
        name of a known normalization: `"batch"`, `"layer"`, or
        `"instance"`. An integer selects a `nn.GroupNorm` with that
        many groups.
    channels : int
        Number of channels
    affine : bool
        Include a learnable affine transform.
    kwargs : dict
        Additional parameters to pass to the constructor or function.

    Returns
    -------
    norm : Module or None
        An instantiated `nn.Module`, or `None` if `norm` is falsy.
    """
    kwargs['affine'] = affine

    if not norm:
        return None

    if isinstance(norm, nn.Module):
        return norm

    if norm is True:
        norm = 'batch'

    if isinstance(norm, int):
        return nn.GroupNorm(norm, channels, **kwargs)

    if isinstance(norm, str):
        norm = norm.lower()
        if 'instance' in norm:
            norm = InstanceNorm
        elif 'layer' in norm:
            norm = LayerNorm
        elif 'batch' in norm:
            norm = BatchNorm
        else:
            raise ValueError(f'Unknown normalization "{norm}"')

    norm = norm(channels, **kwargs)

    if not isinstance(norm, nn.Module):
        raise ValueError('Normalization did not instantiate a Module')
    return norm
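The dispatch order of make_norm can be reproduced in a few lines. The sketch below is a simplified, hypothetical re-implementation (make_norm_sketch) that swaps the library's BatchNorm, InstanceNorm, and LayerNorm for the stand-ins nn.BatchNorm1d, nn.InstanceNorm1d, and nn.GroupNorm(1, C); it is meant to illustrate the rules, not to replace the function above.

```python
from typing import Callable, Optional, Union

from torch import nn

# Hypothetical stand-ins for the library's named normalizations
_NAMED_NORMS = {
    "instance": nn.InstanceNorm1d,
    "layer": lambda channels, affine=True: nn.GroupNorm(1, channels, affine=affine),
    "batch": nn.BatchNorm1d,
}

def make_norm_sketch(
    norm: Union[None, bool, int, str, nn.Module, Callable[..., nn.Module]],
    channels: int,
    affine: bool = True,
    **kwargs,
) -> Optional[nn.Module]:
    kwargs["affine"] = affine
    if not norm:                      # None, False, 0, "" -> no normalization
        return None
    if isinstance(norm, nn.Module):   # already instantiated: returned as-is (shared)
        return norm
    if norm is True:                  # tested before `int`: bool subclasses int
        norm = "batch"
    if isinstance(norm, int):         # integer -> number of groups
        return nn.GroupNorm(norm, channels, **kwargs)
    if isinstance(norm, str):         # substring match, as in the original
        key = next((k for k in _NAMED_NORMS if k in norm.lower()), None)
        if key is None:
            raise ValueError(f'Unknown normalization "{norm}"')
        norm = _NAMED_NORMS[key]
    module = norm(channels, **kwargs)  # class or factory function
    if not isinstance(module, nn.Module):
        raise ValueError("Normalization did not instantiate a Module")
    return module

shared = nn.GroupNorm(2, 8)
print(make_norm_sketch(shared, 8) is shared)   # True: instances are shared
print(make_norm_sketch(4, 8).num_groups)       # 4
```

Note that the `norm is True` test must come before the integer test, otherwise True would be read as "one group".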

cassetta.layers.simple

Cat

Cat(dim=1)

Bases: Module

Concatenate tensors

Diagram

flowchart LR
    subgraph Inputs
        i1["C<sub>1</sub>"]:::i
        i2["C<sub>2</sub>"]:::i
        i3["..."]:::n
        i4["C<sub>N</sub>"]:::i
    end
    i1 & i2 & i3 & i4 ---z(("c")):::d--->
    o["C<sub>1</sub> + C<sub>2</sub> + ... + C<sub>N</sub>"]:::o

    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef n fill:none,stroke:none;
PARAMETER DESCRIPTION
dim

Dimension to concatenate. Default is 1, the channel dimension.

TYPE: int DEFAULT: 1

Source code in cassetta/layers/simple.py
def __init__(self, dim=1):
    """
    Parameters
    ----------
    dim : int
        Dimension to concatenate. Default is 1, the channel dimension.
    """
    super().__init__()
    self.dim = dim

forward

forward(*inputs)
PARAMETER DESCRIPTION
*inputs

A series of tensors

TYPE: tensor DEFAULT: ()

RETURNS DESCRIPTION
output

A single concatenated tensor

TYPE: tensor

Source code in cassetta/layers/simple.py
def forward(self, *inputs):
    """
    Parameters
    ----------
    *inputs : tensor
        A series of tensors

    Returns
    -------
    output : tensor
        A single concatenated tensor
    """
    return torch.cat(inputs, self.dim)

Add

Bases: Module

Add tensors

Diagram

flowchart LR
    subgraph Inputs
        i1["C"]:::i
        i2["C"]:::i
        i3["..."]:::n
        i4["C"]:::i
    end
    i1 & i2 & i3 & i4 ---z(("+")):::d--->
    o["C"]:::o

    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef n fill:none,stroke:none;

forward

forward(*inputs)
PARAMETER DESCRIPTION
*inputs

A series of tensors

TYPE: tensor DEFAULT: ()

RETURNS DESCRIPTION
output

A single summed tensor

TYPE: tensor

Source code in cassetta/layers/simple.py
def forward(self, *inputs):
    """
    Parameters
    ----------
    *inputs : tensor
        A series of tensors

    Returns
    -------
    output : tensor
        A single summed tensor
    """
    return sum(inputs)

Split

Split(nb_chunks=2, dim=1)

Bases: Module

Split tensor

Diagram

flowchart LR
    subgraph Outputs
        o1["C"]:::o
        o2["C"]:::o
        o3["..."]:::n
        o4["C"]:::o
    end
    i["NxC"]:::i ---z(("s")):::d---> o1 & o2 & o3 & o4

    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef n fill:none,stroke:none;
PARAMETER DESCRIPTION
nb_chunks

Number of output tensors

TYPE: int DEFAULT: 2

dim

Dimension to chunk. Default is 1, the channel dimension.

TYPE: int DEFAULT: 1

Source code in cassetta/layers/simple.py
def __init__(self, nb_chunks=2, dim=1):
    """
    Parameters
    ----------
    nb_chunks : int
        Number of output tensors
    dim : int
        Dimension to chunk. Default is 1, the channel dimension.
    """
    super().__init__()
    self.dim = dim
    self.nb_chunks = nb_chunks

forward

forward(input)
PARAMETER DESCRIPTION
input

The tensor to chunk

TYPE: tensor

RETURNS DESCRIPTION
output

Tensor chunks. If the dimension size is not divisible by nb_chunks, the first chunks are one element larger.

TYPE: tuple[tensor]

Source code in cassetta/layers/simple.py
def forward(self, input):
    """
    Parameters
    ----------
    input : tensor
        The tensor to chunk

    Returns
    -------
    output : tuple[tensor]
        Tensor chunks
    """
    return torch.tensor_split(input, self.nb_chunks, dim=self.dim)
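Split, Cat, and Add are thin wrappers around torch.tensor_split, torch.cat, and Python's built-in sum. The sketch below uses those functions directly to show how uneven splits are handled and that concatenation undoes the split.

```python
import torch

x = torch.arange(2 * 5 * 3.0).reshape(2, 5, 3)  # (B=2, C=5, N=3)

# Split(nb_chunks=2, dim=1): 5 channels -> chunks of 3 and 2 (first chunk gets the remainder)
a, b = torch.tensor_split(x, 2, dim=1)
print(a.shape, b.shape)  # torch.Size([2, 3, 3]) torch.Size([2, 2, 3])

# Cat(dim=1): channel counts add up, and the split is undone
y = torch.cat((a, b), 1)
print(torch.equal(x, y))  # True

# Add: built-in sum starts from 0, so all inputs must share one shape
total = sum((x, x, x))
print(torch.equal(total, 3 * x))  # True
```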

DoNothing

Bases: Module

A layer that does nothing

Diagram

flowchart LR
    i["C"]:::i ---> o["C"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;

MoveDim

MoveDim(src, dst)

Bases: Module

Move a tensor dimension from position src to position dst

Source code in cassetta/layers/simple.py
def __init__(self, src, dst):
    super().__init__()
    self.src = src
    self.dst = dst

Hadamard

Bases: Module

Reparameterize tensors using the Hadamard transform: (x, y) -> (x + y, x - y)

Diagram

flowchart LR
    x["C"]:::i
    y["C"]:::i
    x & y ---plus(("+")):::d---> oplus["C"]:::o
    x & y ---minus(("-")):::d---> ominus["C"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef n fill:none,stroke:none;
flowchart LR
    inp["2xC"]:::i ---split(("s")):::d---> x["C"] & y["C"]
    x & y ---plus(("+")):::d---> oplus["C"]
    x & y ---minus(("-")):::d---> ominus["C"]
    oplus & ominus ---cat(("c")):::d---> o["2xC"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef n fill:none,stroke:none;

forward

forward(x, y=None)

Note

This layer can be applied to a single tensor, or to two tensors.

  • If two tensors are provided, their Hadamard transform is computed, and two tensors are returned.
  • If a single tensor is provided, it is split into two chunks, their Hadamard transform is computed, and the resulting chunks are concatenated and returned.
PARAMETER DESCRIPTION
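x

First tensor. If y is None, x must have 2*C channels and is split into two halves along the channel dimension.

TYPE: (B, C, *shape) tensor

y

Second tensor (optional).

TYPE: (B, C, *shape) tensor DEFAULT: None

RETURNS DESCRIPTION
hx, hy

The two transformed tensors (x + y, x - y), or, if only x was provided, a single tensor holding both halves concatenated along the channel dimension.

TYPE: (B, C, *shape) tensor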

Source code in cassetta/layers/simple.py
def forward(self, x, y=None):
    """

    !!! note
        This layer can be applied to a single tensor, or to two tensors.

        * If two tensors are provided, their Hadamard transform is
          computed, and two tensors are returned.
        * If a single tensor is provided, it is split into two chunks,
          their Hadamard transform is computed, and the resulting chunks
          are concatenated and returned.

    Parameters
    ----------
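    x : (B, C, *shape) tensor
        First tensor. If `y` is None, `x` must have 2*C channels
        and is split into two halves along the channel dimension.
    y : (B, C, *shape) tensor, optional
        Second tensor.

    Returns
    -------
    hx, hy : (B, C, *shape) tensor
        One or two transformed tensors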
    """
    if y is None:
        x, y = Split()(x)
        return Cat()(x + y, x - y)
    else:
        return x + y, x - y
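Applying the Hadamard transform twice only rescales its input by 2, so the reparameterization loses no information: (x, y) can be recovered as ((hx + hy) / 2, (hx - hy) / 2). The functional sketch below, hadamard, is a hypothetical stand-in that mirrors the forward method above using plain torch calls.

```python
import torch

def hadamard(x, y=None):
    # Single tensor: split channels in two, transform, concatenate back
    if y is None:
        x, y = torch.tensor_split(x, 2, dim=1)
        return torch.cat((x + y, x - y), dim=1)
    # Two tensors: return both transformed tensors
    return x + y, x - y

torch.manual_seed(0)
z = torch.randn(4, 6, 7)  # (B, 2*C, N) with C = 3

# H(H(z)) = 2 z
print(torch.allclose(hadamard(hadamard(z)), 2 * z, atol=1e-5))  # True

p, q = torch.randn(4, 3, 7), torch.randn(4, 3, 7)
hp, hq = hadamard(p, q)
print(torch.allclose((hp + hq) / 2, p, atol=1e-5))  # True
```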

ModuleSum

Bases: ModuleList

Apply modules in parallel and sum their outputs.

Diagram

flowchart LR
    subgraph nb_blocks
        2("Block 1"):::w  --> 3["C"]
        4("Block 2"):::w  --> 5["C"]
        6("..."):::n
        8("Block N"):::w  --> 9["C"]
    end
    1["C"]:::i --- 2 & 4 & 6 & 8
    3 & 5 & 6 & 9  --- 10(("+")):::d
    10 --> 11["C"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;

The output of all modules must have the same shape

forward

forward(inp)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, channels, *size) tensor

RETURNS DESCRIPTION
out

Output tensor

TYPE: (B, channels, *size) tensor

Source code in cassetta/layers/simple.py
def forward(self, inp):
    """
    Parameters
    ----------
    inp : (B, channels, *size) tensor
        Input tensor

    Returns
    -------
    out : (B, channels, *size) tensor
        Output tensor
    """
    out = 0
    for layer in self:
        out += layer(inp)
    return out
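Every branch of a ModuleSum receives the same input, and the branch outputs are summed, which is why they must share one shape. The sketch below uses a hypothetical ParallelSum class with the same forward logic as ModuleSum and checks it against a manual sum.

```python
import torch
from torch import nn

class ParallelSum(nn.ModuleList):
    """Hypothetical stand-in for ModuleSum (same forward logic)."""

    def forward(self, inp):
        out = 0
        for layer in self:
            out = out + layer(inp)  # every branch sees the same input
        return out

torch.manual_seed(0)
branches = [nn.Conv2d(4, 4, 3, padding=1), nn.Conv2d(4, 4, 1), nn.Identity()]
block = ParallelSum(branches)

x = torch.randn(2, 4, 8, 8)
expected = branches[0](x) + branches[1](x) + x
print(torch.allclose(block(x), expected))  # True
```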

ModuleGroup

ModuleGroup(blocks, residual=False, skip=0)

Bases: Sequential

Multiple layers stacked together, optionally with residual connections.

Diagram

flowchart LR
    subgraph nb_blocks
        2("Block 1"):::w  --> 3["C"] ---
        4("Block 2"):::w  --> 5["C"] ---
        6("..."):::n      --> 7["C"] ---
        8("Block N"):::w
    end
    1["C"]:::i --- 2
    8 ---> 9["C"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;
flowchart LR
    subgraph nb_blocks
        2("Block 1"):::w  --> 3["C"] ---
        4(("+")):::d      --> 5["C"] ---
        6("Block 2"):::w  --> 7["C"] ---
        8(("+")):::d      --> 9["C"] ---
        10("..."):::n     --> 11["C"] ---
        12("Block N"):::w --> 13["C"] ---
        14(("+")):::d
    end
    1["C"]:::i --- 2
    1 --- 4
    5 --- 8
    11 --- 14
    14 ---> 15["C"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;
flowchart LR
    subgraph nb_blocks
        2("Block 1"):::w  --> 3["C"] ---
        4(("+")):::d      --> 5["C"] ---
        6("Block 2"):::w  --> 7["C"] ---
        8(("+")):::d      --> 9["C"] ---
        10("..."):::n     --> 11["C"] ---
        12("Block N"):::w --> 13["C"] ---
        14(("+")):::d
    end
    1["C+S"]:::i --- 2
    1 --- split(("s")):::d --> c["C"] & s["S"]
    c --- 4
    s --- void[" "]:::n
    5 --- 8
    11 --- 14
    14 ---> 15["C"]:::o
    classDef i fill:honeydew,stroke:lightgreen;
    classDef o fill:mistyrose,stroke:lightpink;
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef n fill:none,stroke:none;
    linkStyle 17 stroke:none;

The recurrent variant shares weights across blocks

The number of channels should be preserved throughout

The spatial size should be preserved throughout

PARAMETER DESCRIPTION
blocks

List of blocks, applied in sequence

TYPE: list[Module]

residual

Use residual connections between blocks

TYPE: bool DEFAULT: False

skip

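Number of additional input channels that only the first block receives. The first block maps channels + skip channels to channels; with residual=True, only the first channels input channels are added back to its output.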

TYPE: int DEFAULT: 0

Source code in cassetta/layers/simple.py
def __init__(
    self,
    blocks: List[nn.Module],
    residual: bool = False,
    skip: int = 0,
):
    """
    Parameters
    ----------
    blocks : list[Module]
        List of blocks, applied in sequence
    residual : bool
        Use residual connections between blocks
    skip : int
        Number of additional input channels that only the first
        block receives (they are excluded from its residual sum).
    """
    super().__init__(*blocks)
    self.residual = residual
    self.skip = skip

forward

forward(inp)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, channels [+skip], *size) tensor

RETURNS DESCRIPTION
out

Output tensor

TYPE: (B, channels, *size) tensor

Source code in cassetta/layers/simple.py
def forward(self, inp: Tensor) -> Tensor:
    """
    Parameters
    ----------
    inp : (B, channels [+skip], *size) tensor
        Input tensor

    Returns
    -------
    out : (B, channels, *size) tensor
        Output tensor
    """
    x = inp

    layers = list(self)
    if self.skip:
        first, *layers = layers
        if self.residual:
            # Out-of-place add: an in-place `+=` would modify the block's
            # output, which autograd may have saved (e.g., after a ReLU).
            out = first(x)
            x = out + x[:, :out.shape[1]]
        else:
            x = first(x)

    if self.residual:
        for layer in layers:
            x = layer(x) + x
    else:
        for layer in layers:
            x = layer(x)
    return x
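The role of skip is easiest to see with blocks that output zeros: the residual path then reduces to the identity on the first channels input channels, while the skip channels are consumed by the first block only. ResidualGroup below is a hypothetical stand-in covering the residual=True case of ModuleGroup; it uses out-of-place additions, which stay safe when a block ends with an operation (such as ReLU) whose output autograd saves.

```python
import torch
from torch import nn

class ResidualGroup(nn.Sequential):
    """Hypothetical stand-in for ModuleGroup(blocks, residual=True, skip=S)."""

    def __init__(self, blocks, skip=0):
        super().__init__(*blocks)
        self.skip = skip

    def forward(self, x):
        layers = list(self)
        if self.skip:
            # First block: C + S channels in, C channels out; only the
            # first C input channels take part in the residual sum.
            first, *layers = layers
            out = first(x)
            x = out + x[:, :out.shape[1]]
        for layer in layers:
            x = layer(x) + x
        return x

C, S = 4, 2
blocks = [nn.Conv2d(C + S, C, 3, padding=1), nn.Conv2d(C, C, 3, padding=1)]
for conv in blocks:  # zero weights and biases: every block outputs zeros
    nn.init.zeros_(conv.weight)
    nn.init.zeros_(conv.bias)

group = ResidualGroup(blocks, skip=S)
x = torch.randn(1, C + S, 8, 8)
out = group(x)
print(out.shape)                   # torch.Size([1, 4, 8, 8])
print(torch.equal(out, x[:, :C]))  # True: residual path is the identity
```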

GlobalPool

GlobalPool(reduction='mean', keepdim=True, dim='spatial')

Bases: Module

Global pooling across spatial dimensions

Diagram

flowchart LR
    1["`[B, C, W, H]`"] ---2("`GlobalPool`"):::d-->
    3["`[B, C, 1, 1]`"]
    classDef d fill:lightcyan,stroke:lightblue;
flowchart LR
    1["`[B, C, W, H]`"] ---2("`GlobalPool`"):::d-->
    3["`[B, C]`"]
    classDef d fill:lightcyan,stroke:lightblue;
flowchart LR
    1["`[B, C, W, H]`"] ---2("`GlobalPool`"):::d-->
    3["`[B, 1, W, H]`"]
    classDef d fill:lightcyan,stroke:lightblue;
PARAMETER DESCRIPTION
reduction

Reduction type

TYPE: {'mean', 'max'} DEFAULT: 'mean'

keepdim

Keep the pooled dimensions as singleton dimensions

TYPE: bool DEFAULT: True

dim

Dimension(s) to pool

TYPE: [list of] int or {'spatial'} DEFAULT: 'spatial'

Source code in cassetta/layers/simple.py
def __init__(
    self,
    reduction: Literal['mean', 'max'] = 'mean',
    keepdim: bool = True,
    dim: Union[OneOrSeveral[int], Literal['spatial']] = 'spatial',
):
    """
    Parameters
    ----------
    reduction : {'mean', 'max'}
        Reduction type
    keepdim : bool
        Keep the pooled dimensions as singleton dimensions
    dim : [list of] int or {'spatial'}
        Dimension(s) to pool
    """
    super().__init__()
    self.reduction = reduction.lower()
    self.keepdim = keepdim
    self.dim = dim

forward

forward(inp)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, C, *spatial) tensor

RETURNS DESCRIPTION
out

Output tensor

TYPE: (B, C, [*ones]) tensor

Source code in cassetta/layers/simple.py
def forward(self, inp):
    """
    Parameters
    ----------
    inp : (B, C, *spatial) tensor
        Input tensor

    Returns
    -------
    out : (B, C, [*ones]) tensor
        Output tensor
    """
    if isinstance(self.dim, str):
        if self.dim[0].lower() != 's':
            raise ValueError(f'Unknown dimension: {self.dim}')
        dims = list(range(2, inp.ndim))
    else:
        dims = self.dim
    if self.reduction == 'max':
        # `Tensor.max` reduces a single dimension; `amax` accepts several
        return inp.amax(dim=dims, keepdim=self.keepdim)
    elif self.reduction == 'mean':
        return inp.mean(dim=dims, keepdim=self.keepdim)
    else:
        raise ValueError(f'Unknown reduction "{self.reduction}"')
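Max-pooling over several spatial dimensions at once needs Tensor.amax, because Tensor.max with a dim argument reduces a single dimension. The hypothetical global_pool function below sketches the spatial case (dim='spatial') and shows the output shapes for both values of keepdim.

```python
import torch

def global_pool(x, reduction="mean", keepdim=True):
    """Pool over all spatial dimensions (everything after batch and channel)."""
    dims = tuple(range(2, x.ndim))
    if reduction == "max":
        return x.amax(dim=dims, keepdim=keepdim)  # amax accepts several dims
    if reduction == "mean":
        return x.mean(dim=dims, keepdim=keepdim)
    raise ValueError(f'Unknown reduction "{reduction}"')

x = torch.arange(2 * 3 * 4 * 5.0).reshape(2, 3, 4, 5)  # (B, C, H, W)
print(global_pool(x).shape)                     # torch.Size([2, 3, 1, 1])
print(global_pool(x, keepdim=False).shape)      # torch.Size([2, 3])
print(global_pool(x, "max", keepdim=False)[0])  # tensor([19., 39., 59.])
```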

cassetta.layers.updown

DownConv

DownConv(ndim, inp_channels, out_channels=None, factor=2)

Bases: Module

Downsample using a strided convolution.

This layer includes no activation/norm/dropout

PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

inp_channels

Number of input channels

TYPE: int

out_channels

Number of output channels. Defaults to inp_channels.

TYPE: int DEFAULT: None

factor

Downsampling factor

TYPE: [list of] int DEFAULT: 2

Source code in cassetta/layers/updown.py
def __init__(
    self,
    ndim: int,
    inp_channels: int,
    out_channels: Optional[int] = None,
    factor: OneOrSeveral[int] = 2,
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    inp_channels : int
        Number of input channels
    out_channels : int, default=`inp_channels`
        Number of output channels
    factor : [list of] int
        Downsampling factor
    """
    out_channels = out_channels or inp_channels
    super().__init__()
    Conv = getattr(nn, f'Conv{ndim}d')
    self.conv = Conv(
        in_channels=inp_channels,
        out_channels=out_channels,
        stride=factor,
        kernel_size=factor,
        padding=0,
    )

forward

forward(inp)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, inp_channels, *inp_spatial) tensor

RETURNS DESCRIPTION
out

Output downsampled tensor

TYPE: (B, out_channels, *out_spatial) tensor

Source code in cassetta/layers/updown.py
def forward(self, inp):
    """
    Parameters
    ----------
    inp : (B, inp_channels, *inp_spatial) tensor
        Input tensor

    Returns
    -------
    out : (B, out_channels, *out_spatial) tensor
        Output downsampled tensor
    """
    return self.conv(inp)
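Because kernel_size and stride are both equal to factor, the convolution reads non-overlapping patches, and each spatial size becomes floor(size / factor). The sketch below uses a plain nn.Conv2d configured like DownConv(ndim=2, inp_channels=8, out_channels=16, factor=2).

```python
import torch
from torch import nn

# Stand-in for DownConv(ndim=2, inp_channels=8, out_channels=16, factor=2)
down = nn.Conv2d(8, 16, kernel_size=2, stride=2, padding=0)

print(down(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 16, 8, 8])
# Odd sizes are floored: the last row and column are dropped
print(down(torch.randn(1, 8, 17, 17)).shape)  # torch.Size([1, 16, 8, 8])
```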

UpConv

UpConv(ndim, inp_channels, out_channels=None, factor=2)

Bases: Module

Upsample using a strided transposed convolution.

This layer includes no activation/norm/dropout

PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

inp_channels

Number of input channels

TYPE: int

out_channels

Number of output channels. Defaults to inp_channels.

TYPE: int DEFAULT: None

factor

Upsampling factor

TYPE: [list of] int DEFAULT: 2

Source code in cassetta/layers/updown.py
def __init__(
    self,
    ndim: int,
    inp_channels: int,
    out_channels: Optional[int] = None,
    factor: OneOrSeveral[int] = 2,
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    inp_channels : int
        Number of input channels
    out_channels : int, default=`inp_channels`
        Number of output channels
    factor : [list of] int
        Upsampling factor
    """
    out_channels = out_channels or inp_channels
    super().__init__()
    Conv = getattr(nn, f'ConvTranspose{ndim}d')
    self.conv = Conv(
        in_channels=inp_channels,
        out_channels=out_channels,
        stride=factor,
        kernel_size=factor,
        padding=0,
    )

forward

forward(inp)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, inp_channels, *inp_spatial) tensor

RETURNS DESCRIPTION
out

Output upsampled tensor

TYPE: (B, out_channels, *out_spatial) tensor

Source code in cassetta/layers/updown.py
def forward(self, inp):
    """
    Parameters
    ----------
    inp : (B, inp_channels, *inp_spatial) tensor
        Input tensor

    Returns
    -------
    out : (B, out_channels, *out_spatial) tensor
        Output upsampled tensor
    """
    return self.conv(inp)
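For a transposed convolution with kernel_size == stride == factor and no padding, the output size is (n - 1) * factor + factor = n * factor, so UpConv restores the shape removed by DownConv whenever the input size is divisible by the factor. A stand-in built from plain nn.ConvTranspose2d:

```python
import torch
from torch import nn

# Stand-in for UpConv(ndim=2, inp_channels=16, out_channels=8, factor=2)
up = nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2, padding=0)

x = torch.randn(1, 16, 8, 8)
print(up(x).shape)  # torch.Size([1, 8, 16, 16])
```

An odd size is not recovered: 17 is floored to 8 on the way down and becomes 16 on the way up.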

DownPool

DownPool(ndim, inp_channels, out_channels=None, factor=2, return_indices=False)

Bases: Sequential

Downsampling using max-pooling + channel expansion

This layer includes no activation/norm/dropout

Cinp -[maxpool ÷2]-> Cinp -[conv 1x1x1]-> Cout
(The 1x1x1 convolution is replaced by an identity when out_channels equals inp_channels.)
PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

inp_channels

Number of input channels

TYPE: int

out_channels

Number of output channels. Defaults to inp_channels.

TYPE: int DEFAULT: None

factor

Downsampling factor

TYPE: [list of] int DEFAULT: 2

return_indices

Return indices on top of pooled features

TYPE: bool DEFAULT: False

Source code in cassetta/layers/updown.py
def __init__(
    self,
    ndim: int,
    inp_channels: int,
    out_channels: Optional[int] = None,
    factor: OneOrSeveral[int] = 2,
    return_indices=False,
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    inp_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    factor : [list of] int
        Downsampling factor
    return_indices : bool
        Return indices on top of pooled features
    """
    super().__init__()
    MaxPool = getattr(nn, f'MaxPool{ndim}d')
    Conv = getattr(nn, f'Conv{ndim}d')
    out_channels = out_channels or inp_channels
    layers = [MaxPool(
        kernel_size=factor,
        stride=factor,
        return_indices=return_indices,
    )]
    if out_channels != inp_channels:
        layers += [Conv(
            inp_channels,
            out_channels,
            kernel_size=1,
        )]
    else:
        layers += [DoNothing()]
    super().__init__(*layers)
    self.inp_channels = inp_channels
    self.out_channels = out_channels

forward

forward(x)
PARAMETER DESCRIPTION
x

Input tensor

TYPE: (B, inp_channels, *inp_spatial) tensor

RETURNS DESCRIPTION
out

Output downsampled tensor

TYPE: (B, out_channels, *out_spatial) tensor

indices

Flattened argmax indices of the max-pooling operation, computed before the channel projection. Only returned if return_indices=True.

TYPE: (B, inp_channels, *out_spatial) tensor[long]

Source code in cassetta/layers/updown.py
def forward(self, x):
    """
    Parameters
    ----------
    x : (B, inp_channels, *inp_spatial) tensor
        Input tensor

    Returns
    -------
    out : (B, out_channels, *out_spatial) tensor
        Output downsampled tensor
    indices : (B, inp_channels, *out_spatial) tensor[long]
        Argmax of the max-pooling operation (computed before the
        channel projection). Only returned if `return_indices=True`
    """
    if self[0].return_indices:
        pool, conv = self
        x, ind = pool(x)
        x = conv(x)
        return x, ind
    else:
        return super().forward(x)
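Since DownPool pools before projecting channels, the returned indices have one plane per input channel, not per output channel. The sketch below reproduces the pipeline with plain nn.MaxPool2d and nn.Conv2d as stand-ins for DownPool(ndim=2, inp_channels=4, out_channels=8, return_indices=True).

```python
import torch
from torch import nn

pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
proj = nn.Conv2d(4, 8, kernel_size=1)

x = torch.randn(1, 4, 8, 8)
pooled, indices = pool(x)  # pooling happens first...
out = proj(pooled)         # ...then the 1x1 channel projection

print(out.shape)      # torch.Size([1, 8, 4, 4])
print(indices.shape)  # torch.Size([1, 4, 4, 4]): one index per input channel
print(indices.dtype)  # torch.int64
```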

UpPool

UpPool(ndim, inp_channels, out_channels=None, factor=2)

Bases: Sequential

Upsampling using max-unpooling + channel projection

This layer includes no activation/norm/dropout

Indices --------------------------- .
                                    |
                                    v
Cinp -[conv 1x1x1]-> Cout -[maxunpool x2]-> Cout
PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

inp_channels

Number of input channels

TYPE: int

out_channels

Number of output channels. Defaults to inp_channels.

TYPE: int DEFAULT: None

factor

Upsampling factor

TYPE: [list of] int DEFAULT: 2

Source code in cassetta/layers/updown.py
def __init__(
    self,
    ndim: int,
    inp_channels: int,
    out_channels: Optional[int] = None,
    factor: OneOrSeveral[int] = 2,
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    inp_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    factor : [list of] int
        Upsampling factor
    """
    super().__init__()
    MaxUnpool = getattr(nn, f'MaxUnpool{ndim}d')
    Conv = getattr(nn, f'Conv{ndim}d')
    out_channels = out_channels or inp_channels
    layers = []
    if out_channels != inp_channels:
        layers += [Conv(
            inp_channels,
            out_channels,
            kernel_size=1,
        )]
    else:
        layers += [DoNothing()]
    layers += [MaxUnpool(
        kernel_size=factor,
        stride=factor,
    )]
    super().__init__(*layers)
    self.inp_channels = inp_channels
    self.out_channels = out_channels

forward

forward(inp, *, indices)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, inp_channels, *inp_spatial) tensor

indices

Indices returned by DownPool or MaxPool{ndim}d

TYPE: (B, out_channels, *inp_spatial) tensor[long]

RETURNS DESCRIPTION
out

Output upsampled tensor

TYPE: (B, out_channels, *out_spatial) tensor

Source code in cassetta/layers/updown.py
def forward(self, inp: Tensor, *, indices: Tensor) -> Tensor:
    """
    Parameters
    ----------
    inp : (B, inp_channels, *inp_spatial) tensor
        Input tensor
    indices : (B, out_channels, *inp_spatial) tensor[long]
        Indices returned by `DownPool` or `MaxPool{ndim}d`

    Returns
    -------
    out : (B, out_channels, *out_spatial) tensor
        Output upsampled tensor
    """
    conv, unpool = self
    out = conv(inp)
    out = unpool(out, indices)
    return out
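Max-unpooling writes every value back at the position recorded in indices and fills the rest with zeros. Because UpPool unpools after its channel projection, the indices must have out_channels planes, i.e., they should come from a pooling step applied to a tensor with that many channels. The stand-in below uses nn.MaxPool2d and nn.MaxUnpool2d on a small deterministic input.

```python
import torch
from torch import nn

pool = nn.MaxPool2d(2, stride=2, return_indices=True)
unpool = nn.MaxUnpool2d(2, stride=2)

x = torch.arange(16.0).reshape(1, 1, 4, 4)
pooled, indices = pool(x)
print(pooled.flatten().tolist())  # [5.0, 7.0, 13.0, 15.0]

restored = unpool(pooled, indices)
# Maxima go back to their original positions; everything else is zero
print(restored.flatten().nonzero().flatten().tolist())  # [5, 7, 13, 15]
```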

DownInterpol

DownInterpol(ndim, inp_channels, out_channels=None, factor=2, interpolation='linear', bound='zero', prefilter=True)

Bases: Sequential

Downsampling using spline interpolation + channel expansion

This layer includes no activation/norm/dropout

Cinp -[interpol ÷2]-> Cinp -[conv 1x1x1]-> Cout
PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

inp_channels

Number of input channels

TYPE: int

out_channels

Number of output channels. Defaults to inp_channels.

TYPE: int DEFAULT: None

factor

Downsampling factor

TYPE: [list of] int DEFAULT: 2

interpolation

Interpolation order.

TYPE: [list of] InterpolationType DEFAULT: 'linear'

bound

Boundary conditions.

TYPE: [list of] BoundType DEFAULT: 'zero'

prefilter

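Apply the spline prefilter, so that the spline interpolates (passes through) the input values. If False, the input values are used directly as spline coefficients.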

TYPE: bool DEFAULT: True

Source code in cassetta/layers/updown.py
def __init__(
    self,
    ndim: int,
    inp_channels: int,
    out_channels: Optional[int] = None,
    factor: OneOrSeveral[int] = 2,
    interpolation: InterpolationType = 'linear',
    bound: BoundType = 'zero',
    prefilter: bool = True,
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    inp_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    factor : [list of] int
        Downsampling factor
    interpolation : [list of] InterpolationType
        Interpolation order.
    bound : [list of] BoundType
        Boundary conditions.
    prefilter : bool
        Apply the spline prefilter, so that the spline interpolates
        the input values (if False, they are used as spline coefficients)
    """
    super().__init__()
    Conv = getattr(nn, f'Conv{ndim}d')
    out_channels = out_channels or inp_channels
    factor = list(map(lambda x: 1/x, ensure_list(factor, ndim)))
    layers = [Resize(
        factor=factor,
        interpolation=interpolation,
        bound=bound,
        prefilter=prefilter,
    )]
    if out_channels != inp_channels:
        layers += [Conv(
            inp_channels,
            out_channels,
            kernel_size=1,
        )]
    super().__init__(*layers)
    self.inp_channels = inp_channels
    self.out_channels = out_channels
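DownInterpol converts the downsampling factor into a resize factor of 1 / factor and applies the 1x1 projection after resizing, where there are fewer voxels to process. The sketch below swaps the library's spline Resize for plain bilinear torch.nn.functional.interpolate (no prefilter), which is an assumption for illustration only. For this linear stand-in, whose weights sum to one, resizing and the 1x1 convolution commute, so the cheaper order gives the same result.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
factor = 2
scale = 1 / factor  # the downsampling factor becomes a resize factor < 1
conv = nn.Conv2d(4, 8, kernel_size=1)
x = torch.randn(1, 4, 16, 16)

def resize(t):
    # Stand-in for the spline Resize: bilinear interpolation, no prefilter
    return F.interpolate(t, scale_factor=scale, mode="bilinear", align_corners=False)

down_then_conv = conv(resize(x))  # order used by DownInterpol
conv_then_down = resize(conv(x))

print(down_then_conv.shape)                                       # torch.Size([1, 8, 8, 8])
print(torch.allclose(down_then_conv, conv_then_down, atol=1e-5))  # True
```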

UpInterpol

UpInterpol(ndim, inp_channels, out_channels=None, factor=2, interpolation='linear', bound='zero', prefilter=True)

Bases: Sequential

Upsampling using spline interpolation + channel expansion

This layer includes no activation/norm/dropout

Cinp -[conv 1x1x1]-> Cout -[interpol x2]-> Cout
PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

inp_channels

Number of input channels

TYPE: int

out_channels

Number of output channels. Defaults to inp_channels.

TYPE: int DEFAULT: None

factor

Upsampling factor

TYPE: [list of] int DEFAULT: 2

interpolation

Interpolation order.

TYPE: [list of] InterpolationType DEFAULT: 'linear'

bound

Boundary conditions.

TYPE: [list of] BoundType DEFAULT: 'zero'

prefilter

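Apply the spline prefilter, so that the spline interpolates (passes through) the input values. If False, the input values are used directly as spline coefficients.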

TYPE: bool DEFAULT: True

Source code in cassetta/layers/updown.py
def __init__(
    self,
    ndim: int,
    inp_channels: int,
    out_channels: Optional[int] = None,
    factor: OneOrSeveral[int] = 2,
    interpolation: InterpolationType = 'linear',
    bound: BoundType = 'zero',
    prefilter: bool = True,
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    inp_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    factor : [list of] int
        Upsampling factor
    interpolation : [list of] InterpolationType
        Interpolation order.
    bound : [list of] BoundType
        Boundary conditions.
    prefilter : bool
        Apply the spline prefilter, so that the spline interpolates
        the input values (if False, they are used as spline coefficients)
    """
    super().__init__()
    Conv = getattr(nn, f'Conv{ndim}d')
    out_channels = out_channels or inp_channels
    layers = []
    if out_channels != inp_channels:
        layers += [Conv(
            inp_channels,
            out_channels,
            kernel_size=1,
        )]
    layers += [Resize(
        factor=ensure_list(factor, ndim),
        interpolation=interpolation,
        bound=bound,
        prefilter=prefilter,
    )]
    super().__init__(*layers)
    self.inp_channels = inp_channels
    self.out_channels = out_channels
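UpInterpol uses the opposite order: it projects channels first, at the lower resolution, then upsamples. In 3D with factor=2, the 1x1x1 convolution therefore runs on 8 times fewer voxels. The stand-in below replaces the spline Resize with nn.Upsample in trilinear mode (an assumption for illustration) and shows the resulting shapes.

```python
import torch
from torch import nn

# Stand-in for UpInterpol(ndim=3, inp_channels=16, out_channels=8, factor=2)
up = nn.Sequential(
    nn.Conv3d(16, 8, kernel_size=1),  # channel projection at low resolution
    nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False),
)

x = torch.randn(1, 16, 4, 4, 4)
print(up(x).shape)  # torch.Size([1, 8, 8, 8, 8])
```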