Hands-on: automatic differentiation
Remember those videos you watched two hours ago? Maybe you got panicky at the time (how hard will it be to write out the chain rule, with so many derivatives, during backpropagation?! 😱)
Don't fear! PyTorch has automatic differentiation built in, meaning that it can track all the operations performed on tensors and do the backpropagation automatically for you. This is done thanks to the torch.autograd package.
Tracking computations
PyTorch does not track all the computations on all tensors (this would be extremely memory intensive!). To start tracking computations on a tensor, set its .requires_grad attribute to True:
import torch
x = torch.rand(3, 7, requires_grad=True)
print(x)
The grad_fn attribute
Whenever a tensor is created by an operation involving a tracked tensor, it gets a grad_fn attribute:
x = torch.ones(2, 4, requires_grad=True)
print(x)
y = x + 1
print(y)
print(y.grad_fn)
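Note that x itself has no grad_fn: it was created directly by you, not by an operation on a tracked tensor. A quick check (these two lines are an addition, not part of the original example):
print(y.requires_grad)  # True - y is tracked because x is tracked
print(x.grad_fn)        # None - x is a user-created (leaf) tensor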
Judicious tracking
You don't want to track more than is necessary, and there are several ways to avoid tracking computations you don't need.
You can simply stop tracking computations on a tensor with the in-place method detach_():
x = torch.rand(4, 3, requires_grad=True)
print(x)
print(x.detach_())
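For comparison, the out-of-place detach() leaves x untouched and returns a new, untracked tensor that shares the same data. A minimal sketch (not part of the original example):
x = torch.rand(4, 3, requires_grad=True)
y = x.detach()          # y shares the data of x but is not tracked
print(x.requires_grad)  # True
print(y.requires_grad)  # False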
You can also change its requires_grad flag:
x = torch.rand(4, 3, requires_grad=True)
print(x)
print(x.requires_grad_(False))
Alternatively, you can wrap any code you don't want to track in a with torch.no_grad(): block:
with torch.no_grad():
    <some code>
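For example (with a hypothetical tensor x, not taken from the original), nothing computed inside the block is recorded:
x = torch.rand(2, 2, requires_grad=True)
with torch.no_grad():
    y = x * 2
print(y.requires_grad)  # False - the multiplication was not tracked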
Calculating gradients
After you have performed a number of operations on x and obtained a final scalar (let's call it loss, since in the context of neural networks the output of the loss function is the starting point of the backpropagation process), you can get the gradient of any tensor y you created with requires_grad=True:
loss.backward()
print(y.grad)
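As a minimal sketch with a hypothetical tensor x (not taken from the lesson), before the fuller example below:
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = (x ** 2).sum()  # loss = x1^2 + x2^2 + x3^2
loss.backward()        # computes d(loss)/dx
print(x.grad)          # tensor([2., 4., 6.]), i.e. 2 * x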
Example
Let's go over a simple example:
- let real be the tensor of some real (target) values
- let predicted be the tensor given by some model trying to predict these real values after one iteration
We will calculate the first derivative (the first step of the backpropagation) manually and with the torch.autograd package, to really understand what that package does.
Let's fill real and predicted with random values, since we don't have a real situation with a real network (but let's make sure to start recording the history of computations performed on predicted):
real = torch.rand(3, 8)
print(real)
predicted = torch.rand(3, 8, requires_grad=True)
print(predicted)
Several loss functions can be used in machine learning, but let's use the one you saw in the previous video: \[\text{loss}=\sum (\text{predicted} - \text{real})^2\]
loss = (predicted - real).pow(2).sum()
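Printing the result shows a single scalar whose grad_fn records the last operation performed (the sum); its exact value will differ from run to run since the inputs are random:
print(loss)  # a scalar tensor with grad_fn=<SumBackward0>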
Now, to train a model, after each forward pass we need to go through backpropagation to adjust the weights and biases of the model. That means we need to calculate all the derivatives, starting from the derivative of the predicted values all the way back to the derivatives of the weights and biases.
Here, we will only do the very first step: calculate the gradient of the loss with respect to predicted.
Manual derivative calculation
As you saw in the previous video, the formula for this first derivative, with the loss function we used, is: \[\text{gradient}_\text{predicted}=2(\text{predicted} - \text{real})\]
There is no point in adding this operation to predicted's computation history, so we will exclude it by wrapping it in with torch.no_grad():
with torch.no_grad():
    manual_gradient_predicted = 2.0 * (predicted - real)
print(manual_gradient_predicted)
Automatic derivative calculation
Now, with the torch.autograd package:
loss.backward()
Since we tracked computations on predicted, we can access its gradient with:
auto_gradient_predicted = predicted.grad
print(auto_gradient_predicted)
Comparison
The result is the same, as can be tested with:
print(manual_gradient_predicted.eq(auto_gradient_predicted).all())
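The exact equality check works here because both computations perform the same floating-point operations; for floating-point comparisons in general, a tolerance-based check is usually safer (this line is an addition, not part of the original):
print(torch.allclose(manual_gradient_predicted, auto_gradient_predicted))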
The calculation of this first derivative of backpropagation was simple enough, but propagating all the derivative calculations backward through the chain rule would quickly turn into a deep calculus problem. With torch.autograd, getting the gradients of all the other elements of the network is as simple as reading their grad attribute once torch.Tensor.backward() has been run.
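As a minimal sketch of what that looks like, assuming a hypothetical single linear layer (the weights w and biases b below are made up for illustration, not part of the lesson):
# hypothetical weights and biases of a single linear layer
w = torch.rand(8, 5, requires_grad=True)
b = torch.rand(5, requires_grad=True)
out = real @ w + b                            # reuse real as a dummy input (shape 3x8)
loss = (out - torch.rand(3, 5)).pow(2).sum()  # same squared-error loss as above
loss.backward()                               # one call propagates through the whole graph
print(w.grad.shape)                           # torch.Size([8, 5])
print(b.grad.shape)                           # torch.Size([5])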