cassetta.core
Overview
This module contains core utilities that are mostly used internally and whose API may be less stable than that of the rest of the package. Use at your own risk.
cassetta.core.typing
A set of cassetta-specific type hints, plus backward-compatible hints.
OneOrSeveral
module-attribute
Either a single value of a type, or a sequence of such values.
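For illustration, an alias like this can be written as a generic Union over a type variable. The definition below is an assumption rather than cassetta's actual source, and normalize_kernel_size is a hypothetical helper showing how a OneOrSeveral argument is typically broadcast to one value per dimension.

```python
from typing import List, Sequence, TypeVar, Union

T = TypeVar("T")

# Assumed definition: either a single T or a sequence of T.
OneOrSeveral = Union[T, Sequence[T]]


def normalize_kernel_size(size: OneOrSeveral[int], ndim: int) -> List[int]:
    """Hypothetical helper: broadcast a OneOrSeveral[int] to ndim values."""
    if isinstance(size, int):
        # A single value is repeated along every dimension.
        return [size] * ndim
    values = list(size)
    if len(values) != ndim:
        raise ValueError(f"expected {ndim} values, got {len(values)}")
    return values


print(normalize_kernel_size(3, ndim=2))       # [3, 3]
print(normalize_kernel_size((3, 5), ndim=2))  # [3, 5]
```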
DeviceType
module-attribute
An instantiated torch.device, or a string that can be used to instantiate one, such as "cpu", "cuda" or "cuda:0".
DataType
module-attribute
A torch.dtype, np.dtype, or a string that represents one such data type.
See to_torch_dtype for details.
ActivationType
module-attribute
An instantiated nn.Module, an nn.Module subtype, or the name of an activation class in torch.nn or in cassetta.layers.activations.
See [make_activation][cassetta.layers.make_activation] for details.
NormType
module-attribute
DropoutType
module-attribute
AttentionType
module-attribute
An instantiated nn.Module, an nn.Module subtype, or one of:

| String | Class |
|---|---|
| "sqzex" | [SqzEx][cassetta.layers.SqzEx] |
| "cbam" | [BlockAttention][cassetta.layers.BlockAttention] |
| "dp" | [DotProductAttention][cassetta.layers.DotProductAttention](scaled=False) |
| "sdp" | [DotProductAttention][cassetta.layers.DotProductAttention](scaled=True) |
| "mha" | [MultiHeadAttention][cassetta.layers.MultiHeadAttention] |

See [make_attention][cassetta.layers.make_attention] for details.
InterpolationType
module-attribute
InterpolationType = Union[int, Literal['nearest', 'linear', 'quadratic', 'cubic', 'fourth', 'fifth', 'sixth', 'seventh']]
The degree of the B-splines used for interpolation, between 0 and 7, or one of their string aliases: "nearest" (0), "linear" (1), "quadratic" (2), "cubic" (3), "fourth" (4), "fifth" (5), "sixth" (6), "seventh" (7).
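A minimal sketch of how an InterpolationType value could be turned into an integer spline degree, assuming the aliases map to degrees 0 to 7 in the order they appear in the Literal above. to_spline_degree is a hypothetical helper, not part of the documented API.

```python
from typing import Literal, Union

InterpolationType = Union[int, Literal[
    'nearest', 'linear', 'quadratic', 'cubic',
    'fourth', 'fifth', 'sixth', 'seventh']]

# Aliases listed in order of increasing spline degree (index == degree).
_ALIASES = ('nearest', 'linear', 'quadratic', 'cubic',
            'fourth', 'fifth', 'sixth', 'seventh')


def to_spline_degree(order: InterpolationType) -> int:
    """Hypothetical helper: convert an InterpolationType to a degree in 0..7."""
    if isinstance(order, str):
        try:
            return _ALIASES.index(order)
        except ValueError:
            raise ValueError(f"unknown interpolation alias: {order!r}") from None
    if not 0 <= order <= 7:
        raise ValueError(f"spline degree must be between 0 and 7, got {order}")
    return int(order)


print(to_spline_degree("cubic"))  # 3
print(to_spline_degree(1))        # 1
```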
ModelType
module-attribute
A model can be:
- the name of a cassetta model, such as "SegNet";
- the fully qualified path to a model, such as "cassetta.models.ElasticRegNet" or "monai.networks.nets.ResNet";
- an nn.Module subclass, such as [SegNet][cassetta.models.SegNet];
- an already instantiated nn.Module, such as [SegNet(3, 1, 5)][cassetta.models.SegNet].
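The fully-qualified-path and class cases can be resolved with the standard importlib module. The sketch below uses a hypothetical build helper (not cassetta's API); it does not handle bare cassetta names such as "SegNet", which would first need to be looked up in cassetta.models. It is demonstrated with a standard-library class so that it runs without torch; an nn.Module subclass would go through the same branches.

```python
import importlib
from typing import Any


def resolve_path(path: str) -> Any:
    """Hypothetical helper: import the object named by a dotted path."""
    module_name, _, attr = path.rpartition(".")
    if not module_name:
        raise ValueError(f"not a fully qualified path: {path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr)


def build(spec: Any, *args: Any, **kwargs: Any) -> Any:
    """Hypothetical helper covering the path, class and instance cases."""
    if isinstance(spec, str):
        spec = resolve_path(spec)      # "pkg.module.Class" -> Class
    if isinstance(spec, type):
        return spec(*args, **kwargs)   # class -> new instance
    return spec                        # already an instance: returned unchanged


obj = build("collections.OrderedDict", a=1)
print(type(obj).__name__)  # OrderedDict
```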
LossType
module-attribute
A loss can be:
- the name of a cassetta loss, such as "DiceLoss";
- the fully qualified path to a loss, such as "cassetta.losses.DiceLoss" or "monai.losses.GeneralizedDiceLoss";
- an nn.Module subclass, such as [DiceLoss][cassetta.losses.DiceLoss];
- an already instantiated nn.Module, such as [DiceLoss()][cassetta.losses.DiceLoss].
cassetta.core.utils
A set of various utility functions.
ensure_list
Ensure that an object is a list.
The output list has a length of at least `length`. When `crop` is True, its length is also at most `length`. If padding is needed, the last value is replicated, unless `default` is provided, in which case `default` is used as the padding value.
If x is a list, nothing is done (no copy triggered). If it is a tuple, range, or generator, it is converted to a list. Otherwise, it is placed inside a list.
Source code in cassetta/core/utils.py
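A minimal re-implementation of the behavior described above, written only to make the padding and cropping rules concrete. The name ensure_list_sketch and the default values of its arguments are assumptions; the real signature may differ.

```python
from types import GeneratorType
from typing import Any, List, Optional

_MISSING = object()  # sentinel meaning "no default was given"


def ensure_list_sketch(x: Any, length: Optional[int] = None,
                       crop: bool = True, default: Any = _MISSING) -> List[Any]:
    """Hypothetical re-implementation of the behavior described above."""
    if isinstance(x, (tuple, range, GeneratorType)):
        x = list(x)          # tuples, ranges and generators are converted
    elif not isinstance(x, list):
        x = [x]              # any other object is wrapped in a list
    if length is None:
        return x             # lists pass through without a copy
    if crop and len(x) > length:
        x = x[:length]       # keep at most `length` items
    if len(x) < length:
        # Pad with `default`, or by replicating the last value.
        pad = x[-1] if default is _MISSING else default
        x = x + [pad] * (length - len(x))
    return x


print(ensure_list_sketch(3, length=3))                  # [3, 3, 3]
print(ensure_list_sketch((1, 2), length=4, default=0))  # [1, 2, 0, 0]
print(ensure_list_sketch([1, 2, 3, 4], length=2))       # [1, 2]
```

In this sketch, a list input is returned as-is when it needs neither padding nor cropping; padding and cropping build a new list rather than mutating the caller's.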
ensure_tuple
make_vector
Ensure that the input is a (tensor) vector and pad/crop if necessary.
| PARAMETER | DESCRIPTION |
|---|---|
| input | Input argument(s). |
| length | Target length. |
| crop | Crop input sequence if longer than |

Keyword Parameters
- default (optional): Default value to pad with. If not provided, replicate the last value.
- dtype (torch.dtype, optional): Output data type.
- device (torch.device, optional): Output device.

| RETURNS | DESCRIPTION |
|---|---|
| output | Output vector. |
torch_version
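Check the installed torch version against a reference version.

| PARAMETER | DESCRIPTION |
|---|---|
| mode | Comparison mode, as in "torch.version <mode> version". |
| version | Version to compare against. |

| RETURNS | DESCRIPTION |
|---|---|
| | True if "torch.version <mode> version". |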
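A version check of this kind can be sketched with the standard library alone. compare_version is hypothetical: it maps mode strings such as ">=" to functions from the operator module, which is an assumption about what mode accepts, and it compares only the leading numeric components, so local suffixes like "+cu118" are ignored. In practice, the current version would come from torch.__version__.

```python
import operator
import re
from typing import Tuple

# Assumed mapping from a comparison "mode" string to an operator.
_OPS = {
    "<": operator.lt, "<=": operator.le, "==": operator.eq,
    "!=": operator.ne, ">=": operator.ge, ">": operator.gt,
}


def parse_version(version: str) -> Tuple[int, ...]:
    """Keep the leading numeric components, e.g. '2.1.0+cu118' -> (2, 1, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"cannot parse version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def compare_version(current: str, mode: str, version: str) -> bool:
    """Hypothetical helper: True if `current <mode> version`."""
    a, b = parse_version(current), parse_version(version)
    # Pad with zeros so that "2.1" and "2.1.0" compare as equal.
    n = max(len(a), len(b))
    a += (0,) * (n - len(a))
    b += (0,) * (n - len(b))
    return _OPS[mode](a, b)


print(compare_version("2.1.0+cu118", ">=", "1.13"))  # True
```

Comparing integer tuples rather than raw strings avoids the classic "1.9" > "1.10" pitfall. For full PEP 440 semantics (pre-releases, dev builds), packaging.version.Version from the packaging library is the more robust choice.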
to_torch_dtype
Transform a Python or NumPy dtype, or a dtype name, to a torch dtype.
Python -> PyTorch convention
We follow the PyTorch convention and convert float to torch.float32, int to torch.long and complex to torch.complex64.

| PARAMETER | DESCRIPTION |
|---|---|
| dtype | Input data type. |
| upcast | Upcast to the nearest torch dtype if the input dtype cannot be represented exactly. Else, raise a |
| trunc | Truncate to the nearest torch dtype if the input dtype cannot be represented exactly. Else, raise a |

| RETURNS | DESCRIPTION |
|---|---|
| dtype | Torch data type. |
import_submodules
Pre-import submodules into the parent module, so that we can do … instead of …

| PARAMETER | DESCRIPTION |
|---|---|
| submodules | Names of submodules to import. |
| module | Path to parent module: |
| all | Reference to the parent module's |
| import_into | Also import all objects from the submodule into the parent module. |
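The pre-import mechanism can be sketched with importlib. The function below is a hypothetical stand-in whose parameters loosely mirror the table above (all_list plays the role of the all parameter); it is not cassetta's implementation.

```python
import importlib
from typing import List, Optional, Sequence


def preimport_submodules(submodules: Sequence[str], module: str,
                         all_list: Optional[List[str]] = None,
                         import_into: bool = False) -> None:
    """Hypothetical sketch: import module.<name> for every name given."""
    parent = importlib.import_module(module)
    for name in submodules:
        # Importing a submodule also binds it as an attribute of its parent.
        sub = importlib.import_module(f"{module}.{name}")
        if all_list is not None and name not in all_list:
            all_list.append(name)
        if not import_into:
            continue
        # Copy the submodule's public names into the parent namespace.
        public = getattr(sub, "__all__", None)
        if public is None:
            public = [n for n in vars(sub) if not n.startswith("_")]
        for attr in public:
            setattr(parent, attr, getattr(sub, attr))
            if all_list is not None and attr not in all_list:
                all_list.append(attr)
```

Inside a package's __init__.py, a call would look like preimport_submodules(["foo", "bar"], __name__, __all__), where foo and bar are placeholder submodule names.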
refresh_experiment_dir
Check if the directory has contents, and if so, delete them recursively.
| PARAMETER | DESCRIPTION |
|---|---|
| dir_path | Path to the directory to be checked and cleared. |

| RAISES | DESCRIPTION |
|---|---|
| FileNotFoundError | If the directory does not exist. |
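A minimal sketch of the clearing step, assuming the directory itself is kept and only its contents are removed. clear_directory is a hypothetical name.

```python
import os
import shutil


def clear_directory(dir_path: str) -> None:
    """Hypothetical sketch: empty dir_path but keep the directory itself."""
    if not os.path.isdir(dir_path):
        raise FileNotFoundError(f"directory does not exist: {dir_path}")
    for entry in os.scandir(dir_path):
        if entry.is_dir(follow_symlinks=False):
            shutil.rmtree(entry.path)  # delete sub-directories recursively
        else:
            os.remove(entry.path)      # regular files and symlinks
```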
delete_files_with_pattern
Delete all files in the specified directory that match the given pattern.

| PARAMETER | DESCRIPTION |
|---|---|
| directory | The path to the directory. |
| pattern | The glob pattern to match files. |
Example
delete_files_with_pattern('/path/to/directory', 'last')
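A minimal glob-based sketch, with the hypothetical name delete_matching_files. Note that glob treats a pattern without wildcards literally, so 'last' alone would match only a file named exactly last; whether cassetta wraps the pattern in wildcards is not stated here, so the usage below passes "*last*" explicitly.

```python
import glob
import os
from typing import List


def delete_matching_files(directory: str, pattern: str) -> List[str]:
    """Hypothetical sketch: delete regular files in directory matching a glob."""
    deleted = []
    for path in glob.glob(os.path.join(directory, pattern)):
        if os.path.isfile(path):  # sub-directories are left alone
            os.remove(path)
            deleted.append(path)
    return deleted


# Usage: delete_matching_files('/path/to/directory', '*last*')
```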
find_files_with_pattern
Search for files in the specified directory that match the given grep pattern.
| PARAMETER | DESCRIPTION |
|---|---|
| directory | The path to the directory to search. |
| pattern | The grep pattern to search for. |

| RETURNS | DESCRIPTION |
|---|---|
| paths | A list of file paths that match the pattern. |
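A minimal sketch, assuming the grep pattern is a regular expression matched against file names (not file contents) at the top level of the directory. find_matching_files is a hypothetical name; the real function may search differently.

```python
import os
import re
from typing import List


def find_matching_files(directory: str, pattern: str) -> List[str]:
    """Hypothetical sketch: files in directory whose name matches a regex."""
    regex = re.compile(pattern)
    paths = []
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        # Only regular files are considered; the name is searched, not the content.
        if os.path.isfile(path) and regex.search(name):
            paths.append(path)
    return paths
```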
find_checkpoint
Find the checkpoint file in the specified experiment directory.
| PARAMETER | DESCRIPTION |
|---|---|
| experiment_dir | The path to the directory of the experiment. |
| checkpoint_type | The type of checkpoint to find, either "last" or "best". Defaults to "best". |

| RETURNS | DESCRIPTION |
|---|---|
| Optional[str] | The full path to the checkpoint file if found, otherwise None. |

| RAISES | DESCRIPTION |
|---|---|
| ValueError | If checkpoint_type is not "last" or "best". |
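A minimal sketch of the lookup, assuming checkpoints are stored as *.ckpt files whose names contain "last" or "best" somewhere under the experiment directory. This naming convention and the name find_checkpoint_sketch are assumptions, not documented behavior.

```python
import os
from typing import Optional


def find_checkpoint_sketch(experiment_dir: str,
                           checkpoint_type: str = "best") -> Optional[str]:
    """Hypothetical sketch: first *.ckpt file whose name contains
    checkpoint_type, or None if there is none."""
    if checkpoint_type not in ("last", "best"):
        raise ValueError(
            f"checkpoint_type must be 'last' or 'best', got {checkpoint_type!r}")
    for root, _dirs, files in os.walk(experiment_dir):
        for name in sorted(files):
            # Assumed convention: e.g. "last.ckpt" or "epoch=2-best.ckpt".
            if name.endswith(".ckpt") and checkpoint_type in name:
                return os.path.join(root, name)
    return None
```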