TorchDataset

This is your go-to dataset class for handling large volumes. Each volume is stored as a serialized binary PyTorch file (.pt) together with its metadata. You can always load such a file like this:

import torch
data = torch.load('/path/to/file.pt')

For an item from the CQ500 dataset, data is a dictionary of the form:

{
   'vol': torch.Tensor(),     # Volume tensor of shape (C, D, H, W)
   'vox_scl': torch.Tensor(), # Scale of the voxels (for rectangular grids)
   'name': str                # Name of the volume
}

Note that volumes saved in the torchvtk format are always scaled to [0, 1] and stored as torch.float16 or torch.float32. Also, single-channel volumes always have a 4-dimensional shape: (1, D, H, W).
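To check whether a file follows these conventions, a small test like the following can help. check_torchvtk_item is a hypothetical helper (not part of torchvtk) that only tests the properties stated above: a 4D vol, values in [0, 1] and a float16/float32 dtype. The synthetic item, including the 3-element vox_scl, is made up for illustration.

```python
import tempfile
from pathlib import Path

import torch


def check_torchvtk_item(data: dict) -> None:
    """Hypothetical helper: assert the format conventions described above."""
    vol = data['vol']
    assert vol.ndim == 4, f'expected (C, D, H, W), got shape {tuple(vol.shape)}'
    assert vol.dtype in (torch.float16, torch.float32), f'unexpected dtype {vol.dtype}'
    v = vol.float()  # compare in float32 to avoid half-precision op support issues
    assert v.min() >= 0 and v.max() <= 1, 'values must be scaled to [0, 1]'


# Build a synthetic single-channel item and round-trip it through a .pt file
item = {
    'vol': torch.rand(1, 8, 16, 16).to(torch.float16),  # (C, D, H, W)
    'vox_scl': torch.tensor([1.0, 0.5, 0.5]),           # assumed layout
    'name': 'synthetic',
}
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'synthetic.pt'
    torch.save(item, path)
    loaded = torch.load(path)

check_torchvtk_item(loaded)
print(loaded['vol'].shape, loaded['vol'].dtype)
```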

Basic Example

torchvtk can download some volume datasets automatically. They are downloaded to your local torchvtk cache (~/.torchvtk/ unless specified otherwise), where they are converted to the TorchDataset format. If the dataset already exists in your torchvtk folder, it is simply loaded.

from torchvtk.datasets import TorchDataset

ds = TorchDataset.CQ500('custom/path/torchvtk', num_workers=4)

The num_workers parameter applies to downloading and unpacking.

More involved example

This example shows how TorchDatasets can be used to preprocess datasets easily, both for persistent storage on disk and during loading. The following snippet serializes a preprocessed version of the CQ500 dataset using multiprocessing:

import torch.nn.functional as F
from torchvtk.datasets import TorchDataset
from torchvtk.utils    import make_5d

def to_256(data):
   data['vol'] = F.interpolate(make_5d(data['vol']), size=(256,256,256), mode='trilinear').squeeze(0)
   return data # Must return the dictionary!

ds = TorchDataset.CQ500('/mnt/hdd/torchvtk')
ds_256 = ds.cache_processed(to_256, 'CQ500_256', num_workers=4) # Call this only once
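To try such a transform on a synthetic item before processing the whole dataset, the sketch below replaces make_5d with unsqueeze(0), assuming vol is already 4D (C, D, H, W) as described at the top. resize_vol is a hypothetical name, and the float32 round trip is a precaution against missing half-precision trilinear support on CPU in some PyTorch versions, not something torchvtk requires.

```python
import torch
import torch.nn.functional as F


def resize_vol(data: dict, size=(256, 256, 256)) -> dict:
    """Variant of to_256 that avoids torchvtk.utils.make_5d.

    Assumes data['vol'] has shape (C, D, H, W); unsqueeze(0) adds the batch
    dimension that F.interpolate expects for 3D (5D-tensor) inputs.
    """
    vol = data['vol']
    # Interpolate in float32, then cast back to the stored dtype
    out = F.interpolate(vol.unsqueeze(0).float(), size=size, mode='trilinear',
                        align_corners=False)
    data['vol'] = out.squeeze(0).to(vol.dtype)
    return data  # must return the dictionary, like to_256


item = {'vol': torch.rand(1, 40, 64, 64).half(), 'name': 'synthetic'}
item = resize_vol(item, size=(32, 32, 32))
print(item['vol'].shape)  # torch.Size([1, 32, 32, 32])
```

Because trilinear interpolation produces convex combinations of neighboring voxels, values stay within [0, 1].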

After this call, the resized dataset is serialized on disk. From then on, load it directly (assuming your local torchvtk folder is /mnt/hdd/torchvtk/):

ds = TorchDataset('/mnt/hdd/torchvtk/CQ500_256')

Serializing large volume datasets at the resolution you will use for training (which is likely lower than the original) can substantially reduce data loading times.

With the volumes serialized at low resolution, we can add further preprocessing that the TorchDataset applies each time a sample is loaded. This is ideal for torchvtk.transforms. We also split our data into training and validation sets:

from torchvtk.transforms import GaussianNoise

train_ds = TorchDataset('/mnt/hdd/torchvtk/CQ500_256', filter_fn=lambda p: int(p.name[9:-3]) < 400,
   preprocess_fn=GaussianNoise(apply_on=['vol']))

valid_ds = TorchDataset('/mnt/hdd/torchvtk/CQ500_256', filter_fn=lambda p: int(p.name[9:-3]) >= 400)

We split our dataset into training and validation sets simply by using the filter_fn parameter, which takes a function that receives each file's path (pathlib.Path) and decides whether the file is included in the dataset. Here, p.name[9:-3] trims the CQ500 file name down to its numeric ID. The preprocess_fn parameter takes any callable that accepts a dictionary (as specified at the top of this article) and returns a modified dict. torchvtk.transforms fulfill these requirements and let you specify the keys of your data to which an operation is applied. Check out Transforms.
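The two contracts can be sketched in plain Python. AddNoise is a hypothetical callable, not torchvtk.transforms.GaussianNoise; it only illustrates "take the sample dict, return a dict" with an apply_on-style key list. The file names assume the CQ500-CT-<n>.pt pattern implied by p.name[9:-3], and "True means keep" is inferred from the example above.

```python
from pathlib import Path

import torch


class AddNoise:
    """Hypothetical preprocess_fn: adds Gaussian noise to the listed keys."""

    def __init__(self, std=0.01, apply_on=('vol',)):
        self.std = std
        self.apply_on = apply_on

    def __call__(self, data: dict) -> dict:
        for key in self.apply_on:
            data[key] = data[key] + self.std * torch.randn_like(data[key])
        return data  # the dict is returned, as preprocess_fn requires


def cq500_id(p: Path) -> int:
    # 'CQ500-CT-123.pt' -> drop the 9-character prefix and the '.pt' suffix
    return int(p.name[9:-3])


def is_train(p: Path) -> bool:  # filter_fn: return True to keep the file
    return cq500_id(p) < 400


def is_valid(p: Path) -> bool:
    return cq500_id(p) >= 400


files = [Path('CQ500-CT-7.pt'), Path('CQ500-CT-399.pt'), Path('CQ500-CT-400.pt')]
print([p.name for p in files if is_train(p)])  # ['CQ500-CT-7.pt', 'CQ500-CT-399.pt']
```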

API

class torchvtk.datasets.TorchDataset(ds_files, filter_fn=None, preprocess_fn=None)

Bases: torch.utils.data.Dataset

static CQ500(tvtk_ds_path='~/.torchvtk/', num_workers=0, **kwargs)

Get the QureAI CQ500 Dataset. Downloads, extracts and converts it to a TorchDataset if it is not locally available. Find the dataset here: http://headctstudy.qure.ai/dataset Credits to Chilamkurthy et al. https://arxiv.org/abs/1803.05854

Parameters
  • tvtk_ds_path (str, Path) – Path where your torchvtk datasets shall be saved.

  • num_workers (int) – Number of processes used for downloading, extracting and converting.

  • kwargs – Keyword arguments to pass on to TorchDataset.__init__()

Returns

TorchDataset containing CQ500.

Return type

TorchDataset

__init__(ds_files, filter_fn=None, preprocess_fn=None)

A dataset that uses serialized PyTorch Tensors.

Parameters
  • ds_files (str, Path (Dir), List of Path (Files)) – Path to the TorchDataset directory (containing *.pt files) or list of paths pointing to .pt files

  • filter_fn (function) – Function that filters the found items. Its input is the file path.

  • preprocess_fn (function) – Function to process the loaded dictionary.
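How the three parameters interact can be illustrated with a simplified stand-in. MiniTorchDataset is hypothetical and is not TorchDataset's real code; details such as file ordering or handling a single .pt file may differ in the library. It globs a directory (or takes a list of paths), applies filter_fn to each Path, and runs preprocess_fn on every loaded dictionary.

```python
import tempfile
from pathlib import Path

import torch
from torch.utils.data import Dataset


class MiniTorchDataset(Dataset):
    """Simplified, hypothetical stand-in for TorchDataset."""

    def __init__(self, ds_files, filter_fn=None, preprocess_fn=None):
        if isinstance(ds_files, (str, Path)):
            # A directory: collect all serialized items inside it
            ds_files = sorted(Path(ds_files).glob('*.pt'))
        self.items = [Path(p) for p in ds_files]
        if filter_fn is not None:
            self.items = [p for p in self.items if filter_fn(p)]
        self.preprocess_fn = preprocess_fn

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        data = torch.load(self.items[i])  # files are only read on access
        if self.preprocess_fn is not None:
            data = self.preprocess_fn(data)
        return data


with tempfile.TemporaryDirectory() as tmp:
    for n in (1, 2, 3):
        torch.save({'vol': torch.rand(1, 4, 4, 4), 'name': f'item{n}'},
                   Path(tmp) / f'item{n}.pt')
    ds = MiniTorchDataset(tmp, filter_fn=lambda p: p.stem != 'item2',
                          preprocess_fn=lambda d: {**d, 'vol': d['vol'] * 2})
    names = [ds[i]['name'] for i in range(len(ds))]
print(names)  # ['item1', 'item3']
```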

cache_processed(process_fn, name, num_workers=0, delete_old_from_disk=False)

Processes the given TorchDataset and serializes it. Iterates through the dataset and applies the given process_fn to each item (which should be a dictionary). The resulting new dataset is serialized next to the old one, using the given name. This function can use multiple threads.

Parameters
  • process_fn (function) – The function to be applied on the individual items

  • name (str) – Name of the new processed dataset

  • num_workers (int > 0) – Number of threads used for processing

  • delete_old_from_disk (bool) – If True, the root directory of the old, unprocessed, dataset is removed from disk.

Returns

TorchDataset with the new items. (no filter or preprocess_fn set)

Return type

TorchDataset
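The idea behind cache_processed can be sketched with torch.save/torch.load and pathlib. cache_processed_sketch is a hypothetical re-creation, not torchvtk's implementation: it uses a thread pool purely for illustration (the real function may use processes), writes results to a sibling directory called name, and returns file paths instead of a TorchDataset.

```python
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import torch


def cache_processed_sketch(src_dir, process_fn, name, num_workers=0):
    """Apply process_fn to every .pt item in src_dir and save the results
    into the sibling directory src_dir.parent / name."""
    src_dir = Path(src_dir)
    dst_dir = src_dir.parent / name          # "next to the old one"
    dst_dir.mkdir(parents=True, exist_ok=True)

    def work(path):
        data = process_fn(torch.load(path))  # process_fn must return the dict
        out = dst_dir / path.name
        torch.save(data, out)
        return out

    files = sorted(src_dir.glob('*.pt'))
    if num_workers > 0:
        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            return list(pool.map(work, files))  # map preserves input order
    return [work(p) for p in files]


with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / 'Original'
    src.mkdir()
    for n in range(3):
        torch.save({'vol': torch.rand(1, 8, 8, 8), 'name': f'v{n}'}, src / f'v{n}.pt')

    def halve_depth(data):
        # (1, 8, 8, 8) -> (1, 4, 8, 8); contiguous() copies so only the slice is saved
        data['vol'] = data['vol'][:, ::2].contiguous()
        return data

    new_files = cache_processed_sketch(src, halve_depth, 'Original_small', num_workers=2)
    shapes = [tuple(torch.load(p)['vol'].shape) for p in new_files]
    print([p.parent.name for p in new_files], shapes)
```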

static from_file(file_path, filter_fn=None, preprocess_fn=None)

preload(device='cpu', num_workers=0)

Preloads the dataset into memory.

Parameters
  • device (torch.device, optional) – Device to store the dataset on. Defaults to ‘cpu’.

  • num_workers (int, optional) – Number of workers to load items into memory. Defaults to 0.

Returns

New TorchDataset using the preloaded data.

Return type

PreloadedTorchDataset
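Conceptually, preloading trades memory for loading speed: every item is read once and kept on the chosen device. PreloadedSketch is a hypothetical class, not PreloadedTorchDataset. In this sketch, anything the source dataset does on access (such as a random preprocess_fn) happens only once, at preload time; whether the library behaves the same way is not stated here.

```python
import torch
from torch.utils.data import Dataset


class PreloadedSketch(Dataset):
    """Hypothetical in-memory dataset: loads every item once, up front."""

    def __init__(self, source, device='cpu'):
        # source: any indexable dataset returning dicts (e.g. a TorchDataset)
        self.items = [self._to(source[i], device) for i in range(len(source))]

    @staticmethod
    def _to(data, device):
        # Move tensors to the target device; leave other values (e.g. 'name') alone
        return {k: v.to(device) if torch.is_tensor(v) else v for k, v in data.items()}

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        return self.items[i]


source = [{'vol': torch.rand(1, 4, 4, 4), 'name': f'v{n}'} for n in range(3)]
pre = PreloadedSketch(source, device='cpu')
print(len(pre), pre[0]['vol'].device)  # 3 cpu
```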

tile(keys_to_tile, tile_sz=128, overlap=2, dim=3, **kwargs)

Converts the dataset to a tiled dataset that draws only parts of the data. Since the data needs to be loaded to determine the number of tiles, a tile is drawn randomly after loading the volume, without guaranteeing full coverage.

Parameters
  • keys_to_tile ([str]) – List of strings matching the keys of the data dictionaries that need to be tiled. All must have the same shape and result in the same tiling.

  • tile_sz (int/tuple of ints, optional) – Size of the tiles drawn. Either int or tuple with length matching the given dim. Defaults to 128.

  • overlap (int/tuple of ints, optional) – Overlap between adjacent tiles. Defaults to 2.

  • dim (int, optional) – Dimensionality of the data. If tile_sz or overlap is given as tuple this must match their lengths. Defaults to 3.

Returns

Tiling-aware TorchDataset

Return type

TiledTorchDataset
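The random tile drawing described above can be sketched as follows. draw_random_tile is hypothetical, and how torchvtk actually uses overlap is not documented here; this sketch assumes tiles lie on a regular grid with stride tile_sz - overlap over the last dim axes, with the last tile clamped to the border. The grid can only be built once the volume's shape is known, which mirrors why tiles are drawn after loading.

```python
import random

import torch


def draw_random_tile(data, keys_to_tile, tile_sz=128, overlap=2, dim=3):
    """Cut the same random tile from every key in keys_to_tile."""
    tile_sz = (tile_sz,) * dim if isinstance(tile_sz, int) else tuple(tile_sz)
    overlap = (overlap,) * dim if isinstance(overlap, int) else tuple(overlap)
    assert len(tile_sz) == dim and len(overlap) == dim
    shape = data[keys_to_tile[0]].shape[-dim:]  # all keys must share this shape

    # Possible start indices per axis; the last tile is clamped to the border
    starts = []
    for size, t, o in zip(shape, tile_sz, overlap):
        assert t > o, 'tile size must exceed overlap'
        s = list(range(0, max(size - t, 0) + 1, t - o))
        if s[-1] + t < size:
            s.append(size - t)
        starts.append(s)

    corner = [random.choice(s) for s in starts]  # one tile, chosen at random
    region = tuple(slice(c, c + t) for c, t in zip(corner, tile_sz))
    out = dict(data)
    for key in keys_to_tile:
        out[key] = data[key][(Ellipsis,) + region]
    return out


sample = {'vol': torch.rand(1, 40, 300, 300), 'mask': torch.zeros(1, 40, 300, 300),
          'name': 'synthetic'}
tile = draw_random_tile(sample, ['vol', 'mask'], tile_sz=(32, 128, 128), overlap=2)
print(tile['vol'].shape, tile['mask'].shape)  # both torch.Size([1, 32, 128, 128])
```

Keys not listed in keys_to_tile (here name) are passed through unchanged.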