In this post, I’d like to talk about how to create your own dataset, process it and make data batches ready to be fed into your neural networks, with the help of PyTorch.
In PyTorch, in order to feed your own training data into the network, you will mainly deal with two classes: `Dataset` and `DataLoader`. Below I explain in more detail what each of them does.
The `Dataset` class provides an interface for accessing all the training or testing samples in your dataset. To achieve this, you have to implement at least two methods, `__getitem__` (which returns the sample at a given index) and `__len__` (which returns the number of samples), so that each training sample (in image classification, a sample means an image plus its class label) can be accessed by its index.
In the initialization part of the class (`__init__`), you should collect a list of all the images and their labels in the dataset. When a particular sample is requested, we read the image, transform it, and return the transformed image together with its label.
A good example is the `ImageFolder` class provided by the `torchvision` package; you can check its source code to get a sense of how it actually works.
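The two required methods can be sketched as a minimal custom `Dataset`. This is only an illustration under assumptions: the class name `ImageListDataset` is hypothetical, and it assumes you already have a list of `(image_path, class_index)` pairs pointing to files that Pillow can read.

```python
from PIL import Image
from torch.utils.data import Dataset


class ImageListDataset(Dataset):
    """Hypothetical dataset built from a list of (image_path, label) pairs."""

    def __init__(self, samples, transform=None):
        # Collect all samples up front; the images themselves are read lazily.
        self.samples = list(samples)
        self.transform = transform

    def __len__(self):
        # Total number of samples, so valid indices are 0 .. len - 1.
        return len(self.samples)

    def __getitem__(self, index):
        # Look up the path and label for this index, then read and transform.
        path, label = self.samples[index]
        with Image.open(path) as img:
            image = img.convert("RGB")  # force 3 channels and load pixel data
        if self.transform is not None:
            image = self.transform(image)
        return image, label
```

Because `__getitem__` opens a file only when its index is requested, the dataset keeps just the list of paths in memory, similar to how `ImageFolder` stores `(path, class_index)` pairs and loads each image on access.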
Data augmentation and preprocessing
Data augmentation and preprocessing are an important part of the whole workflow. In PyTorch, we do them by providing a `transform` parameter to the `Dataset` class.
Transforms are class objects that are called to process the given input. You can cascade a series of transforms by passing a list of them to the `torchvision.transforms.Compose` class; the transforms are then applied to the input in the order they appear in the list.
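It should be noted that some transforms operate on `PIL` image objects, such as `RandomCrop()` and `Resize()`, while others operate on torch `Tensor`s, such as `Normalize`. (Recent torchvision releases also accept tensors in many geometric transforms such as `RandomCrop` and `Resize`, but `Normalize` still requires a tensor.) If your dataset contains images, you should first perform all transforms that expect a `PIL` image, then convert the `PIL` image to a `Tensor` using the `ToTensor()` transform. `ToTensor` converts a `PIL` image (laid out as $H\times W\times C$) to a torch float `Tensor` of shape $C\times H\times W$, with its values scaled to the range [0.0, 1.0] for 8-bit images such as RGB.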
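The `Normalize` transform expects torch tensors. Its parameters are the per-channel means and standard deviations (one value for each RGB channel) computed over all the training images, and it normalizes each channel as $(x - \text{mean}) / \text{std}$. For ImageNet, these statistics have already been computed, so the normalize transform is: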
```python
from torchvision import transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
```
For your own dataset, you have to calculate the statistics yourself.
Create data batches using `DataLoader`
Although we can access all the training data using the `Dataset` class, that is not enough. For deep learning, we also need functionality such as batching, shuffling, and multiprocess data loading. This is what the `DataLoader` class does.
The `DataLoader` class accepts a dataset and other parameters such as `batch_size`, `shuffle`, `batch_sampler`, and `num_workers` (the number of worker processes used to load the data). We can then iterate over the `DataLoader` to get batches of training data and train our models.
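A minimal sketch of iterating over a `DataLoader`. Here `TensorDataset` filled with random tensors stands in for your own `Dataset`, and the batch size of 4 is arbitrary:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for your own Dataset: 10 random 3 x 8 x 8 "images", labels 0..9.
dataset = TensorDataset(torch.randn(10, 3, 8, 8), torch.arange(10))

loader = DataLoader(
    dataset,
    batch_size=4,   # samples per batch
    shuffle=True,   # draw a new random order at every pass over the data
    num_workers=0,  # load in the main process; > 0 starts worker processes
)

for epoch in range(2):
    for images, labels in loader:
        # images: (B, 3, 8, 8), labels: (B,); B is 4, 4, then 2 (drop_last=False)
        print(epoch, tuple(images.shape), labels.tolist())
```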
Loading variable size input images
By default, `DataLoader` uses a `collate_fn` function to pack a list of images and targets into tensors (the first dimension of each tensor is the batch size). The default `collate_fn` expects all the images in a batch to have the same size because it uses `torch.stack()` to pack them. If the images provided by your `Dataset` have variable sizes, you have to provide your own custom `collate_fn`. A simple example is shown below:
```python
import torch

# a simple custom collate function, just to show the idea
# `batch` is a list of tuples where the first element is an image tensor and
# the second element is the corresponding label
def my_collate(batch):
    data = [item[0] for item in batch]  # just form a list of tensors
    target = [item[1] for item in batch]
    target = torch.LongTensor(target)
    return [data, target]
```
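A usage sketch of this collate function with a toy list of variable-size image tensors (the sizes are made up for illustration). The function is repeated so the block runs on its own:

```python
import torch
from torch.utils.data import DataLoader


def my_collate(batch):
    # batch is a list of (image_tensor, label) tuples of length batch_size
    data = [item[0] for item in batch]  # keep images as a list, no stacking
    target = torch.LongTensor([item[1] for item in batch])
    return [data, target]


# Toy dataset: a plain list of (image, label) pairs with different H x W.
samples = [
    (torch.rand(3, 32, 48), 0),
    (torch.rand(3, 64, 40), 1),
    (torch.rand(3, 28, 28), 2),
]

loader = DataLoader(samples, batch_size=2, shuffle=False, collate_fn=my_collate)
for images, labels in loader:
    print([tuple(img.shape) for img in images], labels.tolist())
# [(3, 32, 48), (3, 64, 40)] [0, 1]
# [(3, 28, 28)] [2]
```

Because the images come back as a Python list rather than one stacked tensor, the model has to process them one at a time, or you have to pad or resize them yourself before stacking.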
In this post, I gave an introduction to the use of `Dataset` and `DataLoader` in PyTorch. `Dataset` is used to access a single sample from your dataset and transform it, while `DataLoader` is used to load a batch of samples for training or testing your models. If your training images have variable sizes, you may also have to use your own custom `collate_fn`.
License CC BY-NC-ND 4.0