Skip to content

Commit

Permalink
tensorboard support
Browse files Browse the repository at this point in the history
  • Loading branch information
younesbelkada committed Jan 18, 2023
1 parent dc8cda2 commit 37aa98e
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 6 deletions.
2 changes: 2 additions & 0 deletions examples/scripts/ppo-sentiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
# We first define the configuration of the experiment, defining the model, the dataset,
# the training parameters, and the PPO parameters.
# Check the default arguments in the `PPOConfig` class for more details.
# If you want to log with tensorboard, add the kwarg
# `accelerator_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig.
config = PPOConfig(
model_name="lvwerra/gpt2-imdb",
learning_rate=1.41e-5,
Expand Down
16 changes: 16 additions & 0 deletions trl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ def rec(nest, prefix, into):
return flat


def convert_for_tensorboard(stats):
"""
Converts the stats from a flattened dict to single scalar dicts
"""
tensorboard_stats = {}
for k, v in stats.items():
# for tensorboard compatibility - arrays and tensors are ignored with tensorboard
# therefore we convert single element tensors to scalars
if (isinstance(v, torch.Tensor) or isinstance(v, np.ndarray)) and (
len(v.shape) == 0 or (len(v.shape) == 1 and v.shape[0] == 1)
):
v = v.item()
tensorboard_stats[k] = v
return tensorboard_stats


def stack_dicts(stats_dicts):
"""Stack the values of a dict."""
results = dict()
Expand Down
8 changes: 7 additions & 1 deletion trl/trainer/ppo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import numpy as np

from ..core import flatten_dict


@dataclass
class PPOConfig(object):
Expand Down Expand Up @@ -58,6 +60,8 @@ class PPOConfig(object):
log_with (`str`, *optional*, defaults to "wandb"):
Log with either "wandb" or "tensorboard", check
https://huggingface.co/docs/accelerate/usage_guides/tracking for more details
accelerator_kwargs (`dict`, *optional*, defaults to {}):
Keyword arguments for the accelerator (e.g. `logging_dir`)
tracker_kwargs (`dict`, *optional*, defaults to {}):
Keyword arguments for the tracker (e.g. wandb_project)
tracker_project_name (`str`, *optional*, defaults to "trl"):
Expand All @@ -84,6 +88,7 @@ def __init__(
remove_unused_columns: Optional[bool] = True,
log_with: Optional[str] = "wandb",
tracker_kwargs: Optional[dict] = {},
accelerator_kwargs: Optional[dict] = {},
tracker_project_name: Optional[str] = "trl",
):
self.model_name = model_name
Expand Down Expand Up @@ -115,6 +120,7 @@ def __init__(
)

self.tracker_kwargs = tracker_kwargs
self.accelerator_kwargs = accelerator_kwargs
self.tracker_project_name = tracker_project_name

self.total_ppo_epochs = int(np.ceil(steps / batch_size))
Expand All @@ -123,4 +129,4 @@ def to_dict(self):
output_dict = {}
for key, value in self.__dict__.items():
output_dict[key] = value
return output_dict
return flatten_dict(output_dict)
23 changes: 18 additions & 5 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ..core import (
WANDB_PADDING,
clip_by_value,
convert_for_tensorboard,
entropy_from_logits,
flatten_dict,
logprobs_from_logits,
Expand Down Expand Up @@ -81,8 +82,8 @@ def __init__(
super().__init__(config)

# Step 1: Initialize Accelerator
self.accelerator = Accelerator(log_with=config.log_with)
self.accelerator.init_trackers(config.tracker_project_name, config=config, **config.tracker_kwargs)
self.accelerator = Accelerator(log_with=config.log_with, **config.accelerator_kwargs)
self.accelerator.init_trackers(config.tracker_project_name, config=config.to_dict(), **config.tracker_kwargs)

# Step 2: Initialize model, tokenizer, and dataloader
if not isinstance(model, PreTrainedModelWrapper):
Expand Down Expand Up @@ -138,6 +139,9 @@ def __init__(
# or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11
self.is_distributed = self.accelerator.distributed_type == "MULTI_GPU"

# init the current step
self.current_step = 0

# init wandb on the main process:
if self.accelerator.is_main_process and self.config.log_with == "wandb":
import wandb
Expand Down Expand Up @@ -360,6 +364,11 @@ def step(
# Log the total ppo time
timing["time/ppo/total"] = time.time() - t0
stats.update(timing)

# post-process stats for tensorboard
if self.config.log_with == "tensorboard":
stats = convert_for_tensorboard(stats)

return stats

def gather_stats(self, stats):
Expand Down Expand Up @@ -664,11 +673,15 @@ def log_stats(
rewards /= self.accelerator.num_processes

logs.update(stats)
logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy()
logs["env/reward_std"] = torch.std(rewards).cpu().numpy()
logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item()
logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
logs["env/reward_dist"] = rewards.cpu().numpy()

self.accelerator.log(logs)
if self.config.log_with == "tensorboard":
# update the current step
self.current_step += 1

self.accelerator.log(logs, step=self.current_step if self.config.log_with == "tensorboard" else None)

else:
if self.is_distributed:
Expand Down

0 comments on commit 37aa98e

Please sign in to comment.