In this post, I will share how PyTorch sets the number of threads to use for its operations.

torch.set_num_threads() sets the number of threads used for intra-op parallelism on CPU. According to discussions here, intra-op parallelism roughly means parallelism within a single operation, for example, splitting one matrix multiplication across several threads. By default, PyTorch uses all the available cores on the computer. To verify this, we can call torch.get_num_threads() to get the default number of threads.
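As a minimal sketch of the two calls mentioned above, the snippet below reads the default thread count and then overrides it. The value 2 is an arbitrary choice for illustration, and the default you see depends on your machine and PyTorch build.

```python
import torch

# Read the number of intra-op threads PyTorch chose by default.
default_threads = torch.get_num_threads()
print("default intra-op threads:", default_threads)

# Override it for the rest of this process (2 is an arbitrary example value).
torch.set_num_threads(2)
print("after set_num_threads(2):", torch.get_num_threads())
```

It is usually safest to call torch.set_num_threads() early, before running any heavy CPU operations.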

For operations that support parallelism, increasing the number of threads usually leads to faster execution on CPU. Apart from calling torch.set_num_threads(), we can also set the environment variable OMP_NUM_THREADS or MKL_NUM_THREADS to control the number of threads. Below is a toy script to verify this (adapted from code here):

import time

import numpy as np
import torch

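INDEX = 10000  # number of rows in the source tensor
NELE = 1000    # number of elements per row
a = torch.rand(INDEX, NELE)
# Random row indices; selecting INDEX*8 rows makes the operation heavy enough to time.
index = np.random.randint(INDEX-1, size=INDEX*8)
b = torch.from_numpy(index)

# Time 10 repetitions of index_select, which can run on multiple CPU threads.
start = time.time()
for _ in range(10):
    res = a.index_select(0, b)
print("the number of cpu threads: {}, time: {}".format(torch.get_num_threads(), time.time()-start))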

With OMP_NUM_THREADS set to 1, 2, 4, and 8, the running times are:

the number of cpu threads: 1, time: 2.927994728088379
the number of cpu threads: 2, time: 1.6809608936309814
the number of cpu threads: 4, time: 1.092754602432251
the number of cpu threads: 8, time: 0.7028472423553467
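The environment variable has to be visible when the Python process starts, so one way to try several values is to launch a child process per setting. The sketch below only checks how many threads PyTorch reports in each child; the names CHILD_CODE and threads_seen_by_child are hypothetical helpers, not part of PyTorch, and the reported value may depend on your build.

```python
import os
import subprocess
import sys

# Code run in each child process: report the intra-op thread count PyTorch picked up.
CHILD_CODE = "import torch; print(torch.get_num_threads())"


def threads_seen_by_child(num_threads):
    """Start a fresh Python process with OMP_NUM_THREADS set and return what PyTorch reports."""
    # Copy the current environment and override only OMP_NUM_THREADS.
    env = dict(os.environ, OMP_NUM_THREADS=str(num_threads))
    result = subprocess.run(
        [sys.executable, "-c", CHILD_CODE],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )
    return int(result.stdout.strip())


if __name__ == "__main__":
    for n in (1, 2, 4, 8):
        print("OMP_NUM_THREADS={} -> torch.get_num_threads()={}".format(n, threads_seen_by_child(n)))
```

From a shell, the same effect is usually achieved by prefixing the command with the variable, for example OMP_NUM_THREADS=4 followed by the command that runs the toy script.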