Transforms

Usage

Transforms include common data augmentation and preprocessing functions that you would use with volumetric data in a machine learning context. All transforms are implemented directly in PyTorch, so they work on both CPU and GPU and require no further dependencies. If you would like to see transforms that we have not implemented, feel free to open an issue or send us a PR.

Transforms in torchvtk must inherit from DictTransform. A DictTransform takes care of applying transformations to specified items of the data dict (as defined in TorchDataset) and gives control over the dtype and device for the transform. We stick with this dictionary paradigm so that all preprocessing tasks can be chained together in a Composite. An example is presented in the following code snippet.

    from torchvtk.datasets import TorchDataset
    from torchvtk.transforms import Composite, RandFlip, Resize, Lambda

    def _basic_func(data):
        data['newkey'] = 1 # do some custom stuff here
        return data        # Make sure to return the dictionary!

    def _basic_add(number): # This function does not care about dicts
        return number + 1

    tfms = Composite(
        _basic_func,
        Lambda(_basic_add, apply_on=['newkey']),
        RandFlip(),
        Resize((256, 256, 256)),
        apply_on=['vol']
    )
    tmp = tfms(data) # data should be a dict, like an item of a TorchDataset

    ds = TorchDataset(path_to_ds, preprocess_fn=tfms) # Use the transforms upon loading with a TorchDataset

Compositing

A Composite can chain DictTransforms as well as plain functions, as long as they all operate on dicts and return the modified dictionary. All subclasses of DictTransform, which includes every class in torchvtk.transforms, accept the parameters of DictTransform through the **kwargs of their respective __init__s. That means you can specify for each transform on which items in the dictionary it should be applied (e.g. apply_on=['vol']), as well as a preferred torch.dtype and torch.device for the transform.

Note that setting the apply_on parameter of Composite (as in the example) applies it to all transforms that have not specified apply_on themselves. The dtype and device parameters work similarly.
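
For example, in the following sketch (assuming a data dict that contains 'vol' and 'mask' tensors; 'mask' is a hypothetical key for illustration), GaussianNoise keeps its own apply_on while RandFlip inherits the Composite's:

    from torchvtk.transforms import Composite, GaussianNoise, RandFlip

    tfms = Composite(
        GaussianNoise(std_deviation=0.05, apply_on=['vol']), # keeps its own apply_on
        RandFlip(),                                          # inherits apply_on=['vol', 'mask']
        apply_on=['vol', 'mask']
    )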

DictTransform arguments

Further note that by default the transformations are applied to all keys that have a torch.Tensor as value. Beware of other tensors in your data that you do not want to modify! You should usually define apply_on for all your transforms, be it on the transform itself or through the Composite. As for dtype and device, a transform by default executes using the type and device that the data comes in with. Setting a type or device at the beginning of a Composite can thus determine the type or device until a later transform specifies another.
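
A sketch of this propagation (assuming a data dict with a 'vol' tensor): the first transform casts and moves the data, and later transforms keep operating on that dtype and device until one specifies otherwise:

    import torch
    from torchvtk.transforms import Composite, GaussianBlur, GaussianNoise, RandFlip

    tfms = Composite(
        GaussianBlur(device='cuda', dtype=torch.float32), # data is moved to the GPU as float32
        GaussianNoise(),                                  # still runs on the GPU with float32
        RandFlip(device='cpu'),                           # data is moved back to the CPU first
        apply_on=['vol']
    )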

Mixing with non-DictTransforms

In this example, _basic_func is executed first; it receives the whole dictionary and must return the modified one. Here a new key 'newkey' is added. _basic_add is a standard function that knows nothing of dicts, so we wrap it in a Lambda to make use of apply_on etc. As you can see, we apply _basic_add only to 'newkey'. All the other transforms in the Composite are applied to 'vol', since the Composite sets apply_on for all transforms that did not specify it themselves.

API

DictTransform

class torchvtk.transforms.DictTransform(device=None, apply_on=None, dtype=None)

Super class for all transforms.

__init__(device=None, apply_on=None, dtype=None)
Parameters
  • apply_on – The keys of the item dictionaries on which the transform should be applied. Defaults to applying to all torch.Tensor values.

  • device – The torch.device on which the transformation should be executed. Also valid: “cpu”, “cuda”. Defaults to using whatever device the data comes on.

  • dtype – The torch.dtype to which the data should be converted before the transform. Defaults to using whatever dtype the data comes in.

override_apply_on(apply_on)
abstract transform(data)

Transformation method; must be overridden by every subclass.
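
For illustration, a minimal sketch of a custom transform (Clamp is a hypothetical name, not part of torchvtk); we assume transform() receives the tensor selected via apply_on rather than the whole dict:

    import torch
    from torchvtk.transforms import DictTransform

    class Clamp(DictTransform):
        def __init__(self, min_val=0.0, max_val=1.0, **kwargs):
            super().__init__(**kwargs) # forwards apply_on, device, dtype
            self.min_val = min_val
            self.max_val = max_val

        def transform(self, data):
            # Assumption: data is the tensor selected via apply_on, not the whole dict
            return torch.clamp(data, self.min_val, self.max_val)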

Composite

class torchvtk.transforms.Composite(*tfms, apply_on=None, device=None, dtype=None)

Bases: torchvtk.transforms.dict_transform.DictTransform

__init__(*tfms, apply_on=None, device=None, dtype=None)

Composites multiple transforms together

Parameters
  • tfms (Callable, DictTransform) – `DictTransform`s or just callable objects that can handle the incoming dict data

  • apply_on (List of str) – Sets apply_on for all contained transforms that have not specified it themselves. (Only applies to `DictTransform`s)

  • device (torch.device, str) – torch.device, ‘cpu’ or ‘cuda’. Sets the device for all contained `DictTransform`s that have not specified one themselves.

  • dtype (torch.dtype) – Sets the dtype for all contained `DictTransform`s that have not specified one themselves.

override_apply_on(apply_on)
abstract transform(data)

Transformation method; must be overridden by every subclass.

Lambda

class torchvtk.transforms.Lambda(func, as_list=False, **kwargs)

Bases: torchvtk.transforms.dict_transform.DictTransform

__init__(func, as_list=False, **kwargs)

Applies a given function, wrapped in a DictTransform

Parameters
  • func (function) – The function to be executed

  • as_list (bool) – Whether all inputs specified in apply_on are passed as a list, or as separate items. Defaults to False (separate items).

  • kwargs – Arguments for DictTransform

override_apply_on(apply_on)
transform(items)

Applies the wrapped function to the selected items.
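
A brief usage sketch: with the default as_list=False, the wrapped function is called once per key in apply_on and receives each tensor separately (the min-max normalization here is just an illustration):

    from torchvtk.transforms import Lambda

    def _normalize(t): # called separately for each key in apply_on
        return (t - t.min()) / (t.max() - t.min() + 1e-8)

    tfm = Lambda(_normalize, apply_on=['vol'])

With as_list=True, all tensors selected by apply_on would instead be passed to the function together as one list.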

Crop

class torchvtk.transforms.Crop(size=(20, 20, 20), position=0, **kwargs)

Bases: torchvtk.transforms.dict_transform.DictTransform

__init__(size=(20, 20, 20), position=0, **kwargs)

Crops a tensor

Parameters
  • size (3-tuple of int) – Size of the crop.

  • position (3-tuple of int) – Middle point of the cropped region.

  • kwargs – Arguments for DictTransform.

get_center_crop(data, size)

Helper method for the crop.

get_crop_around(data, mid, size)

Helper method for the crop.

override_apply_on(apply_on)
transform(items)

Applies the Center Crop.
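
A short usage sketch, assuming (as the helper names suggest) that the default position of 0 yields a center crop:

    from torchvtk.transforms import Crop

    center_crop = Crop(size=(64, 64, 64), apply_on=['vol'])                          # crop around the volume center
    point_crop  = Crop(size=(64, 64, 64), position=(100, 80, 120), apply_on=['vol']) # crop around a given voxel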

Resize

class torchvtk.transforms.Resize(size, mode='trilinear', is_batch=False, **kwargs)

Bases: torchvtk.transforms.dict_transform.DictTransform

__init__(size, mode='trilinear', is_batch=False, **kwargs)

Resizes volumes to a given size or by a given factor

Parameters
  • size (tuple/list or float) – The new spatial dimensions as a tuple, or a scaling factor as a scalar.

  • mode (str, optional) – Resampling mode. See PyTorch’s torch.nn.functional.interpolate. Defaults to ‘trilinear’.

  • is_batch (bool) – Whether the data passed in here already has a batch dimension (cannot be inferred if size is given as a scalar). Defaults to False.

  • kwargs – Arguments for DictTransform

override_apply_on(apply_on)
transform(items)

Applies the resizing to the selected items.
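
A short usage sketch, resizing once to fixed dimensions and once by a scale factor:

    from torchvtk.transforms import Resize

    to_shape = Resize((128, 128, 128), apply_on=['vol'])     # resize to fixed spatial dims
    by_half  = Resize(0.5, mode='nearest', apply_on=['vol']) # scale all spatial dims by 0.5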

RandFlip

class torchvtk.transforms.RandFlip(flip_probability=0.5, dims=[1, 1, 1], **kwargs)

Bases: torchvtk.transforms.dict_transform.DictTransform

Flips dimensions with a given probability. (The random event occurs independently for each dimension.)

__init__(flip_probability=0.5, dims=[1, 1, 1], **kwargs)

Flips dimensions of a tensor with a given flip_probability.

Parameters
  • flip_probability (float) – Probability of a dimension being flipped. Default 0.5.

  • dims (list of 3 ints) – Dimensions that may be flipped are denoted with a 1, otherwise 0. [1, 0, 1] would randomly flip a volume's depth and width dimensions, while never flipping its height dimension.

  • kwargs – Arguments for DictTransform

override_apply_on(apply_on)
transform(items)

Applies the random flips to the selected items.
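
A short usage sketch matching the dims semantics described above:

    from torchvtk.transforms import RandFlip

    # Flip depth and width independently with 50% probability each; never flip height
    tfm = RandFlip(flip_probability=0.5, dims=[1, 0, 1], apply_on=['vol'])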

RandPermute

class torchvtk.transforms.RandPermute(permutations=None, **kwargs)

Bases: torchvtk.transforms.dict_transform.DictTransform

Randomly chooses one of the 6 possible permutations of the volume axes

__init__(permutations=None, **kwargs)

Randomly choose one of the given permutations.

Parameters
  • permutations (list of 3-tuples) – Overrides the list of possible permutations to choose from. Defaults to all 6 permutations of the spatial axes: [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]. Each item must be a valid argument to torch.permute, where 0 denotes the first spatial dimension; possible batch and channel dimensions are accounted for automatically. The permutation is then chosen at random from the given list/tuple.

  • kwargs – Arguments for DictTransform

override_apply_on(apply_on)
transform(items)

Applies the randomly chosen permutation to the selected items.
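
A short usage sketch that restricts the choice to permutations which keep the first spatial axis in place:

    from torchvtk.transforms import RandPermute

    tfm = RandPermute(permutations=[(0, 1, 2), (0, 2, 1)], apply_on=['vol'])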

GaussianBlur

class torchvtk.transforms.GaussianBlur(channels=1, kernel_size=(3, 3, 3), sigma=1, **kwargs)

Bases: torchvtk.transforms.dict_transform.DictTransform

__init__(channels=1, kernel_size=(3, 3, 3), sigma=1, **kwargs)

Blurs tensors using a Gaussian filter

Parameters
  • channels (int) – Number of channels of the input data.

  • kernel_size (list of int) – Size of the convolution kernel.

  • sigma (float) – Standard deviation of the Gaussian kernel.

  • kwargs – Arguments for DictTransform

override_apply_on(apply_on)
transform(items)

Applies the Blur using a 3D Convolution.
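
A short usage sketch for single-channel volumes:

    from torchvtk.transforms import GaussianBlur

    tfm = GaussianBlur(channels=1, kernel_size=(5, 5, 5), sigma=1.5, apply_on=['vol'])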

GaussianNoise

class torchvtk.transforms.GaussianNoise(std_deviation=0.01, mean=0, **kwargs)

Bases: torchvtk.transforms.dict_transform.DictTransform

__init__(std_deviation=0.01, mean=0, **kwargs)

Adds Gaussian noise to tensors

Parameters
  • std_deviation (float, tensor) – The standard deviation of the noise.

  • mean (float, tensor) – The mean of the noise.

  • kwargs – Arguments for DictTransform.

override_apply_on(apply_on)
transform(items)

Applies the noise to the selected items. The amount of noise is controlled by the std_deviation parameter.
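
A short usage sketch, adding zero-mean noise with a small standard deviation:

    from torchvtk.transforms import GaussianNoise

    tfm = GaussianNoise(std_deviation=0.05, mean=0.0, apply_on=['vol'])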