Transforms
Usage
Transforms include common data augmentation and preprocessing functions that you would use with volumetric data in a machine learning context. We implement all transforms directly in PyTorch, so they should work on both CPU and GPU and require no further dependencies. If you would like to see a transform that we have not implemented yet, feel free to open an issue or send us a PR.
Transforms in torchvtk must inherit from DictTransform. A DictTransform takes care of applying the transformation to specified items of the data dict (as defined in TorchDataset) and gives control over the dtype and device used for the transform. We stick with this dictionary paradigm so that all preprocessing tasks can be chained together in a Composite.
The following code snippet shows an example.
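```python
import torch
from torchvtk.datasets import TorchDataset
from torchvtk.transforms import Composite, RandFlip, Resize, Lambda

def _basic_func(data):
    data['newkey'] = 1  # do some custom stuff here
    return data         # Make sure to return the dictionary!

def _basic_add(number):  # This function does not care about dicts
    return number + 1

tfms = Composite(
    _basic_func,
    Lambda(_basic_add, apply_on=['newkey']),
    RandFlip(),
    Resize((256, 256, 256)),
    apply_on=['vol'])  # used by all transforms that do not set apply_on themselves

# data should be a dict, like an item of a TorchDataset.
# Illustrative item (assumed layout: one-channel volume of shape (C, D, H, W)):
data = {'vol': torch.rand(1, 64, 64, 64)}
tmp = tfms(data)

# Use the transforms upon loading with a TorchDataset (path_to_ds: path to your dataset)
ds = TorchDataset(path_to_ds, preprocess_fn=tfms)
```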
Compositing
A Composite can chain DictTransforms as well as normal functions, provided they all operate on dicts and return the modified dictionary. All classes in torchvtk.transforms are subclasses of DictTransform and accept its parameters through the **kwargs of their respective __init__ methods. That means you can specify for each transform which items of the dictionary it should be applied to (e.g. apply_on=['vol']), as well as a preferred torch.dtype and torch.device for the transform.
Note that setting the apply_on parameter of Composite (as in the example) applies to all transforms that have not specified apply_on themselves. The dtype and device parameters work similarly.
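To make the override rule concrete, here is a minimal pure-Python sketch of how a Composite could pass its apply_on down to the transforms that left it unset. MiniDictTransform and MiniComposite are hypothetical stand-ins written for illustration only; they are not torchvtk's implementation, and the real classes also handle dtype and device.

```python
from typing import Callable, Dict, List, Optional


class MiniDictTransform:
    """Hypothetical stand-in for DictTransform: applies `fn` to selected keys."""

    def __init__(self, fn: Callable, apply_on: Optional[List[str]] = None):
        self.fn = fn
        self.apply_on = apply_on

    def __call__(self, data: Dict) -> Dict:
        # Without apply_on, fall back to every key (torchvtk instead uses all tensor-valued keys)
        keys = self.apply_on if self.apply_on is not None else list(data.keys())
        for k in keys:
            data[k] = self.fn(data[k])
        return data


class MiniComposite:
    """Hypothetical stand-in for Composite: chains callables that take and return a dict."""

    def __init__(self, *tfms: Callable, apply_on: Optional[List[str]] = None):
        self.tfms = tfms
        if apply_on is not None:
            for t in tfms:
                # Only dict transforms that did not choose their own keys inherit the Composite's
                if isinstance(t, MiniDictTransform) and t.apply_on is None:
                    t.apply_on = apply_on

    def __call__(self, data: Dict) -> Dict:
        for t in self.tfms:
            data = t(data)  # every step must return the (modified) dict
        return data


def add_key(data: Dict) -> Dict:  # plain function working on the whole dict
    data['newkey'] = 1
    return data


tfms = MiniComposite(
    add_key,
    MiniDictTransform(lambda x: x + 1, apply_on=['newkey']),  # keeps its own apply_on
    MiniDictTransform(lambda x: x * 2),                       # inherits apply_on=['vol']
    apply_on=['vol'],
)
out = tfms({'vol': 10, 'other': 5})
print(out)  # {'vol': 20, 'other': 5, 'newkey': 2}
```

The key 'other' is left alone because no transform ends up selecting it, which is the behaviour the note above describes for transforms covered by the Composite's apply_on.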
DictTransform arguments
Further note that by default the transformations are applied to all keys that have a torch.Tensor as value. Beware of other tensors in your data that you do not want to modify! You should usually define apply_on for all your transforms, either specifically per transform or through the Composite. As for dtype and device, by default the transform is executed with the dtype and on the device that the data comes in with. Setting a dtype or device at the beginning of a Composite can thus determine the dtype or device until the next transform specifies another.
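The default key selection and the dtype/device handling described above can be sketched as two small helpers. select_keys and prepare are hypothetical names used for illustration, not torchvtk functions. Whether torchvtk converts results back afterwards is not stated here, so the sketch only shows the conversion that happens before the transform.

```python
from typing import Dict, List, Optional

import torch


def select_keys(data: Dict, apply_on: Optional[List[str]] = None) -> List[str]:
    """Keys a transform would touch: apply_on if given, else every torch.Tensor value."""
    if apply_on is not None:
        return list(apply_on)
    return [k for k, v in data.items() if torch.is_tensor(v)]


def prepare(t: torch.Tensor, dtype: Optional[torch.dtype] = None,
            device: Optional[torch.device] = None) -> torch.Tensor:
    """Convert only what was requested; None keeps whatever the data came in with."""
    return t.to(device=device if device is not None else t.device,
                dtype=dtype if dtype is not None else t.dtype)


data = {
    'vol': torch.rand(1, 8, 8, 8),
    'mask': torch.zeros(1, 8, 8, 8, dtype=torch.uint8),  # also a tensor: modified by default!
    'name': 'sample_0',                                   # not a tensor: never selected by default
}
print(select_keys(data))                                 # ['vol', 'mask']
print(select_keys(data, ['vol']))                        # ['vol']
print(prepare(data['vol'], dtype=torch.float16).dtype)   # torch.float16
```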
Mixing with non-DictTransforms
In this example, _basic_func is executed first: it receives the whole dictionary and must return the modified one. Here it adds a new key 'newkey'. _basic_add is a standard function that knows nothing about dicts, so we wrap it in Lambda to make use of apply_on and the other DictTransform parameters.
As you can see, we apply _basic_add only to 'newkey'. The other DictTransforms in the Composite (RandFlip and Resize) are applied to 'vol', since the Composite sets apply_on=['vol'] for all transforms that did not specify apply_on.
API
DictTransform
class torchvtk.transforms.DictTransform(device=None, apply_on=None, dtype=None)
    Super class for the transforms.
    __init__(device=None, apply_on=None, dtype=None)
        Parameters:
        apply_on – The keys of the item dictionaries on which the transform should be applied. Defaults to applying to all torch.Tensors.
        device – The torch.device on which the transformation should be executed. Also valid: “cpu”, “cuda”. Defaults to the device the data comes in on.
        dtype – The torch.dtype to which the data should be converted before the transform. Defaults to the dtype the data comes in with.
    override_apply_on(apply_on)
    abstract transform(data)
        Transformation method, must be overridden by every subclass.
Composite
class torchvtk.transforms.Composite(*tfms, apply_on=None, device=None, dtype=None)
    Bases: torchvtk.transforms.dict_transform.DictTransform
    __init__(*tfms, apply_on=None, device=None, dtype=None)
        Composites multiple transforms together.
        Parameters:
        tfms (Callable, DictTransform) – `DictTransform`s or just callable objects that can handle the incoming dict data.
        apply_on (List of str) – Overrides the apply_on dictionary masks of the given transforms. (Only applies to `DictTransform`s.)
        device (torch.device, str) – torch.device, ‘cpu’ or ‘cuda’. This overrides the device for all `DictTransform`s.
        dtype (torch.dtype) – Overrides the dtype for all `DictTransform`s in this Composite.
    override_apply_on(apply_on)
    abstract transform(data)
        Transformation method, must be overridden by every subclass.
Lambda
class torchvtk.transforms.Lambda(func, as_list=False, **kwargs)
    Bases: torchvtk.transforms.dict_transform.DictTransform
    __init__(func, as_list=False, **kwargs)
        Applies a given function, wrapped in a DictTransform.
        Parameters:
        func (function) – The function to be executed.
        as_list (bool) – Whether all inputs specified in apply_on are passed as a list, or as separate items. Defaults to False (separate items).
        kwargs – Arguments for DictTransform.
    override_apply_on(apply_on)
    transform(items)
        Transformation method, must be overridden by every subclass.
Crop
class torchvtk.transforms.Crop(size=(20, 20, 20), position=0, **kwargs)
    Bases: torchvtk.transforms.dict_transform.DictTransform
    __init__(size=(20, 20, 20), position=0, **kwargs)
        Crops a tensor.
        Parameters:
        size (3-tuple of int) – Size of the crop.
        position (3-tuple of int) – Middle point of the cropped region.
        kwargs – Arguments for DictTransform.
    get_center_crop(data, size)
        Helper method for the crop.
    get_crop_around(data, mid, size)
        Helper method for the crop.
    override_apply_on(apply_on)
    transform(items)
        Applies the center crop.
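Cropping around a midpoint amounts to slicing the three spatial dimensions. The sketch below mirrors the roles of get_crop_around and get_center_crop, but crop_around and center_crop are hypothetical helpers. How torchvtk treats crops that extend past the volume border, and how position=0 is interpreted, is not documented here, so this version simply raises an error for out-of-bounds crops.

```python
from typing import Sequence

import torch


def crop_around(vol: torch.Tensor, mid: Sequence[int], size: Sequence[int]) -> torch.Tensor:
    """Hypothetical helper: crop the last three (spatial) dims to `size`, centred on `mid`."""
    slices = []
    for m, s, extent in zip(mid, size, vol.shape[-3:]):
        start = m - s // 2
        if start < 0 or start + s > extent:
            raise ValueError(f'crop of size {s} around {m} exceeds extent {extent}')
        slices.append(slice(start, start + s))
    return vol[(..., *slices)]  # leading batch/channel dims are kept as they are


def center_crop(vol: torch.Tensor, size: Sequence[int]) -> torch.Tensor:
    """Hypothetical helper: centre crop, using the middle of each spatial dim as midpoint."""
    mid = [extent // 2 for extent in vol.shape[-3:]]
    return crop_around(vol, mid, size)


vol = torch.arange(4 * 6 * 8).reshape(1, 4, 6, 8)  # (C, D, H, W)
print(center_crop(vol, (2, 2, 2)).shape)          # torch.Size([1, 2, 2, 2])
print(crop_around(vol, (1, 1, 1), (2, 2, 2))[0, 0, 0, 0].item())  # 0
```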
Resize
class torchvtk.transforms.Resize(size, mode='trilinear', is_batch=False, **kwargs)
    Bases: torchvtk.transforms.dict_transform.DictTransform
    __init__(size, mode='trilinear', is_batch=False, **kwargs)
        Resizes volumes to a given size or by a given factor.
        Parameters:
        size (tuple/list or float) – The new spatial dimensions as a tuple, or a scale factor as a scalar.
        mode (str, optional) – Resampling mode. See PyTorch’s torch.nn.functional.interpolate. Defaults to ‘trilinear’.
        is_batch (bool) – Whether the data passed in already has a batch dimension (this cannot be inferred if size is given as a scalar). Defaults to False.
        kwargs – Arguments for DictTransform.
    override_apply_on(apply_on)
    transform(items)
        Transformation method, must be overridden by every subclass.
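Resizing to a size or by a factor maps naturally onto torch.nn.functional.interpolate, which for 'trilinear' expects a 5D (N, C, D, H, W) input; this is presumably why is_batch exists. resize is a hypothetical helper that assumes unbatched inputs have shape (C, D, H, W). The align_corners=False choice for interpolating modes is an assumption, not taken from torchvtk.

```python
from typing import Sequence, Union

import torch
import torch.nn.functional as F


def resize(vol: torch.Tensor, size: Union[Sequence[int], float],
           mode: str = 'trilinear', is_batch: bool = False) -> torch.Tensor:
    """Hypothetical helper: resize spatial dims to `size` (sequence) or by a factor (scalar)."""
    x = vol if is_batch else vol.unsqueeze(0)  # (C, D, H, W) -> (1, C, D, H, W)
    # align_corners may only be set for interpolating modes; False is an assumed choice
    align = False if mode in ('linear', 'bilinear', 'bicubic', 'trilinear') else None
    if isinstance(size, (int, float)):
        out = F.interpolate(x, scale_factor=float(size), mode=mode, align_corners=align)
    else:
        out = F.interpolate(x, size=tuple(size), mode=mode, align_corners=align)
    return out if is_batch else out.squeeze(0)


vol = torch.rand(1, 16, 16, 16)        # (C, D, H, W)
print(resize(vol, (32, 24, 8)).shape)  # torch.Size([1, 32, 24, 8])
print(resize(vol, 0.5).shape)          # torch.Size([1, 8, 8, 8])
batch = torch.rand(2, 1, 16, 16, 16)   # (N, C, D, H, W)
print(resize(batch, 2, mode='nearest', is_batch=True).shape)  # torch.Size([2, 1, 32, 32, 32])
```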
RandFlip
class torchvtk.transforms.RandFlip(flip_probability=0.5, dims=[1, 1, 1], **kwargs)
    Bases: torchvtk.transforms.dict_transform.DictTransform
    Flips dimensions with a given probability. (A separate random event occurs for each dimension.)
    __init__(flip_probability=0.5, dims=[1, 1, 1], **kwargs)
        Flips dimensions of a tensor with a given flip_probability.
        Parameters:
        flip_probability (float) – Probability of a dimension being flipped. Defaults to 0.5.
        dims (list of 3 ints) – Dimensions that may be flipped are denoted with a 1, otherwise 0. [1,0,1] would randomly flip a volume’s depth and width dimensions, while never flipping its height dimension.
        kwargs – Arguments for DictTransform.
    override_apply_on(apply_on)
    transform(items)
        Transformation method, must be overridden by every subclass.
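Per-dimension random flipping can be sketched with torch.flip. rand_flip is a hypothetical helper, not torchvtk's implementation; it assumes the last three dimensions are (D, H, W) and draws one random number per allowed dimension, matching the description above.

```python
from typing import Sequence

import torch


def rand_flip(vol: torch.Tensor, flip_probability: float = 0.5,
              dims: Sequence[int] = (1, 1, 1)) -> torch.Tensor:
    """Hypothetical helper: flip each allowed spatial dim independently with flip_probability."""
    spatial = [vol.ndim - 3 + i for i in range(3)]  # indices of the last three dims (D, H, W)
    to_flip = [d for d, allowed in zip(spatial, dims)
               if allowed and torch.rand(1).item() < flip_probability]
    return torch.flip(vol, to_flip) if to_flip else vol


vol = torch.arange(8).reshape(1, 2, 2, 2)                      # (C, D, H, W)
always = rand_flip(vol, flip_probability=1.0, dims=[1, 0, 1])  # flips D and W, never H
print(torch.equal(always, vol.flip(1).flip(3)))                # True
print(torch.equal(rand_flip(vol, flip_probability=0.0), vol))  # True
```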
RandPermute
class torchvtk.transforms.RandPermute(permutations=None, **kwargs)
    Bases: torchvtk.transforms.dict_transform.DictTransform
    Chooses one of the 6 possible permutations of the volume axes at random.
    __init__(permutations=None, **kwargs)
        Randomly chooses one of the given permutations.
        Parameters:
        permutations (list of 3-tuples) – Overrides the list of possible permutations to choose from. The default is [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]. permutations must be a list or tuple of items that are compatible with torch.permute. Assume 0 to be the first spatial dimension; a possible batch and channel dimension is accounted for. The permutation is then chosen at random from the given list/tuple.
        kwargs – Arguments for DictTransform.
    override_apply_on(apply_on)
    transform(items)
        Transformation method, must be overridden by every subclass.
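The permutation logic can be sketched by keeping any leading batch/channel dimensions in place and reordering only the last three. rand_permute and DEFAULT_PERMUTATIONS are hypothetical names; the default list is the one given above, and Python's random.choice stands in for whatever random source torchvtk uses.

```python
import random
from typing import Optional, Sequence, Tuple

import torch

DEFAULT_PERMUTATIONS = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]


def rand_permute(vol: torch.Tensor,
                 permutations: Optional[Sequence[Tuple[int, int, int]]] = None) -> torch.Tensor:
    """Hypothetical helper: permute the three spatial axes with a randomly chosen permutation."""
    perm = random.choice(permutations or DEFAULT_PERMUTATIONS)
    lead = vol.ndim - 3  # batch and/or channel dims stay in front, untouched
    order = list(range(lead)) + [lead + p for p in perm]
    return vol.permute(*order)


vol = torch.rand(1, 2, 3, 4)                 # (C, D, H, W)
print(rand_permute(vol, [(2, 0, 1)]).shape)  # torch.Size([1, 4, 2, 3])
shapes = {tuple(rand_permute(vol).shape) for _ in range(200)}
print(len(shapes))                           # at most 6 distinct spatial orderings
```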
GaussianBlur
class torchvtk.transforms.GaussianBlur(channels=1, kernel_size=(3, 3, 3), sigma=1, **kwargs)
    Bases: torchvtk.transforms.dict_transform.DictTransform
    __init__(channels=1, kernel_size=(3, 3, 3), sigma=1, **kwargs)
        Blurs tensors using a Gaussian filter.
        Parameters:
        channels (int) – Number of channels of the input data.
        kernel_size (list of int) – Size of the convolution kernel.
        sigma (float) – Standard deviation of the Gaussian.
        kwargs – Arguments for DictTransform.
    override_apply_on(apply_on)
    transform(items)
        Applies the blur using a 3D convolution.
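A Gaussian blur via 3D convolution can be sketched with a normalised separable kernel and a grouped torch.nn.functional.conv3d, where groups=channels blurs each channel independently. This is an illustration under assumptions, not torchvtk's code: the hypothetical helpers gaussian_kernel3d and gaussian_blur expect a float (C, D, H, W) input and odd kernel sizes, and the zero padding at the borders may differ from the library's choice.

```python
from typing import Sequence

import torch
import torch.nn.functional as F


def gaussian_kernel3d(kernel_size: Sequence[int] = (3, 3, 3), sigma: float = 1.0) -> torch.Tensor:
    """Separable 3D Gaussian kernel, normalised to sum to 1."""
    axes = []
    for k in kernel_size:
        x = torch.arange(k, dtype=torch.float32) - (k - 1) / 2  # coordinates centred on 0
        g = torch.exp(-x ** 2 / (2 * sigma ** 2))
        axes.append(g / g.sum())
    return axes[0][:, None, None] * axes[1][None, :, None] * axes[2][None, None, :]


def gaussian_blur(vol: torch.Tensor, channels: int = 1,
                  kernel_size: Sequence[int] = (3, 3, 3), sigma: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: blur a (C, D, H, W) volume channel-wise with a 3D convolution."""
    k = gaussian_kernel3d(kernel_size, sigma).to(device=vol.device, dtype=vol.dtype)
    weight = k.expand(channels, 1, *k.shape).contiguous()  # one kernel per channel
    pad = tuple(s // 2 for s in kernel_size)               # keeps the size for odd kernels
    out = F.conv3d(vol.unsqueeze(0), weight, padding=pad, groups=channels)
    return out.squeeze(0)


vol = torch.zeros(2, 5, 5, 5)
vol[:, 2, 2, 2] = 1.0  # a single bright voxel per channel
blurred = gaussian_blur(vol, channels=2, kernel_size=(3, 3, 3), sigma=1.0)
print(blurred.shape)                      # torch.Size([2, 5, 5, 5])
print(round(blurred[0].sum().item(), 4))  # 1.0 (normalised kernel, impulse away from the border)
```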
GaussianNoise
class torchvtk.transforms.GaussianNoise(std_deviation=0.01, mean=0, **kwargs)
    Bases: torchvtk.transforms.dict_transform.DictTransform
    __init__(std_deviation=0.01, mean=0, **kwargs)
        Adds Gaussian noise to tensors.
        Parameters:
        std_deviation (float, tensor) – The standard deviation of the noise.
        mean (float, tensor) – The mean of the noise.
        kwargs – Arguments for DictTransform.
    override_apply_on(apply_on)
    transform(items)
        Applies the noise to the images. The spread of the noise is controlled by the std_deviation parameter.
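Additive Gaussian noise reduces to a single line with torch.randn_like. gaussian_noise is a hypothetical helper, not torchvtk's implementation. Because std_deviation is a standard deviation, the variance of the added noise is std_deviation squared.

```python
import torch


def gaussian_noise(vol: torch.Tensor, std_deviation: float = 0.01, mean: float = 0.0) -> torch.Tensor:
    """Hypothetical helper: add i.i.d. noise drawn from N(mean, std_deviation**2) to every voxel."""
    return vol + torch.randn_like(vol) * std_deviation + mean


torch.manual_seed(0)
vol = torch.zeros(1, 32, 32, 32)
noisy = gaussian_noise(vol, std_deviation=0.1, mean=0.5)
print(noisy.mean().item(), noisy.std().item())  # close to 0.5 and 0.1
```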