Difference between view, reshape, transpose and permute in PyTorch
PyTorch provides a lot of methods for the Tensor type. Some of these methods may be confusing for new users. Here, I would like to talk about view() vs reshape(), view() vs transpose(), and transpose() vs permute().
view() vs reshape()
Both view() and reshape() can be used to change the size or shape of tensors. But they are slightly different.
view() has existed for a long time. It returns a tensor with the new shape, and the returned tensor shares the underlying data with the original tensor. If you change a value in the returned tensor, the corresponding value in the original tensor also changes.
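A minimal sketch of this sharing behavior (the values are arbitrary; torch.arange is used only to get distinct numbers):

```python
import torch

a = torch.arange(12)      # 1-D tensor holding 0..11
b = a.view(3, 4)          # the same 12 elements, seen as a 3x4 matrix

b[0, 0] = 100             # write through the view...
print(a[0])               # ...and the original sees the change: tensor(100)

# Both tensors start at the same address in memory.
print(a.data_ptr() == b.data_ptr())  # True
```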
On the other hand, reshape() seems to have been introduced in version 0.4. According to the documentation:
Returns a tensor with the same data and number of elements as input, but with the specified shape. When possible, the returned tensor will be a view of input. Otherwise, it will be a copy. Contiguous inputs and inputs with compatible strides can be reshaped without copying, but you should not depend on the copying vs. viewing behavior.
This means that torch.reshape may return either a copy or a view of the original tensor, and you cannot count on getting one or the other. According to a PyTorch developer:
if you need a copy use clone() if you need the same storage use view(). The semantics of reshape() are that it may or may not share the storage and you don’t know beforehand.
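The sketch below illustrates both outcomes. Comparing data_ptr() values is just a way to observe the behavior; as the documentation says, real code should not depend on which one you get. When the choice matters, use clone() or view() explicitly.

```python
import torch

x = torch.arange(6).view(2, 3)  # contiguous 2x3 tensor
t = x.t()                       # transposed view of x: shares memory, not contiguous

r1 = x.reshape(3, 2)  # contiguous input: reshape() can return a view
r2 = t.reshape(6)     # elements must be reordered, so reshape() has to copy

print(r1.data_ptr() == x.data_ptr())  # True: r1 shares x's memory
print(r2.data_ptr() == t.data_ptr())  # False: r2 lives in new memory
print(r2)                             # tensor([0, 3, 1, 4, 2, 5])

c = x.clone()    # always a copy
v = x.view(3, 2) # always shares storage (or raises an error)
```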
As a side note, I found that torch versions 0.4.1 and 1.0.1 behave differently when you print the id of the original tensor and of the viewing tensor:
```
In : import torch
In : a = torch.rand(3, 4)
In : id(a), id(a.storage())
Out: (2236511690472, 2236511611848)
In : b = a.view(2, 6)
In : id(b), id(b.storage())
Out: (2236523527984, 2236470501128)
```
You can see that the ids of a.storage() and b.storage() are not the same. Isn't their underlying data the same? Why this difference?
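I filed an issue in the PyTorch repo and got answers from the developer. It turns out that id() only identifies a Python object, and a.storage() and b.storage() are separate Python wrapper objects, so comparing their ids tells you nothing about the memory. To find the data pointer, we have to use the data_ptr() method. You will find that their data pointers are the same.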
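A small sketch of the difference, avoiding storage() entirely and comparing data_ptr() directly. The row example is an extra illustration added here: data_ptr() is the address of a tensor's first element, so a view that starts at an offset reports a different pointer even though it still shares memory.

```python
import torch

a = torch.rand(3, 4)
b = a.view(2, 6)

# id() identifies the Python objects, which are different...
print(id(a) == id(b))                # False
# ...while data_ptr() is the memory address of the first element.
print(a.data_ptr() == b.data_ptr())  # True

# a[1] shares a's memory but starts 4 elements (4 * 4 bytes for float32) later.
row = a[1]
print(row.data_ptr() == a.data_ptr() + 4 * a.element_size())  # True
```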
view() vs transpose()
transpose(), like view(), can also be used to change the shape of a tensor, and it also returns a new tensor sharing the data with the original tensor. From the documentation:
Returns a tensor that is a transposed version of input. The given dimensions dim0 and dim1 are swapped.
The resulting out tensor shares its underlying storage with the input tensor, so changing the content of one would change the content of the other.
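A minimal sketch showing that no data is copied by transpose(): writing into the transposed tensor changes the original.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = x.transpose(0, 1)  # shape (3, 2), same memory as x

y[2, 0] = 30           # y[2, 0] is the same element as x[0, 2]
print(x[0, 2])         # tensor(30)
```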
One difference is that view() generally requires a contiguous tensor (strictly speaking, the tensor's strides must be compatible with the new shape), and viewing a contiguous tensor gives a contiguous result. transpose() works on both contiguous and non-contiguous tensors, but unlike view(), the tensor it returns may no longer be contiguous.
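A sketch of what this means in practice: flattening a transposed tensor with view() raises a RuntimeError, while reshape() handles the same request by copying.

```python
import torch

x = torch.rand(2, 3)
y = x.transpose(0, 1)  # shape (3, 2), non-contiguous view of x

view_failed = False
try:
    y.view(6)  # would need y's elements to be laid out in row-major order
except RuntimeError as err:
    view_failed = True
    print("view() failed:", err)

z = y.reshape(6)         # reshape() falls back to copying the data
print(z.is_contiguous()) # True: z is a fresh contiguous tensor
```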
But what does contiguous mean? There is a good answer on SO which discusses the meaning of contiguous in NumPy. It also applies to PyTorch.

As I understand it, contiguous in PyTorch means that elements which are neighbors in the tensor (in row-major order) are actually next to each other in memory. Let's take a simple example:
```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # x is contiguous
y = torch.transpose(x, 0, 1)              # y is non-contiguous
```
x and y in the above example share the same memory space[1]:
```python
print(x.data_ptr())  # 94018404758288
print(y.data_ptr())  # 94018404758288
```
If you check their contiguity with is_contiguous(), you will find that x is contiguous but y is not:
```python
print(x.is_contiguous())  # True
print(y.is_contiguous())  # False
```
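Since x is contiguous, x[0][0] and x[0][1] (values 1 and 2) are next to each other in memory. But y[0][0] and y[0][1] (values 1 and 4) are not: in memory, 4 sits three elements after 1.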
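Strides make this concrete. stride(d) is the number of elements you skip in memory to move one step along dimension d; roughly speaking, a tensor is contiguous when its strides match a row-major layout of its shape. A small sketch:

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.transpose(x, 0, 1)

print(x.stride())  # (3, 1): the next element in a row is 1 step away in memory
print(y.stride())  # (1, 3): the next element in a row of y is 3 steps away
```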
Many tensor operations require the tensor to be contiguous; otherwise, an error will be thrown (view() is one example). To make a non-contiguous tensor contiguous, call its contiguous() method, which returns a new contiguous tensor. In plain words, it allocates new memory for the new tensor and copies the values from the non-contiguous tensor into it. If the tensor is already contiguous, contiguous() simply returns the original tensor without copying.
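A minimal sketch of contiguous() on the transposed tensor from the example above: the result gets its own memory, and view() works on it.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.transpose(x, 0, 1)  # non-contiguous view of x

z = y.contiguous()                   # new memory, values copied in row-major order
print(z.is_contiguous())             # True
print(z.data_ptr() == x.data_ptr())  # False: z no longer shares x's memory
print(z.view(-1))                    # tensor([1, 4, 2, 5, 3, 6])

# x is already contiguous, so no copy is made.
print(x.contiguous().data_ptr() == x.data_ptr())  # True
```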
transpose() and permute()
permute() and transpose() are similar. transpose() can only swap two dimensions, but permute() can reorder all the dimensions at once. For example:
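```python
import torch

x = torch.rand(16, 32, 3)
y = x.transpose(0, 2)   # swap dims 0 and 2 -> shape (3, 32, 16)
z = x.permute(2, 1, 0)  # give the new order of all dims -> shape (3, 32, 16)
```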
Note that in permute(), you must provide the new order of all the dimensions, while in transpose(), you can only provide two dimensions. transpose() can be thought of as a special case of permute(); for 2D tensors, transpose(0, 1) is the same as permute(1, 0).
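A sketch of this relationship: swapping two dimensions gives the same result either way, while a full reordering that permute() does in one call takes two chained transpose() calls.

```python
import torch

x = torch.rand(16, 32, 3)

# Swapping dims 0 and 2: both methods agree.
print(torch.equal(x.transpose(0, 2), x.permute(2, 1, 0)))  # True

# Cyclic reorder (0, 1, 2) -> (2, 0, 1): one permute() call...
p = x.permute(2, 0, 1)                 # shape (3, 16, 32)
# ...or two transpose() calls: (3, 32, 16) -> (3, 16, 32).
t = x.transpose(0, 2).transpose(1, 2)
print(torch.equal(p, t))               # True

# For a 2D tensor, transpose(0, 1) and permute(1, 0) are the same operation.
m = torch.rand(4, 5)
print(torch.equal(m.transpose(0, 1), m.permute(1, 0)))  # True
```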
[1] To show a tensor's memory address, use tensor.data_ptr().
License CC BY-NC-ND 4.0