add: support for peft in ddpo. (huggingface#1165)
* add: support for peft in ddpo.

* revert to the original modeling_base.

* style

* specify weight_name

* explicitly specify weight_name

* fix: parameter parsing

* fix: trainable_layers.

* parameterize use_lora.

* fix one more trainable_layers

* debug

* debug

* more fixes.

* manually set unet of sd_pipeline

* make trainable_layers cleaner.

* more fixes

* remove prints.

* tester class for LoRA too.
sayakpaul authored and Andrew Lapp committed May 10, 2024
1 parent de0705f commit bff9dde
Showing 4 changed files with 80 additions and 40 deletions.
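
For orientation, a minimal sketch of how the parameterized LoRA path is used after this commit: with peft installed, `use_lora=True` injects LoRA adapters into the UNet through peft instead of diffusers' attention-processor machinery. The model id, config values, and the `prompt_fn`/`reward_fn` callables below are illustrative placeholders, not part of this commit.

```python
import torch

from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline


def prompt_fn():
    # Placeholder prompt function: a fixed prompt and empty metadata.
    return "a photo of a cat", {}


def reward_fn(images, prompts, prompt_metadata):
    # Placeholder reward function: zero reward per generated image.
    return torch.zeros(len(images)), {}


# Placeholder model id; the new test below uses "hf-internal-testing/tiny-stable-diffusion-torch".
pipeline = DefaultDDPOStableDiffusionPipeline(
    "hf-internal-testing/tiny-stable-diffusion-torch",
    pretrained_model_revision="main",
    use_lora=True,  # set to False to train the full UNet instead of LoRA adapters
)

config = DDPOConfig(
    num_epochs=1,
    sample_num_batches_per_epoch=2,
    sample_batch_size=2,
    train_batch_size=2,
    mixed_precision=None,
)  # illustrative values

trainer = DDPOTrainer(config, reward_fn, prompt_fn, pipeline)
trainer.train()
```
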
examples/scripts/ddpo.py: 4 changes (3 additions & 1 deletion)
@@ -41,6 +41,8 @@ class ScriptArguments:
"""HuggingFace model ID for aesthetic scorer model weights"""
hf_hub_aesthetic_model_filename: str = "aesthetic-model.pth"
"""HuggingFace model filename for aesthetic scorer model weights"""
use_lora: bool = True
"""Whether to use LoRA."""

ddpo_config: DDPOConfig = field(
default_factory=lambda: DDPOConfig(
@@ -193,7 +195,7 @@ def image_outputs_logger(image_data, global_step, accelerate_logger):
args = tyro.cli(ScriptArguments)

pipeline = DefaultDDPOStableDiffusionPipeline(
args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=True
args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=args.use_lora
)

trainer = DDPOTrainer(
tests/test_ddpo_trainer.py: 32 changes (30 additions & 2 deletions)
@@ -16,12 +16,12 @@

import torch

from trl import is_diffusers_available
from trl import is_diffusers_available, is_peft_available

from .testing_utils import require_diffusers


if is_diffusers_available():
if is_diffusers_available() and is_peft_available():
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline


@@ -97,3 +97,31 @@ def test_calculate_loss(self):
)

self.assertTrue(torch.isfinite(loss.cpu()))


@require_diffusers
class DDPOTrainerWithLoRATester(DDPOTrainerTester):
"""
Test the DDPOTrainer class with LoRA enabled.
"""

def setUp(self):
self.ddpo_config = DDPOConfig(
num_epochs=2,
train_gradient_accumulation_steps=1,
per_prompt_stat_tracking_buffer_size=32,
sample_num_batches_per_epoch=2,
sample_batch_size=2,
mixed_precision=None,
save_freq=1000000,
)
pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch"
pretrained_revision = "main"

pipeline = DefaultDDPOStableDiffusionPipeline(
pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=True
)

self.trainer = DDPOTrainer(self.ddpo_config, scorer_function, prompt_function, pipeline)

return super().setUp()
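
The LoRA tester inherits every test from DDPOTrainerTester and only swaps in a LoRA-enabled pipeline. A hypothetical extra test method, not part of this commit, that would pin down the expected effect of the peft integration (it assumes the base UNet is frozen and only adapter weights carry `requires_grad=True` once `get_trainable_layers()` has run in `DDPOTrainer.__init__`):

```python
def test_only_lora_params_are_trainable(self):
    # Hypothetical check: after LoRA injection, only adapter weights should be trainable.
    unet = self.trainer.sd_pipeline.unet
    trainable = [name for name, param in unet.named_parameters() if param.requires_grad]
    self.assertTrue(len(trainable) > 0)
    self.assertTrue(all("lora" in name for name in trainable))
```
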
trl/models/modeling_sd_base.py: 70 changes (36 additions & 34 deletions)
@@ -21,11 +21,16 @@
import numpy as np
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
from diffusers.utils import convert_state_dict_to_diffusers

from ..core import randn_tensor
from ..import_utils import is_peft_available


if is_peft_available():
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict


@dataclass
@@ -534,7 +539,11 @@ def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str
self.pretrained_revision = pretrained_model_revision

try:
self.sd_pipeline.unet.load_attn_procs(pretrained_model_name, revision=pretrained_model_revision)
self.sd_pipeline.load_lora_weights(
pretrained_model_name,
weight_name="pytorch_lora_weights.safetensors",
revision=pretrained_model_revision,
)
self.use_lora = True
except OSError:
if use_lora:
@@ -583,42 +592,38 @@ def autocast(self):

def save_pretrained(self, output_dir):
if self.use_lora:
self.sd_pipeline.unet.save_attn_procs(output_dir)
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(self.sd_pipeline.unet))
self.sd_pipeline.save_lora_weights(save_directory=output_dir, unet_lora_layers=state_dict)
self.sd_pipeline.save_pretrained(output_dir)

def set_progress_bar_config(self, *args, **kwargs):
self.sd_pipeline.set_progress_bar_config(*args, **kwargs)

def get_trainable_layers(self):
if self.use_lora:
# Set correct lora layers
lora_attn_procs = {}
for name in self.sd_pipeline.unet.attn_processors.keys():
cross_attention_dim = (
None if name.endswith("attn1.processor") else self.sd_pipeline.unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = self.sd_pipeline.unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(self.sd_pipeline.unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.sd_pipeline.unet.config.block_out_channels[block_id]

lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
self.sd_pipeline.unet.set_attn_processor(lora_attn_procs)
return AttnProcsLayers(self.sd_pipeline.unet.attn_processors)
lora_config = LoraConfig(
r=4,
lora_alpha=4,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
self.sd_pipeline.unet.add_adapter(lora_config)

# To avoid accelerate unscaling problems in FP16.
for param in self.sd_pipeline.unet.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
return self.sd_pipeline.unet
else:
return self.sd_pipeline.unet

def save_checkpoint(self, models, weights, output_dir):
if len(models) != 1:
raise ValueError("Given how the trainable params were set, this should be of length 1")
if self.use_lora and isinstance(models[0], AttnProcsLayers):
self.sd_pipeline.unet.save_attn_procs(output_dir)
if self.use_lora and hasattr(models[0], "peft_config") and getattr(models[0], "peft_config", None) is not None:
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(models[0]))
self.sd_pipeline.save_lora_weights(save_directory=output_dir, unet_lora_layers=state_dict)
elif not self.use_lora and isinstance(models[0], UNet2DConditionModel):
models[0].save_pretrained(os.path.join(output_dir, "unet"))
else:
@@ -627,15 +632,12 @@ def save_checkpoint(self, models, weights, output_dir):
def load_checkpoint(self, models, input_dir):
if len(models) != 1:
raise ValueError("Given how the trainable params were set, this should be of length 1")
if self.use_lora and isinstance(models[0], AttnProcsLayers):
tmp_unet = UNet2DConditionModel.from_pretrained(
self.pretrained_model,
revision=self.pretrained_revision,
subfolder="unet",
if self.use_lora:
lora_state_dict, network_alphas = self.sd_pipeline.lora_state_dict(
input_dir, weight_name="pytorch_lora_weights.safetensors"
)
tmp_unet.load_attn_procs(input_dir)
models[0].load_state_dict(AttnProcsLayers(tmp_unet.attn_processors).state_dict())
del tmp_unet
self.sd_pipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=models[0])

elif not self.use_lora and isinstance(models[0], UNet2DConditionModel):
load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
models[0].register_to_config(**load_model.config)
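
With these changes, LoRA checkpoints round-trip through peft and diffusers helpers instead of `save_attn_procs`/`load_attn_procs`. A condensed sketch of that round-trip using only the calls introduced above; the model id and output directory are placeholders:

```python
from diffusers.utils import convert_state_dict_to_diffusers
from peft.utils import get_peft_model_state_dict

from trl import DefaultDDPOStableDiffusionPipeline

pipeline = DefaultDDPOStableDiffusionPipeline(
    "hf-internal-testing/tiny-stable-diffusion-torch", pretrained_model_revision="main", use_lora=True
)
pipeline.get_trainable_layers()  # injects the LoRA adapter into the UNet
output_dir = "./ddpo-lora-checkpoint"  # placeholder path

# Save: export the peft adapter from the UNet in diffusers' LoRA layout.
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(pipeline.sd_pipeline.unet))
pipeline.sd_pipeline.save_lora_weights(save_directory=output_dir, unet_lora_layers=unet_lora_layers)

# Load: read the weights back and push them into the UNet.
lora_state_dict, network_alphas = pipeline.sd_pipeline.lora_state_dict(
    output_dir, weight_name="pytorch_lora_weights.safetensors"
)
pipeline.sd_pipeline.load_lora_into_unet(
    lora_state_dict, network_alphas=network_alphas, unet=pipeline.sd_pipeline.unet
)
```
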
trl/trainer/ddpo_trainer.py: 14 changes (11 additions & 3 deletions)
@@ -171,7 +171,9 @@ def __init__(
if self.config.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True

self.optimizer = self._setup_optimizer(trainable_layers.parameters())
self.optimizer = self._setup_optimizer(
trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
)

self.neg_prompt_embed = self.sd_pipeline.text_encoder(
self.sd_pipeline.tokenizer(
@@ -193,7 +195,11 @@ def __init__(
# more memory
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast

self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
else:
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)

if self.config.async_reward_computation:
self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
@@ -541,7 +547,9 @@ def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_sample
self.accelerator.backward(loss)
if self.accelerator.sync_gradients:
self.accelerator.clip_grad_norm_(
self.trainable_layers.parameters(),
self.trainable_layers.parameters()
if not isinstance(self.trainable_layers, list)
else self.trainable_layers,
self.config.train_max_grad_norm,
)
self.optimizer.step()
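
Because `get_trainable_layers()` now returns the whole UNet in the LoRA case, the trainer keeps a plain list of the `requires_grad` parameters after `accelerator.prepare`, while the non-LoRA path still carries a module. The optimizer setup and gradient clipping above normalize the two shapes; a small sketch of that pattern, with the optimizer class, learning rate, and max norm as placeholders:

```python
import torch


def setup_optimizer_params(trainable_layers):
    # LoRA path: trainable_layers is already a list of requires_grad parameters.
    # Full-UNet path: it is an nn.Module, so expose its parameters() iterator.
    return trainable_layers if isinstance(trainable_layers, list) else trainable_layers.parameters()


# Illustrative usage with placeholder settings:
# optimizer = torch.optim.AdamW(setup_optimizer_params(trainable_layers), lr=3e-4)
# self.accelerator.clip_grad_norm_(setup_optimizer_params(trainable_layers), self.config.train_max_grad_norm)
```
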
