# Source code for aeronet.dataset.dataset

import numpy as np


class RandomDataset:
    """Dataset yielding one randomly positioned window per band collection."""

    def __init__(self, band_collections, sample_size, input_channels, output_labels,
                 transform=None, seed=42):
        """
        Produce samples from band collections randomly (uniform dist.) with specified shape.

        Args:
            band_collections: list of 'BandCollection' objects
            sample_size: spatial resolution of sample in pixels (height, width)
            input_channels: list of channel names, e.g. ['RED', 'GRN', '100'], case sensitive
            output_labels: list of output label names, e.g. ['100', '101', 'roof'], case sensitive
            transform: function for sample transformation; called as
                ``transform(image=..., mask=...)`` and whatever it returns is
                returned from ``__getitem__`` instead of the sample dict
            seed: `NotImplemented` -- stored on the instance but not used yet;
                window offsets are drawn from NumPy's global random state
                (``np.random``)
        """
        self.band_collections = band_collections
        self.sample_size = sample_size
        self.input_channels = input_channels
        self.output_labels = output_labels
        self.transform = transform
        self.seed = seed  # TODO: make dataset with seed

    @staticmethod
    def _extract(bc, y, x, h, w, names):
        """Cut the (y, x, h, w) window from ``bc`` and return the ``names`` bands channels-last."""
        # ``.transpose(1, 2, 0)`` moves axis 0 to the end; presumably ``.numpy()``
        # yields (channels, height, width), giving (height, width, channels) -- confirm.
        return (bc.sample(y, x, h, w)
                .ordered(*names)
                .numpy()
                .transpose(1, 2, 0))

    def __getitem__(self, i):
        """
        Return one randomly positioned window from the ``i``-th band collection.

        Args:
            i: index into ``band_collections``

        Returns:
            dict with ``'image'`` (bands in ``input_channels`` order) and
            ``'mask'`` (bands in ``output_labels`` order), both cut from the
            same window; or, if ``transform`` is set, the result of
            ``transform(image=..., mask=...)``.
        """
        bc = self.band_collections[i]
        h, w = self.sample_size

        # Uniform sampler over top-left offsets. Assuming ``bc.sample(y, x, h, w)``
        # reads rows [y, y + h) and columns [x, x + w), the last valid offset is
        # ``size - window`` inclusive, i.e. ``size - window + 1`` choices;
        # ``np.random.randint`` excludes its upper bound, hence the ``+ 1``.
        # ``max(1, ...)`` pins the offset to 0 when the collection is not larger
        # than the window.
        x_range = max(1, bc.width - w + 1)
        y_range = max(1, bc.height - h + 1)
        x = np.random.randint(x_range)
        y = np.random.randint(y_range)

        # extract image and mask from the same window
        sample = {
            'image': self._extract(bc, y, x, h, w, self.input_channels),
            'mask': self._extract(bc, y, x, h, w, self.output_labels),
        }

        # transform samples, e.g. augmentations or standardization
        if self.transform is not None:
            sample = self.transform(**sample)
        return sample

    def __len__(self):
        """Number of band collections (one random sample per collection per epoch)."""
        return len(self.band_collections)