Enhancing Diffusion Models with Reinforcement Learning


TL;DR

Today we're going to tell you all about DRLX - our library for Diffusion Reinforcement Learning! Released a few weeks ago, DRLX is a library for scalable distributed training of diffusion models with reinforcement learning. It currently implements the recent DDPO paper, with more RL implementations coming soon. We share here preliminary results training Stable Diffusion v2.1 with PickScore, observing better prompt adherence and image quality. Contributions and help with additional features and experiments are always greatly appreciated, so if you're interested, please join us in the #drlx channel of our Discord server.

Introduction

Over the last year, Large Language Models (LLMs) have gained popularity in part due to the rise of instruction-tuned, preference-tuned, and chat models. This is often done with Reinforcement Learning from Human Feedback (RLHF), popularized by OpenAI with their work on InstructGPT and ChatGPT. RLHF is typically implemented by training a reward model on a dataset of human ratings/preferences over LM generations, and then using RL to optimize the reward yielded by the LM's generations. While some industry labs have established workflows and infrastructure internally, RLHF still remains inaccessible to many other companies, institutions, and labs. CarperAI has been focused on democratizing RLHF through the development of open-source libraries like trlX, as well as sharing research insights regarding the training of LLMs with RLHF.

A natural question to ask regarding RLHF is whether it can be applied to other kinds of models. Concurrent with the rise of LLMs have been advances in image generation with AI, specifically with diffusion models. While text-to-image diffusion models have produced impressive results, the generations often fail to fully capture the prompt or satisfy human preferences. So the question arises of whether RLHF can be applied to train diffusion models to better capture human preferences.

In May, a paper from the Levine Lab at UC Berkeley demonstrated how proximal policy optimization (PPO), the popular RL algorithm used in most implementations of RLHF, can be applied to train diffusion models. Their algorithm, called Denoising Diffusion Policy Optimization (DDPO), showed promising results with a variety of simple reward models.

This paper has led to significant interest in diffusion RLHF. Once again in line with our mission to democratize access to RLHF, we have released DRLX (Diffusion Reinforcement Learning X), a library for modular RLHF training of diffusion models. While it currently only implements DDPO, we are adding implementations of other RL algorithms.

Reward Models

A preliminary ingredient for RLHF is a reward model. At first, it may not be obvious why we would want to predict "human preference" as a single number. The loss function for a reward model in the style of InstructGPT is as follows:
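$$\mathcal{L}(\theta) = -\,\mathbb{E}_{(x,\, y_w,\, y_l) \sim D}\left[\log \sigma\big(r_\theta(x, y_w) - r_\theta(x, y_l)\big)\right]$$

where $x$ is the prompt, $y_w$ and $y_l$ are the preferred and rejected generations from a pairwise comparison, $r_\theta$ is the reward model, and $\sigma$ is the sigmoid function.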

Put in simple terms, the output of a reward model trained on pairwise comparisons can be thought of as log-odds. Specifically, the difference in outputs between two generations (images in our case) can be thought of as the log-odds that the first would be chosen over the second by a human. An individual score can be thought of as the log-odds that an image would be chosen by a human over any other particular image, so an image with a large reward output is more likely to be preferred by a human over images with lower rewards.
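To make this concrete, here is a minimal PyTorch sketch of the pairwise loss above. `reward_model` stands for any network mapping an image to a scalar score; the names are illustrative rather than DRLX's API:

import torch.nn.functional as F

def pairwise_reward_loss(reward_model, preferred_images, rejected_images):
    # Scalar score for each image in the batch
    r_w = reward_model(preferred_images)  # shape: (batch,)
    r_l = reward_model(rejected_images)   # shape: (batch,)
    # The score difference is the log-odds that the preferred image wins the comparison,
    # so we maximize the log-sigmoid of that difference
    return -F.logsigmoid(r_w - r_l).mean()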

DDPO algorithm

Before going over how DDPO works, let’s think about how standard RLHF of language models works. The goal of RL algorithms is to optimize some reward. For RLHF, this is typically the preference score produced by a reward model given the output of a generative model. The reward model is trained on tens of thousands of human preferences over LM generations. Language model generation is treated as a Markov Decision Process (MDP) where the state is the previous tokens, the policy is the language model, and the action is the predicted next token. With this setup, policy gradient methods like REINFORCE and PPO can be easily applied. Policy gradient methods enable us to directly optimize a model with an estimator of the gradient of the expected reward with respect to the model's parameters. Note that unlike many standard RL setups, the reward is only given at the end of a trajectory, when the full LM generation is available.
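As a rough illustration of the policy gradient idea (a simplification, not trlX's or DRLX's actual implementation), a REINFORCE-style loss for sampled generations looks like the following, with the reward only applied once the full sequence is available:

def reinforce_loss(sequence_logprobs, rewards):
    # sequence_logprobs: sum over generated tokens of log p(token_t | previous tokens), shape (batch,)
    # rewards: one scalar reward per generation, available only after the full text is sampled
    # Minimizing this pushes up the log-probability of high-reward generations
    return -(rewards * sequence_logprobs).mean()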

DDPO also formulates the denoising process of a diffusion model as an MDP, which allows the application of policy gradient methods for RLHF tuning of diffusion models. The DDPO paper describes both the use of REINFORCE (referred to as DDPO_SF in the paper) and PPO (referred to as DDPO_IS in the paper). Given the significant improvement in performance of DDPO_IS over DDPO_SF (Figure 4 in the DDPO paper), we only implement the DDPO_IS algorithm in DRLX.

The loss function for DDPO_IS is given as follows:
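$$\mathcal{L}_{\mathrm{DDPO_{IS}}}(\theta) = -\,\mathbb{E}\left[\sum_{t=0}^{T} \frac{p_\theta(x_{t-1} \mid x_t, c)}{p_{\theta_{\mathrm{old}}}(x_{t-1} \mid x_t, c)}\; r(x_0, c)\right]$$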

where the expectation is taken over denoising trajectories sampled from the previous model $p_{\theta_{\mathrm{old}}}$, $p_\theta(x_{t-1} \mid x_t, c)$ is the diffusion model's per-step denoising distribution, and the reward $r(x_0, c)$ is given by the reward model, for which we pass the final image $x_0$ and possibly a contextualizing prompt $c$ as input.

Additionally, the ratio in the loss function is clipped to prevent the current diffusion model from diverging too far from the old one, which could otherwise destabilize training (this is one of the key properties of the PPO algorithm).
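For intuition, a minimal PyTorch sketch of the resulting clipped, importance-weighted objective for a single denoising step might look like the following. This is a simplification rather than DRLX's exact implementation, and `advantages` here stands for the (typically normalized) rewards:

import torch

def clipped_ddpo_step_loss(logprob_new, logprob_old, advantages, clip_range=1e-4):
    # Ratio of the current model's per-step denoising likelihood to the old model's
    # (clip_range is a hyperparameter; the DDPO paper uses a small value such as 1e-4)
    ratio = torch.exp(logprob_new - logprob_old)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    # Taking the elementwise maximum of the two losses gives the pessimistic, PPO-style surrogate
    return torch.max(unclipped, clipped).mean()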

For a deeper dive into the DDPO algorithm and how it is implemented, especially if you are new to RL, please check out this blog post for more information.

Note that there are a couple of differences between DDPO and the standard application of PPO to RLHF:

  1. DDPO does not have a value function baseline, while RLHF of LMs typically implements this with a separate head on the LM and uses generalized advantage estimation.
  2. RLHF as implemented by OpenAI performs KL regularization on the reward function, whereas DDPO does not (see the sketch below).

These differences could provide directions for future work.
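As a rough illustration of the second difference, here is how the reward is typically KL-shaped in LM RLHF. This is not something DDPO or DRLX implements, and the function and argument names are purely illustrative:

def kl_shaped_reward(reward, logprob_policy, logprob_reference, kl_coef=0.1):
    # Penalize the policy for drifting away from the frozen reference (pretrained) model;
    # the KL term is estimated from the log-probability difference on the sampled output
    return reward - kl_coef * (logprob_policy - logprob_reference)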

Results

LAION Aesthetic Score

Similar to the original paper, we train a Stable Diffusion model using DDPO treating the LAION aesthetic score as a reward model. Unlike the original paper though, here we use Stable Diffusion 2.1 (we use the base version to work with 512x512 images).

The LAION aesthetic classifier was trained on human ratings collected for images generated in the SimulacraBot Discord server with latent diffusion models that served as precursors to Stable Diffusion. It remains one of the most commonly used aesthetic scorers, despite significant issues with the dataset and model. Nevertheless, it serves as a useful baseline for our RLHF experiments here.

Using DRLX to train this model is trivial. First we need a prompt distribution. Here we use random ImageNet animal classes as prompts, just like the original DDPO paper.

import random
from pathlib import Path

import requests

from drlx.pipeline import PromptPipeline  # DRLX's base class for prompt pipelines


class ImagenetAnimalPrompts(PromptPipeline):
    """
    Pipeline of prompts consisting of animals from ImageNet, as used in the original DDPO paper.
    """
    def __init__(self, num=10000, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Download the ImageNet synset mapping and cache it locally
        r = requests.get("https://raw.githubusercontent.com/formigone/tf-imagenet/master/LOC_synset_mapping.txt")
        with open("LOC_synset_mapping.txt", "wb") as f:
            f.write(r.content)
        # Map each synset ID to its first (primary) class name, e.g. "n01440764" -> "tench"
        self.synsets = {k: v for k, v in [o.split(',')[0].split(' ', maxsplit=1) for o in Path('LOC_synset_mapping.txt').read_text().splitlines()]}
        self.imagenet_classes = list(self.synsets.values())
        self.num = num

    def __getitem__(self, index):
        # The animal classes sit at the start of the synset list
        animal = random.choice(self.imagenet_classes[:397])
        return f'{animal}'

    def __len__(self):
        """Denotes the total number of samples"""
        return self.num

You can see we simply have to load the ImageNet class names (downloaded from a text file online), randomly choose one from the first ~400 classes (which are the animal classes), and return that as our prompt!

The next thing to set up is our reward function. We use the standard inference code for the aesthetic classifier, which you can see here.
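Roughly speaking, that inference code embeds each image with CLIP ViT-L/14, normalizes the embedding, and feeds it through a small MLP trained to predict aesthetic ratings. A simplified sketch of the scoring step (assuming image arrays as input and that everything lives on the same device; helper names here are illustrative):

import torch
from PIL import Image

def aesthetic_scoring_sketch(images, preprocess, clip_model, mlp):
    # Embed each image with CLIP, L2-normalize the embedding, then score with the trained MLP head
    with torch.no_grad():
        pixels = torch.stack([preprocess(Image.fromarray(img)) for img in images])
        features = clip_model.encode_image(pixels)
        features = features / features.norm(dim=-1, keepdim=True)
        return mlp(features.float()).squeeze(-1)  # one scalar aesthetic score per image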

Then we can define a RewardModel class for the aesthetic predictor:

from typing import Iterable

import clip  # OpenAI's CLIP package, used to embed the images
import numpy as np

from drlx.reward_modelling import RewardModel  # DRLX's base class for reward models

# MLP, load_aesthetic_model_weights and aesthetic_scoring come from the
# aesthetic classifier's inference code, defined alongside this class


class Aesthetics(RewardModel):
    """
    Reward model that rewards images with higher aesthetic score. Uses CLIP and an MLP (not put on any device by default)

    :param device: Device to load model on
    :type device: torch.device
    """
    def __init__(self, device=None):
        super().__init__()
        # MLP head mapping 768-dim CLIP ViT-L/14 embeddings to a scalar aesthetic score
        self.model = MLP(768)
        self.model.load_state_dict(load_aesthetic_model_weights())
        self.clip_model, self.preprocess = clip.load("ViT-L/14", device=device if device is not None else 'cpu')

        if device is not None:
            self.model.to(device)

    def forward(self, images: np.ndarray, prompts: Iterable[str]):
        # The aesthetic score is prompt-independent, so the prompts are ignored here
        return aesthetic_scoring(
            images,
            self.preprocess,
            self.clip_model,
            self.model
        )

As you can see, it's a very simple setup. Next, we need to set up a config file (you can see the defaults in configs.py):

method:
  name: "DDPO"

model:
  model_path: "stabilityai/stable-diffusion-2-1-base"
  model_arch_type: "LDMUnet"
  attention_slicing: True
  xformers_memory_efficient: True
  gradient_checkpointing: True

sampler:
  num_inference_steps: 50

optimizer:
  name: "adamw"
  kwargs:
    lr: 1.0e-5
    weight_decay: 1.0e-4
    betas: [0.9, 0.999]

scheduler:
  name: "linear" # Name of learning rate scheduler
  kwargs:
    start_factor: 1.0
    end_factor: 1.0

logging:
  run_name: 'ddpo_sd_imagenet'
  wandb_entity: 'tmabraham'
  wandb_project: 'DRLX'

train:
  num_epochs: 200
  num_samples_per_epoch: 256
  batch_size: 4
  sample_batch_size: 32
  grad_clip: 1.0
  checkpoint_interval: 50
  tf32: True
  suppress_log_keywords: "diffusers.pipelines,transformers"
  save_samples: False

Note the use of various performance optimizations already implemented in PyTorch or HuggingFace Diffusers. This includes tf32 support on Ampere GPUs, attention slicing, memory-efficient attention, and gradient checkpointing.
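For reference, outside of DRLX these flags correspond roughly to the following PyTorch and Diffusers calls:

import torch
from diffusers import StableDiffusionPipeline

# tf32: True allows TensorFloat-32 matmuls on Ampere and newer GPUs
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
pipe.enable_attention_slicing()                    # attention_slicing: True
pipe.enable_xformers_memory_efficient_attention()  # xformers_memory_efficient: True (requires xformers)
pipe.unet.enable_gradient_checkpointing()          # gradient_checkpointing: True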

Also note that although Stable Diffusion 2.x checkpoints can use v-prediction (the 768-resolution variants do), we do not have to configure the prediction type manually; the HuggingFace Diffusers schedulers read it from the checkpoint automatically.
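If you want to check this yourself, the prediction type is stored in the scheduler config shipped with each checkpoint, and Diffusers picks it up when loading:

from diffusers import DDIMScheduler

sched = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
print(sched.config.prediction_type)  # "epsilon" for the base checkpoint, "v_prediction" for the 768px SD 2.1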

With all of that out of the way, training with DRLX is extremely simple:

import torch
from drlx.trainer.ddpo_trainer import DDPOTrainer
from drlx.configs import DRLXConfig
from drlx.reward_modelling.aesthetics import Aesthetics
from drlx.pipeline.imagenet_animal_prompts import ImagenetAnimalPrompts


config = DRLXConfig.load_yaml("configs/ddpo_sd_imagenet.yml")
pipe = ImagenetAnimalPrompts(num=config.train.num_samples_per_epoch)
trainer = DDPOTrainer(config)
trainer.train(pipe, Aesthetics())

As you can see, our DDPOTrainer class handles everything! It just needs a prompt pipeline, a reward function, and a config. After training for 200 epochs with 256 samples per epoch, we get the following reward curve:

A consistently increasing mean reward score is exactly what we want. Here is a comparison of a few samples before and after training on aesthetic score.

Notice how the model settles on a specific reddish-orange color palette for the images, as well as a shallow depth-of-field effect that blurs the background. The subject tends to be centered, though not always. These effects of RLHF training are more apparent when observing even more samples:

I’d argue that these images certainly are more aesthetically pleasing than those generated by the base model, but they still feel somewhat boring and lack diversity.

It is interesting to observe how a sample changes over the course of training (consistent seed):

For example, these visualizations demonstrate that the unique color palette only starts to appear towards the end of training.

RL Training with PickScore

One particular issue with using the LAION aesthetic classifier as a reward model is that it does not take the prompt into consideration. This means that training to optimize the LAION aesthetic score could give us a model that always outputs the same image with an extremely high score, with no regard for the prompt. In fact, in earlier experiments with Stable Diffusion v1.4, we did observe this behavior, where a noisy adversarial pattern was receiving scores of ~7.1:

This behavior is sometimes referred to as reward hacking in the RL literature. To address this challenge, it would be ideal to ensure that the generated images align well with the provided prompt, and to optimize for this as well. Recently, researchers have been developing preference scores for AI-generated images that also account for prompt alignment. PickScore is one such preference model, developed in collaboration with Stability AI researchers. It is a CLIP ViT-H model finetuned on the Pick-a-Pic dataset of user preferences over Stable Diffusion XL preview beta images. Unlike the LAION aesthetic predictor, it also utilizes the CLIP text encoder, which enables the PickScore model to better capture prompt alignment. For this reason, we tested RL training of Stable Diffusion 2.1 with PickScore as the reward model on the Pick-a-Pic dataset.
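For reference, scoring images against a prompt with PickScore outside of DRLX looks roughly like the following, adapted from the PickScore model card (the processor and model names are the publicly released ones):

import torch
from transformers import AutoModel, AutoProcessor

processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
model = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1").eval()

def pickscore(prompt, pil_images):
    image_inputs = processor(images=pil_images, return_tensors="pt")
    text_inputs = processor(text=prompt, padding=True, truncation=True, max_length=77, return_tensors="pt")
    with torch.no_grad():
        image_embs = model.get_image_features(**image_inputs)
        image_embs = image_embs / image_embs.norm(dim=-1, keepdim=True)
        text_embs = model.get_text_features(**text_inputs)
        text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
        # Scaled cosine similarity between the prompt and each generated image
        return model.logit_scale.exp() * (text_embs @ image_embs.T)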

Our training code is set up in a similar manner to the LAION aesthetic classifier-based training:

import torch
from drlx.trainer.ddpo_trainer import DDPOTrainer
from drlx.configs import DRLXConfig
from drlx.pipeline.pickapic_prompts import PickAPicPrompts, PickAPicReplacementPrompts
from drlx.reward_modelling.pickscore import PickScoreModel
from drlx.reward_modelling.aesthetics import Aesthetics


pipe = PickAPicPrompts()
config = DRLXConfig.load_yaml("configs/ddpo_sd_pickapic.yml")
trainer = DDPOTrainer(config)
trainer.train(pipe, PickScoreModel())

We train the model with a guidance scale of 7.5, for 450 epochs, with 256 samples per epoch and a batch size of 16 (divided over 4 80GB A100 GPUs, so a batch size of 4 per GPU). Again we utilize similar performance optimizations as before.

We once again observe a consistently increasing mean reward score. Below you can see a few examples of generated images from different checkpoints during training:

While not all examples show a demonstrable improvement, in some instances we do see both improved prompt alignment and better aesthetic quality. For example, in the “A panda bear as a mad scientist” example, we see that at the end of training the panda arguably looks more “mad” and is handling glassware as if it’s a scientist. The “jedi duck” is also quite a good example, where there is a much clearer depiction of the duck holding a lightsaber. On the other hand, the final model's generation for the prompt “a golden retriever jumping over a box” fails to actually show a golden retriever jumping over a box, but it is still quite an aesthetically pleasing image.

Once again, we can observe how a sample changes over the course of training (consistent seed):

We have released the weights of our DDPO-trained SD 2.1 with PickScore on Pick-A-Pic prompts to HuggingFace over here.

Discussion and Conclusion

Using RLHF to improve image generation models is a burgeoning field and there is lots of work left to do. On the generative model side, there are many different RL algorithms to try (DPO, offline RL, ReST, etc.) and tweaks that could be made to the RL setup. On the reward model side, a finetuned CLIP is a fairly simplistic scorer (an equivalently simple model would be considered insufficient for instruction-tuned language models), so perhaps something new is in order for image reward models. There is also a need to translate advancements from instruction-tuned language generation to diffusion for images and audio, with new algorithms coming out for the former task every few weeks. If you'd like to join us on the journey of bringing RLHF up to speed for image generation, feel free to say hello in the #drlx channel, and check the GitHub issues for tasks to help out with!
