add: support for peft in ddpo. #1165
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
This matches with what was done previously.
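For context, here is a minimal sketch of how such a LoraConfig is typically attached with peft. This is an illustration rather than the trainer's exact code: the sd_pipeline variable and the use of get_peft_model are assumptions, and the PR may wire the adapter up differently.

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

# Wrap the UNet so that only the injected LoRA matrices are trainable;
# the base weights stay frozen.
unet = get_peft_model(sd_pipeline.unet, lora_config)
unet.print_trainable_parameters()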
# Set correct lora layers
lora_attn_procs = {}
for name in self.sd_pipeline.unet.attn_processors.keys():
    cross_attention_dim = (
        None if name.endswith("attn1.processor") else self.sd_pipeline.unet.config.cross_attention_dim
    )
    if name.startswith("mid_block"):
        hidden_size = self.sd_pipeline.unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(self.sd_pipeline.unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = self.sd_pipeline.unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRAAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )
self.sd_pipeline.unet.set_attn_processor(lora_attn_procs)
No crazy layer iteration and state dict munging. Pretty please. Thanks to peft.
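To illustrate why (a sketch, assuming the UNet has already been wrapped by peft as in the snippet above; not the trainer's exact code): once the LoRA layers are injected, the trainable parameters can be gathered generically rather than by attention-processor name and hidden size.

import torch

# After peft wrapping, only the injected LoRA weights require gradients,
# so no per-block bookkeeping is needed to find them.
trainable_layers = [p for p in unet.parameters() if p.requires_grad]

# The optimizer and learning rate here are illustrative placeholders.
optimizer = torch.optim.AdamW(trainable_layers, lr=1e-4)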
Pretty swell that this is being cleanly replaced. Sweet stuff
hey @sayakpaul things look splendid. Nothing has really changed in theory, but it would be nice to have a test run that shows convergence. I'll leave it to your and @younesbelkada's discretion to do without.
@metric-space yes, will do! Thanks for your reviews.
WandB run page: https://wandb.ai/sayakpaul/stable_diffusion_training/runs/7ebll3fb?workspace=user-sayakpaul. This pig is too cute:
LGTM! thanks
Awesome work @sayakpaul and team!
* add: support for peft in ddpo.
* revert to the original modeling_base.
* style
* specify weight_name
* explicitly specify weight_name
* fix: parameter parsing
* fix: trainable_layers.
* parameterize use_lora.
* fix one more trainable_layers
* debug
* debug
* more fixes.
* manually set unet of sd_pipeline
* make trainable_layers cleaner.
* more fixes
* remove prints.
* tester class for LoRA too.
It's time.
Cc: @metric-space @younesbelkada
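A side note on the weight_name commits in the log above: below is a hedged sketch of what loading the resulting LoRA weights with diffusers could look like. The base model id, output directory, and file name are illustrative assumptions, not values taken from this PR.

import torch
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Both arguments are assumptions for illustration: the directory the trainer
# saved to and the serialized LoRA weight file name.
pipeline.load_lora_weights("ddpo-finetuned-unet", weight_name="pytorch_lora_weights.safetensors")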