In this post, I’d like to talk about how to create your own dataset, process it and make data batches ready to be fed into your neural networks, with the help of PyTorch.
In PyTorch, feeding your own training data into a network mainly involves two classes: the Dataset class and the DataLoader class. I explain what each of them does below.
Create your Dataset class#
Overview#
The Dataset class provides an interface for accessing all the training
or testing samples in your dataset. To achieve this, you have to
implement at least two methods, __getitem__ and __len__, so that each
training sample (in image classification, a sample means an image plus its
class label) can be accessed by its index.
In the initialization part of the class, you should collect a list of all the images and their labels in the dataset. When a particular sample is requested, __getitem__ reads the image, transforms it, and returns the transformed image together with its label.
A good example is the ImageFolder class provided by the torchvision package;
reading its source code gives a good sense of how this actually works.
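The pattern described above can be sketched as follows. This is a minimal illustration, not how ImageFolder itself is implemented: the class name `ImageListDataset` and the `samples` and `loader` parameters are hypothetical names chosen for this example.

```python
from torch.utils.data import Dataset


class ImageListDataset(Dataset):
    """Hypothetical dataset over a list of (image_path, label) pairs."""

    def __init__(self, samples, loader, transform=None):
        # Collect the sample list once; no image is read at this point.
        self.samples = list(samples)
        self.loader = loader        # callable: path -> image (assumed)
        self.transform = transform  # optional callable applied to each image

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        image = self.loader(path)   # read lazily, only when requested
        if self.transform is not None:
            image = self.transform(image)
        return image, label
```

With real image files, `loader` could be something like `lambda p: PIL.Image.open(p).convert("RGB")`, and `transform` would be a torchvision transform pipeline.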
Data augmentation and preprocessing#
Data augmentation and preprocessing are an important part of the whole
workflow. In PyTorch, we do this by passing a transform parameter to the
Dataset class. Transforms are callable objects that process the
given input. You can chain a series of transforms by passing a list of
them to torchvision.transforms.Compose. The transforms are then applied
to the input in the order they appear in the list.
Note that some transforms were designed for PIL image objects, such
as RandomCrop() and Resize(), while others operate on torch tensors, such as
Normalize. (Recent torchvision releases also accept tensors for many of the
former.) If your dataset contains images, you should first apply all transforms
that expect PIL images, then convert the PIL image to a tensor with the
ToTensor() transform. ToTensor converts a PIL image to a torch tensor of
shape $C\times H\times W$, with its values in the range [0.0, 1.0].
The Normalize transform expects torch tensors. Its parameters are the
per-channel means and standard deviations of the RGB channels, computed over
all the training images. For ImageNet, these statistics have already been
computed for us, so the normalize transform should be
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
For your own dataset, you have to calculate the statistics yourself.
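One possible way to compute these statistics is sketched below. It assumes each image has already been converted to a float tensor of shape C x H x W with values in [0, 1] (for example by ToTensor); the function name `channel_mean_std` is made up for this example, and it computes the population standard deviation over all pixels.

```python
import torch


def channel_mean_std(images):
    """Per-channel mean and std over an iterable of (C, H, W) float tensors."""
    total = None      # running sum of pixel values per channel
    total_sq = None   # running sum of squared pixel values per channel
    n_pixels = 0      # number of pixels seen per channel
    for image in images:
        flat = image.reshape(image.shape[0], -1).to(torch.float64)
        if total is None:
            total = torch.zeros(flat.shape[0], dtype=torch.float64)
            total_sq = torch.zeros(flat.shape[0], dtype=torch.float64)
        total += flat.sum(dim=1)
        total_sq += (flat ** 2).sum(dim=1)
        n_pixels += flat.shape[1]
    mean = total / n_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative round-off.
    std = (total_sq / n_pixels - mean ** 2).clamp(min=0).sqrt()
    return mean.float(), std.float()
```

Because it only keeps running sums, this works for images of different sizes and does not need the whole dataset in memory. The resulting `mean` and `std` can be passed to transforms.Normalize as lists via `.tolist()`.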
Create data batches using DataLoader#
Being able to access all the training data through the Dataset class is
not enough. For deep learning, we also need functionality such as
batching, shuffling, and multiprocess data loading. This is what the
DataLoader class does.
The DataLoader class accepts a dataset and other parameters such as
batch_size, batch_sampler, and num_workers (the number of worker processes
used to load the data). We can then iterate over the DataLoader to get
batches of training data and train our models.
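A minimal sketch of this loop is shown below. It uses TensorDataset with random tensors as a stand-in for a real image dataset; the sizes and the batch size are arbitrary.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten fake 3x8x8 "images" with binary labels, standing in for real data.
images = torch.rand(10, 3, 8, 8)
labels = torch.randint(0, 2, (10,))
dataset = TensorDataset(images, labels)

# num_workers=0 loads data in the main process; raise it for parallel loading.
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

batch_sizes = []
for batch_images, batch_labels in loader:
    # batch_images has shape (B, 3, 8, 8); batch_labels has shape (B,)
    batch_sizes.append(batch_images.shape[0])
```

With 10 samples and a batch size of 4, the last batch holds only 2 samples; passing `drop_last=True` would discard it instead.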
Loading variable size input images#
By default, DataLoader uses a default collate_fn function to pack a series
of images and targets into tensors (the first dimension of each tensor is
the batch size). The default collate_fn expects all the images in a
batch to have the same size because it uses torch.stack() to pack the images.
If the images provided by the Dataset have variable sizes, you have to provide
your own collate_fn. A simple example is shown below:
import torch

# a simple custom collate function, just to show the idea
# `batch` is a list of tuples where the first element is the image tensor
# and the second element is the corresponding label
def my_collate(batch):
    data = [item[0] for item in batch]  # keep images as a list of tensors
    target = [item[1] for item in batch]
    target = torch.LongTensor(target)  # labels can still be stacked
    return [data, target]
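To show how such a function plugs into the loader, here is a self-contained sketch. The toy dataset `VariableSizeImages` and the function name `list_collate` are hypothetical; the collate function follows the same idea as the one above, returning images as a list and labels as a tensor.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class VariableSizeImages(Dataset):
    """Hypothetical toy dataset: random 'images' whose height and width differ."""

    def __init__(self):
        self.sizes = [(8, 8), (6, 10), (12, 4), (5, 5)]

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, index):
        height, width = self.sizes[index]
        return torch.rand(3, height, width), index % 2


def list_collate(batch):
    # Images stay in a Python list because they cannot be stacked.
    data = [item[0] for item in batch]
    target = torch.tensor([item[1] for item in batch], dtype=torch.long)
    return data, target


loader = DataLoader(VariableSizeImages(), batch_size=2, collate_fn=list_collate)
first_images, first_labels = next(iter(loader))
```

Since the images arrive as a list, the training loop has to handle them one at a time or pad them to a common size before stacking.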
Conclusion#
In this post, I gave an introduction to the use of Dataset and DataLoader
in PyTorch. Dataset is used to access a single sample from your dataset and
transform it, while DataLoader is used to load a batch of samples for
training or testing your models. If your training images have variable sizes,
you may also have to provide your own custom collate_fn.