pytorch - connection between loss.backward() and optimizer.step()

Machine Learning, Neural Network, PyTorch, Gradient Descent

Machine Learning Problem Overview


Where is the explicit connection between the optimizer and the loss?

How does the optimizer know where to get the gradients of the loss without a call like optimizer.step(loss)?

-More context-

When I minimize the loss, I don't have to pass the gradients to the optimizer.

loss.backward() # Back Propagation
optimizer.step() # Gradient Descent

Machine Learning Solutions


Solution 1 - Machine Learning

Without delving too deep into the internals of PyTorch, I can offer a simplified answer:

Recall that when initializing the optimizer you explicitly tell it which parameters (tensors) of the model it should be updating. The gradients are "stored" by the tensors themselves (they have grad and requires_grad attributes) once you call backward() on the loss. After the gradients for all tensors in the model have been computed, calling optimizer.step() makes the optimizer iterate over all the parameters (tensors) it is supposed to update and use their internally stored grad to update their values.
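For illustration, here is a minimal sketch of that wiring (the layer size, data, and learning rate are arbitrary placeholders):

import torch
import torch.nn as nn

# A toy model; the optimizer is told up front which parameter tensors it will update.
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

inputs, targets = torch.randn(4, 3), torch.zeros(4, 1)
loss = ((model(inputs) - targets) ** 2).mean()

loss.backward()        # fills .grad on every parameter in the graph that produced loss
optimizer.step()       # reads those .grad values and updates the parameters it was given
optimizer.zero_grad()  # clear the gradients before the next iteration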

More info on computational graphs and the additional "grad" information stored in pytorch tensors can be found in this answer.

The fact that the optimizer holds references to the parameters can sometimes cause trouble, e.g., when the model is moved to the GPU after the optimizer has been initialized. Make sure you are done setting up your model before constructing the optimizer. See this answer for more details.
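A small sketch of the safe ordering (the device handling here is just an illustration):

import torch
import torch.nn as nn

# Finish setting up the model (including device placement) before building the optimizer,
# so the optimizer holds references to the tensors that will actually be trained.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(3, 1).to(device)                         # move the model first...
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)   # ...then construct the optimizer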

Solution 2 - Machine Learning

When you call loss.backward(), all it does is compute the gradient of the loss w.r.t. all parameters in the loss's computation graph that have requires_grad = True, and store them in the parameter.grad attribute of every such parameter.

optimizer.step() then updates all the parameters based on their parameter.grad values.
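As a rough illustration of what step() amounts to for plain SGD (sgd_step below is a hypothetical stand-in, not part of the PyTorch API):

import torch

def sgd_step(params, lr):
    # roughly what optimizer.step() does for vanilla SGD (no momentum, no weight decay):
    # iterate over the parameters and apply p <- p - lr * p.grad
    with torch.no_grad():
        for p in params:
            if p.grad is not None:
                p -= lr * p.grad

w = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (w ** 2).sum()
loss.backward()          # w.grad is now tensor([2., 4.])
sgd_step([w], lr=0.1)    # w becomes tensor([0.8000, 1.6000])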

Solution 3 - Machine Learning

Perhaps this will clarify the connection between loss.backward() and optim.step() a little (although the other answers are to the point).

# Our "model"
x = torch.tensor([1., 2.], requires_grad=True)
y = 100*x

# Compute loss
loss = y.sum()

# Compute gradients of the parameters w.r.t. the loss
print(x.grad)     # None
loss.backward()      
print(x.grad)     # tensor([100., 100.])

# Modify the parameters by subtracting the (scaled) gradient
optim = torch.optim.SGD([x], lr=0.001)
print(x)        # tensor([1., 2.], requires_grad=True)
optim.step()
print(x)        # tensor([0.9000, 1.9000], requires_grad=True)

loss.backward() populates the grad attribute of all leaf tensors with requires_grad=True in the computational graph that produced loss (only x in this case).

The optimizer just iterates through the list of parameters (tensors) it received on initialization, and wherever a tensor has requires_grad=True, it subtracts the gradient stored in its .grad property (simply multiplied by the learning rate in the case of SGD). It doesn't need to know with respect to which loss the gradients were computed; it just reads that .grad property so it can do x = x - lr * x.grad.

Note that if we were doing this in a training loop, we would call optim.zero_grad(), because in each training step we want to compute fresh gradients; we don't care about gradients from the previous batch. Not zeroing the grads would lead to gradient accumulation across batches.
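For reference, a minimal training-loop sketch (the model, data, and loss function are placeholders) showing where zero_grad() fits:

import torch
import torch.nn as nn

model = nn.Linear(3, 1)
optim = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

for _ in range(5):                     # pretend these are batches
    inputs, targets = torch.randn(8, 3), torch.randn(8, 1)
    optim.zero_grad()                  # forget the gradients from the previous batch
    loss = loss_fn(model(inputs), targets)
    loss.backward()                    # write fresh gradients into each parameter's .grad
    optim.step()                       # update the parameters from .grad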

Solution 4 - Machine Learning

Let's say we have defined a model, model, and a loss function, criterion, and we run the following sequence of steps:

pred = model(input)
loss = criterion(pred, true_labels)
loss.backward()

pred will have a grad_fn attribute that references the function that created it and ties it back to the model. Therefore, loss.backward() has the information it needs about the model it is working with.

Try removing the grad_fn attribute, for example with:

pred = pred.clone().detach()

Then the model's gradients will remain None and, consequently, the weights will not get updated.

And the optimizer is tied to the model because we pass model.parameters() when we create the optimizer.
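A small sketch putting this together (the shapes and data are arbitrary):

import torch
import torch.nn as nn

model = nn.Linear(3, 1)
criterion = nn.MSELoss()
inputs, true_labels = torch.randn(4, 3), torch.randn(4, 1)

pred = model(inputs)
print(pred.grad_fn)            # e.g. <AddmmBackward0 ...>: ties pred back to the model

# Cut the graph: the detached copy has no grad_fn, so nothing links it to the model.
pred = pred.clone().detach()
print(pred.grad_fn)            # None

loss = criterion(pred, true_labels)
print(loss.requires_grad)      # False: gradients can no longer flow back to the model's weights

Note that in recent PyTorch versions, calling loss.backward() on such a detached loss actually raises an error (the loss no longer requires grad), so the model's .grad attributes stay None either way and optimizer.step() changes nothing.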

Solution 5 - Machine Learning

Some answers have explained this well, but I'd like to give a specific example to illustrate the mechanism.

Suppose we have the function z = 3x^2 + y^3.
The gradients of z w.r.t. x and y are:

dz/dx = 6x,  dz/dy = 3y^2

The initial values are x=1 and y=2.

import torch
from torch import optim

x = torch.tensor([1.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
z = 3*x**2+y**3

print("x.grad: ", x.grad)
print("y.grad: ", y.grad)
print("z.grad: ", z.grad)

# print result should be:
x.grad:  None
y.grad:  None
z.grad:  None

Then we calculate the gradients of z w.r.t. x and y at the current values (x=1, y=2):

dz/dx = 6*1 = 6,  dz/dy = 3*(2^2) = 12

# calculate the gradient
z.backward()

print("x.grad: ", x.grad)
print("y.grad: ", y.grad)
print("z.grad: ", z.grad)

# print result should be:
x.grad:  tensor([6.])
y.grad:  tensor([12.])
z.grad:  None

Finally, use the SGD optimizer to update the values of x and y according to the update rule x = x - lr * dz/dx and y = y - lr * dz/dy:

# create an optimizer, passing x and y as the parameters to be updated, with learning rate lr=0.1
optimizer = optim.SGD([x, y], lr=0.1)

# executing an update step
optimizer.step()

# print the updated values of x and y
print("x:", x)
print("y:", y)

# print result should be:
x: tensor([0.4000], requires_grad=True)
y: tensor([0.8000], requires_grad=True)

Solution 6 - Machine Learning

Short answer:

loss.backward() # computes the gradients of the loss for all parameters with requires_grad=True. These parameters can be any tensors defined in the code, e.g. h2h or i2h.

optimizer.step() # according to the optimizer's update rule (defined earlier in the code), updates those parameters to drive the loss (error) toward a minimum.
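As a concrete sketch of that short answer (i2h and h2h are hypothetical linear layers here, and the loss is a placeholder):

import torch
import torch.nn as nn

i2h = nn.Linear(10, 20)   # input-to-hidden
h2h = nn.Linear(20, 20)   # hidden-to-hidden
optimizer = torch.optim.SGD(list(i2h.parameters()) + list(h2h.parameters()), lr=0.01)

x, h = torch.randn(1, 10), torch.zeros(1, 20)
h_new = torch.tanh(i2h(x) + h2h(h))
loss = h_new.pow(2).mean()   # placeholder loss

loss.backward()              # gradients for every requires_grad=True parameter in the graph
optimizer.step()             # the optimizer updates exactly the parameters it was given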

Attributions

All content on this page is sourced from the original question and answers on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type | Original Author | Original Content on Stackoverflow
Question | aerin | View Question on Stackoverflow
Solution 1 - Machine Learning | Shai | View Answer on Stackoverflow
Solution 2 - Machine Learning | Ganesh | View Answer on Stackoverflow
Solution 3 - Machine Learning | pseudomarvin | View Answer on Stackoverflow
Solution 4 - Machine Learning | Akavall | View Answer on Stackoverflow
Solution 5 - Machine Learning | LollipopKnight | View Answer on Stackoverflow
Solution 6 - Machine Learning | pourya | View Answer on Stackoverflow