
Naive question about multiple seeds #95

Open
kfu02 opened this issue Jun 6, 2024 · 4 comments

kfu02 commented Jun 6, 2024

Hi,

I am familiar with MARL in Pytorch, but very new to JAX, so please forgive me if this question is naive.

I see that many of your baselines are parallelized over multiple seeds at once (e.g. here in QMIX or here in transfQMIX). However, when running the baselines I notice that the resulting WandB runs seem to aggregate the seeds together. Is there some way to separate the performance of each seed for plotting purposes (e.g. to report the min/avg/max)? Your paper has several average return curves with some sort of error shading, so I imagine I must be missing something obvious.
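For context, the seed parallelism being asked about looks roughly like the following. This is a sketch, not JaxMARL's actual training code: `train` here is a toy stand-in for a full training run.

```python
import jax

def train(rng):
    # stand-in for a full training run; returns a toy "final return"
    return jax.random.normal(rng)

rng = jax.random.PRNGKey(0)
seeds = jax.random.split(rng, 4)   # one PRNG key per seed
returns = jax.vmap(train)(seeds)   # all 4 seeds run in parallel on one device
```

Because `jax.vmap` vectorises the whole training function, all seeds execute inside a single process, which is why their logs land in a single WandB run.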

amacrutherford (Collaborator) commented Jun 8, 2024

Hey! Thanks for reaching out, and exciting that you are trying out JAX. Off the top of my head, the easiest way with wandb is to run one seed per script and then sweep over the seeds with wandb sweeps. (If you set XLA_PYTHON_CLIENT_PREALLOCATE=false as an environment variable, you can run multiple scripts on one GPU, but this is quite a bit less efficient than vmapping multiple seeds over one device.) Have I missed something @mttga ?
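A sweep over seeds can be expressed as a small config. A minimal sketch, where `train.py` and the `SEED` parameter are placeholders for your actual entry point and its seed argument:

```python
# hypothetical sweep config: "train.py" and "SEED" are placeholders for
# your actual entry point and its seed argument
sweep_config = {
    "program": "train.py",
    "method": "grid",   # enumerate every listed seed exactly once
    "parameters": {
        "SEED": {"values": [0, 1, 2, 3]},
    },
}
```

Each agent picked up by the sweep then becomes its own WandB run, which the UI can group and compare.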

kfu02 (Author) commented Jun 8, 2024

Hi, thanks for the reply!

Okay, so you're saying the answer is simply not to parallelize across seeds, and instead use WandB's tools to aggregate separate single-seed runs together. If I'm understanding correctly, then what is being plotted when I run multiple seeds in parallel? The average across those seeds?

mttga (Collaborator) commented Jun 9, 2024

The parallel runs will plot in the same space, meaning that you will have datapoints from all your runs but you will not be able to distinguish them. To distinguish them you can use an approach like this:

```python
def function(rng):
    # the first element of the PRNG key doubles as a readable seed id
    original_seed = rng[0]

    # ... training code ...

    metrics = ...  # a dictionary of your logging metrics

    def callback(metrics, original_seed):
        # log each metric under a seed-prefixed name so the parallel
        # runs are distinguishable (build a new dict rather than calling
        # metrics.update, which would log every value twice: once
        # unprefixed and once prefixed)
        metrics = {
            f'rng{int(original_seed)}/{k}': v
            for k, v in metrics.items()
        }
        wandb.log(metrics)

    jax.debug.callback(callback, metrics, original_seed)
```

We will include training code like this soon.
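The renaming step in the callback can be checked in isolation; a plain-Python sketch, where `'returns'` is a made-up metric name:

```python
def prefix_metrics(metrics, original_seed):
    # rename each metric to 'rng<seed>/<name>' so every parallel seed
    # gets its own panel in wandb
    return {f'rng{int(original_seed)}/{k}': v for k, v in metrics.items()}

prefixed = prefix_metrics({'returns': 1.5}, 0)
# -> {'rng0/returns': 1.5}
```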

Chulabhaya (Contributor) commented
Hi all! I'd like to clarify one thing about the current JAX setup and the WandB logging. When you run training with multiple seeds (with or without WANDB_LOG_ALL_SEEDS), is there a way to have the aggregated data shown so that WandB displays the mean/std across those runs? Currently, when the seeds are aggregated, WandB treats them as a single run, so it won't show the std shading; and if you log all seeds separately, I'm not sure how to combine them in the WandB interface to see the combined mean/std shading. Thanks!
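One workaround, if WandB's built-in run grouping doesn't give the shading you want, is to export the per-seed curves and compute the mean/std yourself before plotting. A minimal NumPy sketch, where the curves are toy stand-ins for exported data:

```python
import numpy as np

# toy stand-in for per-seed return curves exported from WandB:
# shape (n_seeds, n_logged_steps)
curves = np.array([
    [0.0, 0.5, 1.0],
    [0.1, 0.6, 0.9],
    [0.2, 0.4, 1.1],
])

mean = curves.mean(axis=0)   # average return at each logged step
std = curves.std(axis=0)     # spread across seeds, for the shading
lower, upper = mean - std, mean + std  # band to pass to fill_between
```

The `mean` curve plus the `lower`/`upper` band is what the shaded-average plots in papers typically show.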
