torchvtk.datasets¶
TorchDataset¶
- class torchvtk.datasets.TorchDataset(ds_files, filter_fn=None, preprocess_fn=None)¶
Bases: torch.utils.data.Dataset
- static CQ500(tvtk_ds_path='~/.torchvtk/', num_workers=0, **kwargs)¶
Get the QureAI CQ500 dataset. Downloads, extracts and converts it to a TorchDataset if it is not available locally.
Find the dataset here: http://headctstudy.qure.ai/dataset
Credits to Chilamkurthy et al.: https://arxiv.org/abs/1803.05854
- Parameters
tvtk_ds_path (str, Path) – Path where your torchvtk datasets shall be saved.
num_workers (int) – Number of processes used for downloading, extracting and converting.
kwargs – Keyword arguments to pass on to TorchDataset.__init__()
- Returns
TorchDataset containing CQ500.
- Return type
TorchDataset
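A minimal usage sketch; the cache path and worker count are illustrative:

    from torchvtk.datasets import TorchDataset

    # Downloads, extracts and converts the CQ500 data on first use,
    # then reuses the local copy under tvtk_ds_path.
    cq500 = TorchDataset.CQ500(tvtk_ds_path='~/.torchvtk/', num_workers=4)
    item = cq500[0]  # items are dictionaries of tensors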
- __init__(ds_files, filter_fn=None, preprocess_fn=None)¶
A dataset that uses serialized PyTorch tensors.
- Parameters
ds_files (str, Path (Directory), List of Path (Files)) – Path to the TorchDataset directory (containing *.pt files) or a list of paths pointing to .pt files.
filter_fn (function) – Function that filters the found items. Input is the file path.
preprocess_fn (function) – Function to process the loaded dictionary.
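A short construction sketch; the dataset directory and the 'vol' key are assumptions, not part of the API:

    from torchvtk.datasets import TorchDataset

    # filter_fn receives the file path, preprocess_fn the loaded dictionary.
    ds = TorchDataset('~/.torchvtk/CQ500',
                      filter_fn=lambda p: 'cq500' in str(p),
                      preprocess_fn=lambda d: {**d, 'vol': d['vol'].float()})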
- cache_processed(process_fn, name, num_workers=0, delete_old_from_disk=False)¶
Processes the given TorchDataset and serializes the result. Iterates through the dataset and applies the given process_fn to each item (which should be a dictionary). The resulting new dataset is serialized next to the old one, using the given name. This function can use multiple threads.
- Parameters
process_fn (function) – The function to be applied to the individual items
name (str) – Name of the new processed dataset
num_workers (int > 0) – Number of threads used for processing
delete_old_from_disk (bool) – If True, the root directory of the old, unprocessed, dataset is removed from disk.
- Returns
TorchDataset with the new items (no filter_fn or preprocess_fn set).
- Return type
TorchDataset
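A sketch of processing and re-serializing a dataset; the 'vol' key and the dataset name are hypothetical:

    # Hypothetical processing step: store volumes as float16 to halve disk usage.
    def to_f16(item):
        item['vol'] = item['vol'].half()
        return item

    ds_f16 = ds.cache_processed(to_f16, name='CQ500_f16', num_workers=4)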
- static from_file(file_path, filter_fn=None, preprocess_fn=None)¶
- preload(device='cpu', num_workers=0)¶
Preloads the dataset into memory.
- Parameters
device (torch.device, optional) – Device to store the dataset on. Defaults to ‘cpu’.
num_workers (int, optional) – Number of workers to load items into memory. Defaults to 0.
- Returns
New TorchDataset using the preloaded data.
- Return type
PreloadedTorchDataset
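A minimal sketch; the worker count is illustrative:

    # Hold all items in CPU memory; returns a PreloadedTorchDataset.
    ds_mem = ds.preload(device='cpu', num_workers=4)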
- tile(keys_to_tile, tile_sz=128, overlap=2, dim=3, **kwargs)¶
Converts the dataset to a tiled dataset that draws only parts of the data. Since the data needs to be loaded to determine the number of tiles, a tile is drawn randomly after loading the volume, without guaranteeing full coverage.
- Parameters
keys_to_tile ([str]) – List of strings matching the keys of the data dictionaries that need to be tiled. All must have the same shape and result in the same tiling.
tile_sz (int/tuple of ints, optional) – Size of the tiles drawn. Either int or tuple with length matching the given dim. Defaults to 128.
overlap (int/tuple of ints, optional) – Overlap between adjacent tiles. Defaults to 2.
dim (int, optional) – Dimensionality of the data. If tile_sz or overlap is given as tuple this must match their lengths. Defaults to 3.
- Returns
Tiling-aware TorchDataset
- Return type
TiledTorchDataset
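A sketch of tiling; the 'vol' key is an assumption:

    # Draw random 128x128x128 tiles of the 'vol' tensor from each loaded volume.
    tiled_ds = ds.tile(['vol'], tile_sz=128, overlap=2, dim=3)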
TorchQueueDataset¶
- class torchvtk.datasets.TorchQueueDataset(torch_ds, epoch_len=1000, mode='onsample', num_workers=1, q_maxlen=None, ram_use=0.5, wait_fill=True, wait_fill_timeout=60, sample_tfm=<function noop>, batch_tfm=<function noop>, bs=1, collate_fn=<function dict_collate_fn>, log_sampling=False, avg_item_size=None, preprocess_fn=None, filter_fn=None)¶
Bases: torch.utils.data.IterableDataset
- __init__(torch_ds, epoch_len=1000, mode='onsample', num_workers=1, q_maxlen=None, ram_use=0.5, wait_fill=True, wait_fill_timeout=60, sample_tfm=<function noop>, batch_tfm=<function noop>, bs=1, collate_fn=<function dict_collate_fn>, log_sampling=False, avg_item_size=None, preprocess_fn=None, filter_fn=None)¶
An iterable-style dataset that caches items in a queue in memory.
- Parameters
torch_ds (TorchDataset, str, Path) – A TorchDataset to be used for queueing, or a path to the dataset on disk
mode (string) – Queue filling mode. ‘onsample’ refills the queue after an item is sampled; ‘always’ keeps refilling the queue as fast as possible
num_workers (int) – Number of threads used to load data
q_maxlen (int) – Sets the queue size. Overrides ram_use
ram_use (float) – Fraction of available system memory to use for the queue, or a memory budget in MB (if > 1.0). Defaults to 0.5 (50%)
avg_item_size (float, torch.Tensor) – Example tensor, or average item size in MB
wait_fill (int, bool) – If True, the queue is filled on init; if an int, waits until the queue holds at least that many items
wait_fill_timeout (int, float) – Time in seconds until wait_fill times out. Default is 60s
sample_tfm (Transform, function) – Transform (receiving and producing a dict) that is applied upon sampling from the queue
batch_tfm (Transform, function) – Transforms to be applied on batches of items
preprocess_fn (function) – Overrides the preprocess_fn of the given torch_ds
filter_fn (function) – Filters file names to load, like in TorchDataset. Only used if torch_ds is a path to a dataset.
bs (int) – Batch size
collate_fn (function) – Collate function to merge items into batches. The default assumes dictionaries (like from TorchDataset) and stacks all tensors, while collecting non-tensors in a list
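A construction sketch, assuming an existing TorchDataset ds; the parameter values are illustrative:

    from torchvtk.datasets import TorchQueueDataset

    # Two loader threads fill the queue; block until it holds at least 4 items.
    qds = TorchQueueDataset(ds, num_workers=2, ram_use=0.5, bs=2, wait_fill=4)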
- batch_generator()¶
Generator for sampling the queue. It makes use of the object attributes bs (batch size) and the collate function.
- Returns
Generator that randomly samples batches from the queue.
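A minimal sketch drawing one batch from the queue:

    gen = qds.batch_generator()
    batch = next(gen)  # a dict collated by collate_fn, with batch size bs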
- get_dataloader(**kwargs)¶
- Returns
A dataloader that uses the batched sampling of the queue with appropriate collate_fn and batch_size.
- Return type
torch.utils.data.DataLoader
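A sketch of a training loop over the queue-backed DataLoader; the 'vol' key is an assumption:

    dl = qds.get_dataloader()
    for batch in dl:        # batches are dictionaries of stacked tensors
        vol = batch['vol']  # hypothetical key
        ...                 # training step goes here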
- property qsize¶
Current queue length.
- wait_fill_queue(fill_atleast=None, timeout=60, polling_interval=0.25)¶
Waits until the queue is filled (if fill_atleast is None) or until it holds at least fill_atleast items. Times out after timeout seconds.
- Parameters
fill_atleast (int) – Waits until the queue holds at least this many items.
timeout (Number) – Time in seconds before this method terminates regardless of the queue size.
polling_interval (Number) – Interval in seconds at which the queue size is polled while waiting.
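A minimal sketch:

    # Block until at least 8 items are queued, or give up after 30 seconds.
    qds.wait_fill_queue(fill_atleast=8, timeout=30)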
dict_collate_fn¶
- torchvtk.datasets.dict_collate_fn(items, key_filter=None, stack_tensors=True, convert_np=True, convert_numbers=True, warn_when_unstackable=True)¶
Collate function for dictionary data.
Tensors are stacked only if they are stackable, meaning they all have the same shape.
- Parameters
items (list) – List of individual items for a Dataset
key_filter (list of str or callable, optional) – A list of keys to filter the dict data. Defaults to None.
stack_tensors (bool, optional) – Whether to stack dict entries of type torch.Tensor. Disable if you have unstackable tensors; they will be collected in a list instead. Defaults to True.
convert_np (bool, optional) – Convert NumPy arrays to torch.Tensors and stack them. Defaults to True.
convert_numbers (bool, optional) – Converts standard Python numbers to torch.Tensors and stacks them. Defaults to True.
warn_when_unstackable (bool, list of str, optional) – If True, prints a warning when a set of Tensors is unstackable. You can also specify a list of keys for which to print the warning. Defaults to True.
- Returns
A single dictionary with the tensors stacked
- Return type
dict
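A small sketch of the collate behavior; keys and shapes are illustrative:

    import torch
    from torchvtk.datasets import dict_collate_fn

    items = [{'vol': torch.rand(1, 32, 32, 32), 'name': 'a'},
             {'vol': torch.rand(1, 32, 32, 32), 'name': 'b'}]
    batch = dict_collate_fn(items)
    # batch['vol'] has shape (2, 1, 32, 32, 32); batch['name'] == ['a', 'b']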