Fix monitor state_averager in examples/albert (preliminary) (#452)
This PR fixes several minor issues found in #446:

- fix `prefix=...` in training monitor
- create scheduler in training monitor
- rename experiment_prefix -> run_id
- enable checkpoints on aux peer by default
- decouple the total number of training steps from the scheduler's maximum steps (see the sketch below)
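
A rough sketch of that last item (not code from this commit; the model and optimizer below are placeholders): the learning rate schedule is now sized by `total_steps`, while the trainer's `max_steps` is set so high that peers effectively train until stopped.

```python
import torch
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(128, 128)  # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=0.00176)

total_steps = 125_000   # length of the LR schedule only
max_steps = 10 ** 30    # "train forever": no longer tied to the schedule

# Linear warmup + decay over total_steps, regardless of how long peers keep training.
scheduler = get_linear_schedule_with_warmup(
    opt, num_warmup_steps=5000, num_training_steps=total_steps
)
```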

Co-authored-by: Yi Zhou <[email protected]>
Co-authored-by: Alexander Borzunov <[email protected]>
Co-authored-by: Max Ryabinin <[email protected]>
4 people authored Jan 24, 2022
1 parent aa6d65a commit 2abda6f
Showing 5 changed files with 25 additions and 18 deletions.
7 changes: 4 additions & 3 deletions examples/albert/README.md
@@ -55,9 +55,7 @@ To join the collaboration with a GPU trainer,
(see [default paths](./arguments.py#L117-L134) for reference)
- Run:
```bash
-./run_trainer.py \
-    --initial_peers ONE_OR_MORE_PEERS \
-    --logging_first_step --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
+./run_trainer.py --initial_peers ONE_OR_MORE_PEERS --per_device_train_batch_size BATCH_SIZE_FOR_YOUR_GPU
```

Here, `ONE_OR_MORE_PEERS` stands for multiaddresses of one or multiple existing peers (training monitors or existing
@@ -82,6 +80,9 @@ To join the collaboration with a GPU trainer,
You may need to change the IP address to a publicly visible one if some of the initial peers are located behind NAT.
If you have any trouble doing this, consider the ["Using IPFS"](#using-ipfs) section.

+The `BATCH_SIZE_FOR_YOUR_GPU` should be tweaked so that the model fits into your GPU memory.
+For 1080Ti or 2080Ti GPUs, a good initial value is 4. For 8 GB GPUs, try a batch size of 1-2.

See the ["Tips and tricks"](#tips-and-tricks) section for more information on setting up collaborative training.

As the peer begins training, it will periodically report training logs in the following form:
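For illustration (an editor's aside, not part of this diff): a hypothetical invocation of the updated `run_trainer.py` command with a concrete batch size; the peer multiaddress below is a placeholder.

```bash
./run_trainer.py \
    --initial_peers /ip4/192.0.2.10/tcp/31337/p2p/QmExamplePeerID \
    --per_device_train_batch_size 4
```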
13 changes: 9 additions & 4 deletions examples/albert/arguments.py
@@ -6,7 +6,7 @@

@dataclass
class BaseTrainingArguments:
-experiment_prefix: str = field(
+run_id: str = field(
default="albert", metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
)
initial_peers: List[str] = field(
@@ -127,7 +127,7 @@ class AlbertTrainingArguments(TrainingArguments):
gradient_accumulation_steps: int = 2
seq_length: int = 512

-max_steps: int = 125_000  # please note: this affects both number of steps and learning rate schedule
+total_steps: int = 125_000  # please note: this only affects the learning rate schedule
learning_rate: float = 0.00176
warmup_steps: int = 5000
adam_epsilon: float = 1e-6
@@ -138,9 +138,14 @@ class AlbertTrainingArguments(TrainingArguments):
fp16: bool = True
fp16_opt_level: str = "O2"
do_train: bool = True
+do_eval: bool = False

+logging_dir: str = "logs"
+output_dir: str = "outputs"
logging_steps: int = 100
+logging_first_step: bool = True
+overwrite_output_dir: bool = True

save_total_limit: int = 2
save_steps: int = 500

-output_dir: str = "outputs"
+max_steps: int = 10 ** 30  # meant as "peer should compute gradients forever"
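For context, a minimal sketch (not part of this diff, and the exact dataclass grouping in `run_trainer.py` may differ) of how such dataclasses are typically parsed with transformers' `HfArgumentParser`:

```python
from transformers import HfArgumentParser

# Assumes examples/albert is on the import path.
from arguments import AlbertTrainingArguments, BaseTrainingArguments

parser = HfArgumentParser((BaseTrainingArguments, AlbertTrainingArguments))
collab_args, training_args = parser.parse_args_into_dataclasses()

print(collab_args.run_id)         # "albert" unless overridden with --run_id
print(training_args.total_steps)  # 125_000: sizes the LR schedule only
print(training_args.max_steps)    # 10 ** 30: peers keep computing gradients
```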
6 changes: 3 additions & 3 deletions examples/albert/run_trainer.py
@@ -215,7 +215,7 @@ def main():
# This data collator will take care of randomly masking the tokens.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)

-validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)
+validators, local_public_key = utils.make_validators(collaboration_args.run_id)

dht = DHT(
start=True,
@@ -260,12 +260,12 @@ def main():
]

scheduler = lambda opt: get_linear_schedule_with_warmup(
-    opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
+    opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
)

optimizer = Optimizer(
dht=dht,
-run_id=collaboration_args.experiment_prefix,
+run_id=collaboration_args.run_id,
target_batch_size=adjusted_target_batch_size,
batch_size_per_step=total_batch_size_per_step,
optimizer=opt,
13 changes: 7 additions & 6 deletions examples/albert/run_training_monitor.py
@@ -9,7 +9,7 @@
import torch
import wandb
from torch_optimizer import Lamb
-from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
+from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser, get_linear_schedule_with_warmup

import hivemind
from hivemind.optim.state_averager import TrainingStateAverager
@@ -40,6 +40,7 @@ class TrainingMonitorArguments(BaseTrainingArguments):
wandb_project: Optional[str] = field(
default=None, metadata={"help": "Name of Weights & Biases project to report the training progress to"}
)
+store_checkpoints: bool = field(default=True, metadata={"help": "If False, disables periodic checkpoint saving"})
save_checkpoint_step_interval: int = field(
default=5, metadata={"help": "Frequency (in steps) of fetching and saving state from peers"}
)
@@ -56,7 +57,6 @@ class TrainingMonitorArguments(BaseTrainingArguments):
upload_interval: Optional[float] = field(
default=None, metadata={"help": "Frequency (in seconds) of uploading the model to Hub"}
)
-store_checkpoints: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})


class CheckpointHandler:
Expand Down Expand Up @@ -99,7 +99,8 @@ def __init__(
self.state_averager = TrainingStateAverager(
dht=dht,
optimizer=opt,
-prefix=experiment_prefix,
+scheduler=get_linear_schedule_with_warmup(opt, num_warmup_steps=5000, num_training_steps=125_000),
+prefix=f"{run_id}_state_averager",
state_compression=hivemind.Float16Compression(),
bandwidth=optimizer_args.bandwidth,
client_mode=optimizer_args.client_mode,
@@ -155,8 +156,8 @@ def upload_checkpoint(self, current_loss):
version = ip_address(address).version
monitor_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0"]

-experiment_prefix = monitor_args.experiment_prefix
-validators, local_public_key = utils.make_validators(experiment_prefix)
+run_id = monitor_args.run_id
+validators, local_public_key = utils.make_validators(run_id)

dht = hivemind.DHT(
start=True,
@@ -177,7 +178,7 @@ def upload_checkpoint(self, current_loss):
checkpoint_handler = CheckpointHandler(monitor_args, optimizer_args, averager_args, dht)

while True:
-    metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True)
+    metrics_dict = dht.get(run_id + "_metrics", latest=True)
if metrics_dict is not None:
metrics_dict = metrics_dict.value
metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
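With these changes, a hypothetical monitor launch (an editor's aside, not part of this diff; the Weights & Biases project name is a placeholder, and checkpoints are now saved by default) might look like:

```bash
./run_training_monitor.py --run_id albert --wandb_project my-albert-collab
```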
4 changes: 2 additions & 2 deletions examples/albert/utils.py
@@ -24,9 +24,9 @@ class MetricSchema(BaseModel):
metrics: Dict[BytesWithPublicKey, LocalMetrics]


-def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase], bytes]:
+def make_validators(run_id: str) -> Tuple[List[RecordValidatorBase], bytes]:
signature_validator = RSASignatureValidator()
-    validators = [SchemaValidator(MetricSchema, prefix=experiment_prefix), signature_validator]
+    validators = [SchemaValidator(MetricSchema, prefix=run_id), signature_validator]
return validators, signature_validator.local_public_key


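A small usage sketch (an assumption mirroring how the trainer and monitor scripts appear to use this helper; the run id is a placeholder):

```python
import hivemind

import utils

validators, local_public_key = utils.make_validators("albert")

# Attach the validators when constructing the DHT so that metric records are schema-checked and signed.
dht = hivemind.DHT(start=True, record_validators=validators)
```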
