# gpt2_ppo.yml
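
# PPO fine-tuning of GPT-2 on the daily_dialog data pool, rewarded by
# intent accuracy blended with an automatic-metric term.
# GPT-2 is decoder-only, so padding and truncation are applied on the
# left, and EOS doubles as the pad token (GPT-2 ships without one).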
tokenizer:
  model_name: gpt2
  padding_side: left
  truncation_side: left
  pad_token_as_eos_token: True
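
# Reward: 0.75 * intent accuracy + 0.25 * an automatic metric term.
# The exact auto metric is determined by the intent_accuracy reward
# implementation; the coefficients here only set the blend.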
reward_fn:
  id: "intent_accuracy"
  args:
    intent_coeff: 0.75
    auto_coeff: 0.25
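
# Data: DailyDialog dialogues; context_size presumably keeps the 5
# preceding utterances as the prompt context.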
datapool:
  id: "daily_dialog"
  args:
    context_size: 5
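
# Rollouts: 10 parallel envs; prompts truncated to 128 tokens, at most
# 20 generated tokens per episode, ending early when EOS is sampled.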
env:
  n_envs: 10
  args:
    max_prompt_length: 128
    max_episode_length: 20
    terminate_on_eos: True
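
# PPO: 128 rollout steps per env per update, minibatch size 64,
# 5 epochs per batch, learning rate 1e-6. kl_div presumably penalizes
# KL divergence from the initial (pre-RL) policy, with the coefficient
# adapted toward target_kl.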
alg:
  id: ppo
  args:
    n_steps: 128
    batch_size: 64
    verbose: 1
    learning_rate: 0.000001
    n_epochs: 5
  kl_div:
    coeff: 0.2
    target_kl: 0.5
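
# Policy: actor-critic heads on top of a causal LM (GPT-2);
# apply_model_parallel spreads the model across available GPUs.
# These generation_kwargs govern sampling during training rollouts.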
policy:
  id: causal_lm_actor_critic_policy
  args:
    model_name: gpt2
    apply_model_parallel: True
    generation_kwargs:
      do_sample: True
      top_k: 20
      min_length: 2
      max_new_tokens: 20
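
# Training/eval loop: 50 iterations, evaluating every 5 and saving
# every 10. Metrics cover task success (intent_accuracy), fluency
# (causal_perplexity scored with GPT-2), and lexical/semantic overlap
# (METEOR, ROUGE, BLEU, BERTScore, SacreBLEU) plus diversity.
# The trailing generation_kwargs apply at evaluation time.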
train_evaluation:
  eval_batch_size: 32
  n_iters: 50
  eval_every: 5
  save_every: 10
  metrics:
    - id: intent_accuracy
    - id: causal_perplexity
      args:
        tokenizer_id: gpt2
        stride: 128
        model_type: causal
    - id: diversity
      args: {}
    - id: meteor
      args: {}
    - id: rouge
    - id: bleu
      args: {}
    - id: bert_score
      args:
        language: en
    - id: sacre_bleu
      args:
        tokenize: "intl"
  generation_kwargs:
    do_sample: True
    top_k: 20
    min_length: 2
    max_new_tokens: 20
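
# A typical launch, assuming an RL4LMs-style training script
# (hypothetical path; adjust to this repo's actual entry point):
#   python scripts/training/train_text_generation.py --config_path gpt2_ppo.yml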