Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shape descriptions in calculate_loss method #1204

Conversation

yuta0x89
Copy link

@yuta0x89 yuta0x89 commented Jan 9, 2024

The content of latents is defined in the following code in modeling_sd_base.py:

latents = self.prepare_latents(
    batch_size * num_images_per_prompt,
    num_channels_latents,
    height,
    width,
    prompt_embeds.dtype,
    device,
    generator,
    latents,
)

As a result, the shape of latents and next_latents mentioned in the description of the calculate_loss method in the DDPOTrainer class should be [batch_size, num_channels_latents, height, width].

@lvwerra lvwerra requested a review from kashif January 9, 2024 13:11
@lvwerra
Copy link
Member

lvwerra commented Jan 9, 2024

Letting @sayakpaul have a look too.

@kashif kashif merged commit b181e40 into huggingface:main Jan 9, 2024
1 check failed
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants