Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add: support for peft in ddpo. #1165

Merged
merged 17 commits into from
Jan 2, 2024
Merged

add: support for peft in ddpo. #1165

merged 17 commits into from
Jan 2, 2024

Conversation

sayakpaul
Copy link
Member

@sayakpaul sayakpaul commented Jan 1, 2024

It's time.

Cc: @metric-space @younesbelkada

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +600 to +605
lora_config = LoraConfig(
r=4,
lora_alpha=4,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This matches with what was done previously.

Comment on lines -594 to -612
# Set correct lora layers
lora_attn_procs = {}
for name in self.sd_pipeline.unet.attn_processors.keys():
cross_attention_dim = (
None if name.endswith("attn1.processor") else self.sd_pipeline.unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = self.sd_pipeline.unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(self.sd_pipeline.unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.sd_pipeline.unet.config.block_out_channels[block_id]

lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
self.sd_pipeline.unet.set_attn_processor(lora_attn_procs)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No crazy layer iteration and state dict munging. Pretty please. Thanks to peft.

Copy link
Contributor

@metric-space metric-space Jan 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty swell that this is being cleanly replaced. Sweet stuff

@metric-space
Copy link
Contributor

hey @sayakpaul things look splendid. Nothing has really changed in theory but would be nice to have a test run that shows convergence but I'll leave it to your and @younesbelkada's discretion to do without

@sayakpaul
Copy link
Member Author

@metric-space yes, will do! Thanks for your reviews.

@sayakpaul
Copy link
Member Author

sayakpaul commented Jan 2, 2024

WandB run page: https://wandb.ai/sayakpaul/stable_diffusion_training/runs/7ebll3fb?workspace=user-sayakpaul.

This pig is too cute:

image

@kashif
Copy link
Collaborator

kashif commented Jan 2, 2024

LGTM! thanks

@kashif kashif self-requested a review January 2, 2024 11:51
@kashif kashif merged commit 20428c4 into huggingface:main Jan 2, 2024
9 checks passed
@sayakpaul sayakpaul deleted the harmonize-lora-ddpo branch January 2, 2024 12:00
@younesbelkada
Copy link
Contributor

Awesome work @sayakpaul and team !

lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
* add: support for peft in ddpo.

* revert to the original modeling_base.

* style

* specify weight_name

* explicitly specify weight_name

* fix: parameter parsing

* fix: trainable_layers.

* parameterize use_lora.

* fix one more trainable_layers

* debug

* debug

* more fixes.

* manually set unet of sd_pipeline

* make trainable_layers cleaner.

* more fixes

* remove prints.

* tester class for LoRA too.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants