PyTorch provides a lot of methods for the Tensor type. Some of these methods may be confusing for new users. Here, I would like to talk about view() vs reshape(), transpose() vs permute().

view() vs reshape() and transpose()

view() vs transpose()

Both view() and reshape() can be used to change the size or shape of tensors. But they are slightly different.

The view() has existed for a long time. It will return a tensor with the new shape. The returned tensor shares the underling data with the original tensor. If you change the tensor value in the returned tensor, the corresponding value in the viewed tensor also changes.

On the other hand, it seems that reshape() has been introduced in version 0.4. According to the document, this method will

Returns a tensor with the same data and number of elements as input, but with the specified shape. When possible, the returned tensor will be a view of input. Otherwise, it will be a copy. Contiguous inputs and inputs with compatible strides can be reshaped without copying, but you should not depend on the copying vs. viewing behavior.

It means that torch.reshape may return a copy or a view of the original tensor. You can not count on that to return a view or a copy. According to the developer:

if you need a copy use clone() if you need the same storage use view(). The semantics of reshape() are that it may or may not share the storage and you don’t know beforehand.

As a side note, I found that torch version 0.4.1 and 1.0.1 behaves differently when you print the id of original tensor and viewing tensor:

In [1]: import torch

In [2]: a = torch.rand(3, 4)

In [3]: id(a), id(a.storage())
Out[3]: (2236511690472, 2236511611848)

In [4]: b = a.view(2, 6)

In [5]: id(b), id(b.storage())
Out[5]: (2236523527984, 2236470501128)

You see that id of a.storage() and b.storage() is not the same. Isn’t that their underlying data the same? Why this difference?

I filed an issue in the PyTorch repo and got answers from the developer. It turns out that to find the data pointer, we have to use the data_ptr() method. You will find that their data pointers are the same.

view() vs transpose()

transpose(), like view() can also be used to change the shape of a tensor and it also returns a new tensor sharing the data with the original tensor:

Returns a tensor that is a transposed version of input. The given dimensions dim0 and dim1 are swapped.

The resulting out tensor shares it’s underlying storage with the input tensor, so changing the content of one would change the content of the other.

One difference is that view() can only operate on contiguous tensor and the returned tensor is still contiguous. transpose() can operate both on contiguous and non-contiguous tensor. Unlike view(), the returned tensor may be not contiguous any more.

But what does contiguous mean?

There is a good answer on SO which discusses the meaning of contiguous in Numpy. It also applies to PyTorch.

As I understand, contiguous in PyTorch means if the neighboring elements in the tensor are actually next to each other in memory. Let’s take a simple example:

x = torch.tensor([[1, 2, 3], [4, 5, 6]]) # x is contiguous
y = torch.transpose(0, 1) # y is non-contiguous

Tensor x and y in the above example share the same memory space1.

print(x.data_ptr()) # 94018404758288
print(y.data_ptr()) # 94018404758288

If you check their contiguity with is_contiguous(), you will find that x is contiguous but y is not.

print(x.is_contiguous()) # True
print(y.is_contiguous()) # False

Since x is contiguous, x[0][0] and x0 are next to each other in memory. But y[0][0] and y0 is not.

A lot of tensor operations requires that the tensor should be contiguous, otherwise, an error will be thrown. To make a non-contiguous tensor become contiguous, use call the contiguous(), which will return a new contiguous tensor. In plain words, it will create a new memory space for the new tensor and copy the value from the non-contiguous tensor to the new tensor.

transpose() and permute()

permute() and tranpose() are similar. transpose() can only swap two dimension. But permute() can swap all the dimensions. For example:

x = torch.rand(16, 32, 3)
y = x.tranpose(0, 2)

z = x.permute(2, 1, 0)

Note that, in permute(), you must provide the new order of all the dimensions. In transpose(), you can only provide two dimensions. tranpose() can be thought as a special case of permute() method in for 2D tensors.

References


  1. To show a tensor’s memory address, use tensor.data_ptr().