What is the KL penalty?

Note

From now on, we’ll use the reward function that measures alliteration, as defined here.

Why do we even need the KL penalty?

We could just train the model on the raw reward as-is. If we do, the reward curve looks great:

The x-axis is number of steps

But if we look at the text generations near the end of training, the model always predicts the same 100 words:

Once upon a time to the tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall tall

Gradient descent has single-mindedly pushed the model weights toward maximizing the reward function, and we’ve lost important qualities like:

  • The model should generate a variety of texts, not the same one every time
  • The text should be a story with acceptable grammar and coherence, not just unrelated words

Definition of the KL penalty

KL stands for Kullback–Leibler, as in Kullback–Leibler divergence. For a single generated sequence, the KL divergence between the current model and the original model can be estimated as a difference of cross-entropy losses, and we’ll use that form for the penalty:

kl_penalty_term = beta * (lm_cross_entropy - original_lm_cross_entropy)

where

  • beta >= 0 is a hyperparameter controlling the penalty strength
  • lm_cross_entropy is the cross-entropy loss of the language model on the text it just generated
  • original_lm_cross_entropy is the cross-entropy loss on that same text, computed with a frozen copy of the weights from before reinforcement learning

We add the kl_penalty_term to raw_reward (i.e. the reward from counting alliteration) and then apply RLOO to that composite reward.
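
To make this concrete, here’s a minimal sketch of how the composite reward might be computed, assuming PyTorch and assuming we already have the logits each model assigns to the sampled tokens. The function name, beta value, and tensor shapes are illustrative, not the exact training code:

import torch.nn.functional as F

def composite_reward(raw_reward, lm_logits, original_lm_logits, generated_ids, beta=0.1):
    # Cross-entropy of each model on the sampled generation (mean negative
    # log-likelihood per token). lm_logits and original_lm_logits have shape
    # [seq_len, vocab_size]; generated_ids has shape [seq_len].
    lm_cross_entropy = F.cross_entropy(lm_logits, generated_ids)
    original_lm_cross_entropy = F.cross_entropy(original_lm_logits, generated_ids)

    # Negative whenever the current model fits this text better than the original model would.
    kl_penalty_term = beta * (lm_cross_entropy - original_lm_cross_entropy)

    # RLOO then sees this composite reward instead of the raw alliteration reward.
    return raw_reward + kl_penalty_term

With beta = 0 this reduces to the raw reward; larger beta pulls the model’s generations back toward what the original weights would have produced.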

Motivation:

  • Suppose the LM has generated a sequence like “tall tall tall….” This sequence has a high raw reward, but since it isn’t coherent English, we’d like the composite reward to be low.
  • lm_cross_entropy will be a small positive number, since the current LM assigns high probability to the text it just generated.
  • original_lm_cross_entropy will be a large positive number, since the original LM would have been very unlikely to generate this text.
  • Hence the kl_penalty_term will be negative, which reduces the composite reward, as the worked numbers below illustrate.
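
For a rough sense of the arithmetic (these numbers are invented purely for illustration): with beta = 0.1, lm_cross_entropy = 0.5, and original_lm_cross_entropy = 8.0, we get kl_penalty_term = 0.1 * (0.5 - 8.0) = -0.75, so whatever raw reward the alliteration counter gave “tall tall tall…” is pulled down by 0.75. Text that the original model is just as likely to produce gets a penalty near zero.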