In this post, I’d like to talk about how to create your own dataset, process it and make data batches ready to be fed into your neural networks, with the help of PyTorch.

In PyTorch, to feed your own training data into a network, you mainly deal with two classes from `torch.utils.data`: `Dataset` and `DataLoader`. Below I explain in more detail what each of them does.

# Create your Dataset class

## Overview

The `Dataset` class provides an interface for accessing all the training or testing samples in your dataset. To achieve this, you implement at least two methods, `__getitem__` and `__len__`, so that each sample (in image classification, a sample is an image plus its class label) can be accessed by its index and the dataset can report how many samples it holds.

In the class's initializer (`__init__`), you collect a list of all the images (for example, their file paths) and their labels. When a particular sample is requested, `__getitem__` reads the image, transforms it, and returns the transformed image together with its label.
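A minimal sketch of such a class is shown here. It assumes the samples are given as a list of `(image_path, label)` pairs; the class name `ImageListDataset` and that input format are hypothetical, chosen only to illustrate the `__init__` / `__getitem__` / `__len__` split described above.

```python
from PIL import Image
from torch.utils.data import Dataset


class ImageListDataset(Dataset):
    """Hypothetical dataset built from a list of (image_path, label) pairs."""

    def __init__(self, samples, transform=None):
        # Collect all (path, label) pairs once, at initialization time.
        self.samples = list(samples)
        self.transform = transform

    def __len__(self):
        # Total number of samples; DataLoader's samplers rely on this.
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        # Read the image lazily, only when this sample is requested.
        image = Image.open(path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label
```

Reading images lazily in `__getitem__` (rather than loading them all in `__init__`) keeps memory usage low, at the cost of disk reads during training.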

A good example is the `ImageFolder` class provided by the torchvision package; reading its source code gives a good sense of how this works in practice.

## Data augmentation and preprocessing

Data augmentation and preprocessing are an important part of the whole workflow. In PyTorch, we do them by passing a `transform` parameter to the Dataset class. Transforms are callable class objects that process the given input. You can chain a series of transforms by passing a list of them to `torchvision.transforms.Compose`; the transforms are then applied to the input in the order they appear in the list.

Note that some transforms, such as `RandomCrop()` and `Resize()`, were designed for PIL image objects (recent torchvision versions also accept tensors for these), while others, such as `Normalize()`, operate only on torch tensors. If your dataset contains images, you should first apply all the transforms that expect a PIL image, then convert the PIL image to a tensor using the `ToTensor()` transform. For 8-bit images, `ToTensor` converts a PIL image to a torch tensor of shape $C\times H\times W$, with its values scaled to the range [0.0, 1.0].

The `Normalize` transform expects torch tensors. Its parameters are the per-channel (R, G, B) means and standard deviations computed over all the training images, and it maps each channel to `(input - mean) / std`. For ImageNet, these statistics have already been computed, so the normalize transform is:

```python
import torchvision.transforms as transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
```

For your own dataset, you have to calculate the statistics yourself.
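One common way to compute these statistics is to convert each training image with `ToTensor()` (without `Normalize`) and accumulate per-channel sums and sums of squares over all pixels. The helper `compute_channel_stats` below is a hypothetical sketch; it assumes `dataset[i]` returns an `(image, label)` pair where `image` is a $C\times H\times W$ tensor.

```python
import torch


def compute_channel_stats(dataset):
    """Per-channel mean and std over every pixel of every image.

    Hypothetical helper: assumes dataset[i] returns (image, label) where
    `image` is a C x H x W tensor produced by ToTensor (no Normalize).
    """
    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0
    for index in range(len(dataset)):
        image, _ = dataset[index]
        # Flatten the spatial dimensions: shape becomes C x (H*W).
        flat = image.reshape(image.shape[0], -1).double()
        if channel_sum is None:
            channel_sum = torch.zeros(flat.shape[0], dtype=torch.float64)
            channel_sq_sum = torch.zeros_like(channel_sum)
        channel_sum += flat.sum(dim=1)
        channel_sq_sum += (flat ** 2).sum(dim=1)
        pixel_count += flat.shape[1]
    mean = channel_sum / pixel_count
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative rounding errors.
    std = (channel_sq_sum / pixel_count - mean ** 2).clamp(min=0).sqrt()
    return mean.tolist(), std.tolist()


# Toy check with two constant images: one all zeros, one all ones.
toy = [(torch.zeros(3, 4, 4), 0), (torch.ones(3, 4, 4), 1)]
mean, std = compute_channel_stats(toy)
print(mean)  # [0.5, 0.5, 0.5]
print(std)   # [0.5, 0.5, 0.5]
```

Because pixels are counted per image, this also works when images differ in size. The returned lists can be passed directly as the `mean` and `std` arguments of `transforms.Normalize`.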

# Create data batches using DataLoader

Although the Dataset class gives us access to every training sample, that is not enough. For deep learning, we also need functionality such as batching, shuffling, and multiprocess data loading. This is what the `DataLoader` class does.

The `DataLoader` class accepts a dataset along with other parameters such as `batch_size`, `shuffle`, `batch_sampler`, and `num_workers` (the number of worker processes used to load the data). We can then iterate over the DataLoader to get batches of training data and train our models.
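A minimal sketch of iterating over a `DataLoader`, using a toy `TensorDataset` of random tensors in place of a real image dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 10 random "images" of shape 3 x 8 x 8 with integer labels.
images = torch.randn(10, 3, 8, 8)
labels = torch.arange(10)
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

for batch_images, batch_labels in loader:
    # The first dimension is the batch size; the last batch holds the 2 leftovers.
    print(batch_images.shape, batch_labels.shape)
# torch.Size([4, 3, 8, 8]) torch.Size([4])
# torch.Size([4, 3, 8, 8]) torch.Size([4])
# torch.Size([2, 3, 8, 8]) torch.Size([2])
```

With `num_workers=0` the data is loaded in the main process; a larger value starts that many worker processes, which helps when `__getitem__` does expensive decoding or augmentation. Passing `drop_last=True` would discard the incomplete final batch.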

By default, DataLoader uses a `collate_fn` function to pack a list of samples (images and targets) into batched tensors whose first dimension is the batch size. The default `collate_fn` expects all the images in a batch to have the same size, because it uses `torch.stack()` to pack them. If the images returned by your Dataset vary in size, you have to provide a custom `collate_fn`. A simple example is shown below:

```python
import torch


# A simple custom collate function, just to show the idea.
# `batch` is a list of tuples where the first element is an image tensor
# and the second element is the corresponding label.
def my_collate(batch):
    data = [item[0] for item in batch]  # just form a list of tensors (sizes may differ)
    target = [item[1] for item in batch]
    target = torch.tensor(target, dtype=torch.long)  # labels stack into one tensor
    return [data, target]
```
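Returning a list leaves the images unstacked, so the model has to process them one by one. A different strategy, sketched below under the assumption that zero-padding is acceptable for your model, is to pad every image in the batch to the largest height and width and then stack them. `pad_collate` is a hypothetical name; note that `torch.nn.functional.pad` pads the last dimension first, which is why the width padding comes before the height padding in the tuple.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


def pad_collate(batch):
    """Zero-pad each C x H x W image (bottom/right) to the batch's max H and W, then stack."""
    images = [item[0] for item in batch]
    labels = torch.tensor([item[1] for item in batch], dtype=torch.long)
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    padded = [
        # Pad tuple order: (left, right, top, bottom) for the last two dims.
        F.pad(img, (0, max_w - img.shape[2], 0, max_h - img.shape[1]))
        for img in images
    ]
    return torch.stack(padded), labels


# Toy variable-size samples; a plain list works as a map-style dataset.
samples = [(torch.ones(3, 20, 30), 0), (torch.ones(3, 25, 10), 1), (torch.ones(3, 15, 15), 2)]
loader = DataLoader(samples, batch_size=3, collate_fn=pad_collate)
batch_images, batch_labels = next(iter(loader))
print(batch_images.shape)  # torch.Size([3, 3, 25, 30])
print(batch_labels)        # tensor([0, 1, 2])
```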

# Conclusion

In this post, I introduced the use of Dataset and DataLoader in PyTorch. Dataset is used to access a single sample from your dataset and transform it, while DataLoader is used to load batches of samples for training or testing your models. If your training images have variable sizes, you may also need to provide your own custom `collate_fn`.