January 12, 2023
Apart from less code, of course! Another reason: it is safer!
I like to have more control over what I am doing and how, instead of using black boxes. But be warned that you can get burned when computing the negative log likelihood yourself (the same is true for softmaxes, for example).
See this example:
import torch
logits = torch.tensor([-100, -5, 2, 100])
# naive softmax: exponentiate, then normalize
logits = logits.exp()  # exp(100) overflows to inf in float32
probs = logits / logits.sum()
probs, probs.sum()
(tensor([0., 0., 0., nan]), tensor(nan))
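Makes sense, right? exp(100) ≈ 2.7e43 is larger than the largest float32 value (about 3.4e38), so it overflows to inf, and inf / inf gives nan. So if your network misbehaves and produces extreme activations, you have a problem, but…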
import torch
logits = torch.tensor([-100, -5, 2, 100])
# here we subtract the max value from the logits, so everything is in (-∞, 0] and exp() is at most 1
#----------------------
logits -= logits.max()
#----------------------
logits = logits.exp()
probs = logits / logits.sum()
probs, probs.sum()
(tensor([0.0000e+00, 0.0000e+00, 2.7465e-43, 1.0000e+00]), tensor(1.))
This works nicely, and that’s what F.cross_entropy does internally. Of course, you can always add that normalization yourself to safeguard against such cases (or add batchnorm layers to your architecture if you don’t want to bother with such cases, at the cost of a little more complexity and state in your model).
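The max-subtraction trick from the second example can be wrapped in a small reusable function that also works on batches. This is a minimal sketch: `stable_softmax` is a hypothetical helper name, and the comparison against `torch.softmax` is only a sanity check on these inputs, not a claim about how PyTorch implements it internally.

```python
import torch

def stable_softmax(logits: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Softmax with the max-subtraction trick (hypothetical helper)."""
    logits = logits.float()
    # subtracting the max along `dim` leaves every entry in (-inf, 0],
    # so exp() is at most 1 and cannot overflow
    shifted = logits - logits.amax(dim=dim, keepdim=True)
    counts = shifted.exp()
    return counts / counts.sum(dim=dim, keepdim=True)

logits = torch.tensor([[-100.0, -5.0, 2.0, 100.0],
                       [1.0, 2.0, 3.0, 4.0]])
probs = stable_softmax(logits)
print(probs.sum(dim=-1))  # each row sums to (approximately) 1, no nan
print(torch.allclose(probs, torch.softmax(logits, dim=-1)))  # should print True
```

Using `keepdim=True` keeps the reduced dimension so the max and the sum broadcast back against each row.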
Plus, I am sure there are also good computational-efficiency reasons to use torch’s built-in method to do that.
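The same burn shows up in the negative log likelihood itself. Even a properly normalized softmax can underflow a probability to exactly 0, and log(0) = -inf, so the loss becomes inf. This minimal sketch assumes a single example whose correct class happens to have the most negative logit. It computes the log-probabilities directly as logits minus `torch.logsumexp` and compares the result with `F.cross_entropy` on this input.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[-100.0, -5.0, 2.0, 100.0]])
t = 0                      # index of the correct class (assumed for illustration)
target = torch.tensor([t])

# naive: softmax first, then log; probs[0, t] ~ exp(-200) underflows to 0
probs = torch.softmax(logits, dim=1)
naive_nll = -probs.log()[0, t]        # -log(0) -> inf

# stable: log p = logit - logsumexp(logits), never forms the tiny probability
log_probs = logits - torch.logsumexp(logits, dim=1, keepdim=True)
stable_nll = -log_probs[0, t]         # -(-100 - 100) -> 200

builtin = F.cross_entropy(logits, target)
print(naive_nll, stable_nll, builtin)
```

Working in log space keeps the loss finite even when the probability itself is too small for float32.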