
Learning Goal-Conditioned Representations for Language Reward Models

Vaskar Nath∗†, Dylan Slack∗, Jeff Da, Yuntao Ma, Hugh Zhang, Spencer Whitehead‡, Sean Hendryx‡

∗Equal contribution
†Corresponding author: vaskar.nath@scale.com
‡Equal senior authorship

Techniques that learn improved representations via offline data or self-supervised objectives have shown impressive results in traditional reinforcement learning (RL). Nevertheless, it is unclear how improved representation learning can benefit reinforcement learning from human feedback (RLHF) on language models (LMs). In this work, we propose training reward models (RMs) in a contrastive, goal-conditioned fashion by increasing the representation similarity of future states along sampled preferred trajectories and decreasing the similarity along randomly sampled dispreferred trajectories. This objective improves reward model performance by up to 0.09 AUROC across challenging benchmarks, such as MATH and GSM8k. These findings extend to general alignment as well: on the Helpful-Harmless dataset, we observe a 2.3% increase in accuracy. Beyond improving reward model performance, we show that training RM representations this way improves steerability, because it allows us to evaluate the likelihood of an action achieving a particular goal-state (e.g., whether a solution is correct or helpful). Leveraging this insight, we find that we can filter up to 55% of generated tokens during majority voting by discarding trajectories likely to end up in an “incorrect” state, which leads to significant cost savings. We additionally find that these representations can perform fine-grained control by conditioning on desired future goal-states. For example, we show that steering a Llama 3 model towards helpful generations with our approach improves helpfulness by 9.6% over a supervised-fine-tuning trained baseline. Similarly, steering the model towards complex generations improves complexity by 21.6% over the baseline. Overall, we find that training RMs in this contrastive, goal-conditioned fashion significantly improves performance and enables model steerability.

Introduction

Aligning Language Models (LMs) with human preferences has proven to be an essential step for the adoption and safe use of these models, with the dominant paradigm being Reinforcement Learning from Human Feedback (RLHF) [50]. To accomplish this, a standard setup is to collect labels from humans for generated responses (e.g., preferences, quality ratings) [50]. These labels can then be used to train a reward model to produce a ranking/scoring of a given sequence or set of sequences. The policy LM is then trained to maximize the expected returns from this reward model using a Reinforcement Learning (RL) algorithm.

Figure 1: Overview of contrastive goal-conditioned learning for text. Pictured is a prompt with a preferred and a dispreferred response. The source state tokens (ten) for both the positive and the negative trajectory are sampled from the preferred response. For illustrative purposes, the positive and negative source states are shown as the same token, but in practice they can be different. The positive goal state is sampled as some future token (subtract) from the preferred response, and the negative goal state is sampled as any token (add) from the dispreferred response. The corresponding representations are retrieved from the last hidden state of the reward model. The training objective is then to maximize the similarity of the positive representation pair and minimize the similarity of the negative representation pair.

High-quality representations have been shown to be an important piece for the success of RL algorithms [6, 42]. Although such representations can be learned during end-to-end training, many efforts have found it important to integrate more explicit representation learning components into RL algorithms, such as via data augmentation [39] or auxiliary losses [21]. Some work even casts certain RL algorithms as representation learning methods where using the similarity between state representations serves as a value function, demonstrating success on manipulation and navigation tasks [20]. Despite these successes in different areas, representation learning for aligning LMs has been less explored, while more emphasis has been placed on, e.g., pre-training reward models [7, 37] or learning from different types of rewards [15, 38].

In this paper, we present a simple yet effective approach to improve the representations learned by reward models for aligning LMs. We train LM-based reward models to learn representations that capture the expected reward or likelihood of achieving a goal state (e.g., correct solution to a problem, helpful response) at intermediate steps of the input sequence, inspired by goal-conditioned RL [5, 13, 20]. To do so, we use a contrastive objective and apply it to the reward model’s hidden representations from desirable and undesirable sequences. Enforcing this loss on representations from intermediate steps of the sequence helps encode a dense signal as to which trajectories are more promising at different points in the sequence, which we show offers several useful properties, such as helping to localize errors or evaluating partial completed sequences. This method is flexible enough to support different kinds of alignment data and does not require further annotations beyond common sequence-level annotations.
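The sampling and objective described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the authors' implementation: the cosine similarity, the log-sigmoid form of the loss, and the names `sample_indices` and `contrastive_loss` are assumptions made for the example, and for simplicity the positive and negative pairs share one source state. Hidden states are stood in for by short lists of floats.

```python
import math
import random


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v + 1e-12)


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def sample_indices(len_preferred, len_dispreferred, rng):
    """Pick a source token and a later goal token in the preferred response,
    plus any token in the dispreferred response as the negative goal.
    Requires len_preferred >= 2 so that a future token exists."""
    source = rng.randrange(len_preferred - 1)
    pos_goal = rng.randrange(source + 1, len_preferred)
    neg_goal = rng.randrange(len_dispreferred)
    return source, pos_goal, neg_goal


def contrastive_loss(h_source, h_pos_goal, h_neg_goal):
    """Assumed loss form: raise similarity of (source, positive goal),
    lower similarity of (source, negative goal)."""
    pos_term = -log_sigmoid(cosine(h_source, h_pos_goal))
    neg_term = -log_sigmoid(-cosine(h_source, h_neg_goal))
    return pos_term + neg_term


# Toy usage with made-up 2-d "hidden states".
rng = random.Random(0)
source, pos_goal, neg_goal = sample_indices(8, 6, rng)
well_separated = contrastive_loss([1.0, 0.0], [1.0, 0.0], [-1.0, 0.0])
poorly_separated = contrastive_loss([1.0, 0.0], [-1.0, 0.0], [1.0, 0.0])
```

In this toy setup `well_separated` is smaller than `poorly_separated`, which is the behavior the objective is meant to reward. In practice such a term would be added to the usual reward model training loss and computed on the model's last hidden states.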

We find that this approach improves the reward model’s ability to identify correct/incorrect solutions in mathematical reasoning, boosting the AUROC on the task by up to 0.09 over standard preference ranking training. For natural language alignment, we find this method increases the reward model’s accuracy in identifying helpful versus harmless responses by 2.3% on the Helpful-Harmless dataset [8].
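For readers less familiar with the metric, AUROC here can be read as the probability that a randomly chosen correct solution receives a higher reward than a randomly chosen incorrect one. The sketch below computes it by direct pairwise comparison; the scores and labels are made-up illustrative values, not results from the paper.

```python
def auroc(scores, labels):
    """Fraction of (positive, negative) pairs in which the positive example
    gets the higher score; ties count as half a win."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Hypothetical reward scores for four solutions (1 = correct, 0 = incorrect).
example = auroc([0.9, 0.8, 0.4, 0.3], [1, 0, 1, 0])  # 3 of 4 pairs ranked correctly -> 0.75
```

On this scale, 0.5 corresponds to random ranking and 1.0 to perfect separation, so a gain of 0.09 is a sizeable share of the available headroom.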

We also show the utility of the learned representations themselves, e.g., for filtering solutions to improve accuracy and steering the outputs towards responses with certain attributes in guided decoding (e.g., helpfulness, coherence, and complexity). For mathematical reasoning, we show that we are able to filter up to 55% of generated tokens by discarding trajectories that are likely to lead to incorrect solutions as deemed by the learned representations while achieving similar or better performance. Similarly, using these representations to steer a Llama 3 model by conditioning on desired future goal-states, we improve helpfulness by 9.6%, correctness by 12.2%, coherence by 16.5%, and complexity by 21.6% over the supervised-fine-tuning trained baseline.
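The filtering idea can be illustrated with a small sketch. It assumes each sampled trajectory comes with a representation of its partial generation, and that a goal-state representation `goal_repr` for "correct" is available; how that goal representation is obtained, the cosine scoring, the threshold value, and the function name `filtered_majority_vote` are all assumptions made for this example.

```python
import math
from collections import Counter


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v + 1e-12)


def filtered_majority_vote(candidates, goal_repr, threshold):
    """candidates: list of (partial_representation, final_answer) pairs.
    Keep only trajectories whose partial representation is similar enough to
    the goal state, then take a majority vote over the surviving answers."""
    kept = [answer for repr_, answer in candidates
            if cosine(repr_, goal_repr) >= threshold]
    if not kept:
        # Fall back to an unfiltered vote if the filter removes everything.
        kept = [answer for _, answer in candidates]
    winner = Counter(kept).most_common(1)[0][0]
    return winner, len(kept)


# Toy usage: made-up 2-d representations and answers.
candidates = [
    ([1.0, 0.0], "42"),
    ([0.9, 0.1], "42"),
    ([-1.0, 0.0], "7"),
    ([-0.9, 0.2], "7"),
    ([-1.0, 0.1], "7"),
]
answer, num_kept = filtered_majority_vote(candidates, [1.0, 0.0], threshold=0.0)
# answer == "42", num_kept == 2; an unfiltered vote would have returned "7"
```

In an actual decoding loop the check would run on partial generations, so discarded trajectories would never be completed; that is where the token savings come from.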

In summary, our contributions are as follows: 1) We explore improving the learned representations of reward models and the effect of doing so on LM alignment. To this end, we present a simple and effective representation learning method based on a goal-conditioned contrastive objective. 2) We demonstrate that training reward models with this method can improve reward model performance on mathematical reasoning and helpfulness/harmlessness benchmarks. 3) We show that simply utilizing a reward model trained with this method can improve policy LM alignment on math reasoning benchmarks. 4) We investigate using these representations as a mechanism to evaluate the likelihood of achieving a desired goal state by filtering generations in a majority-vote scheme and guided decoding, showing that they can be used to increase accuracy and control policy LM outputs.

Bibtex Citation

@misc{nath2024learninggoalconditionedrepresentationslanguage,
      title={Learning Goal-Conditioned Representations for Language Reward Models}, 
      author={Vaskar Nath and Dylan Slack and Jeff Da and Yuntao Ma and Hugh Zhang and Spencer Whitehead and Sean Hendryx},
      year={2024},
      eprint={2407.13887},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2407.13887}, 
}