Why do we need to explicitly call zero_grad()?


Neural Network Problem Overview


Why do we need to explicitly zero the gradients in PyTorch? Why can't gradients be zeroed when loss.backward() is called? What scenario is served by keeping the gradients on the graph and asking the user to explicitly zero the gradients?

Neural Network Solutions


Solution 1 - Neural Network

We need to call zero_grad() explicitly because, after loss.backward() computes the gradients, optimizer.step() has to use them to perform the gradient-descent update. The gradients are not zeroed automatically because these two operations, loss.backward() and optimizer.step(), are separate, and optimizer.step() needs the gradients that were just computed.
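
A minimal sketch of that hand-off, assuming a toy two-element parameter and plain SGD with lr=0.1 (both chosen only for illustration): backward() leaves the gradient in w.grad, and step() reads it from there. Clearing the gradients between the two calls leaves step() with nothing to apply.

```python
import torch

# Toy parameter and optimizer (illustrative values only)
w = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
opt = torch.optim.SGD([w], lr=0.1)

loss = (w ** 2).sum()
loss.backward()            # stores d(loss)/dw = 2 * w in w.grad
print(w.grad)              # tensor([2., 4.])
opt.step()                 # reads w.grad: w <- w - 0.1 * w.grad
print(w.detach())          # tensor([0.8000, 1.6000])

# Same setup, but the gradients are cleared between backward() and step()
w2 = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
opt2 = torch.optim.SGD([w2], lr=0.1)
(w2 ** 2).sum().backward()
opt2.zero_grad()           # throws away the gradient that step() needs
opt2.step()                # nothing to apply, so w2 is unchanged
print(w2.detach())         # tensor([1., 2.])
```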

In addition, we sometimes need to accumulate gradients across several batches; to do that, we can simply call backward() multiple times and step the optimizer once.
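
A minimal sketch of gradient accumulation, assuming a toy torch.nn.Linear(3, 1) model, MSELoss, and 8 samples split into 4 micro-batches (all illustrative). Each micro-batch loss is scaled by 1/accum_steps, an assumption about the desired normalization, so that the accumulated gradient equals the gradient of the full-batch mean loss; the script checks this before taking a single optimizer step.

```python
import torch

torch.manual_seed(0)

model = torch.nn.Linear(3, 1)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 3)  # toy data, 8 samples
y = torch.randn(8, 1)
accum_steps = 4        # 4 micro-batches of 2 samples each

# Reference: gradient of the full-batch loss from one backward pass
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: several backward() calls, one step()
optimizer.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    loss = loss_fn(model(xb), yb) / accum_steps  # scale so the sum equals the full-batch mean
    loss.backward()                               # adds into .grad instead of overwriting it
accum_grad = model.weight.grad.clone()
optimizer.step()       # a single update using the accumulated gradient
optimizer.zero_grad()  # clear before the next accumulation cycle

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # True
```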

Solution 2 - Neural Network

I have a use case for the current setup in PyTorch.

If one is using a recurrent neural network (RNN) that makes a prediction at every step, one might want a hyperparameter that allows gradients to accumulate back in time. Not zeroing the gradients at every time step makes it possible to use backpropagation through time (BPTT) in interesting and novel ways.
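
A minimal sketch of the idea, assuming a toy torch.nn.RNNCell with a linear read-out (sizes, names, and data are illustrative). The loss at every time step is backpropagated as soon as it is computed, without zeroing in between, so the per-step gradients accumulate in .grad. retain_graph=True is needed because later steps backpropagate through the same earlier hidden states. The accumulated result matches a single backward pass on the summed loss.

```python
import torch

torch.manual_seed(0)

cell = torch.nn.RNNCell(input_size=2, hidden_size=4)
head = torch.nn.Linear(4, 1)                  # per-step prediction
params = list(cell.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)
loss_fn = torch.nn.MSELoss()

seq = torch.randn(5, 1, 2)       # toy sequence: 5 time steps, batch of 1, 2 features
targets = torch.randn(5, 1, 1)   # one target per time step

def grads_after_sequence(backward_every_step):
    optimizer.zero_grad()
    h = torch.zeros(1, 4)
    losses = []
    for t in range(seq.size(0)):
        h = cell(seq[t], h)
        loss = loss_fn(head(h), targets[t])
        if backward_every_step:
            # No zeroing here: each call adds this step's gradient into .grad.
            # retain_graph=True because later steps reuse the earlier hidden states.
            loss.backward(retain_graph=True)
        losses.append(loss)
    if not backward_every_step:
        torch.stack(losses).sum().backward()  # one backward pass through the whole sequence
    return [p.grad.clone() for p in params]

per_step = grads_after_sequence(backward_every_step=True)
summed = grads_after_sequence(backward_every_step=False)
print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(per_step, summed)))  # True

optimizer.step()  # one update using the gradients accumulated over all time steps
```

If optimizer.step() is called partway through a sequence, detach the hidden state (h = h.detach()) before continuing; otherwise the next backward() would run through a graph whose weights step() has modified in place, which autograd rejects.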

If you would like more info on BPTT or RNNs see the article Recurrent Neural Networks Tutorial, Part 3 – Backpropagation Through Time and Vanishing Gradients or The Unreasonable Effectiveness of Recurrent Neural Networks.

Solution 3 - Neural Network

Each training iteration in PyTorch follows a cycle:

  • forward pass, where we get the output y_hat from the input;
  • loss calculation, where loss = loss_fn(y_hat, y);
  • loss.backward(), where we calculate the gradients;
  • optimizer.step(), where we update the parameters.

Or in code:

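import torch
model = torch.nn.Linear(10, 1)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x, y = torch.randn(32, 10), torch.randn(32, 1)  # toy stand-in for one mini-batch

for mb in range(10):  # 10 iterations (reusing the same toy batch)
    y_pred = model(x)          # forward
    loss = loss_fn(y_pred, y)  # compute the loss
    optimizer.zero_grad()      # clear gradients left over from the previous iteration
    loss.backward()            # compute fresh gradients
    optimizer.step()           # update the parameters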

If we did not clear the gradients, either after optimizer.step() or just before the next backward() call, they would accumulate. Here is an example showing accumulation:

import torch
w = torch.rand(5)
w.requires_grad_()
print(w) 
s = w.sum() 
s.backward()
print(w.grad) # tensor([1., 1., 1., 1., 1.])
s.backward()
print(w.grad) # tensor([2., 2., 2., 2., 2.])
s.backward()
print(w.grad) # tensor([3., 3., 3., 3., 3.])
s.backward()
print(w.grad) # tensor([4., 4., 4., 4., 4.])

loss.backward() has no option for zeroing the gradients beforehand. It delegates to torch.autograd.backward, whose signature is:

torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None)

None of these options clears the gradients, so you have to zero them manually, as in the previous mini example:

w.grad.zero_()
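
Applied to the mini example, a sketch that resets w.grad before every backward() call keeps the gradient at ones instead of letting it grow. The None check covers the first iteration, before any gradient exists.

```python
import torch

w = torch.rand(5, requires_grad=True)
s = w.sum()

for _ in range(4):
    if w.grad is not None:
        w.grad.zero_()        # manual reset; without it each backward() adds another 1
    s.backward()
    print(w.grad)             # tensor([1., 1., 1., 1., 1.]) on every iteration
```

Setting w.grad = None has the same effect here, and optimizer.zero_grad() performs the reset for every parameter registered with an optimizer (recent PyTorch versions set the gradients to None by default rather than filling them with zeros).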

There was some discussion about having backward() zero the previous gradients automatically every time, with a preserve_grads=True flag for keeping them, but this was never implemented.

Solution 4 - Neural Network

Leaving the gradients in place before calling .step() is useful in case you'd like to accumulate the gradient across multiple batches (as others have mentioned).

Keeping them after calling .step() can also be useful if you'd like to implement momentum for SGD, since momentum and various other methods depend on values from the previous update's gradient.
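
A minimal sketch of hand-rolled SGD with momentum, assuming a toy nn.Linear model, MSELoss, lr=0.1 and mu=0.9 (all illustrative), checked against torch.optim.SGD(momentum=...). Each update reads the current p.grad and folds it into a velocity buffer, which is how information from earlier gradients carries forward; torch.optim.SGD keeps an analogous momentum_buffer in its optimizer state.

```python
import copy
import torch

torch.manual_seed(0)
x, y = torch.randn(16, 3), torch.randn(16, 1)   # toy data
loss_fn = torch.nn.MSELoss()
lr, mu = 0.1, 0.9                               # illustrative hyperparameters

init = torch.nn.Linear(3, 1)
manual = copy.deepcopy(init)     # updated by hand below
builtin = copy.deepcopy(init)    # updated by torch.optim.SGD for comparison

# Hand-rolled momentum: v <- mu * v + grad, p <- p - lr * v
velocity = [torch.zeros_like(p) for p in manual.parameters()]
for _ in range(5):
    manual.zero_grad()                        # clear the previous iteration's gradient
    loss_fn(manual(x), y).backward()
    with torch.no_grad():
        for p, v in zip(manual.parameters(), velocity):
            v.mul_(mu).add_(p.grad)           # fold the new gradient into the running velocity
            p.sub_(lr * v)

# Built-in equivalent
opt = torch.optim.SGD(builtin.parameters(), lr=lr, momentum=mu)
for _ in range(5):
    opt.zero_grad()
    loss_fn(builtin(x), y).backward()
    opt.step()

print(all(torch.allclose(a, b, atol=1e-6)
          for a, b in zip(manual.parameters(), builtin.parameters())))  # True
```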

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type | Original Author | Original Content on Stackoverflow
Question | Wasi Ahmad | View Question on Stackoverflow
Solution 1 - Neural Network | danche | View Answer on Stackoverflow
Solution 2 - Neural Network | twrichar | View Answer on Stackoverflow
Solution 3 - Neural Network | prosti | View Answer on Stackoverflow
Solution 4 - Neural Network | 190290000 Ruble Man | View Answer on Stackoverflow