Leveraging Einsum to Improve Your Deep Learning Code
I recently stumbled upon the terms “einops” and “einsum” while implementing some ideas for my graduation thesis on time series prediction. If you’ve been following the trends, you can see that recent implementations [1] of state-of-the-art deep learning methods have been using these mysterious, magical toolboxes that can do complex matrix operations, all while improving code readability and performance. In this article, I’ll give a quick introduction to what einsum actually means, and hopefully it’ll give you a head start on implementing your operations with einsum to improve your code. Of course, I’m open to suggestions, and if you find any mistakes, I’ll be happy to correct them. I’ll be using the PyTorch implementation of einsum for the rest of the tutorial. I might also do a continuation of this tutorial that dives into the world of einops.
Let’s talk a bit about the names einsum and einops themselves. Einsum stands for the “Einstein summation” convention (introduced by Albert Einstein himself in 1916), which is often used in applications of linear algebra to physics [2]. In linear algebra, we deal with matrices and tensors that have multiple dimensions. Operations involving these elements require us to write things in the form of index summations (Σ ← this thing you see in most math textbooks), and they get really complicated when you have several dimensions in your tensors. By using einsum, we can write the summation in a much simpler form without losing any clarity. A simple example would be an inner product between two 1D vectors of N components each.
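In explicit form, that inner product is c = Σᵢ aᵢbᵢ, with i running from 1 to N; in Einstein notation it is written simply as c = aᵢbᵢ, because the repeated index i already implies the summation.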
The application to deep learning programming was started by some genius who apparently found a way to apply this convention to NumPy, and thus einsum was born. Shortly after, another genius [1] thought, “Why limit the convention to the summation operation when we can use it for other miscellaneous tensor operations?”, and thus he blessed us with einops, which stands for “Einstein operations”. Einsum can now be found in most generally available deep learning libraries such as TensorFlow, PyTorch, and the like, while einops is a standalone package that you need to install separately. As mentioned above, this article will mainly give you an introduction to einsum; einops is out of its scope.
There are three rules to understanding einsum: [3]
1. Repeated indices are summed over if they don’t appear in the resulting tensor.
2. Each index can appear at most twice.
3. Each term must contain identical non-repeated indices.
Once you grasp the meaning of these rules, you’ll see that they guarantee an einsum expression has only one meaning (one valid operation for the tensors involved). Let’s see how one can write an einsum operation in code and how it obeys the three rules above; but first, here is a quick illustrative example of rule 1.
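In the string "ij, j -> i", the index j is repeated across the two operands and absent from the output, so (by rule 1) the matching elements are multiplied and summed over, which gives a matrix-vector product. This is just a minimal example of my own:
import torch

A = torch.tensor([[1., 2.], [3., 4.]])
x = torch.tensor([10., 20.])

# j is repeated and does not appear after "->", so it is summed over
print(torch.einsum("ij, j -> i", A, x))
>> tensor([ 50., 110.])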
You can write an einsum operation in PyTorch (the syntax barely changes in other packages) using these steps:
1. Write torch.einsum("", a, b), with a and b denoting the variable names of the two tensors. The first string argument is where you write the notation of your operation.
2. Inside the quotation marks, write an index name for each dimension of the first tensor you want to operate on, then a comma, then the indices of the second tensor, e.g. "ijk, ijk" if you have two tensors with 3 dimensions. The index names you write here matter according to the rules described above.
3. Finally, write -> followed by the indices you want to keep, omitting the indices you want to sum over. E.g. "ijk, ijk -> i" means the operation performs an element-wise multiplication over the 2nd (j) and 3rd (k) dimensions and sums along them, resulting in a 1-dimensional tensor whose size equals the 1st dimension (i) of the inputs. You can actually think of this as the following nested loop. In fact, if you are not sure about your einsum operation, writing it out as a nested loop is the best way to debug it.
import torch

a = torch.ones([2, 3, 4])
b = torch.ones([2, 3, 4])
print(torch.einsum("ijk, ijk -> i", a, b))
>> tensor([12., 12.])

# is equivalent to:
c = torch.zeros(a.shape[0])
for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        for k in range(a.shape[2]):
            c[i] += a[i, j, k] * b[i, j, k]
print(c)
>> tensor([12., 12.])
Now, let’s look at some operations you can do with einsum to get a feel for how it works. As you’ll see, most linear algebra operations can be rewritten with einsum, which demonstrates its versatility.
- Inner product
a = torch.Tensor([1, 2, 3])
b = torch.Tensor([1, 2, 5])
print(torch.einsum("i, i->", a, b))
>> tensor(20.)
- Hadamard product (element-wise product)
a = torch.Tensor([[1, 2, 4], [2, 3, 1], [5, 2, 1]])
b = torch.Tensor([[5, 2, 4], [9, 4, 1], [8, 1, 6]])
print(torch.einsum("ij, ij -> ij", a, b))
>> tensor([[ 5., 4., 16.], [18., 12., 1.], [40., 2., 6.]])
- Matrix multiplication
a = torch.Tensor([[1, 2, 4], [2, 3, 1], [5, 2, 1]])
b = torch.Tensor([[5, 2, 4], [9, 4, 1], [8, 1, 6]])
print(torch.einsum("ij, jl -> il", a, b))
>> tensor([[55., 14., 30.], [45., 17., 17.], [51., 19., 28.]])
Note that the indices are written as ij and jl, meaning the 2nd dimension of the 1st matrix and the 1st dimension of the 2nd matrix are the same. The index j also doesn’t appear in the resulting matrix (il), meaning it’s multiplied and summed over, which is exactly what matrix multiplication does over these two matrices.
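As suggested earlier, writing the operation out as a nested loop is a good sanity check. Here is a rough sketch of my own (reusing a and b from the snippet above) of what "ij, jl -> il" computes:
# Rough nested-loop equivalent of torch.einsum("ij, jl -> il", a, b)
c = torch.zeros(a.shape[0], b.shape[1])
for i in range(a.shape[0]):
    for l in range(b.shape[1]):
        for j in range(a.shape[1]):  # j is summed over because it does not appear in "il"
            c[i, l] += a[i, j] * b[j, l]
print(c)
>> tensor([[55., 14., 30.], [45., 17., 17.], [51., 19., 28.]])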
- Trace (diagonal sum of matrix)
a = torch.Tensor([[1, 2, 4], [2, 3, 1], [5, 2, 1]])
print(torch.einsum("ii->", a))
>> tensor(5.)
The notation ii means both dimensions use the same index (the diagonal elements), and having no index in the result means those elements are summed over (with no multiplication, as there is no second matrix).
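The same loop trick works here; a rough equivalent (again my own sketch, using a from above) would be:
# Rough loop equivalent of torch.einsum("ii->", a)
trace = torch.tensor(0.)
for i in range(a.shape[0]):
    trace += a[i, i]  # the repeated index ii walks along the diagonal
print(trace)
>> tensor(5.)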
One might think that all of these operations can be done easily using built-in operators such as @ or torch.matmul (matrix multiplication), torch.trace() for calculating the trace, and so on. However, by notating your operation in einsum, you can clearly state which dimensions are multiplied and summed over, which dimensions are preserved, and, most importantly, see the dimensions of the resulting tensor easily. This really comes in handy when you do more complex operations such as batched multiplication (over temporal or channel dimensions) and the like. To show this, I will be using einsum to calculate the batched attention mechanism that has been gaining popularity and has inspired several state-of-the-art deep learning methods. The specific type of attention mechanism I’m going to implement here is dot-product attention, an explanation of which can be found in [5].
- Dot product attention calculation
# Initialize data batch size, time length of sequence and the dimension
batch_size, time_length, dimension = 20, 10, 256

# Set the dimension of Q, K and V
dimension_weight = 512

# Initialize random vectors as attention inputs
query_vector = torch.randn((batch_size, time_length, dimension))
context_vector = torch.randn((batch_size, time_length, dimension))
print(query_vector.shape)
>> torch.Size([20, 10, 256])

# Initialize random vectors as attention weights
W_q = torch.randn((dimension, dimension_weight))
W_k = torch.randn((dimension, dimension_weight))
W_v = torch.randn((dimension, dimension_weight))
print(W_q.shape)
>> torch.Size([256, 512])

# Calculate Q, K, V with einsum
Q = torch.einsum("dw, btd -> btw", W_q, query_vector)
K = torch.einsum("dw, btd -> btw", W_k, context_vector)
V = torch.einsum("dw, btd -> btw", W_v, context_vector)
print(Q.shape)
>> torch.Size([20, 10, 512])

# Calculate scores (QK^T) using einsum.
# Essentially a batched matrix multiplication, with K transposed first.
scores = torch.einsum("bij, bkj -> bik", Q, K)
print(scores.shape)
>> torch.Size([20, 10, 10])

# Calculate softmax as attention scores
scores = torch.softmax(scores, dim=-1)

# Calculate output (scores @ V). Also a batched matrix multiplication.
output = torch.einsum("bik, bkj -> bij", scores, V)
print(output.shape)
>> torch.Size([20, 10, 512])
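As a quick sanity check (my own addition, not part of the implementation above), the two batched einsum calls should match PyTorch’s built-in batched matrix multiplication; the difference is that the einsum strings make the transposed dimension explicit:
# Should give the same output as the einsum version (up to floating-point error)
scores_bmm = torch.bmm(Q, K.transpose(1, 2))                  # same as "bij, bkj -> bik"
output_bmm = torch.bmm(torch.softmax(scores_bmm, dim=-1), V)  # same as "bik, bkj -> bij"
print(output_bmm.shape)
>> torch.Size([20, 10, 512])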
Hopefully, with this introduction, you can start using einsum in your model implementations. I might update this article to include more examples in the future, so stay tuned!
References — all accessed on 15 December 2020
[1] Einops: https://github.com/arogozhnikov/einops
[2] Einstein notation — Wikipedia: https://en.wikipedia.org/wiki/Einstein_notation
[3] Einstein Summation — WolframMathWorld: https://mathworld.wolfram.com/EinsteinSummation.html
[4] Einsum Is All You Need — Einstein Summation In Deep Learning: https://rockt.github.io/2018/04/30/einsum
[5] Attention? Attention!: https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html#summary