Calculating Loss on a Transformer

Introduction

If you’ve ever fine-tuned a language model, you’ve probably just called model(**inputs, labels=labels) and trusted that outputs.loss gives you the right thing. But what’s actually happening under the hood? Understanding this is crucial when you need to implement custom loss functions or debug training issues.

In this post, we’ll build up the loss calculation from first principles, see why naive implementations fail, and arrive at the numerically stable version that HuggingFace uses.

Setup

Let’s start with a simple example - three sentences of different lengths:

test_text = ["The cat sat on the mat.", "The dog", "Test sentence"]
inputs = tokenizer(test_text, return_tensors="pt", padding=True).to(device)
Since our sentences have different lengths, the tokenizer pads shorter sequences. This padding will become important when we calculate loss - we don’t want to penalize the model for its predictions on padding tokens.
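To make the padding concrete, here is a toy sketch of what right-padding produces - made-up token ids, not a real tokenizer vocabulary:

```python
# Toy illustration of right-padding: made-up token ids, not a real tokenizer.
sequences = [[7, 2, 9, 4, 11, 5, 13], [7, 8], [20, 21]]  # three "sentences"
max_len = max(len(seq) for seq in sequences)
pad_id = 0

# Pad every sequence to the same length and record which positions are real.
input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
# attention_mask marks real tokens with 1 and padding with 0.
```

The attention_mask is what we will use later to exclude padding positions from the loss.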
The HuggingFace Way
Here’s how you’d typically compute loss with HuggingFace transformers:
labels = inputs["input_ids"].clone()
labels[inputs["attention_mask"] == 0] = -100
with torch.no_grad():
outputs = llm_model(**inputs, labels=labels)
outputs.loss

Two things to note here:

- We clone input_ids: Python passes objects by reference, so without .clone(), setting padding positions to -100 would modify our original inputs too.
- The magic -100: PyTorch’s CrossEntropyLoss ignores any position with this label. By setting padding positions to -100, we exclude them from the loss calculation.
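The effect of ignore_index=-100 boils down to masking: ignored positions contribute nothing, and the mean is taken over the remaining positions only. A plain-Python sketch with made-up per-token loss values:

```python
# Sketch of the ignore_index=-100 behaviour: positions whose label is -100
# are dropped, and the mean is taken over the kept positions only.
labels = [42, 7, -100, -100]           # last two positions are padding
per_token_loss = [2.0, 1.0, 9.9, 9.9]  # made-up values; the padding ones never matter

kept = [l for l, lab in zip(per_token_loss, labels) if lab != -100]
loss = sum(kept) / len(kept)
```

Note that the values at the ignored positions are irrelevant - they are excluded before averaging, not zeroed out and averaged over.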
But what’s actually happening inside? Let’s derive it step by step.
Deriving the Loss from First Principles
The goal of language modeling is simple: given a sequence of tokens, predict the next token. The loss measures how well our predictions match reality.
Mathematically, we want the negative log probability of the correct next token. The model outputs logits (unnormalized scores) for each vocabulary item, so we need to:
- Convert logits to probabilities via softmax
- Extract the probability assigned to the correct token
- Take the negative log
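These three steps can be checked on a toy example before moving to tensors - a 3-token vocabulary with made-up logits:

```python
import math

logits = [2.0, 1.0, 0.1]  # made-up scores over a 3-token vocabulary
correct = 0               # index of the actual next token

# 1. Softmax: exponentiate and normalize so the scores sum to 1
exps = [math.exp(x) for x in logits]
probs = [e / sum(exps) for e in exps]

# 2. Extract the probability assigned to the correct token
p_correct = probs[correct]

# 3. Take the negative log
loss = -math.log(p_correct)
```

A confident correct prediction (p_correct near 1) gives a loss near 0; a confident wrong one sends it toward infinity.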
Here’s the naive implementation. Pay close attention to the shape of the tensors and the offsets used in the indexing.
# Convert to probabilities
unnormalized_prob = outputs.logits.exp()
normalised_prob = unnormalized_prob / unnormalized_prob.sum(dim=-1, keepdim=True)
# normalised_prob.shape = (batch_size, sequence_length, vocab_size)

# Extract probability of correct tokens (shifted by 1 since we predict next token)
correct_token_probs = torch.gather(
    normalised_prob, 2, inputs.input_ids[:, 1:].unsqueeze(-1)
).squeeze(-1)
# correct_token_probs.shape = (batch_size, sequence_length - 1)

# Negative log likelihood, ignoring padding
valid_indices = inputs.attention_mask[:, 1:] == 1  # (batch_size, sequence_length - 1)
loss = -correct_token_probs[valid_indices].log().mean()

This code implements the naive, first-principles approach to calculating cross-entropy loss for a language model. First, it converts the model’s raw logits into probabilities by applying the softmax function manually: exponentiating each logit and dividing by the sum across the vocabulary dimension (unnormalized_prob.sum(dim=-1, keepdim=True)).
The torch.gather operation then extracts the probability the model assigned to each correct next token—note the [:, 1:] offset, which is essential because the prediction at position t corresponds to the actual token at position t+1.
Finally, the loss is computed as the negative log of these probabilities, averaged only over non-padding positions (identified by the attention mask).
While mathematically correct, this implementation is numerically unstable in practice: calling .exp() on large logits can produce values that overflow to infinity, which is why production code uses the log-sum-exp trick instead.
The Numerical Stability Problem
The code above works mathematically, but fails in practice. The culprit? .exp() on large logits produces enormous numbers that overflow to infinity.
The fix is the log-sum-exp trick. Instead of computing log(exp(x) / sum(exp(x))), we compute x - logsumexp(x), which is mathematically equivalent but numerically stable:
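The trick itself fits in a few lines of plain Python: subtract the maximum logit before exponentiating, so every exp() argument is at most 0 and can never overflow. A sketch with made-up large logits:

```python
import math

def log_softmax(logits):
    # Stable log-softmax: shift by the max so no exp() argument exceeds 0.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

big_logits = [1000.0, 999.0, 998.0]  # made-up values; naive math.exp(1000.0) overflows
stable = log_softmax(big_logits)
```

Shifting by the max changes nothing mathematically - the constant cancels inside the log-sum-exp - but it keeps every intermediate value finite.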
# Log probabilities directly (numerically stable)
log_probs = outputs.logits - outputs.logits.logsumexp(dim=-1, keepdim=True)
# Extract log probability of correct tokens
correct_token_log_probs = torch.gather(
    log_probs, 2, inputs.input_ids[:, 1:].unsqueeze(-1)
).squeeze(-1)
# Negative log likelihood, ignoring padding
loss = -correct_token_log_probs[inputs.attention_mask[:, 1:] == 1].mean()

The PyTorch CrossEntropyLoss Way
Of course, PyTorch provides all of this in CrossEntropyLoss. To match HuggingFace exactly:
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fn(
outputs.logits[:, :-1, :].reshape(-1, outputs.logits.shape[-1]), # All predictions except last
labels[:, 1:].reshape(-1), # All targets except first
)Note the offsets: logits[:, :-1] and labels[:, 1:]. We drop the last logit because there’s no “next token” to predict after the final token. Similarly, we drop the first label because there’s no prediction for the first token (nothing came before it).
When the model sees an <end_of_sequence> token, it has nothing meaningful left to predict - so we simply stop there.
Summary
The transformer loss is conceptually simple - it’s just cross-entropy between predicted and actual next tokens. The implementation details matter though:
- Offset by 1: Predictions at position t correspond to labels at position t+1
- Ignore padding: Use -100 as the ignore index
- Numerical stability: Use log-sum-exp instead of naive softmax
Understanding this foundation is essential when you move beyond standard training - whether that’s implementing custom losses, working with masked language models, or debugging why your fine-tuning isn’t converging.