TorchDataset¶
This is your go-to dataset class for handling large volumes.
Each volume is stored as a serialized binary PyTorch file (.pt), together with its metadata. You can always load such a file like this:
import torch
data = torch.load('/path/to/file.pt')
In the case of an item from the CQ500 dataset, data would be a dictionary of the form
{
'vol': torch.Tensor(), # Volume tensor of shape (C, D, H, W)
'vox_scl': torch.Tensor(), # Scale of the voxels (for rectangular grids)
'name': str # Name of the volume
}
Note that volumes saved in the torchvtk format are always scaled to [0, 1] and saved as torch.float16 or torch.float32. Also, the shape of single-channel volumes is always 4-dimensional: (1, D, H, W).
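For instance, a minimal sanity check on a loaded item might look like this (the path is a placeholder):
import torch
data = torch.load('/path/to/file.pt')
vol = data['vol']
# Single-channel volumes come as (1, D, H, W), scaled to [0, 1]
assert vol.dim() == 4
assert vol.dtype in (torch.float16, torch.float32)
print(data['name'], tuple(vol.shape), vol.min().item(), vol.max().item())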
Basic Example¶
We have mechanisms to download some volume datasets automatically using torchvtk. These datasets are downloaded to your local torchvtk cache (~/.torchvtk/ unless specified otherwise), where they are converted to the TorchDataset format. If the dataset already exists in your torchvtk folder, it is simply loaded.
from torchvtk.datasets import TorchDataset
ds = TorchDataset.CQ500('custom/path/torchvtk', num_workers=4)
The num_workers parameter applies to downloading and unpacking.
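Since a TorchDataset is a regular torch.utils.data.Dataset, you can index it directly once loaded; a minimal sketch (the index is illustrative):
data = ds[0]  # dictionary with 'vol', 'vox_scl' and 'name' as shown above
print(data['name'], data['vol'].shape)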
More involved example¶
This example showcases how a TorchDataset can be used to easily preprocess datasets, both for persisting to disk and on the fly during loading.
The following snippet shows how we can serialize a preprocessed version of the CQ500 dataset using multiprocessing:
import torch.nn.functional as F
from torchvtk.datasets import TorchDataset
from torchvtk.utils import make_5d
def to_256(data):
    # Resize the volume to 256³ using trilinear interpolation
    data['vol'] = F.interpolate(make_5d(data['vol']), size=(256,256,256), mode='trilinear').squeeze(0)
    return data # Must return the dictionary!
ds = TorchDataset.CQ500('/mnt/hdd/torchvtk')
ds_256 = ds.cache_processed(to_256, 'CQ500_256', num_workers=4) # Call this only once
After calling this, the resized dataset is serialized. From then on, use the following, assuming your local torchvtk folder is /mnt/hdd/torchvtk/:
ds = TorchDataset('/mnt/hdd/torchvtk/CQ500_256')
Serializing large volume datasets at the resolution you will use for training (which is likely lower than the original) can significantly speed up data loading.
Having serialized the volumes in low resolution, we can attach further preprocessing to the TorchDataset that is applied upon loading a sample. This is ideal for torchvtk.transforms. We further split our data into train and validation sets:
from torchvtk.transforms import GaussianNoise
train_ds = TorchDataset('/mnt/hdd/torchvtk/CQ500_256', filter_fn=lambda p: int(p.name[9:-3]) < 400,
preprocess_fn=GaussianNoise(apply_on=['vol']))
valid_ds = TorchDataset('/mnt/hdd/torchvtk/CQ500_256', filter_fn=lambda p: int(p.name[9:-3]) >= 400)
We split our dataset into training and validation simply by using the filter_fn parameter, which takes a function that filters the found files based on their filepath (a pathlib.Path). Here the file name is sliced to extract the item number specific to CQ500.
The preprocess_fn parameter takes any callable that accepts a dictionary (as specified at the top of this article) and returns a modified dict. torchvtk.transforms fulfill these requirements and let you easily specify to which keys in your data the operations shall be applied. Check out Transforms.
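Such datasets plug directly into a standard PyTorch DataLoader. A minimal training-loop sketch, assuming the default collate function (batch size and worker count are illustrative; batching works here because all volumes were resized to 256³):
from torch.utils.data import DataLoader
train_dl = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2)
for batch in train_dl:
    vols = batch['vol']  # (B, 1, 256, 256, 256) after default collation
    # ... feed vols to your model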
API¶
class torchvtk.datasets.TorchDataset(ds_files, filter_fn=None, preprocess_fn=None)
Bases: torch.utils.data.Dataset
static CQ500(tvtk_ds_path='~/.torchvtk/', num_workers=0, **kwargs)
Get the QureAI CQ500 dataset. Downloads, extracts and converts it to a TorchDataset if not locally available. Find the dataset here: http://headctstudy.qure.ai/dataset. Credits to Chilamkurthy et al.: https://arxiv.org/abs/1803.05854
- Parameters
tvtk_ds_path (str, Path) – Path where your torchvtk datasets shall be saved.
num_workers (int) – Number of processes used for downloading, extracting and converting.
kwargs – Keyword arguments passed on to TorchDataset.__init__().
- Returns
TorchDataset containing CQ500.
- Return type
TorchDataset
__init__(ds_files, filter_fn=None, preprocess_fn=None)
A dataset that uses serialized PyTorch tensors.
- Parameters
ds_files (str, Path (directory), list of Path (files)) – Path to the TorchDataset directory (containing *.pt) or list of paths pointing to .pt files.
filter_fn (function) – Function that filters the found items. Input is the filepath.
preprocess_fn (function) – Function to process the loaded dictionary.
cache_processed(process_fn, name, num_workers=0, delete_old_from_disk=False)
Processes the given TorchDataset and serializes it. Iterates through the dataset and applies the given process_fn to each item (which should be a dictionary). The resulting new dataset is serialized next to the old one, using the given name. This function can work multithreaded.
- Parameters
process_fn (function) – The function to be applied to the individual items.
name (str) – Name of the new processed dataset
num_workers (int > 0) – Number of threads used for processing
delete_old_from_disk (bool) – If True, the root directory of the old, unprocessed, dataset is removed from disk.
- Returns
TorchDataset with the new items (no filter_fn or preprocess_fn set).
- Return type
TorchDataset
static from_file(file_path, filter_fn=None, preprocess_fn=None)
preload(device='cpu', num_workers=0)
Preloads the dataset into memory.
- Parameters
device (torch.device, optional) – Device to store the dataset on. Defaults to ‘cpu’.
num_workers (int, optional) – Number of workers to load items into memory. Defaults to 0.
- Returns
New TorchDataset using the preloaded data.
- Return type
PreloadedTorchDataset
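A hypothetical usage sketch (device and worker count are illustrative):
ds_gpu = ds.preload(device='cuda', num_workers=4)  # returns a PreloadedTorchDataset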
tile(keys_to_tile, tile_sz=128, overlap=2, dim=3, **kwargs)
Converts the dataset to a tiled dataset, drawing only parts of the data. Since the data needs to be loaded to determine the number of tiles, a tile is drawn randomly after loading the volume, without guaranteeing full coverage.
- Parameters
keys_to_tile ([str]) – List of strings matching the keys of the data dictionaries that need to be tiled. All must have the same shape and result in the same tiling.
tile_sz (int/tuple of ints, optional) – Size of the tiles drawn. Either int or tuple with length matching the given dim. Defaults to 128.
overlap (int/tuple of ints, optional) – Overlap between adjacent tiles. Defaults to 2.
dim (int, optional) – Dimensionality of the data. If tile_sz or overlap is given as tuple this must match their lengths. Defaults to 3.
- Returns
Tiling-aware TorchDataset
- Return type
TiledTorchDataset
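A hypothetical usage sketch (key name and tile size are illustrative):
tiled_ds = ds.tile(['vol'], tile_sz=64, overlap=2)  # draws a random 64³ tile from 'vol' per sample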