====== Stochastic Blahut-Arimoto for Fine-Tuning LLMs ======

> We will derive a reinforcement learning algorithm suitable for integration with deep learning architectures, grounded in principles from rate-distortion theory. The resulting agent is optimized for memory efficiency.

===== Ingredients =====

We need the following ingredients:

  * an alphabet $X$ over which we define the set of strings $X^\ast$;
  * a prior distribution $P(\tau)$ over the strings $X^\ast$ from which we can generate samples, e.g. implemented using an LLM;
  * a reward model $R(\tau) \in \mathbb{R}$ over strings.

===== Bounded-rational policies =====

First we'll briefly review bounded-rational policies motivated by information theory. Such policies have many desirable properties. An incomplete list:

  * they have a controllable policy search **complexity**;
  * they are **robust**, in the sense that they protect against adversarial reward perturbations.

Rather than optimizing the expected reward, bounded-rational policies optimize the **free energy objective**

\[
F(Q) = \mathbb{E}_{Q}\bigl[R(\tau)\bigr] - \frac{1}{\beta} D_{KL}\bigl( Q \,\|\, P \bigr),
\]

that is, the KL-regularized expected reward. The objective has a closed-form maximizer:

\[
P^\ast(\tau) = \frac{P(\tau) \exp(\beta R(\tau))}{\sum_{\tau'} P(\tau') \exp(\beta R(\tau'))}.
\]

The parameter $\beta$, known as the inverse temperature, controls how strongly we modify the prior distribution to adjust to the reward. To protect against adversarial perturbations, **we do not want $\beta$ to be too large**; this avoids, for instance, overfitting to a potentially wrong reward function.

==== Sampling from the bounded-rational policy ====

Because the optimal policy is a distribution, acting optimally means obtaining a sample from it. There are many ways to do this, but the most straightforward is **rejection sampling**, which works as follows:

  * Generate a string $\tau \sim P(\tau)$.
  * Generate a uniform random variate $u \sim U(0, 1)$.
  * If $u \leq \exp(\beta R(\tau) - \beta R^\ast)$, return $\tau$.
  * Else, repeat the procedure.

Here, $R^\ast$ is the maximum attainable reward. If you don't know it, you can use an estimate (preferably an upper bound, so that the acceptance probability stays at most one).

The inverse temperature also controls how many candidates you'll have to generate before a sample is accepted. If $\beta$ is close enough to zero, then almost every sample gets accepted. But if $\beta$ is very large, it might take forever to obtain a sample, especially if your reward function is structured like a **needle-in-a-haystack** problem. How can we mitigate this?

==== Follow the geodesic ====

Rather than sampling directly from the optimal policy, we can approach it by taking smaller, intermediate steps, **following a geodesic** in information geometry. To do so, notice that if $\beta \approx 0$ then accepting a sample is easy. We can exploit this. Choose a new inverse temperature $\alpha \approx 0$, and then alternate the following steps:

  - **Initialization**: Let $P_0(\tau)$ be the starting model with parameters $\theta_0$, and set $t = 0$.
  - **Generate better trajectories**: Generate $N$ samples $\tau^1, \ldots, \tau^N$ from the bounded-rational optimal policy, using $\alpha$ as the inverse temperature and $P_t(\tau)$ as the prior over strings.
  - **Stopping condition**: If all the samples $\tau^1, \ldots, \tau^N$ also pass the acceptance test with the target inverse temperature $\beta$ in place of $\alpha$, stop.
  - **Fine-tune the prior model**: Fine-tune the prior model $P_t(\tau)$ on the generated samples $\tau^1, \ldots, \tau^N$. This yields a new model $P_{t+1}(\tau)$ with parameters $\theta_{t+1}$.
  - **Repeat**: Set $t \leftarrow t + 1$ and repeat from step 2.

The final distribution $P_t(\tau)$ is our bounded-rational policy. You will have to experiment with the choices of $\alpha$ (which controls the step size) and $N$ (which controls how well the samples represent the target distribution) to obtain a satisfactory training time.
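To make the loop concrete, here is a minimal Python sketch of the procedure. It is an illustration under assumptions, not a reference implementation: ''sample_prior'', ''reward'', and ''fine_tune'' are hypothetical stand-ins for sampling from the current model $P_t(\tau)$, evaluating the reward model $R(\tau)$, and the supervised fine-tuning step, and ''r_star'' is the (estimated) maximum reward $R^\ast$.

<code python>
import math
import random


def rejection_sample(sample_prior, reward, inv_temp, r_star, max_tries=100_000):
    """Draw one string from the bounded-rational policy
    P*(tau) proportional to P(tau) * exp(inv_temp * R(tau)), by rejection sampling."""
    for _ in range(max_tries):
        tau = sample_prior()                      # tau ~ P_t(tau)
        u = random.random()                       # u ~ U(0, 1)
        if u <= math.exp(inv_temp * (reward(tau) - r_star)):
            return tau                            # accept
    raise RuntimeError("no sample accepted; lower inv_temp or improve r_star")


def passes_test(tau, reward, inv_temp, r_star):
    """The acceptance test alone, used for the stopping condition at the target beta."""
    return random.random() <= math.exp(inv_temp * (reward(tau) - r_star))


def stochastic_blahut_arimoto(sample_prior, fine_tune, reward,
                              alpha, beta, r_star, n_samples):
    """Alternate cheap sampling at the small inverse temperature alpha with
    fine-tuning, until the samples also pass the acceptance test at the target beta.

    sample_prior():     draws tau ~ P_t(tau)                        (assumed interface)
    fine_tune(samples): fine-tunes P_t on samples, returns the new sampler
    """
    while True:
        # Step 2: generate N samples at the small inverse temperature alpha.
        samples = [rejection_sample(sample_prior, reward, alpha, r_star)
                   for _ in range(n_samples)]
        # Step 3: stop if every sample also passes the test at the target beta.
        if all(passes_test(tau, reward, beta, r_star) for tau in samples):
            return sample_prior
        # Step 4: fine-tune the prior on the samples, yielding P_{t+1}.
        sample_prior = fine_tune(samples)
</code>

Returning the updated sampler from ''fine_tune'' is just one way of threading the new model $P_{t+1}(\tau)$ back into the loop; in practice the parameters $\theta_t$ would be updated in place.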
===== Adding a user-provided context =====

The above algorithm produces a new prior $P(\tau)$ which places more weight on desirable strings. Often, however, we want policies to respond to a user-provided context string $c \in X^\ast$, i.e. we want to sample strings from $P(\tau|c)$, not from $P(\tau)$. The problem is that the contexts $c$ are not generated in a way that conforms to the reward function, so the training procedure above will bias the model.

==== Enter memory-constrained agents ====

Memory-constrained agents maximize the objective

\[
\mathbb{E}\bigl[ R(\tau) \bigr] - \frac{1}{\beta} I(\tau; c),
\]

where $I(\tau; c)$ is the mutual information between the user-provided context and the generated string.
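Under this reading, the regularizer can be rewritten as an expected KL divergence to the marginal policy, which makes the connection to the free energy objective above explicit:

\[
I(\tau; c) = \mathbb{E}_{P(c)}\Bigl[ D_{KL}\bigl( P(\tau \mid c) \,\|\, P(\tau) \bigr) \Bigr],
\qquad
P(\tau) = \sum_{c} P(c)\, P(\tau \mid c).
\]

This is the earlier objective averaged over contexts, with the fixed prior replaced by the marginal $P(\tau)$. Alternating between optimizing the conditional policy for a fixed marginal and recomputing the marginal from the conditionals is exactly the structure of the Blahut-Arimoto iteration that gives the algorithm its name.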