In this post, I’d like to talk about how to create your own dataset, process it, and prepare batches of data ready to be fed into your neural networks, all with the help of PyTorch.
In PyTorch, in order to feed your own training data into the network, you will mainly deal with two classes: the Dataset class and the DataLoader class. Below I will explain in more detail what they do.
Create your Dataset class#
Overview#
The Dataset class is used to provide an interface for accessing all the training or testing samples in your dataset. In order to achieve this, you have to implement at least two methods, __getitem__ and __len__, so that each training sample (in image classification, a sample means an image plus its class label) can be accessed by its index.
In the initialization part of the class, you should collect a list of all the images and their labels in the dataset. When we want to get a particular sample, we then read the image, transform it, and return the transformed image together with the corresponding label.
A good example is the ImageFolder class provided by the torchvision package; you can check its source code here to get a sense of how it actually works.
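To make this more concrete, here is a minimal sketch of a custom image-classification Dataset. The file layout and the samples list of (path, label) pairs are assumptions for illustration; adapt the __init__ part to however your data is actually stored.
import os
from PIL import Image
from torch.utils.data import Dataset

class MyImageDataset(Dataset):
    def __init__(self, root, samples, transform=None):
        # `samples` is assumed to be a list of (relative_path, class_index) tuples
        self.root = root
        self.samples = samples
        self.transform = transform

    def __len__(self):
        # total number of samples in the dataset
        return len(self.samples)

    def __getitem__(self, index):
        # read the image, transform it and return (image, label)
        path, label = self.samples[index]
        image = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label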
Data augmentation and preprocessing#
Data augmentation and preprocessing are an important part of the whole workflow. In PyTorch, we do it by providing a transform parameter to the Dataset class. Transforms are callable class objects that process the given input. You can cascade a series of transforms by passing a list of them to torchvision.transforms.Compose; the given transforms will then be applied to the input in the order they appear.
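For example, a small cascade of transforms (the exact transforms and sizes here are only illustrative) could look like this:
from torchvision import transforms

# applied in order: resize the PIL image, take a random crop,
# then convert the result to a torch Tensor
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
])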
It should be noted that some of the transforms expect a PIL image object, such as RandomCrop() and Resize(), while others expect a torch Tensor, such as Normalize. If your dataset contains images, you should first perform all transforms that expect a PIL image, then convert the PIL image to a Tensor using the ToTensor() transform. The ToTensor transform converts a PIL image to a torch Tensor of shape $C\times H\times W$, with its values in the range [0.0, 1.0].
The Normalize transform expects torch tensors. Its parameters are the per-channel means and standard deviations of the RGB channels of all the training images. For ImageNet, these statistics have already been computed for us, and the normalize transform should be
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
For your own dataset, you have to calculate the statistics yourself.
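A rough sketch of how you could estimate those statistics is shown below. It assumes a train_dataset that already returns fixed-size image tensors in the range [0.0, 1.0] (e.g. with Resize() and ToTensor() applied) together with their labels.
import torch
from torchvision import transforms
from torch.utils.data import DataLoader

# accumulate per-channel sums over the whole training set
loader = DataLoader(train_dataset, batch_size=64, num_workers=4)
channel_sum = torch.zeros(3)
channel_sq_sum = torch.zeros(3)
n_pixels = 0
for images, _ in loader:
    # images has shape (batch, C, H, W)
    n_pixels += images.size(0) * images.size(2) * images.size(3)
    channel_sum += images.sum(dim=[0, 2, 3])
    channel_sq_sum += (images ** 2).sum(dim=[0, 2, 3])

mean = channel_sum / n_pixels
std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()
normalize = transforms.Normalize(mean=mean.tolist(), std=std.tolist())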
Create data batches using DataLoader#
Although we can access all the training samples using the Dataset class, that alone is not enough. For deep learning we also need functionality such as batching, shuffling, and multiprocess data loading, and this is what the DataLoader class does. The DataLoader class accepts a dataset and other parameters such as batch_size, batch_sampler, and num_workers that control how the data is loaded. We can then iterate over the DataLoader to get batches of training data and train our models.
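As a small illustration (the train_dataset variable and the parameter values are assumptions), creating and iterating over a DataLoader could look like this:
from torch.utils.data import DataLoader

# shuffle the dataset, group samples into batches of 32 and
# load them with 4 worker processes
train_loader = DataLoader(train_dataset, batch_size=32,
                          shuffle=True, num_workers=4)

for images, labels in train_loader:
    # images: tensor of shape (batch_size, C, H, W); labels: tensor of shape (batch_size,)
    ...  # forward pass, loss, backward pass, optimizer step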
Loading variable size input images#
By default, the DataLoader uses a collate_fn to pack a list of samples into batch tensors whose first dimension is the batch size. The default collate_fn expects all the images in a batch to have the same size because it uses torch.stack() to stack them. If the images provided by the Dataset have variable sizes, you have to provide your own custom collate_fn. A simple example is shown below:
# a simple custom collate function, just to show the idea
# `batch` is a list of tuples where the first element is an image tensor and
# the second element is the corresponding label
def my_collate(batch):
    data = [item[0] for item in batch]  # just form a list of tensors
    target = [item[1] for item in batch]
    target = torch.LongTensor(target)
    return [data, target]
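To use it, pass the function to the DataLoader via its collate_fn parameter (the variable_size_dataset variable is again an assumption):
# the DataLoader will now call my_collate to build each batch, so images of
# different sizes end up in a plain Python list instead of a stacked tensor
train_loader = DataLoader(variable_size_dataset, batch_size=8,
                          shuffle=True, collate_fn=my_collate)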
Conclusion#
In this post, I gave an introduction to the use of Dataset and DataLoader in PyTorch. Dataset is used to access a single sample from your dataset and transform it, while DataLoader is used to load a batch of samples for training or testing your models. If your training images have variable sizes, you may also have to provide your own custom collate_fn.