torchvtk.datasets

TorchDataset

class torchvtk.datasets.TorchDataset(ds_files, filter_fn=None, preprocess_fn=None)

Bases: torch.utils.data.Dataset

static CQ500(tvtk_ds_path='~/.torchvtk/', num_workers=0, **kwargs)

Get the QureAI CQ500 dataset. Downloads, extracts and converts it to a TorchDataset if it is not available locally. The dataset is available at http://headctstudy.qure.ai/dataset and described in Chilamkurthy et al., https://arxiv.org/abs/1803.05854

Parameters
  • tvtk_ds_path (str, Path) – Path where your torchvtk datasets shall be saved.

  • num_workers (int) – Number of processes used for downloading, extracting and converting

  • kwargs – Keyword arguments to pass on to TorchDataset.__init__()

Returns

TorchDataset containing CQ500.

Return type

TorchDataset
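
A minimal usage sketch; indexing and len() are assumed from the torch.utils.data.Dataset base class, and the exact dictionary keys depend on the conversion:

    from torchvtk.datasets import TorchDataset

    # Downloads, extracts and converts on first use; later calls reuse the local copy.
    cq500 = TorchDataset.CQ500(tvtk_ds_path='~/.torchvtk/', num_workers=4)
    print(len(cq500))    # number of serialized items
    item = cq500[0]      # items are dictionaries of tensors
    print(item.keys())   # exact keys depend on the conversion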

__init__(ds_files, filter_fn=None, preprocess_fn=None)

A dataset that uses serialized PyTorch Tensors.

Parameters
  • ds_files (str, Path (Dir), List of Path (Files)) – Path to the TorchDataset directory (containing *.pt files) or a list of paths pointing to .pt files

  • filter_fn (function) – Function that filters the found items. Input is the file path

  • preprocess_fn (function) – Function to process the loaded dictionary.
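
A construction sketch; the dataset directory and the ‘vol’ key used in preprocess_fn are placeholders, not part of this API:

    from pathlib import Path
    from torchvtk.datasets import TorchDataset

    ds_path = Path('~/.torchvtk/my_dataset').expanduser()  # directory containing *.pt files

    def keep_train(fp):                  # filter_fn: decides from the file path whether to keep an item
        return 'test' not in Path(fp).name

    def to_float(d):                     # preprocess_fn: applied to each loaded dictionary
        d['vol'] = d['vol'].float()      # 'vol' is an assumed key; adapt to your data
        return d

    ds = TorchDataset(ds_path, filter_fn=keep_train, preprocess_fn=to_float)
    item = ds[0]                         # a preprocessed dictionary of tensors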

cache_processed(process_fn, name, num_workers=0, delete_old_from_disk=False)

Processes the given TorchDataset and serializes the result. Iterates through the dataset and applies the given process_fn to each item (which should be a dictionary). The resulting new dataset is serialized next to the old one, using the given name. This function can run multi-threaded.

Parameters
  • process_fn (function) – The function to be applied to the individual items

  • name (str) – Name of the new processed dataset

  • num_workers (int > 0) – Number of threads used for processing

  • delete_old_from_disk (bool) – If True, the root directory of the old, unprocessed, dataset is removed from disk.

Returns

TorchDataset with the new items (no filter_fn or preprocess_fn set)

Return type

TorchDataset
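
A sketch of offline preprocessing with cache_processed, assuming ds is a TorchDataset as constructed above; the normalization and the ‘vol’ key are illustrative assumptions:

    def normalize(d):
        # Applied once per item; the results are serialized next to the old dataset.
        vol = d['vol'].float()                              # 'vol' is an assumed key
        d['vol'] = (vol - vol.mean()) / (vol.std() + 1e-8)
        return d

    normalized_ds = ds.cache_processed(normalize, name='my_dataset_normalized',
                                       num_workers=4, delete_old_from_disk=False)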

static from_file(file_path, filter_fn=None, preprocess_fn=None)

preload(device='cpu', num_workers=0)

Preloads the dataset into memory.

Parameters
  • device (torch.device, optional) – Device to store the dataset on. Defaults to ‘cpu’.

  • num_workers (int, optional) – Number of workers to load items into memory. Defaults to 0.

Returns

New TorchDataset using the preloaded data.

Return type

PreloadedTorchDataset
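
A preload sketch, assuming ds is a TorchDataset and fits into the target device’s memory:

    ram_ds = ds.preload(device='cpu', num_workers=4)  # load all items into memory once
    # gpu_ds = ds.preload(device='cuda')              # only if the dataset fits into GPU memory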

tile(keys_to_tile, tile_sz=128, overlap=2, dim=3, **kwargs)

Converts the dataset into a tiled dataset that draws only parts of the data. Since the data needs to be loaded to determine the number of tiles, a tile is drawn randomly after loading the volume, without guaranteeing full coverage.

Parameters
  • keys_to_tile ([str]) – List of strings matching the keys of the data dictionaries that need to be tiled. All must have the same shape and result in the same tiling.

  • tile_sz (int/tuple of ints, optional) – Size of the tiles drawn. Either int or tuple with length matching the given dim. Defaults to 128.

  • overlap (int/tuple of ints, optional) – Overlap between adjacent tiles. Either int or tuple with length matching the given dim. Defaults to 2.

  • dim (int, optional) – Dimensionality of the data. If tile_sz or overlap is given as tuple this must match their lengths. Defaults to 3.

Returns

Tiling-aware TorchDataset

Return type

TiledTorchDataset
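
A tiling sketch; ‘vol’ and ‘mask’ are assumed keys of the item dictionaries:

    tiled_ds = ds.tile(['vol', 'mask'],  # keys that share the same spatial shape
                       tile_sz=64,       # cube edge length (or a 3-tuple)
                       overlap=2,        # overlap between adjacent tiles
                       dim=3)            # volumetric data
    tile_item = tiled_ds[0]              # contains one randomly drawn tile per key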

TorchQueueDataset

class torchvtk.datasets.TorchQueueDataset(torch_ds, epoch_len=1000, mode='onsample', num_workers=1, q_maxlen=None, ram_use=0.5, wait_fill=True, wait_fill_timeout=60, sample_tfm=<function noop>, batch_tfm=<function noop>, bs=1, collate_fn=<function dict_collate_fn>, log_sampling=False, avg_item_size=None, preprocess_fn=None, filter_fn=None)

Bases: torch.utils.data.IterableDataset

__init__(torch_ds, epoch_len=1000, mode='onsample', num_workers=1, q_maxlen=None, ram_use=0.5, wait_fill=True, wait_fill_timeout=60, sample_tfm=<function noop>, batch_tfm=<function noop>, bs=1, collate_fn=<function dict_collate_fn>, log_sampling=False, avg_item_size=None, preprocess_fn=None, filter_fn=None)

An iterable-style dataset that caches items in a queue in memory.

Parameters
  • torch_ds (TorchDataset, str, Path) – A TorchDataset to be used for queueing, or a path to the dataset on disk

  • mode (string) – Queue filling mode. ‘onsample’ refills the queue after it has been sampled; ‘always’ keeps refilling the queue as fast as possible

  • num_workers (int) – Number of threads loading in data

  • q_maxlen (int) – Set queue size. Overrides ram_use

  • ram_use (float) – Fraction of available system memory to use for the queue, or a memory budget in MB (values > 1.0). Defaults to 0.5

  • avg_item_size (float, torch.Tensor) – Example tensor or size in MB

  • wait_fill (int, bool) – If True, the queue is filled on init; if an int, waits until the queue contains at least that many items

  • wait_fill_timeout (int, float) – Time in seconds until wait_fill times out. Default is 60 s

  • sample_tfm (Transform, function) – Applicable transform (receiving and producing a dict) that is applied upon sampling from the queue

  • batch_tfm (Transform, function) – Transforms to be applied on batches of items

  • preprocess_fn (function) – Override preprocess_fn from given torch_ds

  • filter_fn (function) – Filters filenames to load, like TorchDataset. Only used if torch_ds is a path to a dataset.

  • bs (int) – Batch Size

  • collate_fn (function) – Collate function to merge items into batches. The default assumes dictionaries (like those from TorchDataset) and stacks all tensors, while collecting non-tensors in a list
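
A construction sketch; the dataset path and the no-op sample transform are placeholders:

    from torchvtk.datasets import TorchDataset, TorchQueueDataset

    ds = TorchDataset('/path/to/my_torch_ds')   # placeholder path; a path can also be passed as torch_ds directly

    def augment(d):    # sample_tfm: applied whenever an item is sampled from the queue
        return d       # e.g. random crop / intensity augmentation

    qds = TorchQueueDataset(ds,
                            mode='onsample',         # refill after sampling
                            num_workers=2,           # loader threads
                            ram_use=0.5,             # fraction of free RAM used for the queue
                            bs=4,                    # batch size for batch_generator / get_dataloader
                            sample_tfm=augment,
                            wait_fill=True,          # block until the queue has items
                            wait_fill_timeout=120)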

batch_generator()

Generator for sampling the queue. This makes use of the object attributes bs (batch size) and collate_fn.

Returns

Generator that randomly samples batches from the queue.
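
A sampling sketch, assuming qds is a TorchQueueDataset as constructed above:

    gen = qds.batch_generator()   # yields batches of size bs, merged with collate_fn
    batch = next(gen)             # a dictionary with stacked tensors (see dict_collate_fn)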

get_dataloader(**kwargs)

Returns

A dataloader that uses the batched sampling of the queue with appropriate collate_fn and batch_size.

Return type

torch.utils.data.DataLoader
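
A training-loop sketch; the ‘vol’ key is an assumption about the batch dictionaries, and the kwargs are assumed to be forwarded to torch.utils.data.DataLoader:

    dl = qds.get_dataloader()   # DataLoader wired to the queue's bs and collate_fn

    for batch in dl:
        vols = batch['vol']     # 'vol' is an assumed key; same-shaped tensors are stacked
        # ... forward / backward pass ...
        break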

property qsize

Current queue length

wait_fill_queue(fill_atleast=None, timeout=60, polling_interval=0.25)

Waits until the queue is filled (fill_atleast=None) or until it contains at least fill_atleast items. Times out after timeout seconds.

Parameters
  • fill_atleast (int) – Waits until the queue is filled with at least this many items.

  • timeout (Number) – Time in seconds before this method terminates regardless of the queue size

  • polling_interval (Number) – Interval in seconds at which the queue size is polled while waiting.
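
A small sketch:

    # Block until at least 8 items are queued, polling every 0.5 s, for at most 30 s.
    qds.wait_fill_queue(fill_atleast=8, timeout=30, polling_interval=0.5)
    print(qds.qsize)   # current queue length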

dict_collate_fn

torchvtk.datasets.dict_collate_fn(items, key_filter=None, stack_tensors=True, convert_np=True, convert_numbers=True, warn_when_unstackable=True)

Collate function for dictionary data

This stacks tensors only if they are stackable, meaning they are of the same shape.

Parameters
  • items (list) – List of individual items for a Dataset

  • key_filter (list of str or callable, optional) – A list of keys to filter the dict data. Defaults to None.

  • stack_tensors (bool, optional) – Whether to stack dict entries of type torch.Tensor. Disable if you have unstackable tensors; they will then be collected in a list instead. Defaults to True.

  • convert_np (bool, optional) – Convert NumPy arrays to torch.Tensors and stack them. Defaults to True.

  • convert_numbers (bool, optional) – Converts standard Python numbers to torch.Tensors and stacks them. Defaults to True.

  • warn_when_unstackable (bool, list of str, optional) – If True, prints a warning when a set of Tensors is unstackable. You can also specify a list of keys for which to print the warning. Defaults to True.

Returns

One dictionary with the tensors stacked

Return type

dict
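
A standalone sketch of the collate behavior on toy dictionaries, following the documented defaults:

    import torch
    from torchvtk.datasets import dict_collate_fn

    items = [
        {'vol': torch.rand(1, 8, 8, 8), 'name': 'a', 'label': 1},
        {'vol': torch.rand(1, 8, 8, 8), 'name': 'b', 'label': 0},
    ]
    batch = dict_collate_fn(items)
    print(batch['vol'].shape)   # e.g. torch.Size([2, 1, 8, 8, 8]): same-shaped tensors are stacked
    print(batch['name'])        # ['a', 'b']: non-tensors are collected in a list
    print(batch['label'])       # Python numbers are converted to tensors and stacked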