In this post, I’d like to talk about how to create your own dataset, process it and make data batches ready to be fed into your neural networks, with the help of PyTorch.
In PyTorch, to feed your own training data into the network, you will mainly deal with two classes: the Dataset class and the DataLoader class. I will now explain in more detail what they do.
The Dataset class provides an interface for accessing all the training
or testing samples in your dataset. To achieve this, you have to
implement at least two methods, __getitem__ and __len__, so that each
training sample (in image classification, a sample means an image plus its
class label) can be accessed by its index.
In the initialization part of the class, you should collect a list of all the images and their labels in the dataset. When a particular sample is requested, we read the image, transform it, and return the transformed image together with the corresponding label.
A good example is the ImageFolder class provided by the torchvision
package. You can check its source code to get a sense of how it
actually works.
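The idea can be sketched with a minimal custom dataset. This is only an illustration, not how ImageFolder is implemented: the class name `ToyImageDataset`, the file names, and `fake_loader` are hypothetical, and the loader returns a blank tensor so that the example runs without any image files on disk.

```python
import torch
from torch.utils.data import Dataset


class ToyImageDataset(Dataset):
    """Hypothetical dataset: stores (path, label) pairs and loads images lazily."""

    def __init__(self, samples, loader, transform=None):
        # samples: list of (image_path, class_index) pairs, collected up front
        self.samples = samples
        self.loader = loader        # callable that reads one image given its path
        self.transform = transform  # optional callable applied to each image

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        return image, label


def fake_loader(path):
    # Stand-in for real image decoding (assumption: a real loader would open the file).
    return torch.zeros(3, 8, 8)


dataset = ToyImageDataset([("cat_0.jpg", 0), ("dog_0.jpg", 1)], loader=fake_loader)
image, label = dataset[1]
```

In a real project, the loader would typically open the file with PIL and the transform would be a torchvision transform pipeline.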
Data augmentation and preprocessing
Data augmentation and preprocessing are an important part of the whole
workflow. In PyTorch, we do it by providing a transform parameter to the
Dataset class. Transforms are class objects that are called to process the
given input. You can cascade a series of transforms by providing a list of
transforms to transforms.Compose. The given transforms will then be applied
to the input in the order in which they appear in the list.
Note that some transforms operate on PIL image objects, while others,
such as Normalize, operate on torch tensors.
If your dataset contains images, you should first perform all the transforms
that expect a PIL image, then convert the PIL image to a torch Tensor with
the ToTensor transform. ToTensor converts a PIL image to a torch Tensor of
shape $C\times H\times W$, with its values in the range [0.0, 1.0].
The Normalize transform expects torch tensors. Its parameters are the
per-channel means and standard deviations of the RGB channels, computed over
all the training images. For ImageNet, these statistics have already been
computed for us, and the normalize transform is:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
For your own dataset, you have to calculate the statistics yourself.
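One possible way to compute these statistics is sketched below. The helper `channel_stats` is hypothetical, and it assumes every image is already a float tensor of shape (3, H, W) with values in [0, 1], for example the output of ToTensor. It accumulates sums over all pixels, so images of different sizes are handled correctly.

```python
import torch


def channel_stats(images):
    """Per-channel mean and (population) std over an iterable of (3, H, W) tensors."""
    total = torch.zeros(3, dtype=torch.float64)
    total_sq = torch.zeros(3, dtype=torch.float64)
    n_pixels = 0
    for img in images:
        img = img.double()
        total += img.sum(dim=(1, 2))          # sum of pixel values per channel
        total_sq += (img ** 2).sum(dim=(1, 2))  # sum of squared values per channel
        n_pixels += img.shape[1] * img.shape[2]
    mean = total / n_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative rounding errors.
    std = (total_sq / n_pixels - mean ** 2).clamp(min=0).sqrt()
    return mean, std


# Toy check: one all-black and one all-white image of the same size.
mean, std = channel_stats([torch.zeros(3, 2, 2), torch.ones(3, 2, 2)])
```

The statistics should be computed on the training set only and then reused unchanged for validation and test data.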
Create data batches using DataLoader
Although we can access all the training data using the Dataset class,
that is not enough. For deep learning, we also need functionality such as
batching, shuffling, and multiprocess data loading. This is what the
DataLoader class does.
The DataLoader class accepts a dataset and other parameters, such as
batch_sampler and the number of workers used to load the data.
We can then iterate over the DataLoader to get batches of training data
and train our models.
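A minimal sketch of this loop is shown below. It uses TensorDataset with random tensors as a stand-in for a real image dataset, and the batch size of 4 is an arbitrary choice for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten fake "images" with binary labels, standing in for a real dataset.
images = torch.randn(10, 3, 8, 8)
labels = torch.arange(10) % 2
dataset = TensorDataset(images, labels)

# num_workers=0 loads data in the main process; raise it for parallel loading.
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

batch_shapes = []
for batch_images, batch_labels in loader:
    # A training step would go here; we only record the batch shape.
    batch_shapes.append(tuple(batch_images.shape))
```

With ten samples and a batch size of 4, the last batch holds only two samples, because incomplete batches are kept by default.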
Loading variable size input images
By default, DataLoader uses a collate_fn function to pack a series of images
and targets into tensors (the first dimension of each tensor is the batch
size). The default collate_fn expects all the images in a
batch to have the same size because it uses
torch.stack() to pack the images.
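The size restriction comes from torch.stack() itself, which the short sketch below illustrates with made-up tensor sizes: stacking works for equally sized tensors and raises a RuntimeError otherwise.

```python
import torch

# Two tensors of the same size stack into a single batch tensor.
same = torch.stack([torch.zeros(3, 4, 4), torch.zeros(3, 4, 4)])
stacked_shape = tuple(same.shape)

# Tensors of different sizes cannot be stacked.
try:
    torch.stack([torch.zeros(3, 4, 4), torch.zeros(3, 5, 5)])
    stack_failed = False
except RuntimeError:
    stack_failed = True
```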
If the images provided by the Dataset have variable size, you have to
provide your own custom collate_fn. A simple example is shown below:
import torch

# A simple custom collate function, just to show the idea.
# `batch` is a list of tuples, where the first element is an image tensor
# and the second element is the corresponding label.
def my_collate(batch):
    data = [item[0] for item in batch]    # keep the images as a list of tensors
    target = [item[1] for item in batch]  # collect the labels
    target = torch.LongTensor(target)
    return [data, target]
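To show how such a function plugs into the loader, here is a self-contained sketch. The dataset `VariableSizeDataset`, its image sizes, and the function name `list_collate` are hypothetical and exist only to produce samples of different shapes.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class VariableSizeDataset(Dataset):
    """Hypothetical dataset whose images all have different heights and widths."""

    def __init__(self):
        self.sizes = [(8, 8), (6, 10), (12, 4), (5, 5)]

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, index):
        h, w = self.sizes[index]
        return torch.zeros(3, h, w), index % 2


def list_collate(batch):
    # Keep images as a plain list because they cannot be stacked.
    images = [image for image, _ in batch]
    targets = torch.tensor([label for _, label in batch], dtype=torch.long)
    return images, targets


loader = DataLoader(VariableSizeDataset(), batch_size=2, collate_fn=list_collate)
images, targets = next(iter(loader))
```

Because the images arrive as a list, the model has to process them one by one, or they have to be padded or resized to a common size before being stacked.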
In this post, I gave an introduction to the use of Dataset and DataLoader
in PyTorch. Dataset is used to access a single sample from your dataset and
transform it, while DataLoader is used to load a batch of samples for
training or testing your models. If your training images have variable size,
you may also have to use your own custom collate_fn.
License CC BY-NC-ND 4.0