cassetta.functional

Overview

This module contains utility functions that are either incomplete or entirely missing in PyTorch. They are all low-level functions and, unlike the rest of the package, do not form an object-oriented API.

cassetta.functional.jit

Indexing

[ind2sub][cassetta.functional.jit.ind2sub]

Convert linear indices into sub indices (i, j, k).

[sub2ind][cassetta.functional.jit.sub2ind], [sub2ind_list][cassetta.functional.jit.sub2ind_list]

Convert sub indices (i, j, k) into linear indices.

Math

[square][cassetta.functional.jit.square], [square_][cassetta.functional.jit.square_]

x**2

[cube][cassetta.functional.jit.cube], [cube_][cassetta.functional.jit.cube_]

x**3

[pow4][cassetta.functional.jit.pow4], [pow4_][cassetta.functional.jit.pow4_]

x**4

[pow5][cassetta.functional.jit.pow5], [pow5_][cassetta.functional.jit.pow5_]

x**5

[pow6][cassetta.functional.jit.pow6], [pow6_][cassetta.functional.jit.pow6_]

x**6

[pow7][cassetta.functional.jit.pow7], [pow7_][cassetta.functional.jit.pow7_]

x**7

[floor_div][cassetta.functional.jit.floor_div], [floor_div_int][cassetta.functional.jit.floor_div_int]

floor(x / y)

[trunc_div][cassetta.functional.jit.trunc_div], [trunc_div_int][cassetta.functional.jit.trunc_div_int]

trunc(x / y)

Meshgrid

[meshgrid_list_ij][cassetta.functional.jit.meshgrid_list_ij]

Meshgrid with indexing="ij"

[meshgrid_list_xy][cassetta.functional.jit.meshgrid_list_xy]

Meshgrid with indexing="xy"

Python objects

[pad_list_int][cassetta.functional.jit.pad_list_int], [pad_list_float][cassetta.functional.jit.pad_list_float], [pad_list_str][cassetta.functional.jit.pad_list_str]

Pad a list

[any_list_bool][cassetta.functional.jit.any_list_bool]

any(list)

[all_list_bool][cassetta.functional.jit.all_list_bool]

all(list)

[prod_list_int][cassetta.functional.jit.prod_list_int]

prod(list)

[sum_list_int][cassetta.functional.jit.sum_list_int]

sum(list)

[reverse_list_int][cassetta.functional.jit.reverse_list_int]

reversed(list)

[cumprod_list_int][cassetta.functional.jit.cumprod_list_int]

Cumulative product

Tensors

[prod_list_tensor][cassetta.functional.jit.prod_list_tensor]

prod(list[tensor])

[sum_list_tensor][cassetta.functional.jit.sum_list_tensor]

sum(list[tensor])

[movedim][cassetta.functional.jit.movedim]

movedim(tensor, src, dst)

cassetta.functional.jit.indexing

ind2sub

ind2sub(ind, shape)

Convert linear indices into sub indices (i, j, k).

The rightmost dimension is the most rapidly changing one -> if shape == [D, H, W], the strides are therefore [H*W, W, 1]

PARAMETER DESCRIPTION
ind

Linear indices

TYPE: tensor_like

shape

Size of each dimension.

TYPE: (D,) vector_like

RETURNS DESCRIPTION
subs

Sub-indices.

TYPE: (D, ...) tensor

Source code in cassetta/functional/jit/indexing.py
@torch.jit.script
def ind2sub(ind, shape: List[int]):
    """Convert linear indices into sub indices (i, j, k).

    The rightmost dimension is the most rapidly changing one
    -> if shape == [D, H, W], the strides are therefore [H*W, W, 1]

    Parameters
    ----------
    ind : tensor_like
        Linear indices
    shape : (D,) vector_like
        Size of each dimension.

    Returns
    -------
    subs : (D, ...) tensor
        Sub-indices.
    """
    stride = cumprod_list_int(shape, reverse=True, exclusive=True)
    sub = ind.new_empty([len(shape)] + ind.shape)
    sub.copy_(ind)
    for d in range(len(shape)):
        if d > 0:
            sub[d] = torch.remainder(sub[d], stride[d-1])
        sub[d] = floor_div_int(sub[d], stride[d])
    return sub
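
The stride arithmetic described above can be illustrated without tensors. The following is a minimal plain-Python sketch for a single linear index, not the library's implementation; the helper name `ind2sub_scalar` is hypothetical and only mirrors the logic of `ind2sub`.

```python
from typing import List


def ind2sub_scalar(ind: int, shape: List[int]) -> List[int]:
    """Plain-Python counterpart of `ind2sub` for one linear index (illustrative)."""
    # Exclusive reverse cumulative product: [D, H, W] -> [H*W, W, 1]
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]

    sub = []
    for stride in strides:
        # The quotient is the coordinate along this dimension; the remainder
        # is what is left for the faster-changing dimensions to its right.
        coord, ind = divmod(ind, stride)
        sub.append(coord)
    return sub


# With shape [2, 3, 4] the strides are [12, 4, 1], so 17 = 1*12 + 1*4 + 1*1
print(ind2sub_scalar(17, [2, 3, 4]))  # [1, 1, 1]
print(ind2sub_scalar(23, [2, 3, 4]))  # [1, 2, 3]
```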

sub2ind

sub2ind(subs, shape)

Convert sub indices (i, j, k) into linear indices.

The rightmost dimension is the most rapidly changing one -> if shape == [D, H, W], the strides are therefore [H*W, W, 1]

PARAMETER DESCRIPTION
subs

List of sub-indices. The size of the first dimension is the number of dimensions. All elements should have the same shape.

TYPE: (D, ...) tensor

shape

Size of each dimension. Its length should be the same as the first dimension of subs.

TYPE: (D,) list[int]

RETURNS DESCRIPTION
ind

Linear indices

TYPE: (...) tensor

Source code in cassetta/functional/jit/indexing.py
@torch.jit.script
def sub2ind(subs, shape: List[int]):
    """Convert sub indices (i, j, k) into linear indices.

    The rightmost dimension is the most rapidly changing one
    -> if shape == [D, H, W], the strides are therefore [H*W, W, 1]

    Parameters
    ----------
    subs : (D, ...) tensor
        List of sub-indices. The size of the first dimension is the number
        of dimensions. All elements should have the same shape.
    shape : (D,) list[int]
        Size of each dimension. Its length should be the same as the
        first dimension of ``subs``.

    Returns
    -------
    ind : (...) tensor
        Linear indices
    """
    subs = subs.unbind(0)
    ind = subs[-1]
    subs = subs[:-1]
    ind = ind.clone()
    stride = cumprod_list_int(shape[1:], reverse=True, exclusive=False)
    for i, s in zip(subs, stride):
        ind += i * s
    return ind
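
The inverse mapping follows the same stride convention. As a minimal sketch in plain Python (the helper name `sub2ind_scalar` is hypothetical and works on a single coordinate tuple rather than tensors), the linear index is the sum of each coordinate times its stride, and it matches the order in which `itertools.product` enumerates coordinates.

```python
from itertools import product
from typing import List


def sub2ind_scalar(sub: List[int], shape: List[int]) -> int:
    """Plain-Python counterpart of `sub2ind` for one coordinate (illustrative)."""
    # The last coordinate has stride 1
    ind = sub[-1]
    stride = 1
    # Walk from the second-to-last dimension to the first,
    # growing the stride by the size of the dimension just passed
    for d in range(len(shape) - 2, -1, -1):
        stride *= shape[d + 1]
        ind += sub[d] * stride
    return ind


shape = [2, 3, 4]
print(sub2ind_scalar([1, 2, 3], shape))  # 1*12 + 2*4 + 3 = 23

# The rightmost coordinate changes fastest, exactly as in itertools.product
coords = product(*(range(s) for s in shape))
print(all(sub2ind_scalar(list(c), shape) == n for n, c in enumerate(coords)))  # True
```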

sub2ind_list

sub2ind_list(subs, shape)

Convert sub indices (i, j, k) into linear indices.

The rightmost dimension is the most rapidly changing one -> if shape == [D, H, W], the strides are therefore [H*W, W, 1]

PARAMETER DESCRIPTION
subs

List of sub-indices, one tensor per dimension. All tensors should have the same shape.

TYPE: (D,) list[tensor]

shape

Size of each dimension. Its length should be the same as the first dimension of subs.

TYPE: (D,) list[int]

RETURNS DESCRIPTION
ind

Linear indices

TYPE: (...) tensor

Source code in cassetta/functional/jit/indexing.py
@torch.jit.script
def sub2ind_list(subs: List[Tensor], shape: List[int]):
    """Convert sub indices (i, j, k) into linear indices.

    The rightmost dimension is the most rapidly changing one
    -> if shape == [D, H, W], the strides are therefore [H*W, W, 1]

    Parameters
    ----------
    subs : (D,) list[tensor]
        List of sub-indices, one tensor per dimension.
        All tensors should have the same shape.
    shape : (D,) list[int]
        Size of each dimension. Its length should be the same as the
        first dimension of ``subs``.

    Returns
    -------
    ind : (...) tensor
        Linear indices
    """
    ind = subs[-1]
    subs = subs[:-1]
    ind = ind.clone()
    stride = cumprod_list_int(shape[1:], reverse=True, exclusive=False)
    for i, s in zip(subs, stride):
        ind += i * s
    return ind

cassetta.functional.jit.math

square

square(x)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def square(x):
    """"""
    return x * x

square_

square_(x)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def square_(x):
    """"""
    return x.mul_(x)

cube

cube(x)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def cube(x):
    """"""
    return x * x * x

cube_

cube_(x)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def cube_(x):
    """"""
    # `square_` overwrites `x` with x**2, so keep a copy of the input first
    x0 = x.clone()
    return square_(x).mul_(x0)

pow4

pow4(x)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def pow4(x):
    """"""
    return square(square(x))

pow4_

pow4_(x)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def pow4_(x):
    """"""
    return square_(square_(x))

pow5

pow5(x)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def pow5(x):
    """"""
    return x * pow4(x)

pow5_

pow5_(x)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def pow5_(x):
    """"""
    # `pow4_` overwrites `x` with x**4, so keep a copy of the input first
    x0 = x.clone()
    return pow4_(x).mul_(x0)

pow6

pow6(x)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def pow6(x):
    """"""
    return square(cube(x))

pow6_

pow6_(x)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def pow6_(x):
    """"""
    return square_(cube_(x))

pow7

pow7(x)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def pow7(x):
    """"""
    return pow6(x) * x

pow7_

pow7_(x)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def pow7_(x):
    """"""
    # `pow6_` overwrites `x` with x**6, so keep a copy of the input first
    x0 = x.clone()
    return pow6_(x).mul_(x0)

floor_div

floor_div(x, y)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def floor_div(x, y) -> torch.Tensor:
    """"""
    return torch.div(x, y, rounding_mode='floor')

floor_div_int

floor_div_int(x, y)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def floor_div_int(x, y: int) -> torch.Tensor:
    """"""
    return torch.div(x, y, rounding_mode='floor')

trunc_div

trunc_div(x, y)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def trunc_div(x, y) -> torch.Tensor:
    """"""
    return torch.div(x, y, rounding_mode='trunc')

trunc_div_int

trunc_div_int(x, y)
Source code in cassetta/functional/jit/math.py
@torch.jit.script
def trunc_div_int(x, y: int) -> torch.Tensor:
    """"""
    return torch.div(x, y, rounding_mode='trunc')
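
The two rounding modes only differ when the quotient is negative. The sketch below uses the standard `math` module on Python scalars rather than tensors, so the function bodies are illustrative stand-ins for the `torch.div` calls above: flooring rounds towards minus infinity (like Python's `//`), whereas truncation rounds towards zero (like C integer division).

```python
import math


def floor_div_scalar(x: float, y: float) -> int:
    # Round the quotient towards minus infinity
    return math.floor(x / y)


def trunc_div_scalar(x: float, y: float) -> int:
    # Round the quotient towards zero
    return math.trunc(x / y)


for x, y in [(7, 2), (-7, 2), (7, -2)]:
    print(x, y, floor_div_scalar(x, y), trunc_div_scalar(x, y))
# 7 2 3 3
# -7 2 -4 -3
# 7 -2 -4 -3
```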

cassetta.functional.jit.meshgrid

Eager vs TorchScript

The signature of torch.meshgrid differs between eager and TorchScript modes.

In eager mode, meshgrid takes an unpacked list of tensors as inputs:

gx, gy = torch.meshgrid(coordx, coordy)

In TorchScript mode, it takes a (packed) list of tensors instead:

gx, gy = torch.meshgrid([coordx, coordy])

This complicates writing code that works both in eager mode (enabled by setting the environment variable PYTORCH_JIT=0) and in TorchScript mode (the default, PYTORCH_JIT=1). Instead, we define functions with explicit names (meshgrid_list_*) that always take a (packed) list of tensors as input.

Backward compatibility

For torch<1.10, torch.meshgrid always worked in "ij" mode: the first tensor of coordinates was mapped to the first dimension of the output grid and the second tensor of coordinates to the second dimension. torch 1.10 introduced the keyword argument indexing, which takes the value "ij" or "xy". Furthermore, the default behavior of the function when indexing is not specified will change from "ij" to "xy" in the future.

To make any code backward compatible, we define explicit functions postfixed by either _ij or _xy.
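
To make the difference between the two conventions concrete, here is a minimal two-input sketch that uses nested Python lists instead of tensors. The helpers `meshgrid_ij_2d` and `meshgrid_xy_2d` are hypothetical and only reproduce the layout of the grids, under the assumption of exactly two 1D inputs.

```python
from typing import List, Tuple

Grid = List[List[int]]


def meshgrid_ij_2d(xs: List[int], ys: List[int]) -> Tuple[Grid, Grid]:
    # "ij": output shape is (len(xs), len(ys)); the first input varies along rows
    gx = [[x for _ in ys] for x in xs]
    gy = [[y for y in ys] for _ in xs]
    return gx, gy


def meshgrid_xy_2d(xs: List[int], ys: List[int]) -> Tuple[Grid, Grid]:
    # "xy": output shape is (len(ys), len(xs)); the first input varies along columns
    gx = [[x for x in xs] for _ in ys]
    gy = [[y for _ in xs] for y in ys]
    return gx, gy


xs, ys = [1, 2, 3], [10, 20]
gx_ij, gy_ij = meshgrid_ij_2d(xs, ys)
gx_xy, gy_xy = meshgrid_xy_2d(xs, ys)
print(gx_ij)  # [[1, 1], [2, 2], [3, 3]]        (3 rows, 2 columns)
print(gx_xy)  # [[1, 2, 3], [1, 2, 3]]          (2 rows, 3 columns)
print(gy_xy)  # [[10, 10, 10], [20, 20, 20]]
```

In this two-input case, each "xy" grid is the transpose of the corresponding "ij" grid.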

meshgrid_list_ij

meshgrid_list_ij(tensors)

Creates grids of coordinates specified by the 1D inputs in tensors.

This is helpful when you want to visualize data over some range of inputs.

Given \(N\) 1D tensors \(T_0, \dots, T_{N-1}\) as inputs with corresponding sizes \(S_0, \dots, S_{N-1}\), this creates \(N\) N-dimensional tensors \(G_0, \dots, G_{N-1}\), each with shape \((S_0, \dots, S_{N-1})\) where the output \(G_i\) is constructed by expanding \(T_i\) to the result shape.

Note

0D inputs are treated equivalently to 1D inputs of a single element.

PARAMETER DESCRIPTION
tensors

list of scalars or 1 dimensional tensors. Scalars will be treated as tensors of size \((1,)\) automatically

TYPE: list[tensor]

RETURNS DESCRIPTION
seq

list of expanded tensors

TYPE: list[tensor]

Source code in cassetta/functional/jit/meshgrid.py
@torch.jit.script
def meshgrid_list_ij(tensors: TensorList) -> TensorList:
    r"""
    Creates grids of coordinates specified by the 1D inputs in `tensors`.

    This is helpful when you want to visualize data over some
    range of inputs.

    Given $N$ 1D tensors $T_0, \dots, T_{N-1}$ as inputs with
    corresponding sizes $S_0, \dots, S_{N-1}$, this creates $N$
    N-dimensional tensors $G_0, \dots, G_{N-1}$, each with shape
    $(S_0, \dots, S_{N-1})$ where the output $G_i$ is constructed
    by expanding $T_i$ to the result shape.

    !!! note
        0D inputs are treated equivalently to 1D inputs of a
        single element.

    Parameters
    ----------
    tensors : list[tensor]
        list of scalars or 1 dimensional tensors. Scalars will be
        treated as tensors of size $(1,)$ automatically

    Returns
    -------
    seq : list[tensor]
        list of expanded tensors
    """
    return list(torch.meshgrid(tensors, indexing='ij'))

meshgrid_list_xy

meshgrid_list_xy(tensors)

Creates grids of coordinates specified by the 1D inputs in tensors.

This is helpful when you want to visualize data over some range of inputs.

Given \(N\) 1D tensors \(T_0, \dots, T_{N-1}\) as inputs with corresponding sizes \(S_0, \dots, S_{N-1}\), this creates \(N\) N-dimensional tensors \(G_0, \dots, G_{N-1}\), each with shape \((S_0, \dots, S_{N-1})\) where the output \(G_i\) is constructed by expanding \(T_i\) to the result shape.

Note

0D inputs are treated equivalently to 1D inputs of a single element.

Warning

In mode xy, the first dimension of the output corresponds to the cardinality of the second input and the second dimension of the output corresponds to the cardinality of the first input.

PARAMETER DESCRIPTION
tensors

list of scalars or 1 dimensional tensors. Scalars will be treated as tensors of size \((1,)\) automatically

TYPE: list[tensor]

RETURNS DESCRIPTION
seq

list of expanded tensors

TYPE: list[tensor]

Source code in cassetta/functional/jit/meshgrid.py
@torch.jit.script
def meshgrid_list_xy(tensors: TensorList) -> TensorList:
    r"""
    Creates grids of coordinates specified by the 1D inputs in `tensors`.

    This is helpful when you want to visualize data over some
    range of inputs.

    Given $N$ 1D tensors $T_0, \dots, T_{N-1}$ as inputs with
    corresponding sizes $S_0, \dots, S_{N-1}$, this creates $N$
    N-dimensional tensors $G_0, \dots, G_{N-1}$, each with shape
    $(S_0, \dots, S_{N-1})$ where the output $G_i$ is constructed
    by expanding $T_i$ to the result shape.

    !!! note
        0D inputs are treated equivalently to 1D inputs of a
        single element.

    !!! warning
        In mode `xy`, the first dimension of the output corresponds to the
        cardinality of the second input and the second dimension of the output
        corresponds to the cardinality of the first input.

    Parameters
    ----------
    tensors : list[tensor]
        list of scalars or 1 dimensional tensors. Scalars will be
        treated as tensors of size $(1,)$ automatically

    Returns
    -------
    seq : list[tensor]
        list of expanded tensors
    """
    return list(torch.meshgrid(tensors, indexing='xy'))

cassetta.functional.jit.python

TorchScript-compatible functions that act on Python builtins.

pad_list_int

pad_list_int(x, length)

Pad/crop a list of int until it reaches a target length.

Note

If padding, the last element gets replicated.

PARAMETER DESCRIPTION
x

List of int

TYPE: list[int]

length

Target length

TYPE: int

RETURNS DESCRIPTION
x

List of length length.

TYPE: (length,) list[int]

Source code in cassetta/functional/jit/python.py
@torch.jit.script
def pad_list_int(x: List[int], length: int) -> List[int]:
    """
    Pad/crop a list of int until it reaches a target length.

    !!! note
        If padding, the last element gets replicated.

    Parameters
    ----------
    x : list[int]
        List of int
    length : int
        Target length

    Returns
    -------
    x : (length,) list[int]
        List of length `length`.

    """
    if len(x) < length:
        x = x + x[-1:] * (length - len(x))
    if len(x) > length:
        x = x[:length]
    return x
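
The pad/crop behavior can be checked with an untyped plain-Python equivalent. This is a sketch under the assumption that the scripted functions behave like the listing above; `pad_list` is a hypothetical generic name, not part of the package.

```python
from typing import List, TypeVar

T = TypeVar("T")


def pad_list(x: List[T], length: int) -> List[T]:
    """Pad by replicating the last element, or crop, to reach `length`."""
    if len(x) < length:
        # x[-1:] is a one-element list (or an empty list if x is empty)
        x = x + x[-1:] * (length - len(x))
    return x[:length]


print(pad_list([3], 3))           # [3, 3, 3]
print(pad_list([1, 2], 3))        # [1, 2, 2]
print(pad_list([1, 2, 3, 4], 2))  # [1, 2]
print(pad_list([], 2))            # [] (nothing to replicate)
```

Note the last case: an empty list has no last element to replicate, so it stays empty instead of reaching the target length.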

pad_list_float

pad_list_float(x, dim)

See [pad_list_int][cassetta.functional.jit.pad_list_int].

Source code in cassetta/functional/jit/python.py
@torch.jit.script
def pad_list_float(x: List[float], dim: int) -> List[float]:
    """
    See [`pad_list_int`][cassetta.functional.jit.pad_list_int].
    """
    if len(x) < dim:
        x = x + x[-1:] * (dim - len(x))
    if len(x) > dim:
        x = x[:dim]
    return x

pad_list_str

pad_list_str(x, dim)

See [pad_list_int][cassetta.functional.jit.pad_list_int].

Source code in cassetta/functional/jit/python.py
@torch.jit.script
def pad_list_str(x: List[str], dim: int) -> List[str]:
    """
    See [`pad_list_int`][cassetta.functional.jit.pad_list_int].
    """
    if len(x) < dim:
        x = x + x[-1:] * (dim - len(x))
    if len(x) > dim:
        x = x[:dim]
    return x

any_list_bool

any_list_bool(x)

TorchScript equivalent to any(x)

Source code in cassetta/functional/jit/python.py
@torch.jit.script
def any_list_bool(x: List[bool]) -> bool:
    """TorchScript equivalent to `any(x)`"""
    for elem in x:
        if elem:
            return True
    return False

all_list_bool

all_list_bool(x)

TorchScript equivalent to all(x)

Source code in cassetta/functional/jit/python.py
@torch.jit.script
def all_list_bool(x: List[bool]) -> bool:
    """TorchScript equivalent to `all(x)`"""
    for elem in x:
        if not elem:
            return False
    return True

prod_list_int

prod_list_int(x)

Compute the product of elements in the list

Source code in cassetta/functional/jit/python.py
@torch.jit.script
def prod_list_int(x: List[int]) -> int:
    """Compute the product of elements in the list"""
    if len(x) == 0:
        return 1
    x0 = x[0]
    for x1 in x[1:]:
        x0 = x0 * x1
    return x0

sum_list_int

sum_list_int(x)

Compute the sum of elements in the list. Equivalent to sum(x).

Source code in cassetta/functional/jit/python.py
@torch.jit.script
def sum_list_int(x: List[int]) -> int:
    """Compute the sum of elements in the list. Equivalent to `sum(x)`."""
    if len(x) == 0:
        return 0
    x0 = x[0]
    for x1 in x[1:]:
        x0 = x0 + x1
    return x0

reverse_list_int

reverse_list_int(x)

TorchScript equivalent to x[::-1]

Source code in cassetta/functional/jit/python.py
@torch.jit.script
def reverse_list_int(x: List[int]) -> List[int]:
    """TorchScript equivalent to `x[::-1]`"""
    if len(x) == 0:
        return x
    return [x[i] for i in range(-1, -len(x)-1, -1)]

cumprod_list_int

cumprod_list_int(x, reverse=False, exclusive=False)

Cumulative product of elements in the list

PARAMETER DESCRIPTION
x

List of integers

TYPE: list[int]

reverse

Cumulative product from right to left. Else, cumulative product from left to right (default).

TYPE: bool DEFAULT: False

exclusive

Start series from 1. Else start series from first element (default).

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
y

Cumulative product

TYPE: list[int]

Source code in cassetta/functional/jit/python.py
@torch.jit.script
def cumprod_list_int(x: List[int], reverse: bool = False,
                     exclusive: bool = False) -> List[int]:
    """Cumulative product of elements in the list

    Parameters
    ----------
    x : list[int]
        List of integers
    reverse : bool
        Cumulative product from right to left.
        Else, cumulative product from left to right (default).
    exclusive : bool
        Start series from 1.
        Else start series from first element (default).

    Returns
    -------
    y : list[int]
        Cumulative product

    """
    if len(x) == 0:
        lx: List[int] = []
        return lx
    if reverse:
        x = reverse_list_int(x)

    x0 = 1 if exclusive else x[0]
    lx = [x0]
    all_x = x[:-1] if exclusive else x[1:]
    for x1 in all_x:
        x0 = x0 * x1
        lx.append(x0)
    if reverse:
        lx = reverse_list_int(lx)
    return lx
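
The four combinations of `reverse` and `exclusive` are easiest to compare side by side. The function below is a plain-Python sketch intended to mirror the listing above (the name `cumprod_list` is hypothetical); the reverse, exclusive variant is the one that yields the strides used by `ind2sub`.

```python
from typing import List


def cumprod_list(x: List[int], reverse: bool = False,
                 exclusive: bool = False) -> List[int]:
    if reverse:
        x = x[::-1]
    out: List[int] = []
    acc = 1
    for value in x:
        if exclusive:
            # Emit the product of the elements seen *before* this one
            out.append(acc)
            acc *= value
        else:
            # Emit the product *including* this element
            acc *= value
            out.append(acc)
    return out[::-1] if reverse else out


shape = [2, 3, 4]
print(cumprod_list(shape))                                # [2, 6, 24]
print(cumprod_list(shape, exclusive=True))                # [1, 2, 6]
print(cumprod_list(shape, reverse=True))                  # [24, 12, 4]
print(cumprod_list(shape, reverse=True, exclusive=True))  # [12, 4, 1]
```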

cassetta.functional.jit.tensors

prod_list_tensor

prod_list_tensor(x)

Compute the product of tensors in the list.

Source code in cassetta/functional/jit/tensors.py
@torch.jit.script
def prod_list_tensor(x: List[Tensor]) -> Tensor:
    """Compute the product of tensors in the list."""
    if len(x) == 0:
        empty: List[int] = []
        return torch.ones(empty)
    x0 = x[0]
    for x1 in x[1:]:
        x0 = x0 * x1
    return x0

sum_list_tensor

sum_list_tensor(x)

Compute the sum of tensors in the list. Equivalent to sum(x).

Source code in cassetta/functional/jit/tensors.py
@torch.jit.script
def sum_list_tensor(x: List[Tensor]) -> Tensor:
    """Compute the sum of tensors in the list. Equivalent to `sum(x)`."""
    if len(x) == 0:
        empty: List[int] = []
        return torch.zeros(empty)
    x0 = x[0]
    for x1 in x[1:]:
        x0 = x0 + x1
    return x0

movedim

movedim(x, source, destination)

Backward compatible torch.movedim

Source code in cassetta/functional/jit/tensors.py
@torch.jit.script
def movedim(x, source: int, destination: int):
    """Backward compatible `torch.movedim`"""
    dim = x.dim()
    src, dst = source, destination
    src = dim + src if src < 0 else src
    dst = dim + dst if dst < 0 else dst
    permutation = [d for d in range(dim)]
    permutation = permutation[:src] + permutation[src+1:]
    permutation = permutation[:dst] + [src] + permutation[dst:]
    return x.permute(permutation)
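
The permutation built by `movedim` can be inspected on its own. The following plain-Python sketch (the helper name `movedim_permutation` is hypothetical) reproduces the index manipulation above and applies it to a list of sizes instead of a tensor.

```python
from typing import List


def movedim_permutation(ndim: int, source: int, destination: int) -> List[int]:
    # Resolve negative indices, as in the listing above
    src = source + ndim if source < 0 else source
    dst = destination + ndim if destination < 0 else destination
    perm = list(range(ndim))
    perm.pop(src)          # remove the dimension being moved
    perm.insert(dst, src)  # re-insert it at its destination
    return perm


shape = [8, 3, 32, 64]  # e.g. (batch, channel, height, width)
perm = movedim_permutation(len(shape), 1, -1)
print(perm)                      # [0, 2, 3, 1]
print([shape[p] for p in perm])  # [8, 32, 64, 3]
```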