diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py index ff300690..5cbc58da 100644 --- a/verl/utils/tracking.py +++ b/verl/utils/tracking.py @@ -22,7 +22,7 @@ class Tracking(object): - supported_backend = ["wandb", "mlflow", "swanlab", "console"] + supported_backend = ["wandb", "mlflow", "swanlab", "vemlp_wandb", "console"] def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = 'console', config=None): if isinstance(default_backend, str): @@ -63,6 +63,24 @@ def __init__(self, project_name, experiment_name, default_backend: Union[str, Li mode=SWANLAB_MODE) self.logger["swanlab"] = swanlab + if 'vemlp_wandb' in default_backend: + import os + import volcengine_ml_platform + from volcengine_ml_platform import wandb as vemlp_wandb + volcengine_ml_platform.init( + ak=os.environ["VOLC_ACCESS_KEY_ID"], + sk=os.environ["VOLC_SECRET_ACCESS_KEY"], + region=os.environ["MLP_TRACKING_REGION"], + ) + + vemlp_wandb.init( + project=project_name, + name=experiment_name, + config=config, + sync_tensorboard=True, + ) + self.logger['vemlp_wandb'] = vemlp_wandb + if 'console' in default_backend: from verl.utils.logger.aggregate_logger import LocalLogger self.console_logger = LocalLogger(print_to_console=True) @@ -78,6 +96,8 @@ def __del__(self): self.logger['wandb'].finish(exit_code=0) if 'swanlab' in self.logger: self.logger['swanlab'].finish() + if 'vemlp_wandb' in self.logger: + self.logger['vemlp_wandb'].finish(exit_code=0) class _MlflowLoggingAdapter: