Understanding Computational Graphs in PyTorch

PyTorch is a relatively new deep learning library that supports dynamic computation graphs. It has gained a lot of attention since its official release in January. In this post, I want to share what I have learned about the computation graph in PyTorch. Without a basic understanding of the computation graph, we can hardly know what is actually happening under the hood when we train our landscape-changing neural networks.

Computation graphs and their use in PyTorch

The idea of a computation graph is important in the optimization of large-scale neural networks. In simple terms, a computation graph is a DAG in which nodes represent variables (tensors, matrices, scalars, etc.) and edges represent mathematical operations (for example, summation, multiplication). The computation graph has some leaf variables, and the root variables of the graph are computed from them according to the operations defined by the graph. During the optimization step, we combine the chain rule and the graph structure to compute the derivatives of the output w.r.t. the learnable variables in the graph, and then update these variables to bring the output closer to what we want. In neural networks, these learnable variables are often called weights and biases.

You can also think of a neural network as a computational graph: the input images and the parameters of each layer are leaf variables, and the output of the network (usually called the loss, which we minimize to update the parameters) is the root variable of the graph.
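
To make the leaf/root distinction concrete, here is a minimal sketch (assuming PyTorch 0.4 or later, where plain tensors carry requires_grad directly) that builds a tiny graph and inspects its leaves and root:

import torch

# leaf variables: created directly by the user
w = torch.rand(3, requires_grad=True)   # learnable, so it will receive a gradient
x = torch.rand(3)                       # input data, no gradient needed

# the root is produced by operations on the leaves
y = (w * x).sum()

print(w.is_leaf, y.is_leaf)             # True False
print(y.grad_fn)                        # the operation that produced y

y.backward()                            # chain rule back to the leaves
print(w.grad)                           # dy/dw = x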

How is the computation graph created and freed?

In PyTorch, a computation graph is created for each iteration in an epoch. In each iteration, we execute the forward pass, compute the derivatives of the output w.r.t. the parameters of the network, and update the parameters to fit the given examples. After the backward pass, the graph is freed to save memory. In the next iteration, a fresh graph is created and ready for back-propagation.
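
A rough sketch of this per-iteration lifecycle, using a hypothetical linear model and random data purely for illustration:

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()

for step in range(3):
    x, y = torch.rand(4, 10), torch.rand(4, 1)
    optimizer.zero_grad()
    loss = criterion(model(x), y)   # forward pass builds a fresh graph
    loss.backward()                 # backward pass computes gradients and frees the graph
    optimizer.step()                # update the parameters; next iteration rebuilds the graph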

Because the computation graph is freed by default after the first backward pass, you will encounter an error if you try to do a backward pass on the same graph a second time. That is why the following error message pops up:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time

A toy example

Now, let’s take a small example to illustrate the idea. Suppose that we have the small computation graph built by the code below. The variables d and e are the outputs, and a is the input. The underlying computation is:

import torch
from torch.autograd import Variable  # in PyTorch >= 0.4, a plain tensor with requires_grad=True works the same way

a = Variable(torch.rand(1, 4), requires_grad=True)  # leaf variable (the input)
b = a**2
c = b*2
d = c.mean()  # one root of the graph
e = c.sum()   # another root of the graph

When we call d.backward(), that is fine. After this computation, the part of the graph that computes d is freed by default to save memory. So if we then call e.backward(), the error message will pop up. In order to call e.backward(), we have to set the parameter retain_graph to True in d.backward(), i.e.,

d.backward(retain_graph=True)

As long as you pass retain_graph=True to backward(), you can do backward passes on the same graph as many times as you want:

d.backward(retain_graph=True) # fine
e.backward(retain_graph=True) # fine
d.backward() # also fine
e.backward() # error will occur!
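
Another point worth noting: repeated backward passes do not overwrite .grad, they accumulate into it. For the toy example, the gradient of d w.r.t. a is a and that of e is 4*a, so two backward passes leave their sum, 5*a, in a.grad. A small check of this (rebuilding the graph first so that it has not been freed yet):

# rebuild the toy graph so it is fresh
a = Variable(torch.rand(1, 4), requires_grad=True)
b = a**2
c = b*2
d = c.mean()
e = c.sum()

d.backward(retain_graph=True)               # dd/da = a          -> a.grad == a
e.backward(retain_graph=True)               # de/da = 4*a, added -> a.grad == 5*a
print(torch.allclose(a.grad, 5 * a.data))   # True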

Real use cases

A real use case where you want to backward through the graph more than once is multi-task learning, where you have multiple losses at different layers. Suppose that you have two losses, loss1 and loss2, residing in different layers. In order to back-propagate the gradients of loss1 and loss2 w.r.t. the learnable weights of your network independently, you have to use retain_graph=True in the backward() call of the first back-propagated loss.

# suppose you first back-propagate loss1, then loss2
# (you can also do it in reverse order)
loss1.backward(retain_graph=True)
loss2.backward() # the graph is freed here; the next iteration of gradient descent will build a fresh one

optimizer.step() # update the network parameters
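
If you do not actually need to back-propagate the two losses independently, a common alternative is to sum them and call backward() only once: the gradients of both terms are accumulated in a single pass through the graph, so retain_graph is not needed. A minimal sketch:

# alternative: combine the losses and do a single backward pass
total_loss = loss1 + loss2
total_loss.backward()   # gradients from both losses are accumulated; the graph is freed once
optimizer.step()        # update the network parameters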

