
Stochastic Blahut Arimoto for Fine-Tuning LLMs

We will derive a reinforcement learning algorithm suitable for integration with deep learning architectures, grounded in robust principles from rate-distortion theory. This approach will yield an agent optimized for memory efficiency.

We need the following ingredients:

  • an alphabet $X$ over which we define the set of strings $X^\ast$;
  • a prior distribution $P(\tau)$ over the strings $X^\ast$ to generate samples, e.g. implemented using an LLM;
  • a reward model $R(\tau) \in \mathbb{R}$ over strings.

First we'll briefly review bounded-rational policies motivated by information theory. Such policies have several useful properties, including:

  • they have a controllable policy search complexity;
  • they are robust, in the sense that they protect against adversarial reward perturbations.

Rather than optimizing the expected reward, bounded-rational policies optimize the free energy objective

\[ F(Q) = \mathbb{E}_{Q}\bigl[R(\tau)\bigr] - \frac{1}{\beta} D_{KL}\bigl( Q \| P \bigr) \]

that is, the KL-regularized expected reward. The objective has a closed-form solution:

\[ P^\ast(\tau) = \frac{P(\tau) \exp(\beta R(\tau))}{\sum_{\tau'} P(\tau') \exp(\beta R(\tau'))} \]

The parameter $\beta$, known as the inverse temperature, controls how strongly we modify the prior distribution to adjust to the reward. To protect against adversarial reward perturbations, $\beta$ should not be too large; this also avoids overfitting to a potentially wrong reward function.
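To make the closed-form solution concrete, here is a minimal Python sketch on a toy problem. The three-string prior, the reward table and the helper names `bounded_rational_policy` and `free_energy` are illustrative assumptions, not part of the method; a real prior would be an LLM over $X^\ast$, where the normalizer cannot be enumerated.

```python
import math

def bounded_rational_policy(prior, reward, beta):
    """Return P*(tau) proportional to P(tau) * exp(beta * R(tau)) on a finite set."""
    # Subtracting the max reward is only for numerical stability;
    # the constant cancels in the normalization.
    r_max = max(reward[t] for t in prior)
    weights = {t: p * math.exp(beta * (reward[t] - r_max)) for t, p in prior.items()}
    z = sum(weights.values())
    return {t: w / z for t, w in weights.items()}

def free_energy(q, prior, reward, beta):
    """F(Q) = E_Q[R] - (1/beta) * KL(Q || P)."""
    expected_reward = sum(q[t] * reward[t] for t in q)
    kl = sum(q[t] * math.log(q[t] / prior[t]) for t in q if q[t] > 0)
    return expected_reward - kl / beta

# Hypothetical toy prior and reward over three strings.
prior = {"aa": 0.5, "ab": 0.3, "bb": 0.2}
reward = {"aa": 0.0, "ab": 1.0, "bb": 2.0}
beta = 1.0

policy = bounded_rational_policy(prior, reward, beta)
print(policy)
print(free_energy(policy, prior, reward, beta), free_energy(prior, prior, reward, beta))
```

On this toy problem the free energy of `policy` equals $\frac{1}{\beta} \log \sum_{\tau} P(\tau) \exp(\beta R(\tau))$, the value of the objective at its maximum, and it exceeds the free energy of the unmodified prior.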

Because the optimal policy is a distribution, acting optimally means obtaining a sample from it. There are many ways to do this, but the most straightforward is rejection sampling. This works as follows:

  • Generate a string $\tau \sim P(\tau)$.
  • Generate a uniform random variate $u \sim U(0, 1)$.
  • If $u \leq \exp(\beta R(\tau) - \beta R^\ast)$ return $\tau$.
  • Else, repeat the procedure.

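Here, $R^\ast$ is the optimal reward, i.e. the maximum of $R(\tau)$ over strings. If you don't know it, you can use an estimate. An estimate that is too low pushes the acceptance threshold above one for the best strings and biases the samples, while an estimate that is too high only lowers the acceptance rate.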
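The rejection sampling procedure above can be sketched in a few lines of Python. The prior sampler and reward below are toy stand-ins (random strings of length 4 over `{a, b}`, rewarded by the number of `a` characters) chosen purely for illustration; in practice `sample_prior` would call the LLM and `reward` the reward model. The `max_tries` cap is also an added safeguard, not part of the procedure.

```python
import math
import random

def rejection_sample(sample_prior, reward, beta, r_star, rng, max_tries=100_000):
    """Draw one sample from P*(tau) proportional to P(tau) * exp(beta * R(tau)).

    Returns the accepted string and the number of candidates drawn.
    """
    for tries in range(1, max_tries + 1):
        tau = sample_prior(rng)   # tau ~ P(tau)
        u = rng.random()          # u ~ U(0, 1)
        if u <= math.exp(beta * (reward(tau) - r_star)):
            return tau, tries
    raise RuntimeError("no candidate accepted; beta may be too large")

# Hypothetical toy prior: uniform random strings of length 4 over {a, b}.
def sample_prior(rng, length=4):
    return "".join(rng.choice("ab") for _ in range(length))

# Hypothetical toy reward: number of 'a' characters (so R* = 4).
def reward(tau):
    return float(tau.count("a"))

rng = random.Random(0)
samples = [
    rejection_sample(sample_prior, reward, beta=1.0, r_star=4.0, rng=rng)[0]
    for _ in range(2000)
]
mean_reward = sum(reward(t) for t in samples) / len(samples)
print(mean_reward)  # noticeably above the prior's mean reward of 2
```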

The inverse temperature parameter controls how many candidates you'll have to generate before accepting a sample. If $\beta$ is close enough to zero, then almost every sample gets accepted. But if $\beta$ is very large, it might take forever to obtain a sample, especially if your reward function is structured like a needle-in-a-haystack problem. How can we mitigate this?
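The cost can be made concrete. Each candidate is accepted independently with probability $\sum_{\tau} P(\tau) \exp(\beta R(\tau) - \beta R^\ast)$, so the number of candidates per accepted sample is geometrically distributed, with mean equal to the inverse of that probability. The sketch below evaluates this on a hypothetical needle-in-a-haystack problem: a uniform prior over 1000 items, exactly one of which has reward 1 while the rest have reward 0. The problem and the function names are illustrative assumptions.

```python
import math

def acceptance_probability(prior, reward, beta, r_star):
    """Probability that a single candidate drawn from the prior is accepted."""
    return sum(p * math.exp(beta * (reward[t] - r_star)) for t, p in prior.items())

def expected_candidates(prior, reward, beta, r_star):
    """Mean number of candidates per accepted sample (mean of a geometric law)."""
    return 1.0 / acceptance_probability(prior, reward, beta, r_star)

# Hypothetical needle-in-a-haystack: item 0 is the needle.
k = 1000
prior = {i: 1.0 / k for i in range(k)}
reward = {i: (1.0 if i == 0 else 0.0) for i in range(k)}

costs = {
    beta: expected_candidates(prior, reward, beta, r_star=1.0)
    for beta in (0.0, 0.1, 1.0, 5.0, 50.0)
}
for beta, cost in costs.items():
    print(f"beta={beta:5.1f}  expected candidates={cost:10.2f}")
```

For $\beta = 0$ every candidate is accepted, and as $\beta$ grows the expected cost approaches the size of the haystack.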

Rather than sampling directly from the optimal policy, we can approach it by taking smaller, intermediate steps, following a geodesic in information geometry.

To do so, notice that if $\beta \approx 0$ then accepting a sample is easy. We can exploit this. Choose a new inverse temperature $\alpha \approx 0$, and then alternate the following steps:

  1. Initialization: Let $P_0(\tau)$ be the starting model with parameters $\theta_0$, and set $t = 0$.
  2. Generate better trajectories: generate $N$ samples $\tau^1, \ldots, \tau^N$ from the bounded-rational optimal policy using $\alpha$ as the inverse temperature parameter and $P_t(\tau)$ as the prior over strings.
  3. Stopping condition: If in addition all the samples $\tau^1, \ldots, \tau^N$ pass the acceptance test using the target inverse temperature $\beta$ instead of $\alpha$, stop.
  4. Fine-tune prior model: Fine-tune the prior model $P_t(\tau)$ with the generated samples $\tau^1, \ldots, \tau^N$. This yields a new model $P_{t+1}(\tau)$ with parameters $\theta_{t+1}$.
  5. Repeat: Set $t \leftarrow t + 1$ and repeat from step 2.
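
As a rough guide to the step size, consider the idealized case in which the sampling step is exact and fine-tuning reproduces the sampled distribution perfectly. This is an assumption, not something the procedure guarantees. Each iteration then multiplies the current model by $\exp(\alpha R(\tau))$ and renormalizes, and the iterations compose:

\[ P_{t+1}(\tau) = \frac{P_t(\tau) \exp(\alpha R(\tau))}{\sum_{\tau'} P_t(\tau') \exp(\alpha R(\tau'))} \quad \Longrightarrow \quad P_t(\tau) = \frac{P_0(\tau) \exp(t \alpha R(\tau))}{\sum_{\tau'} P_0(\tau') \exp(t \alpha R(\tau'))} \]

Under this idealization the model after $t$ iterations is the bounded-rational policy with inverse temperature $t\alpha$, so reaching the target $\beta$ takes about $\beta / \alpha$ iterations. With finite $N$ and imperfect fine-tuning, this count is only indicative.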

The resulting distribution $P_t(\tau)$ is our bounded-rational policy. You will have to experiment with the choices of $\alpha$ (which controls the step size) and $N$ (which controls the representation quality of the target distribution) to obtain a satisfactory training time.
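The loop can be sketched end to end on a toy problem. Everything model-specific here is a stand-in, labeled as such: the "model" is a categorical distribution over all strings of length 4 over `{a, b}`, and "fine-tuning" is replaced by maximum-likelihood re-estimation from the samples with a small additive smoothing term. A real implementation would fine-tune an LLM instead. The stopping test draws fresh uniform variates for the $\beta$ test, which is one possible reading of the stopping condition, and `max_iters` is an added safeguard.

```python
import itertools
import math
import random

def accept(tau, reward, inv_temp, r_star, rng):
    """Acceptance test of rejection sampling at the given inverse temperature."""
    return rng.random() <= math.exp(inv_temp * (reward(tau) - r_star))

def sample_model(model, rng):
    """Draw one string from a categorical model {string: probability}."""
    strings = list(model)
    return rng.choices(strings, weights=[model[s] for s in strings], k=1)[0]

def sample_bounded_rational(model, reward, alpha, r_star, rng):
    """Rejection-sample from the model tilted by exp(alpha * R)."""
    while True:
        tau = sample_model(model, rng)
        if accept(tau, reward, alpha, r_star, rng):
            return tau

def fine_tune(model, samples, smoothing=1e-3):
    """Stand-in for fine-tuning: smoothed empirical distribution of the samples."""
    counts = {s: smoothing for s in model}
    for tau in samples:
        counts[tau] += 1.0
    total = sum(counts.values())
    return {s: c / total for s, c in counts.items()}

def stochastic_blahut_arimoto(model, reward, alpha, beta, r_star, n, rng, max_iters=200):
    """Alternate sampling and fine-tuning; return the final model and step count."""
    for t in range(max_iters):
        samples = [sample_bounded_rational(model, reward, alpha, r_star, rng)
                   for _ in range(n)]
        # Stopping condition: every sample also passes the test at the target beta.
        if all(accept(tau, reward, beta, r_star, rng) for tau in samples):
            return model, t
        model = fine_tune(model, samples)
    return model, max_iters

def expected_reward(model, reward):
    return sum(p * reward(s) for s, p in model.items())

# Hypothetical toy setup: uniform starting model, reward = number of 'a's.
strings = ["".join(chars) for chars in itertools.product("ab", repeat=4)]
p0 = {s: 1.0 / len(strings) for s in strings}
reward = lambda tau: float(tau.count("a"))

final_model, steps = stochastic_blahut_arimoto(
    p0, reward, alpha=0.5, beta=3.0, r_star=4.0, n=200, rng=random.Random(0)
)
print(steps, expected_reward(p0, reward), expected_reward(final_model, reward))
```

In this toy setting the stopping condition is strict: it is only met once nearly every sample attains the maximum reward, so the loop tends to run somewhat past the point where the idealized inverse temperature reaches $\beta$.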

The above algorithm generates a new prior $P(\tau)$ which places more weight on desirable strings. However, we often want policies to respond to a user-provided context string $c \in X^\ast$, i.e. we want to sample strings from $P(\tau|c)$, not $P(\tau)$. The problem is that the contexts $c$ are not generated in a way that conforms to the reward function, so the training procedure above will bias the model.

\[ \mathbb{E}\Bigl[ R(\tau) \Bigr] - \frac{1}{\beta} I(c; \tau) \]

  • sba.txt
  • Last modified: 2024/10/27 18:11
  • by pedroortega