arguments.py
from dataclasses import dataclass, field
from typing import Optional, List, Union

import torch
from transformers import TrainingArguments
from transformers.training_args import ParallelMode
from transformers.utils import cached_property, logging

from utils import _infer_slurm_init

logger = logging.get_logger(__name__)


@dataclass
class ModelArguments:
    """Arguments for the pretrained model, config, and tokenizer to load."""

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    target_model_path: str = field(
        default=None,
        metadata={"help": "Path to pretrained reranker target model"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where to store the pretrained models downloaded from s3"}
    )

    # modeling
    shared_encoder: bool = field(
        default=True,
        metadata={"help": "share weights between the query and passage encoders"}
    )
    teacher_model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Path to pretrained teacher model or model identifier from huggingface.co/models"}
    )
    teacher_config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as teacher_model_name"}
    )
    teacher_tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as teacher_model_name"}
    )


@dataclass
class DataArguments:
    """Arguments for the training / evaluation data and its preprocessing."""

    task: str = field(
        default="XOR-Retrieve", metadata={"help": "task name"}
    )
    train_dir: str = field(
        default=None, metadata={"help": "Path to train directory"}
    )
    train_path: str = field(
        default=None, metadata={"help": "Path to train data"}
    )
    dataset_name: str = field(
        default=None, metadata={"help": "huggingface dataset name"}
    )
    dataset_proc_num: int = field(
        default=12, metadata={"help": "number of processes used in dataset preprocessing"}
    )
    corpus_file: str = field(default="corpus.tsv", metadata={"help": "corpus text path"})
    query_file: str = field(default="train.query.txt", metadata={"help": "query text path"})
    eval_query_file: str = field(default="xor_dev_retrieve_eng_span_v1_1.jsonl",
                                 metadata={"help": "query text path for evaluation"})
    teacher_corpus_file: str = field(default="corpus.tsv", metadata={"help": "corpus text path for teacher model"})
    teacher_query_file: str = field(default="train.query.txt", metadata={"help": "query text path for teacher model"})
    qrels_file: str = field(default="qrels.train.tsv", metadata={"help": "query-passage relevance (qrels) path"})
    output_path: str = field(default=None, metadata={"help": "output path"})
    labelled_train_path: str = field(
        default=None, metadata={"help": "Path to labelled train data"}
    )
    labelled_query_file: str = field(default="train.query.txt", metadata={"help": "labelled query text path"})
    train_n_passages: int = field(default=8)
    no_shuffle_positive: bool = field(default=False)
    sample_hard_negative_prob: float = field(default=1.0)

    encode_in_path: List[str] = field(default=None, metadata={"help": "Path to data to encode"})
    encoded_save_path: str = field(default=None, metadata={"help": "where to save the encoded output"})
    encode_is_qry: bool = field(default=False)
    encode_num_shard: int = field(default=1)
    encode_shard_index: int = field(default=0)

    split: str = field(default='train', metadata={"help": "dataset split"})
    normalize_text: bool = field(default=False, metadata={"help": "normalize text"})
    lower_case: bool = field(default=False, metadata={"help": "lower-case text"})
    max_query_length: int = field(
        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization for the query. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_passage_length: int = field(
        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization for the passage. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_query_passage_length: int = field(
        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization for query & passage. Sequences "
                    "longer than this will be truncated, sequences shorter will be padded."
        },
    )
    max_answer_length: int = field(default=50, metadata={"help": "max answer length"})
    retrieval_data_path: str = field(default=None, metadata={"help": "retrieval results data path"})
    max_teacher_query_length: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization for the teacher query. Sequences "
                    "longer than this will be truncated, sequences shorter will be padded."
        },
    )
    max_teacher_passage_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization for the teacher passage. Sequences "
                    "longer than this will be truncated, sequences shorter will be padded."
        },
    )


@dataclass
class BiEncoderTrainingArguments(TrainingArguments):
    """TrainingArguments extended with bi-encoder training options and SLURM-aware distributed setup."""

    warmup_ratio: float = field(default=0.1)
    negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
    do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})
    qgen_pretrain: bool = field(default=False, metadata={"help": "do pretraining on cleaned synthetic data"})
    retriever_score_scaling: bool = field(default=False, metadata={"help": "whether to scale the retriever score "
                                                                           "when computing the distribution"})
    grad_cache: bool = field(default=False, metadata={"help": "use gradient cache updates"})
    gc_chunk_size: int = field(default=8)
    print_steps: int = field(default=100, metadata={"help": "step interval for printing progress"})
    teacher_temp: float = field(default=1)
    student_temp: float = field(default=1)
    distributed_port: int = field(default=None, metadata={"help": "port for multi-node multi-GPU distributed "
                                                                   "training with SLURM"})
    multi_task: bool = field(default=False, metadata={"help": "use multi-task training"})
    distillation_start_steps: int = field(default=3000, metadata={"help": "start distillation after this many steps"})
    separate_joint_encoding: bool = field(default=False, metadata={"help": "separate joint encoding"})
    cross_lingual_alignment: bool = field(default=False, metadata={"help": "cross-lingual alignment"})
    query_alignment: bool = field(default=False, metadata={"help": "query alignment"})
    refresh_passages: bool = field(default=False, metadata={"help": "refresh retrieved passages during training"})
    refresh_intervals: int = field(default=3000, metadata={"help": "number of steps between passage refreshes"})
    retriever_weight: float = field(default=1.0, metadata={"help": "retriever loss weight"})
    english_distill: bool = field(default=False, metadata={"help": "distillation from an English teacher"})
    parallel_queries: bool = field(default=False, metadata={"help": "use parallel queries"})
    wikidata: bool = field(default=False, metadata={"help": "use Wikidata"})
    e2e_training: bool = field(default=False, metadata={"help": "end-to-end training"})
    eval_on_test: bool = field(default=False, metadata={"help": "evaluate on the test set when training is done"})
    eval_on_mkqa: bool = field(default=False, metadata={"help": "evaluate on the MKQA set when training is done"})
    load_corpus: bool = field(default=True, metadata={"help": "load corpus from file"})
    debug: bool = field(default=False, metadata={"help": "debug mode"})
    eval_at_start: bool = field(default=False, metadata={"help": "evaluate at the start of training"})
    de_avg_pooling: bool = field(default=False, metadata={"help": "use average pooling for the dual encoder"})

    @cached_property
    def _setup_devices(self) -> "torch.device":
        # When a SLURM port is given, derive the distributed init method, rank,
        # world size, and device id from the SLURM environment.
        if self.distributed_port:
            logger.info("PyTorch: setting up devices")
            distributed_init_method, local_rank, world_size, device_id = _infer_slurm_init(self.distributed_port)
            self.local_rank = local_rank
            torch.distributed.init_process_group(
                backend="nccl", init_method=distributed_init_method, world_size=world_size, rank=local_rank
            )
            logger.info("local rank {}, device id {}".format(local_rank, device_id))
            self._n_gpu = 1
            if device_id is None:
                device = torch.device("cuda")
            else:
                device = torch.device("cuda", device_id)
            if device.type == "cuda":
                torch.cuda.set_device(device)
            return device
        else:
            return super(BiEncoderTrainingArguments, self)._setup_devices

    @property
    def parallel_mode(self):
        if self.local_rank != -1:
            return ParallelMode.DISTRIBUTED
        elif self.n_gpu > 1:
            return ParallelMode.NOT_DISTRIBUTED
        else:
            return ParallelMode.NOT_PARALLEL
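

# The `_setup_devices` override above relies on `_infer_slurm_init` from utils.py.
# The sketch below is a hypothetical illustration of what such a helper commonly
# derives from SLURM environment variables (rendezvous address, rank, world size,
# local device id). The name `_infer_slurm_init_sketch` and its body are
# assumptions for illustration only, not the repo's actual implementation; the
# return order mirrors the unpacking at the call site above.
def _infer_slurm_init_sketch(port):
    import os
    import subprocess

    # Use the first hostname in the SLURM node list as the rendezvous host.
    node_list = os.environ["SLURM_JOB_NODELIST"]
    host = subprocess.check_output(
        ["scontrol", "show", "hostnames", node_list]
    ).decode().split()[0]
    distributed_init_method = "tcp://{}:{}".format(host, port)
    world_size = int(os.environ["SLURM_NTASKS"])  # total number of processes in the job
    rank = int(os.environ["SLURM_PROCID"])        # this process's rank within the job
    device_id = int(os.environ["SLURM_LOCALID"])  # GPU index on the local node
    return distributed_init_method, rank, world_size, device_id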


@dataclass
class DistilModelArguments(ModelArguments):
    """ModelArguments extended with a teacher model for distillation."""

    teacher_model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Path to pretrained teacher model or model identifier from huggingface.co/models"}
    )
    teacher_config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as teacher_model_name"}
    )
    teacher_tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as teacher_model_name"}
    )
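

# A minimal usage sketch, assuming the standard Hugging Face HfArgumentParser
# pattern for consuming these dataclasses; the repo's actual training scripts
# may wire them up differently.
if __name__ == "__main__":
    from transformers import HfArgumentParser

    # Parse command-line flags into the three argument groups defined above.
    parser = HfArgumentParser((ModelArguments, DataArguments, BiEncoderTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    print(model_args.model_name_or_path, data_args.task, training_args.output_dir)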