
cassetta.losses

Overview

Losses are differentiable modules that return a scalar tensor (or one value per batch element when no reduction is applied).

MODULE DESCRIPTION
base

Base class for all losses

segmentation

Losses for semantic segmentation

make_loss

make_loss(loss, *args, **kwargs)

Instantiate a loss

A loss can be:

  • the name of a cassetta or torch.nn loss, such as "DiceLoss" or "MSELoss";
  • the fully qualified path to a loss, such as "cassetta.losses.DiceLoss" or "monai.losses.GeneralizedDiceLoss";
  • an nn.Module subclass, such as DiceLoss (cassetta.losses.DiceLoss);
  • an already instantiated nn.Module, such as DiceLoss().
PARAMETER DESCRIPTION
loss

Instantiated or non-instantiated loss

TYPE: LossType

*args

Positional arguments passed to the loss constructor

TYPE: tuple DEFAULT: ()

**kwargs

Keyword arguments passed to the loss constructor

TYPE: dict DEFAULT: {}

RETURNS DESCRIPTION
loss

Instantiated loss

TYPE: Module

Source code in cassetta/losses/__init__.py
def make_loss(loss: LossType, *args, **kwargs):
    """
    Instantiate a loss

    A loss can be:

    - the name of a cassetta loss, such as `"DiceLoss"`;
    - the fully qualified path to a loss, such as
    `"cassetta.losses.DiceLoss"`, or `"monai.losses.GeneralizedDiceLoss"`;
    - a [`nn.Module`][torch.nn.Module] subclass, such as
    [`DiceLoss`][cassetta.losses.DiceLoss];
    - an already instantiated [`nn.Module`][torch.nn.Module], such as
    [`DiceLoss()`][cassetta.losses.DiceLoss].

    Parameters
    ----------
    loss : LossType
        Instantiated or non-instantiated loss
    *args : tuple
        Positional arguments passed to the loss constructor
    **kwargs : dict
        Keyword arguments passed to the loss constructor

    Returns
    -------
    loss : nn.Module
        Instantiated loss
    """
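    reentrant = kwargs.pop('__reentrant', False)
    if isinstance(loss, str):
        if not reentrant:
            # Try each prefix in turn. Nested calls are flagged as reentrant
            # so that they do not start a prefix search of their own.
            kwargs['__reentrant'] = True
            for prefix in ('', 'cassetta.', 'cassetta.losses.', 'torch.nn.'):
                try:
                    return make_loss(prefix + loss, *args, **kwargs)
                except Exception:
                    pass
            # Every prefix failed: drop the internal flag so that it is not
            # forwarded to the loss constructor below.
            kwargs.pop('__reentrant')
        loss = import_fullname(loss)
    if not isinstance(loss, Module):
        # `loss` is a class (or factory): instantiate it
        loss = loss(*args, **kwargs)
    if not isinstance(loss, Module):
        raise ValueError('Instantiated object is not a Module')
    return loss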
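The string branch of make_loss tries the prefixes '', 'cassetta.', 'cassetta.losses.' and 'torch.nn.' in turn. The stdlib-only sketch below mimics that lookup; `import_by_name` and `resolve` are hypothetical stand-ins for cassetta's `import_fullname`, and unlike make_loss they only resolve the name and do not instantiate anything.

```python
import importlib


def import_by_name(fullname):
    """Import `package.module.attr` and return `attr` (hypothetical helper)."""
    module_name, _, attr = fullname.rpartition('.')
    if not module_name:
        raise ImportError(f'{fullname!r} is not a fully qualified name')
    module = importlib.import_module(module_name)  # may raise ModuleNotFoundError
    try:
        return getattr(module, attr)
    except AttributeError as e:
        raise ImportError(f'{attr!r} not found in {module_name!r}') from e


def resolve(name, prefixes=('', 'cassetta.', 'cassetta.losses.', 'torch.nn.')):
    """Return the first object that `prefix + name` resolves to."""
    errors = []
    for prefix in prefixes:
        try:
            return import_by_name(prefix + name)
        except ImportError as e:
            # Remember why this candidate failed, then try the next prefix
            errors.append(f'{prefix + name}: {e}')
    raise ImportError('could not resolve ' + repr(name) + ':\n' + '\n'.join(errors))


# Usage with stdlib names, so that neither cassetta nor torch is required
print(resolve('OrderedDict', prefixes=('', 'collections.')))
```

The sketch collects the per-prefix errors so that a failed lookup reports every candidate it tried; make_loss itself discards them and re-raises the error from the unprefixed name.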

cassetta.losses.base

Loss

Loss(reduction='mean')

Bases: Module

Base class for losses

PARAMETER DESCRIPTION
reduction

Reduction to apply across batch elements

TYPE: {'mean', 'sum'} or callable DEFAULT: 'mean'

Source code in cassetta/losses/base.py
def __init__(self, reduction='mean'):
    """
    Parameters
    ----------
    reduction : {'mean', 'sum'} or callable
        Reduction to apply across batch elements
    """
    super().__init__()
    self.reduction = reduction
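Loss only stores `reduction`; where it is applied is not shown on this page. As an assumption, the options documented across these losses ('mean', 'sum', None, or a callable) could be dispatched as below, with a plain list standing in for the (batch,) tensor of per-element losses. `apply_reduction` is a hypothetical helper, not part of cassetta.

```python
def apply_reduction(values, reduction='mean'):
    """Reduce a list of per-batch-element losses (hypothetical helper).

    `reduction` may be 'mean', 'sum', None (no reduction) or a callable.
    """
    if reduction is None:
        return list(values)          # keep one value per batch element
    if callable(reduction):
        return reduction(values)     # user-provided reduction
    if reduction == 'mean':
        return sum(values) / len(values)
    if reduction == 'sum':
        return sum(values)
    raise ValueError(f'Unknown reduction: {reduction!r}')


batch_losses = [1.0, 2.0, 3.0]
print(apply_reduction(batch_losses), apply_reduction(batch_losses, 'sum'))
```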

cassetta.losses.segmentation

DiceLoss

DiceLoss(square=True, weighted=False, labels=None, eps=None, reduction='mean', activation=None)

Bases: Loss

Soft Dice Loss

By default, each class is weighted identically. The weighted mode allows classes to be weighted by frequency.
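A minimal sketch of the soft Dice loss on plain lists, for one batch element. It assumes the V-Net form 1 - (2 Σ p·g + ε) / (Σ p² + Σ g² + ε) per class, with ε added to both numerator and denominator as in Tilborghs et al. and scaled by the number of voxels as in forward below; frequency weights are assumed to be normalized class counts. How forward_onehot actually combines these terms is not shown on this page, and both function names are hypothetical.

```python
def frequency_weights(ref):
    """Weight each class by its share of the reference voxels (assumption)."""
    counts = [sum(r) for r in ref]
    total = sum(counts)
    return [c / total for c in counts]


def soft_dice_loss(pred, ref, square=True, eps=None, weights=None):
    """Soft Dice loss for one batch element (sketch).

    pred, ref: K per-class lists of voxel values in [0, 1].
    eps: scalar stabilizer, default 1/K, multiplied by the number of voxels.
    weights: per-class weights, default uniform (1/K each).
    """
    nb_classes = len(pred)
    nvox = len(pred[0])
    if eps is None:
        eps = 1 / nb_classes
    if weights is None:
        weights = [1 / nb_classes] * nb_classes
    total = 0.0
    for p, r, w in zip(pred, ref, weights):
        inter = sum(pi * ri for pi, ri in zip(p, r))
        if square:
            denom = sum(pi * pi for pi in p) + sum(ri * ri for ri in r)
        else:
            denom = sum(p) + sum(r)
        # An empty class with an empty prediction gives eps/eps = 1 (no loss)
        dice = (2 * inter + eps * nvox) / (denom + eps * nvox)
        total += w * (1 - dice)
    return total


ref = [[1.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
pred = [[0.9, 0.8, 0.6, 0.1], [0.1, 0.2, 0.4, 0.9]]
print(soft_dice_loss(pred, ref), soft_dice_loss(pred, ref, weights=frequency_weights(ref)))
```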

References

  1. Milletari, Navab & Ahmadi, "V-Net: Fully convolutional neural networks for volumetric medical image segmentation." 3DV (2016). arxiv:1606.04797
  2. Sudre, Li, Vercauteren, Ourselin & Cardoso, "Generalised dice overlap as a deep learning loss function for highly unbalanced segmentations." DLMIA (2017). arxiv:1707.03237
  3. Tilborghs, Bertels, Robben, Vandermeulen & Maes, "The Dice loss in the context of missing or empty labels: introducing Φ and ε." MICCAI (2022). arxiv:2207.09521
PARAMETER DESCRIPTION
square

Square the denominator in SoftDice.

TYPE: bool DEFAULT: True

weighted

If True, weight the Dice of each class by its frequency in the reference. If a list, use these weights for each class.

TYPE: bool or list[float] DEFAULT: False

labels

Label corresponding to each one-hot class. Only used if the reference is an integer label map.

TYPE: list[int] DEFAULT: range(nb_class)

eps

Stabilization term of the Dice loss. Ideally, it should equal each class's expected frequency across the whole dataset (see Tilborghs et al.). It is multiplied by the number of voxels before use.

TYPE: float or list[float] DEFAULT: 1/K, where K is the number of classes

reduction

Type of reduction to apply across minibatch elements.

TYPE: {'mean', 'sum', None} or callable DEFAULT: 'mean'

activation

Activation to apply to the prediction before computing the loss

TYPE: Module or str DEFAULT: None

Source code in cassetta/losses/segmentation.py
def __init__(self, square=True, weighted=False, labels=None,
             eps=None, reduction='mean', activation=None):
    """

    Parameters
    ----------
    square : bool, default=True
        Square the denominator in SoftDice.
    weighted : bool or list[float], default=False
        If True, weight the Dice of each class by its frequency in the
        reference. If a list, use these weights for each class.
    labels : list[int], default=range(nb_class)
        Label corresponding to each one-hot class. Only used if the
        reference is an integer label map.
    eps : float or list[float], default=1/K
        Stabilization of the Dice loss.
        Optimally, should be equal to each class' expected frequency
        across the whole dataset. See Tilborghs et al.
    reduction : {'mean', 'sum', None} or callable, default='mean'
        Type of reduction to apply across minibatch elements.
    activation : nn.Module or str
        Activation to apply to the prediction before computing the loss
    """
    super().__init__(reduction)
    self.square = square
    self.weighted = weighted
    self.labels = labels
    self.eps = eps
    self.activation = _make_activation(activation)

forward

forward(pred, ref, mask=None)
PARAMETER DESCRIPTION
pred

Predicted classes.

TYPE: (batch, nb_class, *spatial) tensor

ref

Reference classes (or their expectation).

TYPE: (batch, nb_class|1, *spatial) tensor

mask

Loss mask

TYPE: (batch, 1, *spatial) tensor DEFAULT: None

RETURNS DESCRIPTION
loss

The output shape depends on the type of reduction used. If 'mean' or 'sum', this function returns a scalar tensor.

TYPE: scalar or (batch,) tensor

Source code in cassetta/losses/segmentation.py
def forward(self, pred, ref, mask=None):
    """

    Parameters
    ----------
    pred : (batch, nb_class, *spatial) tensor
        Predicted classes.
    ref : (batch, nb_class|1, *spatial) tensor
        Reference classes (or their expectation).
    mask : (batch, 1, *spatial) tensor, optional
        Loss mask

    Returns
    -------
    loss : scalar or (batch,) tensor
        The output shape depends on the type of reduction used.
        If 'mean' or 'sum', this function returns a scalar tensor.

    """
    if self.activation:
        pred = self.activation(pred)

    nb_classes = pred.shape[1]
    backend = dict(dtype=pred.dtype, device=pred.device)
    nvox = pred.shape[2:].numel()

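    # Explicit None test: `self.eps or ...` would replace eps=0 by the
    # default and fails on multi-element tensors
    eps = self.eps if self.eps is not None else 1/nb_classes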
    eps = make_vector(eps, nb_classes, **backend)
    eps = eps * nvox

    # prepare weights
    weighted = self.weighted
    if not torch.is_tensor(weighted) and not weighted:
        weighted = False
    if not isinstance(weighted, bool):
        weighted = make_vector(weighted, nb_classes, **backend)

    if ref.dtype.is_floating_point:
        return self.forward_onehot(pred, ref, mask, weighted, eps)
    else:
        return self.forward_labels(pred, ref, mask, weighted, eps)
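For example, with K = 4 classes on an 8×8 image (nvox = 64), the default ε is 1/4 per class, i.e. 16 after scaling. The sketch below reproduces that preparation step on plain lists; `make_vector_sketch` is a hypothetical stand-in for cassetta's `make_vector`, whose exact behavior (for instance with lists shorter than K) is not shown here, so this version simply requires exactly K values. It uses an explicit `is None` test so that eps=0 is kept rather than replaced by the default.

```python
def make_vector_sketch(value, n):
    """Broadcast a scalar to a length-n list, or check a length-n sequence.

    Hypothetical stand-in for cassetta's `make_vector`.
    """
    if isinstance(value, (int, float)):
        return [float(value)] * n
    values = [float(v) for v in value]
    if len(values) != n:
        raise ValueError(f'expected {n} values, got {len(values)}')
    return values


def prepare_eps(eps, nb_classes, nvox):
    """Per-class Dice stabilizer: default 1/K, then scaled by the voxel count."""
    if eps is None:
        eps = 1 / nb_classes
    return [e * nvox for e in make_vector_sketch(eps, nb_classes)]


print(prepare_eps(None, nb_classes=4, nvox=8 * 8))
```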

CatLoss

CatLoss(weighted=False, labels=None, reduction='mean', activation=None)

Bases: Loss

Weighted categorical cross-entropy.

By default, each class is weighted identically.

This differs from the classical "categorical cross-entropy loss", which is the true categorical log-likelihood and therefore implicitly weights each class by its frequency. By default, our loss behaves as a "weighted categorical cross-entropy".

With weighted=True, classes are weighted by frequency.
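The distinction above can be checked numerically: averaging -log p over all voxels is the same as averaging each class's mean term with weights equal to the class frequencies. The sketch assumes each class term is normalized by that class's voxel count; both function names are hypothetical and this is not cassetta's forward_onehot.

```python
import math


def weighted_cross_entropy(prob, ref, weights=None):
    """Categorical cross-entropy with per-class weights (sketch).

    prob: K per-class lists of predicted probabilities (one value per voxel).
    ref:  K per-class lists of one-hot reference values.
    Each class term is the mean of -log(p) over the voxels of that class;
    terms are combined with `weights` (default: uniform 1/K).
    """
    nb_classes = len(prob)
    if weights is None:
        weights = [1 / nb_classes] * nb_classes
    loss = 0.0
    for p, r, w in zip(prob, ref, weights):
        count = sum(r)
        if count == 0:
            continue  # class absent from the reference: no term
        # math.log(0) raises ValueError for a reference voxel with p == 0
        term = -sum(ri * math.log(pi) for pi, ri in zip(p, r) if ri > 0) / count
        loss += w * term
    return loss


def classical_cross_entropy(prob, ref):
    """Categorical negative log-likelihood averaged over all voxels."""
    nvox = len(prob[0])
    return -sum(ri * math.log(pi)
                for p, r in zip(prob, ref)
                for pi, ri in zip(p, r) if ri > 0) / nvox


prob = [[0.9, 0.8, 0.3], [0.1, 0.2, 0.7]]
ref = [[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(weighted_cross_entropy(prob, ref), classical_cross_entropy(prob, ref))
```

With uniform weights the rare class (here one voxel out of three) counts as much as the frequent one, which is the intended default.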

PARAMETER DESCRIPTION
weighted

If True, weight the term of each class by its frequency in the reference. If a list, use these weights for each class.

TYPE: bool or list[float] DEFAULT: False

labels

Label corresponding to each one-hot class. Only used if the reference is an integer label map.

TYPE: list[int] DEFAULT: range(nb_class)

reduction

Type of reduction to apply across minibatch elements.

TYPE: {'mean', 'sum', None} or callable DEFAULT: 'mean'

activation

Activation to apply to the prediction before computing the loss

TYPE: Module or str DEFAULT: None

Source code in cassetta/losses/segmentation.py
def __init__(self, weighted=False, labels=None,
             reduction='mean', activation=None):
    """

    Parameters
    ----------
    weighted : bool or list[float], default=False
        If True, weight the term of each class by its frequency
        in the reference. If a list, use these weights for each class.
    labels : list[int], default=range(nb_class)
        Label corresponding to each one-hot class. Only used if the
        reference is an integer label map.
    reduction : {'mean', 'sum', None} or callable, default='mean'
        Type of reduction to apply across minibatch elements.
    activation : nn.Module or str
        Activation to apply to the prediction before computing the loss
    """
    super().__init__(reduction)
    self.weighted = weighted
    self.labels = labels
    self.reduction = reduction
    self.activation = _make_activation(activation)

forward

forward(pred, ref, mask=None)
PARAMETER DESCRIPTION
pred

Predicted classes.

TYPE: (batch, nb_class, *spatial) tensor

ref

Reference classes (or their expectation).

TYPE: (batch, nb_class|1, *spatial) tensor

mask

Loss mask

TYPE: (batch, 1, *spatial) tensor DEFAULT: None

RETURNS DESCRIPTION
loss

The output shape depends on the type of reduction used. If 'mean' or 'sum', this function returns a scalar tensor.

TYPE: scalar or (batch,) tensor

Source code in cassetta/losses/segmentation.py
def forward(self, pred, ref, mask=None):
    """

    Parameters
    ----------
    pred : (batch, nb_class, *spatial) tensor
        Predicted classes.
    ref : (batch, nb_class|1, *spatial) tensor
        Reference classes (or their expectation).
    mask : (batch, 1, *spatial) tensor, optional
        Loss mask

    Returns
    -------
    loss : scalar or (batch,) tensor
        The output shape depends on the type of reduction used.
        If 'mean' or 'sum', this function returns a scalar tensor.

    """
    if self.activation:
        pred = self.activation(pred)

    nb_classes = pred.shape[1]
    backend = dict(dtype=pred.dtype, device=pred.device)

    pred = pred.log()
    pred.masked_fill_(~torch.isfinite(pred), 0)

    # prepare weights
    weighted = self.weighted
    if not torch.is_tensor(weighted) and not weighted:
        weighted = False
    if not isinstance(weighted, bool):
        weighted = make_vector(weighted, nb_classes, **backend)

    if ref.dtype.is_floating_point:
        return self.forward_onehot(pred, ref, mask, weighted)
    else:
        return self.forward_labels(pred, ref, mask, weighted)
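The `pred.log()` / `masked_fill_(~torch.isfinite(pred), 0)` pair means that the log-probabilities handed to forward_onehot and forward_labels contain 0 wherever the log is not finite (for example log(0) = -inf), instead of an infinite value. A stdlib sketch of the same elementwise rule; `safe_log` is a hypothetical name.

```python
import math


def safe_log(values):
    """Elementwise log where non-finite results are replaced by 0."""
    out = []
    for v in values:
        # math.log raises for v <= 0, whereas torch returns -inf or nan;
        # treat those cases as non-finite, like the torch code does
        r = math.log(v) if v > 0 else float('-inf')
        out.append(r if math.isfinite(r) else 0.0)
    return out


print(safe_log([1.0, 0.5, 0.0]))
```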

CatMSELoss

CatMSELoss(weighted=False, labels=None, reduction='mean', activation=None)

Bases: Loss

Mean Squared Error between one-hots.
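When `ref` is an integer label map, `labels` says which label value each one-hot channel stands for. The sketch below assumes `labels[k]` is the value encoding class k (values not listed end up as zero in every channel) and uniform class weights; `one_hot` and `cat_mse` are hypothetical names, not cassetta functions.

```python
def one_hot(label_map, labels):
    """Convert a flat integer label map to K per-class one-hot lists."""
    return [[1.0 if v == lab else 0.0 for v in label_map] for lab in labels]


def cat_mse(prob, ref, weights=None):
    """Mean squared error between predicted probabilities and one-hots."""
    nb_classes = len(prob)
    if weights is None:
        weights = [1 / nb_classes] * nb_classes
    loss = 0.0
    for p, r, w in zip(prob, ref, weights):
        # mean over voxels of the squared difference, for this class
        loss += w * sum((pi - ri) ** 2 for pi, ri in zip(p, r)) / len(p)
    return loss


ref = one_hot([0, 2, 2, 5], labels=[0, 2, 5])
prob = [[0.8, 0.1, 0.0, 0.1], [0.1, 0.7, 0.9, 0.1], [0.1, 0.2, 0.1, 0.8]]
print(ref, cat_mse(prob, ref))
```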

PARAMETER DESCRIPTION
weighted

If True, weight the squared error of each class by its frequency in the reference. If a list, use these weights for each class.

TYPE: bool or list[float] DEFAULT: False

labels

Label corresponding to each one-hot class. Only used if the reference is an integer label map.

TYPE: list[int] DEFAULT: range(nb_class)

reduction

Type of reduction to apply across minibatch elements.

TYPE: {'mean', 'sum', None} or callable DEFAULT: 'mean'

activation

Activation to apply to the prediction before computing the loss

TYPE: Module or str DEFAULT: None

Source code in cassetta/losses/segmentation.py
def __init__(self, weighted=False, labels=None, reduction='mean',
             activation=None):
    """

    Parameters
    ----------
    weighted : bool or list[float], default=False
        If True, weight the squared error of each class by its frequency
        in the reference. If a list, use these weights for each class.
    labels : list[int], default=range(nb_class)
        Label corresponding to each one-hot class. Only used if the
        reference is an integer label map.
    reduction : {'mean', 'sum', None} or callable, default='mean'
        Type of reduction to apply across minibatch elements.
    activation : nn.Module or str
        Activation to apply to the prediction before computing the loss
    """
    super().__init__(reduction)
    self.weighted = weighted
    self.labels = labels
    self.reduction = reduction
    # Build the activation as in DiceLoss and CatLoss: `getattr(nn, name)`
    # alone returns a class, which forward would then call on the tensor
    self.activation = _make_activation(activation)

forward

forward(pred, ref, mask=None)
PARAMETER DESCRIPTION
pred

Predicted classes.

TYPE: (batch, nb_class, *spatial) tensor

ref

Reference classes (or their expectation).

TYPE: (batch, nb_class|1, *spatial) tensor

mask

Loss mask

TYPE: (batch, 1, *spatial) tensor DEFAULT: None

RETURNS DESCRIPTION
loss

The output shape depends on the type of reduction used. If 'mean' or 'sum', this function returns a scalar tensor.

TYPE: scalar or (batch,) tensor

Source code in cassetta/losses/segmentation.py
def forward(self, pred, ref, mask=None):
    """

    Parameters
    ----------
    pred : (batch, nb_class, *spatial) tensor
        Predicted classes.
    ref : (batch, nb_class|1, *spatial) tensor
        Reference classes (or their expectation).
    mask : (batch, 1, *spatial) tensor, optional
        Loss mask

    Returns
    -------
    loss : scalar or (batch,) tensor
        The output shape depends on the type of reduction used.
        If 'mean' or 'sum', this function returns a scalar tensor.

    """
    if self.activation:
        pred = self.activation(pred)

    nb_classes = pred.shape[1]
    backend = dict(dtype=pred.dtype, device=pred.device)

    # prepare weights
    weighted = self.weighted
    if not torch.is_tensor(weighted) and not weighted:
        weighted = False
    if not isinstance(weighted, bool):
        weighted = make_vector(weighted, nb_classes, **backend)

    if ref.dtype.is_floating_point:
        return self.forward_onehot(pred, ref, mask, weighted)
    else:
        return self.forward_labels(pred, ref, mask, weighted)

LogitMSELoss

LogitMSELoss(target=5, weighted=False, labels=None, reduction='mean', activation=None)

Bases: Loss

Mean Squared Error between logits and target positive/negative values.
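A sketch of the idea under an explicit assumption: positives are regressed toward +target and negatives toward -target (the source only says "target positive/negative values" and that `target` applies where the ground truth is True). The 'inv' weighting is sketched as normalized inverse class frequency, which is also an assumption; `logit_mse` and `class_weights` are hypothetical names.

```python
def logit_mse(logits, ref, target=5.0):
    """MSE between logits and +target (positive) / -target (negative).

    Assumption: negatives are pushed toward -target.
    """
    n = len(logits)
    return sum((z - (target if g else -target)) ** 2
               for z, g in zip(logits, ref)) / n


def class_weights(ref, inverse=False):
    """Normalized class frequencies, or normalized inverse frequencies."""
    counts = [sum(r) for r in ref]
    total = sum(counts)
    if inverse:
        inv = [total / c if c else 0.0 for c in counts]
        s = sum(inv)
        return [v / s for v in inv]
    return [c / total for c in counts]


print(logit_mse([4.0, -6.0, 0.0], [1, 0, 1]))
print(class_weights([[1, 1, 1, 0], [0, 0, 0, 1]], inverse=True))
```

Inverse-frequency weights give the rarer class the larger share, the opposite of weighted=True.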

PARAMETER DESCRIPTION
target

Target value when the ground truth is True.

TYPE: float DEFAULT: 5

weighted

If True, weight the score of each class by its frequency in the reference. If 'inv', weight the score of each class by its inverse frequency in the reference. If a list, use these weights for each class.

TYPE: bool or list[float] or 'inv' DEFAULT: False

labels

Label corresponding to each one-hot class. Only used if the reference is an integer label map.

TYPE: list[int] DEFAULT: range(nb_class)

reduction

Type of reduction to apply across minibatch elements.

TYPE: {'mean', 'sum', None} or callable DEFAULT: 'mean'

activation

Activation to apply to the prediction before computing the loss

TYPE: Module or str DEFAULT: None

Source code in cassetta/losses/segmentation.py
def __init__(self, target=5, weighted=False, labels=None, reduction='mean',
             activation=None):
    """

    Parameters
    ----------
    target : float, default=5
        Target value when the ground truth is True.
    weighted : bool or list[float] or 'inv', default=False
        If True, weight the score of each class by its frequency in
        the reference.
        If 'inv', weight the score of each class by its inverse
        frequency in the reference.
        If a list, use these weights for each class.
    labels : list[int], default=range(nb_class)
        Label corresponding to each one-hot class. Only used if the
        reference is an integer label map.
    reduction : {'mean', 'sum', None} or callable, default='mean'
        Type of reduction to apply across minibatch elements.
    activation : nn.Module or str
        Activation to apply to the prediction before computing the loss
    """
    super().__init__(reduction)
    self.weighted = weighted
    self.labels = labels
    self.reduction = reduction
    self.target = target
    # Build the activation as in DiceLoss and CatLoss: `getattr(nn, name)`
    # alone returns a class, which forward would then call on the tensor
    self.activation = _make_activation(activation)

forward

forward(pred, ref, mask=None)
PARAMETER DESCRIPTION
pred

Predicted classes.

TYPE: (batch, nb_class, *spatial) tensor

ref

Reference classes (or their expectation).

TYPE: (batch, nb_class|1, *spatial) tensor

mask

Loss mask

TYPE: (batch, 1, *spatial) tensor DEFAULT: None

RETURNS DESCRIPTION
loss

The output shape depends on the type of reduction used. If 'mean' or 'sum', this function returns a scalar tensor.

TYPE: scalar or (batch,) tensor

Source code in cassetta/losses/segmentation.py
def forward(self, pred, ref, mask=None):
    """

    Parameters
    ----------
    pred : (batch, nb_class, *spatial) tensor
        Predicted classes.
    ref : (batch, nb_class|1, *spatial) tensor
        Reference classes (or their expectation).
    mask : (batch, 1, *spatial) tensor, optional
        Loss mask

    Returns
    -------
    loss : scalar or (batch,) tensor
        The output shape depends on the type of reduction used.
        If 'mean' or 'sum', this function returns a scalar tensor.

    """
    if self.activation:
        pred = self.activation(pred)

    nb_classes = pred.shape[1]
    backend = dict(dtype=pred.dtype, device=pred.device)

    # prepare weights
    weighted = self.weighted
    if not torch.is_tensor(weighted) and not weighted:
        weighted = False
    if not isinstance(weighted, bool):
        weighted = make_vector(weighted, nb_classes, **backend)

    if ref.dtype.is_floating_point:
        return self.forward_onehot(pred, ref, mask, weighted)
    else:
        return self.forward_labels(pred, ref, mask, weighted)