cassetta.backbones

Overview

Backbones are deep architectures composed of many layers.

In our philosophy, backbones are task-independent: they do not depend on the number of input and output channels of the problem at hand. Instead, they map an input feature space to an output feature space.

Models wrap a backbone between a feature-extraction layer (a single convolution without activation, possibly with a fairly large kernel size) and a feature-mapping layer (a single 1x1x1 convolution, possibly followed by a task-specific activation such as a SoftMax).
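The wrapping described above can be traced at the level of tensor shapes. The sketch below is plain Python and does not import cassetta; the function names are hypothetical, and it assumes that every layer preserves the spatial size. It only illustrates that the backbone never sees the problem's input or output channel counts.

```python
def feature_extraction(shape, nb_features):
    """Single convolution: problem channels -> backbone features."""
    batch, _, *spatial = shape
    return (batch, nb_features, *spatial)


def backbone(shape):
    """Stand-in backbone: feature space in, feature space out."""
    return shape


def feature_mapping(shape, out_channels):
    """Single 1x1x1 convolution: backbone features -> problem channels."""
    batch, _, *spatial = shape
    return (batch, out_channels, *spatial)


def model_output_shape(inp_shape, nb_features, out_channels):
    x = feature_extraction(inp_shape, nb_features)
    x = backbone(x)
    return feature_mapping(x, out_channels)


# One input channel, four output classes, 16 backbone features
print(model_output_shape((1, 1, 64, 64, 64), nb_features=16, out_channels=4))
# (1, 4, 64, 64, 64)
```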

MODULE DESCRIPTION
fcn

Fully convolutional encoders and decoders

unet

U-Nets: autoencoders with skip connections

atrous

Networks that use dilated convolutions

cassetta.backbones.fcn

ConvEncoder

ConvEncoder(ndim, nb_features=16, mul_features=2, nb_levels=3, nb_conv_per_level=2, kernel_size=3, residual=False, activation='ReLU', norm=None, dropout=None, attention=None, order='cndax', pool_factor=2, pool_mode='interpolate')

Bases: Sequential

A fully convolutional encoder

Diagram

flowchart LR
    1["`[F0, W]`"]    ---2("ConvGroup"):::w-->
    3["`[F0, W]`"]    ---4("Down"):::w-->
    5["`[F1, W//2]`"] ---6("ConvGroup"):::w-->
    7["`[F1, W//2]`"] --- 8("Down"):::w-->
    9["`[F2, W//4]`"] ---10("ConvGroup"):::w-->
    11["`[F2, W//4]`"]
    classDef w fill:papayawhip,stroke:peachpuff;
PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

nb_features

Number of features at the finest level. If a list, number of features at each level of the encoder.

TYPE: [list of] int DEFAULT: 16

mul_features

Multiply the number of features by this number each time we go down one level.

TYPE: int DEFAULT: 2

nb_levels

Number of levels in the encoder

TYPE: int DEFAULT: 3

nb_conv_per_level

Number of convolutional layers at each level.

TYPE: int DEFAULT: 2

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

residual

Use residual connections between convolutional blocks

TYPE: bool DEFAULT: False

activation

Type of activation

TYPE: ActivationLike DEFAULT: 'ReLU'

norm

Normalization

TYPE: NormType DEFAULT: None

dropout

Channel dropout probability

TYPE: DropoutType DEFAULT: None

attention

Attention

TYPE: AttentionType DEFAULT: None

order

Modules order (permutation of 'ncdax')

TYPE: str DEFAULT: 'cndax'

pool_factor

Downsampling factor (per dimension).

TYPE: [list of] int DEFAULT: 2

pool_mode

Method used to go down one level.

TYPE: (interpolate, conv, pool) DEFAULT: 'interpolate'
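As a rough illustration of how nb_features, mul_features, nb_levels and pool_factor interact, the sketch below recomputes the `[F, W]` labels of the diagram in plain Python. The helper names are hypothetical, and the spatial sizes assume floor division, as the diagram suggests; the exact size may depend on pool_mode.

```python
def encoder_features(nb_features=16, mul_features=2, nb_levels=3):
    """Number of features at each level, finest first."""
    if isinstance(nb_features, int):
        return [
            int(nb_features * mul_features**level)
            for level in range(nb_levels)
        ]
    # A list fixes the features and the number of levels.
    return list(nb_features)


def encoder_widths(width, pool_factor=2, nb_levels=3):
    """Spatial width at each level (assumes floor division)."""
    return [width // pool_factor**level for level in range(nb_levels)]


print(encoder_features())                 # [16, 32, 64]
print(encoder_features([8, 24, 48, 48]))  # [8, 24, 48, 48]
print(encoder_widths(96))                 # [96, 48, 24]
```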

Source code in cassetta/backbones/fcn.py
def __init__(
    self,
    ndim: int,
    nb_features: OneOrSeveral[int] = 16,
    mul_features: int = 2,
    nb_levels: int = 3,
    nb_conv_per_level: int = 2,
    kernel_size: OneOrSeveral[int] = 3,
    residual: bool = False,
    activation: ActivationType = 'ReLU',
    norm: NormType = None,
    dropout: DropoutType = None,
    attention: AttentionType = None,
    order: str = 'cndax',
    pool_factor: OneOrSeveral[int] = 2,
    pool_mode: str = 'interpolate',
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    nb_features : [list of] int
        Number of features at the finest level.
        If a list, number of features at each level of the encoder.
    mul_features : int
        Multiply the number of features by this number
        each time we go down one level.
    nb_levels : int
        Number of levels in the encoder
    nb_conv_per_level : int
        Number of convolutional layers at each level.
    kernel_size : [list of] int
        Kernel size
    residual : bool
        Use residual connections between convolutional blocks
    activation : ActivationLike
        Type of activation
    norm : NormType
        Normalization
    dropout : DropoutType
        Channel dropout probability
    attention : AttentionType
        Attention
    order : str
        Modules order (permutation of 'ncdax')
    pool_factor : [list of] int
        Downsampling factor (per dimension).
    pool_mode : {'interpolate', 'conv', 'pool'}
        Method used to go down one level.
    """
    make_inp = partial(
        ConvGroup,
        ndim,
        kernel_size=kernel_size,
        residual=residual,
        activation=activation,
        norm=norm,
        dropout=dropout,
        attention=attention,
        order=order,
        nb_conv=nb_conv_per_level,
    )
    make_down = partial(
        DownConvGroup,
        ndim,
        kernel_size=kernel_size,
        residual=residual,
        activation=activation,
        norm=norm,
        dropout=dropout,
        attention=attention,
        order=order,
        nb_conv=nb_conv_per_level,
        factor=pool_factor,
        mode=pool_mode,
    )

    # number of features per level
    if isinstance(nb_features, int):
        enc_features = [
            int(nb_features * mul_features**level)
            for level in range(nb_levels)
        ]
    else:
        enc_features = list(nb_features)
        # Overwrite any defined nb_levels if user defines feat list.
        nb_levels = len(nb_features)
        enc_features += [
            int(enc_features[-1] * mul_features**level)
            for level in range(nb_levels - len(enc_features))
        ]
        enc_features = enc_features[:nb_levels]

    # build encoder
    encoder = [make_inp(enc_features[0])]
    for i in range(1, nb_levels):
        encoder += [make_down(enc_features[i-1], enc_features[i])]
    super().__init__(*encoder)

forward

forward(inp, *, return_all=False)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, nb_features[0], *inp_size) tensor

return_all

Return all intermediate output tensors (at each level)

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
out

Output tensor(s). If return_all, return all intermediate tensors, from finest to coarsest. Else, return the final tensor only.

TYPE: [tuple of] (B, nb_features[-1], *out_size) tensor
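The effect of return_all can be sketched with stand-in layers that act on `(features, width)` pairs instead of tensors. This is only an illustration of the control flow; `run_encoder` and the lambda layers are hypothetical and not part of cassetta.

```python
def run_encoder(layers, inp, return_all=False):
    """Apply layers in order; optionally keep every level's output."""
    out = inp
    outputs = []
    for layer in layers:
        out = layer(out)
        outputs.append(out)
    return tuple(outputs) if return_all else out


# Stand-in layers operating on (features, width) pairs
layers = [
    lambda x: (16, x[1]),        # ConvGroup at the finest level
    lambda x: (32, x[1] // 2),   # Down + ConvGroup
    lambda x: (64, x[1] // 2),   # Down + ConvGroup
]

print(run_encoder(layers, (16, 96)))
# (64, 24)
print(run_encoder(layers, (16, 96), return_all=True))
# ((16, 96), (32, 48), (64, 24))
```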

Source code in cassetta/backbones/fcn.py
def forward(self, inp, *, return_all=False):
    """
    Parameters
    ----------
    inp : (B, nb_features[0], *inp_size) tensor
        Input tensor
    return_all : bool
        Return all intermediate output tensors (at each level)

    Returns
    -------
    out : [tuple of] (B, nb_features[-1], *out_size) tensor
        Output tensor(s).
        If `return_all`, return all intermediate tensors, from
        finest to coarsest. Else, return the final tensor only.
    """
    if return_all:
        out = inp
        all = []
        for layer in self:
            out = layer(out)
            all.append(out)
        return tuple(all)
    else:
        return super().forward(inp)

ConvDecoder

ConvDecoder(ndim, nb_features=16, div_features=2, nb_levels=3, nb_conv_per_level=2, skip=False, kernel_size=3, residual=False, activation='ReLU', norm=None, dropout=None, attention=None, order='cndax', unpool_factor=2, unpool_mode='interpolate')

Bases: Sequential

A fully convolutional decoder

Diagram: pure decoder

flowchart LR
    1["`[F0, W]`"]     ---2("Up"):::w-->
    3["`[F1, W*2]`"]   ---4("ConvGroup"):::w-->
    5["`[F1, W*2]`"]   ---6("Up"):::w-->
    7["`[F2, W*4]`"]   ---8("ConvGroup"):::w-->
    9["`[F2, W*4]`"]
    classDef w fill:papayawhip,stroke:peachpuff;
Diagram: decoder with concatenated skip connections

flowchart LR
    S1["`[S1, W*2]`"]
    S2["`[S2, W*4]`"]
    1["`[F0, W]`"]        ---2("Up"):::w-->
    3["`[F1, W*2]`"]      ---4(("c")):::d-->
    5["`[F1+S1, W*2]`"]   ---6("ConvGroup"):::w-->
    7["`[F1, W*2]`"]      ---8("Up"):::w-->
    9["`[F2, W*4]`"]      ---10(("c")):::d-->
    11["`[F2+S2, W*4]`"]  ---12("ConvGroup"):::w-->
    13["`[F2, W*4]`"]
    S1 --- 4
    S2 --- 10
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
Diagram: decoder with additive skip connections

flowchart LR
    S1["`[F1, W*2]`"]
    S2["`[F2, W*4]`"]
    1["`[F0, W]`"]        ---2("Up"):::w-->
    3["`[F1, W*2]`"]      ---4(("+")):::d-->
    5["`[F1, W*2]`"]      ---6("ConvGroup"):::w-->
    7["`[F1, W*2]`"]      ---8("Up"):::w-->
    9["`[F2, W*4]`"]      ---10(("+")):::d-->
    11["`[F2, W*4]`"]     ---12("ConvGroup"):::w-->
    13["`[F2, W*4]`"]
    S1 --- 4
    S2 --- 10
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

nb_features

Number of features at the coarsest level (the decoder input). If a list, number of features at each level of the decoder, from coarsest to finest.

TYPE: [list of] int DEFAULT: 16

div_features

Divide the number of features by this number each time we go up one level.

TYPE: int DEFAULT: 2

nb_levels

Number of levels in the decoder

TYPE: int DEFAULT: 3

nb_conv_per_level

Number of convolutional layers at each level.

TYPE: int DEFAULT: 2

skip

Number of channels to concatenate in the skip connection. If 0 (or False) and skip tensors are provided, will try to add them instead of cat. If True, the number of skipped channels and the number of features are identical.

TYPE: int or bool DEFAULT: False

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

residual

Use residual connections between convolutional blocks

TYPE: bool DEFAULT: False

activation

Type of activation

TYPE: ActivationLike DEFAULT: 'ReLU'

norm

Normalization

TYPE: NormType DEFAULT: None

dropout

Channel dropout probability

TYPE: DropoutType DEFAULT: None

attention

Attention

TYPE: AttentionType DEFAULT: None

order

Modules order (permutation of 'ncdax')

TYPE: str DEFAULT: 'cndax'

unpool_factor

Upsampling factor (per dimension).

TYPE: [list of] int DEFAULT: 2

unpool_mode

Method used to go up one level.

TYPE: (interpolate, conv) DEFAULT: 'interpolate'
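The way nb_features, div_features and skip are turned into per-level channel counts can be sketched in plain Python. The helpers below mirror the logic of the constructor shown in the source listing, but their names are hypothetical and they are not part of the cassetta API.

```python
def decoder_features(nb_features=16, div_features=2, nb_levels=3):
    """Number of features at each level, coarsest first."""
    if isinstance(nb_features, int):
        return [
            max(1, nb_features // div_features**level)
            for level in range(nb_levels)
        ]
    return list(nb_features)


def skip_channels(skip, dec_features):
    """Number of concatenated skip channels for each upsampling step."""
    nb_up = len(dec_features) - 1
    if skip is True:
        # Same number of skipped channels as decoder features
        return dec_features[1:]
    if not skip:
        # No concatenation; skip tensors, if provided, are added
        return [0] * nb_up
    if isinstance(skip, int):
        return [skip] * nb_up
    skip = list(skip)
    return skip + [0] * max(0, nb_up - len(skip))


features = decoder_features(64)
print(features)                        # [64, 32, 16]
print(skip_channels(True, features))   # [32, 16]
print(skip_channels(False, features))  # [0, 0]
print(skip_channels(8, features))      # [8, 8]
print(skip_channels([8], features))    # [8, 0]
```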

Source code in cassetta/backbones/fcn.py
def __init__(
    self,
    ndim: int,
    nb_features: OneOrSeveral[int] = 16,
    div_features: int = 2,
    nb_levels: int = 3,
    nb_conv_per_level: int = 2,
    skip: Union[bool, OneOrSeveral[int]] = False,
    kernel_size: OneOrSeveral[int] = 3,
    residual: bool = False,
    activation: ActivationType = 'ReLU',
    norm: NormType = None,
    dropout: DropoutType = None,
    attention: AttentionType = None,
    order: str = 'cndax',
    unpool_factor: OneOrSeveral[int] = 2,
    unpool_mode: str = 'interpolate',
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    nb_features : [list of] int
        Number of features at the coarsest level (the decoder input).
        If a list, number of features at each level of the decoder,
        from coarsest to finest.
    div_features : int
        Divide the number of features by this number
        each time we go up one level.
    nb_levels : int
        Number of levels in the decoder
    nb_conv_per_level : int
        Number of convolutional layers at each level.
    skip : int or bool
        Number of channels to concatenate in the skip connection.
        If 0 (or False) and skip tensors are provided, will try to
        add them instead of cat. If True, the number of skipped
        channels and the number of features are identical.
    kernel_size : [list of] int
        Kernel size
    residual : bool
        Use residual connections between convolutional blocks
    activation : ActivationLike
        Type of activation
    norm : NormType
        Normalization
    dropout : DropoutType
        Channel dropout probability
    attention : AttentionType
        Attention
    order : str
        Modules order (permutation of 'ncdax')
    unpool_factor : [list of] int
        Upsampling factor (per dimension).
    unpool_mode : {'interpolate', 'conv'}
        Method used to go up one level.
    """
    make_up = partial(
        UpConvGroup,
        ndim,
        kernel_size=kernel_size,
        residual=residual,
        activation=activation,
        norm=norm,
        dropout=dropout,
        attention=attention,
        order=order,
        nb_conv=nb_conv_per_level,
        factor=unpool_factor,
        mode=unpool_mode,
    )

    # number of features per level
    if isinstance(nb_features, int):
        dec_features = [
            max(1, int(nb_features // div_features**level))
            for level in range(nb_levels)
        ]
    else:
        # Overwrite any defined nb_levels if user defines feat list.
        dec_features = list(nb_features)
        nb_levels = len(dec_features)
        dec_features += [
            max(1, int(dec_features[-1] // div_features**level))
            for level in range(nb_levels - len(dec_features))
        ]
        dec_features = dec_features[:nb_levels]

    # number of skipped channels per level
    if skip is True:
        skip = dec_features[1:]
    elif not skip:
        skip = [0] * (nb_levels-1)
    elif isinstance(skip, int):
        skip = [skip] * (nb_levels - 1)
    else:
        skip = list(skip) + [0] * max(0, nb_levels - 1 - len(skip))

    # build decoder
    decoder = []
    for i in range(nb_levels-1):
        decoder += [make_up(
            dec_features[i], dec_features[i+1], skip=skip[i]
        )]
    super().__init__(*decoder)

forward

forward(*inp, return_all=False)
PARAMETER DESCRIPTION
*inp

Input tensor(s), optionally including skip connections. Ordered from coarsest to finest.

TYPE: (B, nb_features[n], *inp_size[n]) tensor DEFAULT: ()

return_all

Return all intermediate output tensors (at each level).

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
out

Output tensor(s). If return_all, return all intermediate tensors, from coarsest to finest. Else, return the final tensor only.

TYPE: [tuple of] (B, nb_features[-1], *out_size) tensor
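The order in which positional inputs are consumed can be sketched with string-valued stand-ins: the first input feeds the first upsampling step, and each remaining input is handed to the next step as its skip tensor until none are left. `run_decoder` and `up` are hypothetical helpers for illustration only.

```python
def run_decoder(layers, *inp, return_all=False):
    """First input is the coarsest tensor; the rest are skips."""
    out, *skips = inp
    outputs = []
    for layer in layers:
        # Each step consumes at most one skip, coarsest first
        args = [skips.pop(0)] if skips else []
        out = layer(out, *args)
        outputs.append(out)
    return tuple(outputs) if return_all else out


def up(name):
    """Stand-in for an upsampling step; returns a readable trace."""
    def layer(x, skip=None):
        if skip is None:
            return f"{name}({x})"
        return f"{name}({x}, skip={skip})"
    return layer


layers = [up("up1"), up("up2")]
print(run_decoder(layers, "coarse"))
# up2(up1(coarse))
print(run_decoder(layers, "coarse", "mid", "fine"))
# up2(up1(coarse, skip=mid), skip=fine)
```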

Source code in cassetta/backbones/fcn.py
def forward(self, *inp, return_all=False):
    """
    Parameters
    ----------
    *inp : (B, nb_features[n], *inp_size[n]) tensor
        Input tensor(s), optionally including skip connections.
        Ordered from coarsest to finest.
    return_all : bool
        Return all intermediate output tensors (at each level).

    Returns
    -------
    out : [tuple of] (B, nb_features[-1], *out_size) tensor
        Output tensor(s).
        If `return_all`, return all intermediate tensors, from
        coarsest to finest. Else, return the final tensor only.
    """
    inp, *skips = inp
    skips = list(skips)
    all = []

    out = inp
    for layer in self:
        args = [skips.pop(0)] if skips else []
        out = layer(out, *args)
        if return_all:
            all.append(out)
    return tuple(all) if return_all else out

cassetta.backbones.unet

UNet

UNet(ndim, nb_features=64, mul_features=2, nb_levels=5, nb_levels_decoder=None, nb_conv_per_level=2, kernel_size=3, residual=False, activation='ReLU', norm=None, dropout=None, attention=None, order='cndax', pool_factor=2, pool_mode='pool', unpool_mode='conv', skip=True)

Bases: Module

A UNet

Diagram: concatenated skip connections

flowchart LR
    II0["`[F0, W]`"]            ---CI0("`ConvGroup`"):::w-->
    IO0["`[F0, W]`"]            ---D1("`Down`"):::w-->
    II1["`[F1, W//2]`"]         ---CI1("`ConvGroup`"):::w-->
    IO1["`[F1, W//2]`"]         ---D2("`Down`"):::w-->
    II2["`[F2, W//4]`"]         ---CI2("`ConvGroup`"):::w-->
    OO2["`[F2, W//4]`"]:::o     ---U2("`Up`"):::w-->
    OI1["`[F1, W//2]`"]         ---Z1(("c")):::d-->
    OZ1["`[F1*2, W//2]`"]       ---CO1("`ConvGroup`"):::w-->
    OO1["`[F1, W//2]`"]:::o     ---U1("`Up`"):::w-->
    OI0["`[F0, W]`"]            ---Z0(("c")):::d-->
    OZ0["`[F0*2, W]`"]          ---CO0("`ConvGroup`"):::w-->
    OO0["`[F0, W]`"]:::o
    IO0 --- Z0
    IO1 --- Z1
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef o fill:mistyrose,stroke:lightpink;
Diagram: additive skip connections

flowchart LR
    II0["`[F0, W]`"]            ---CI0("`ConvGroup`"):::w-->
    IO0["`[F0, W]`"]            ---D1("`Down`"):::w-->
    II1["`[F1, W//2]`"]         ---CI1("`ConvGroup`"):::w-->
    IO1["`[F1, W//2]`"]         ---D2("`Down`"):::w-->
    II2["`[F2, W//4]`"]         ---CI2("`ConvGroup`"):::w-->
    OO2["`[F2, W//4]`"]:::o     ---U2("`Up`"):::w-->
    OI1["`[F1, W//2]`"]         ---Z1(("+")):::d-->
    OZ1["`[F1, W//2]`"]         ---CO1("`ConvGroup`"):::w-->
    OO1["`[F1, W//2]`"]:::o     ---U1("`Up`"):::w-->
    OI0["`[F0, W]`"]            ---Z0(("+")):::d-->
    OZ0["`[F0, W]`"]            ---CO0("`ConvGroup`"):::w-->
    OO0["`[F0, W]`"]:::o
    IO0 --- Z0
    IO1 --- Z1
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef o fill:mistyrose,stroke:lightpink;
Diagram: no skip connections

flowchart LR
    II0["`[F0, W]`"]            ---CI0("`ConvGroup`"):::w-->
    IO0["`[F0, W]`"]            ---D1("`Down`"):::w-->
    II1["`[F1, W//2]`"]         ---CI1("`ConvGroup`"):::w-->
    IO1["`[F1, W//2]`"]         ---D2("`Down`"):::w-->
    II2["`[F2, W//4]`"]         ---CI2("`ConvGroup`"):::w-->
    OO2["`[F2, W//4]`"]:::o     ---U2("`Up`"):::w-->
    OI1["`[F1, W//2]`"]         ---CO1("`ConvGroup`"):::w-->
    OO1["`[F1, W//2]`"]:::o     ---U1("`Up`"):::w-->
    OI0["`[F0, W]`"]            ---CO0("`ConvGroup`"):::w-->
    OO0["`[F0, W]`"]:::o
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef o fill:mistyrose,stroke:lightpink;

Difference with Ronneberger et al.

  • Default parameters are from Ronneberger et al.
  • However, instead of performing a 3x3 channel-expanding convolution + ReLU in the encoder, we first perform a 1x1 channel-expanding convolution without ReLU, followed by a 3x3 channel-preserving convolution + ReLU.
  • Both implementations have the same representational power, although ours adds free parameters that are not strictly needed.
  • The benefit of our approach is it brings a bit more flexibility. We can easily replace max-pooling with other types of downsampling operators (e.g., linear downsampling or strided convolution) using pool_mode="interpolate" or pool_mode="conv".
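The extra free parameters mentioned above can be estimated with simple arithmetic. The sketch below counts weights and biases for the single step that differs, assuming 2D convolutions with a bias term and a 64 to 128 channel expansion; these values are illustrative, and the helper names are hypothetical.

```python
def conv_params(in_channels, out_channels, kernel_size, ndim=2, bias=True):
    """Number of weights (and biases) in a single convolution."""
    weights = in_channels * out_channels * kernel_size**ndim
    return weights + (out_channels if bias else 0)


def expanding_3x3(cin, cout, ndim=2):
    """One 3x3 channel-expanding convolution."""
    return conv_params(cin, cout, 3, ndim)


def expanding_1x1_then_3x3(cin, cout, ndim=2):
    """1x1 channel-expanding, then 3x3 channel-preserving convolution."""
    return conv_params(cin, cout, 1, ndim) + conv_params(cout, cout, 3, ndim)


print(expanding_3x3(64, 128))           # 73856
print(expanding_1x1_then_3x3(64, 128))  # 155904
```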

Reference

Ronneberger, Fischer & Brox, "U-Net: Convolutional Networks for Biomedical Image Segmentation." MICCAI (2015). arxiv:1505.04597

PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

nb_features

Number of features at the finest level. If a list, number of features at each level of the encoder.

TYPE: [list of] int DEFAULT: 64

mul_features

Multiply the number of features by this number each time we go down one level.

TYPE: int DEFAULT: 2

nb_levels

Number of levels in the encoder

TYPE: int DEFAULT: 5

nb_levels_decoder

Number of levels in the decoder

TYPE: int DEFAULT: `nb_levels`

nb_conv_per_level

Number of convolutional layers at each level.

TYPE: int DEFAULT: 2

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

residual

Use residual connections between convolutional blocks

TYPE: bool DEFAULT: False

activation

Type of activation

TYPE: ActivationLike DEFAULT: 'ReLU'

norm

Normalization

TYPE: NormType DEFAULT: None

dropout

Channel dropout probability

TYPE: DropoutType DEFAULT: None

attention

Attention

TYPE: AttentionType DEFAULT: None

order

Modules order (permutation of 'ncdax')

TYPE: str DEFAULT: 'cndax'

pool_factor

Down/Upsampling factor (per dimension).

TYPE: [list of] int DEFAULT: 2

pool_mode

Method used to go down one level.

  • If "interpolate", use linear interpolation.
  • If "conv", use strided convolutions.
  • If "pool", use max pooling.

TYPE: (interpolate, conv, pool) DEFAULT: 'pool'

unpool_mode

Method used to go up one level.

  • If "interpolate", use linear interpolation.
  • If "conv", use transposed convolutions.
  • "pool" (i.e., unpooling) is not supported right now.

TYPE: (interpolate, conv) DEFAULT: 'conv'

skip

Type of skip connections:

  • False: no skip connections
  • True: concatenate skip connections
  • '+': add skip connections

TYPE: bool or {'+'} DEFAULT: True
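With an integer nb_features, the encoder and decoder feature counts follow directly from mul_features, nb_levels and nb_levels_decoder. The sketch below covers only that case, with a decoder no deeper than the encoder; `unet_features` is a hypothetical helper, not a cassetta function.

```python
def unet_features(nb_features=64, mul_features=2, nb_levels=5,
                  nb_levels_decoder=None):
    """Encoder features (finest first) and decoder features (coarsest first)."""
    nb_levels_decoder = nb_levels_decoder or nb_levels
    enc = [
        int(nb_features * mul_features**level)
        for level in range(nb_levels)
    ]
    # The decoder walks back up the encoder levels
    dec = list(reversed(enc))[:nb_levels_decoder]
    return enc, dec


enc, dec = unet_features()
print(enc)  # [64, 128, 256, 512, 1024]
print(dec)  # [1024, 512, 256, 128, 64]

# A shallower decoder stops before reaching full resolution
print(unet_features(nb_levels_decoder=3)[1])  # [1024, 512, 256]
```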

Source code in cassetta/backbones/unet.py
def __init__(
    self,
    ndim: int,
    nb_features: OneOrSeveral[int] = 64,
    mul_features: int = 2,
    nb_levels: int = 5,
    nb_levels_decoder: Optional[int] = None,
    nb_conv_per_level: int = 2,
    kernel_size: OneOrSeveral[int] = 3,
    residual: bool = False,
    activation: ActivationType = 'ReLU',
    norm: NormType = None,
    dropout: DropoutType = None,
    attention: AttentionType = None,
    order: str = 'cndax',
    pool_factor: OneOrSeveral[int] = 2,
    pool_mode: str = 'pool',
    unpool_mode: Optional[str] = 'conv',
    skip: Union[bool, Literal["+"]] = True,
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    nb_features : [list of] int
        Number of features at the finest level.
        If a list, number of features at each level of the encoder.
    mul_features : int
        Multiply the number of features by this number
        each time we go down one level.
    nb_levels : int
        Number of levels in the encoder
    nb_levels_decoder : int, default=`nb_levels`
        Number of levels in the decoder
    nb_conv_per_level : int
        Number of convolutional layers at each level.
    kernel_size : [list of] int
        Kernel size
    residual : bool
        Use residual connections between convolutional blocks
    activation : ActivationLike
        Type of activation
    norm : NormType
        Normalization
    dropout : DropoutType
        Channel dropout probability
    attention : AttentionType
        Attention
    order : str
        Modules order (permutation of 'ncdax')
    pool_factor : [list of] int
        Down/Upsampling factor (per dimension).
    pool_mode : {'interpolate', 'conv', 'pool'}
        Method used to go down one level.

        - If `"interpolate"`, use linear interpolation.
        - If `"conv"`, use strided convolutions.
        - If `"pool"`, use max pooling.
    unpool_mode : {'interpolate', 'conv'}, default='conv'
        Method used to go up one level. If None, fall back to `pool_mode`.

        - If `"interpolate"`, use linear interpolation.
        - If `"conv"`, use transposed convolutions.
        - `"pool"` (i.e., unpooling) is not supported right now.
    skip : bool or {'+'}
        Type of skip connections:

        - `False`: no skip connections
        - `True`: concatenate skip connections
        - `'+'`: add skip connections
    """
    # number of features per level
    nb_levels_decoder = nb_levels_decoder or nb_levels
    if isinstance(nb_features, int):
        enc_features = [
            int(nb_features * mul_features**level)
            for level in range(nb_levels)
        ]
    else:
        enc_features = list(nb_features)
        enc_features += [
            int(enc_features[-1] * mul_features**level)
            for level in range(nb_levels - len(enc_features))
        ]
        enc_features = enc_features[:nb_levels]
    dec_features = list(reversed(enc_features))
    dec_features += [
        max(1, int(dec_features[-1] * mul_features**(-level)))
        for level in range(nb_levels_decoder - len(dec_features))
    ]
    dec_features = dec_features[:nb_levels_decoder]

    unpool_mode = unpool_mode or pool_mode
    if unpool_mode == 'pool':
        unpool_mode = 'interpolate'

    # build encoder/decoder
    super().__init__()
    self.encoder = ConvEncoder(
        ndim,
        nb_features=enc_features,
        mul_features=1,
        nb_levels=nb_levels,
        nb_conv_per_level=nb_conv_per_level,
        kernel_size=kernel_size,
        residual=residual,
        activation=activation,
        norm=norm,
        dropout=dropout,
        attention=attention,
        order=order,
        pool_factor=pool_factor,
        pool_mode=pool_mode,
    )
    self.decoder = ConvDecoder(
        ndim,
        nb_features=dec_features,
        nb_levels=nb_levels_decoder,
        nb_conv_per_level=nb_conv_per_level,
        skip=(skip and skip != '+'),
        kernel_size=kernel_size,
        residual=residual,
        activation=activation,
        norm=norm,
        dropout=dropout,
        attention=attention,
        order=order,
        unpool_factor=pool_factor,
        unpool_mode=unpool_mode,
    )
    self.skip = skip

forward

forward(inp, *, return_all=False)
PARAMETER DESCRIPTION
inp

Input tensor

TYPE: (B, nb_features[0], *inp_size) tensor

return_all

Return all intermediate output tensors (at each level).

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
out

Output tensor(s).

  • If return_all, return all intermediate tensors, from coarsest to finest.
  • Else, return the final tensor only.

TYPE: [tuple of] (B, nb_features[n], *out_size[n]) tensor
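How the skip option shapes the forward pass can be summarized without tensors: it decides whether the decoder concatenates, and which encoder outputs reach the decoder. The sketch below uses string labels in place of tensors; `unet_forward_plan` is a hypothetical helper that mirrors the logic of the constructor and of forward.

```python
def unet_forward_plan(skip, encoder_outputs):
    """
    encoder_outputs: labels ordered from finest to coarsest.
    Returns (decoder concatenates?, decoder inputs coarsest first).
    """
    # '+' disables concatenation; provided skips are then added
    decoder_concatenates = bool(skip and skip != '+')
    if skip:
        decoder_inputs = list(reversed(encoder_outputs))
    else:
        # Without skips, only the coarsest tensor is passed on
        decoder_inputs = [encoder_outputs[-1]]
    return decoder_concatenates, decoder_inputs


levels = ["e0", "e1", "e2"]
print(unet_forward_plan(True, levels))   # (True, ['e2', 'e1', 'e0'])
print(unet_forward_plan('+', levels))    # (False, ['e2', 'e1', 'e0'])
print(unet_forward_plan(False, levels))  # (False, ['e2'])
```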

Source code in cassetta/backbones/unet.py
def forward(self, inp, *, return_all=False):
    """
    Parameters
    ----------
    inp : (B, nb_features[0], *inp_size) tensor
        Input tensor
    return_all : bool
        Return all intermediate output tensors (at each level).

    Returns
    -------
    out : [tuple of] (B, nb_features[n], *out_size[n]) tensor
        Output tensor(s).

        - If `return_all`, return all intermediate tensors, from
          coarsest to finest.
        - Else, return the final tensor only.
    """
    out = self.encoder(inp, return_all=bool(self.skip))
    if self.skip:
        out = list(reversed(out))
    else:
        out = [out]
    return self.decoder(*out, return_all=return_all)

cassetta.backbones.atrous

MeshNet

MeshNet(ndim, nb_features=21, nb_layers=6, nb_conv_per_layer=2, dilation=1, mul_dilation=2, kernel_size=3, residual=False, activation='ReLU', norm='batch', dropout=None, attention=None, order='caxnd')

Bases: ModuleGroup

A stack of dilated convolutions

Diagram

flowchart LR
    II0["`[F, W]`"]         ---CI0("`ConvGroup(dilation=1)`"):::w-->
    IO0["`[F, W]`"]         ---CI1("`ConvGroup(dilation=2)`"):::w-->
    IO1["`[F, W]`"]         ---CI2("`ConvGroup(dilation=4)`"):::w-->
    OO2["`[F, W]`"]         ---CO1("`ConvGroup(dilation=8)`"):::w-->
    OO1["`[F, W]`"]         ---CO0("`ConvGroup(dilation=16)`"):::w-->
    OO0["`[F, W]`"]:::o
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef o fill:mistyrose,stroke:lightpink;

Difference with Fedorov et al.

  • Default parameters are from Fedorov et al.
  • However, Fedorov et al. end with a final convolution block with dilation=1, which our default network discards.
  • To recover their behavior, explicitly set the dilation list: dilation=[1, 2, 4, 8, 16, 1].
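The difference between the default schedule and the explicit list can be seen by expanding the dilation factors in plain Python. `dilation_schedule` is a hypothetical helper; it covers an integer dilation and a full-length list, and does not handle lists shorter than nb_layers.

```python
def dilation_schedule(dilation=1, mul_dilation=2, nb_layers=6):
    """Dilation factor of each layer, first layer first."""
    if isinstance(dilation, int):
        return [
            int(dilation * mul_dilation**layer)
            for layer in range(nb_layers)
        ]
    # An explicit list is used as given (truncated to nb_layers)
    return list(dilation)[:nb_layers]


print(dilation_schedule())                      # [1, 2, 4, 8, 16, 32]
print(dilation_schedule([1, 2, 4, 8, 16, 1]))   # [1, 2, 4, 8, 16, 1]
```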

References

  1. Yu & Koltun, "Multi-Scale Context Aggregation by Dilated Convolutions." ICLR (2016). arxiv:1511.07122

  2. Fedorov, Johnson, Damaraju, Ozerin, Calhoun & Plis, "End-to-end learning of brain tissue segmentation from imperfect labeling." IJCNN (2017). arxiv:1612.00940

PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

nb_features

Number of features in every layer.

TYPE: int DEFAULT: 21

nb_layers

Number of layers in the network.

TYPE: int DEFAULT: 6

nb_conv_per_layer

Number of convolutional blocks in each layer.

TYPE: int DEFAULT: 2

dilation

Dilation factor in the first layer. If a list, dilation factor in each layer.

TYPE: [list of] int DEFAULT: 1

mul_dilation

Multiply the dilation by this number each time we go down one level.

TYPE: int DEFAULT: 2

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

residual

Use residual connections between convolutional blocks and between layers.

TYPE: bool DEFAULT: False

activation

Type of activation

TYPE: ActivationLike DEFAULT: 'ReLU'

norm

Normalization

TYPE: NormType DEFAULT: 'batch'

dropout

Channel dropout probability

TYPE: DropoutType DEFAULT: None

attention

Attention

TYPE: AttentionType DEFAULT: None

order

Modules order (permutation of 'ncdax')

TYPE: str DEFAULT: 'caxnd'
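A quick way to see what the dilation schedule buys is to compute the receptive field of the stack. The sketch below uses the standard formula for stride-1 convolutions, where each convolution with kernel size k and dilation d widens the field by (k - 1) * d per dimension. It assumes that every convolution inside a layer uses that layer's dilation, which is an assumption about ConvGroup rather than something stated here; `receptive_field` is a hypothetical helper.

```python
def receptive_field(dilations, kernel_size=3, nb_conv_per_layer=2):
    """Receptive field (per dimension) of stacked stride-1 convolutions."""
    field = 1
    for dilation in dilations:
        # Assumption: all convolutions in a layer share its dilation
        field += nb_conv_per_layer * (kernel_size - 1) * dilation
    return field


print(receptive_field([1, 2, 4, 8, 16, 32]))  # 253
print(receptive_field([1, 2, 4, 8, 16, 1]))   # 129
```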

Source code in cassetta/backbones/atrous.py
def __init__(
    self,
    ndim: int,
    nb_features: int = 21,
    nb_layers: int = 6,
    nb_conv_per_layer: int = 2,
    dilation: OneOrSeveral[int] = 1,
    mul_dilation: int = 2,
    kernel_size: OneOrSeveral[int] = 3,
    residual: bool = False,
    activation: ActivationType = 'ReLU',
    norm: NormType = 'batch',
    dropout: DropoutType = None,
    attention: AttentionType = None,
    order: str = 'caxnd',
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    nb_features : int
        Number of features in every layer.
    nb_layers : int
        Number of layers in the network.
    nb_conv_per_layer : int
        Number of convolutional blocks in each layer.
    dilation : [list of] int
        Dilation factor in the first layer.
        If a list, dilation factor in each layer.
    mul_dilation : int
        Multiply the dilation by this number
        each time we go down one level.
    kernel_size : [list of] int
        Kernel size
    residual : bool
        Use residual connections between convolutional blocks and
        between layers.
    activation : ActivationLike
        Type of activation
    norm : NormType
        Normalization
    dropout : DropoutType
        Channel dropout probability
    attention : AttentionType
        Attention
    order : str
        Modules order (permutation of 'ncdax')
    """
    if isinstance(dilation, int):
        dilation = [
            int(dilation * mul_dilation**layer)
            for layer in range(nb_layers)
        ]
    else:
        dilation = list(dilation)
        dilation += [
            int(dilation[-1] * mul_dilation**layer)
            for layer in range(nb_layers - len(dilation))
        ]
        dilation = dilation[:nb_layers]

    make_layer = partial(
        ConvGroup,
        ndim,
        channels=nb_features,
        kernel_size=kernel_size,
        residual=residual,
        activation=activation,
        norm=norm,
        dropout=dropout,
        attention=attention,
        order=order,
        nb_conv=nb_conv_per_layer,
    )
    layers = [make_layer(dilation=d) for d in dilation]
    super().__init__(layers, residual=residual)

ATrousNet

ATrousNet(ndim, nb_features=21, nb_levels=5, nb_conv_per_level=2, dilation=1, mul_dilation=2, kernel_size=3, residual=False, activation='ReLU', norm='batch', dropout=None, attention=None, order='caxnd')

Bases: ModuleGroup

Parallel dilated convolutions

Diagram

flowchart LR
    1["`[F, W]`"] ---C11("`ConvGroup(dilation=1)`"):::w-->
    2["`[F, W]`"] ---C21("`ConvGroup(dilation=1)`"):::w--> 3["`[F, W]`"]
    2             ---C22("`ConvGroup(dilation=2)`"):::w--> 4["`[F, W]`"]
    3 & 4         ---Z2(("+")):::d-->
    5["`[F, W]`"] ---C31("`ConvGroup(dilation=1)`"):::w--> 6["`[F, W]`"]
    5             ---C32("`ConvGroup(dilation=2)`"):::w--> 7["`[F, W]`"]
    5             ---C34("`ConvGroup(dilation=4)`"):::w--> 8["`[F, W]`"]
    6 & 7 & 8     ---Z3(("+")):::d-->
    9["`[F, W]`"] ---C41("`ConvGroup(dilation=1)`"):::w-->10["`[F, W]`"]
    9             ---C42("`ConvGroup(dilation=2)`"):::w-->11["`[F, W]`"]
    10 & 11       ---Z4(("+")):::d-->
    12["`[F, W]`"]---C51("`ConvGroup(dilation=1)`"):::w-->13["`[F, W]`"]:::o
    classDef w fill:papayawhip,stroke:peachpuff;
    classDef d fill:lightcyan,stroke:lightblue;
    classDef o fill:mistyrose,stroke:lightpink;

Reference

Chen, Papandreou, Kokkinos, Murphy & Yuille, "DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs." TPAMI (2017). arxiv:1606.00915

PARAMETER DESCRIPTION
ndim

Number of spatial dimensions

TYPE: int

nb_features

Number of features at every level.

TYPE: int DEFAULT: 21

nb_levels

Number of levels in the network.

TYPE: int DEFAULT: 5

nb_conv_per_level

Number of convolutional blocks in each layer.

TYPE: int DEFAULT: 2

dilation

Dilation factor at the first level. If a list, dilation factor at each level.

TYPE: [list of] int DEFAULT: 1

mul_dilation

Multiply the dilation by this number each time we go down one level.

TYPE: int DEFAULT: 2

kernel_size

Kernel size

TYPE: [list of] int DEFAULT: 3

residual

Use residual connections between convolutional blocks and between layers.

TYPE: bool DEFAULT: False

activation

Type of activation

TYPE: ActivationLike DEFAULT: 'ReLU'

norm

Normalization

TYPE: NormType DEFAULT: 'batch'

dropout

Channel dropout probability

TYPE: DropoutType DEFAULT: None

attention

Attention

TYPE: AttentionType DEFAULT: None

order

Modules order (permutation of 'ncdax')

TYPE: str DEFAULT: 'caxnd'
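The layout of parallel branches can be listed without building any module. The sketch below mirrors the two loops of the constructor for an integer dilation: the number of summed branches grows by one per block up to nb_levels, then shrinks again. `atrous_layout` is a hypothetical helper for illustration.

```python
def atrous_layout(nb_levels=5, dilation=1, mul_dilation=2):
    """Dilations of the parallel branches in each summed block."""
    dilations = [
        int(dilation * mul_dilation**level)
        for level in range(nb_levels)
    ]
    # Growing half: one more branch per block
    blocks = [dilations[:level + 1] for level in range(nb_levels)]
    # Shrinking half: one fewer branch per block
    blocks += [dilations[:level] for level in range(nb_levels - 1, 0, -1)]
    return blocks


# Three levels reproduce the diagram
print(atrous_layout(nb_levels=3))
# [[1], [1, 2], [1, 2, 4], [1, 2], [1]]

# Default settings: number of blocks and of ConvGroups
layout = atrous_layout()
print(len(layout), sum(len(block) for block in layout))  # 9 25
```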

Source code in cassetta/backbones/atrous.py
def __init__(
    self,
    ndim: int,
    nb_features: int = 21,
    nb_levels: int = 5,
    nb_conv_per_level: int = 2,
    dilation: OneOrSeveral[int] = 1,
    mul_dilation: int = 2,
    kernel_size: OneOrSeveral[int] = 3,
    residual: bool = False,
    activation: ActivationType = 'ReLU',
    norm: NormType = 'batch',
    dropout: DropoutType = None,
    attention: AttentionType = None,
    order: str = 'caxnd',
):
    """
    Parameters
    ----------
    ndim : int
        Number of spatial dimensions
    nb_features : int
        Number of features at every level.
    nb_levels : int
        Number of levels in the network.
    nb_conv_per_level : int
        Number of convolutional blocks in each layer.
    dilation : [list of] int
        Dilation factor at the first level.
        If a list, dilation factor at each level.
    mul_dilation : int
        Multiply the dilation by this number
        each time we go down one level.
    kernel_size : [list of] int
        Kernel size
    residual : bool
        Use residual connections between convolutional blocks and
        between layers.
    activation : ActivationLike
        Type of activation
    norm : NormType
        Normalization
    dropout : DropoutType
        Channel dropout probability
    attention : AttentionType
        Attention
    order : str
        Modules order (permutation of 'ncdax')
    """
    if isinstance(dilation, int):
        enc_dilation = [
            int(dilation * mul_dilation**layer)
            for layer in range(nb_levels)
        ]
    else:
        # Start from the user-provided list, then extend it if too short
        enc_dilation = list(dilation)
        enc_dilation += [
            int(enc_dilation[-1] * mul_dilation**layer)
            for layer in range(nb_levels - len(enc_dilation))
        ]
        enc_dilation = enc_dilation[:nb_levels]

    make_layer = partial(
        ConvGroup,
        ndim,
        channels=nb_features,
        kernel_size=kernel_size,
        residual=residual,
        activation=activation,
        norm=norm,
        dropout=dropout,
        attention=attention,
        order=order,
        nb_conv=nb_conv_per_level,
    )
    layers = []
    for level in range(nb_levels):
        sublayers = []
        for n in range(level+1):
            sublayers += [make_layer(dilation=enc_dilation[n])]
        layers += [ModuleSum(sublayers)]
    for level in range(nb_levels-1, 0, -1):
        sublayers = []
        for n in range(level):
            sublayers += [make_layer(dilation=enc_dilation[n])]
        layers += [ModuleSum(sublayers)]
    super().__init__(layers, residual=residual)