-
Notifications
You must be signed in to change notification settings - Fork 478
/
Copy pathpipeline.py
1221 lines (1036 loc) · 43.2 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2024 The TensorFlow Ranking Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ranking pipeline to train tf.keras.Model in tfr.keras."""
import abc
import dataclasses
import os
from typing import Any, Dict, List, Optional, Tuple, Union
import tensorflow as tf
from tensorflow_ranking.python import data
from tensorflow_ranking.python.keras import losses
from tensorflow_ranking.python.keras import metrics
from tensorflow_ranking.python.keras import model as model_lib
from tensorflow_ranking.python.keras import saved_model
from tensorflow_ranking.python.keras import strategy_utils
class AbstractPipeline(metaclass=abc.ABCMeta):
  """Interface of a ranking pipeline that trains a `tf.keras.Model`.

  `AbstractPipeline` is the abstract base for pipelines that train and
  validate a ranking model in tfr.keras.

  Subclasses provide:
    * `build_loss()`: the `tf.keras.losses.Loss` (or a dict or list of
      `tf.keras.losses.Loss`s) optimized during training.
    * `build_metrics()`: a list or dict of `tf.keras.metrics.Metric`s used to
      monitor and evaluate training.
    * `build_weighted_metrics()`: a list or dict of `tf.keras.metrics.Metric`s
      that take sample weights into account.
    * `train_and_validate()`: the main training and validation routine.

  Example subclass implementation:

  ```python
  class BasicPipeline(AbstractPipeline):

    def __init__(self, model, train_data, valid_data, name=None):
      self._model = model
      self._train_data = train_data
      self._valid_data = valid_data
      self._name = name

    def build_loss(self):
      return tfr.keras.losses.get('softmax_loss')

    def build_metrics(self):
      return [
          tfr.keras.metrics.get(
              'ndcg', topn=topn, name='ndcg_{}'.format(topn)
          ) for topn in [1, 5, 10]
      ]

    def build_weighted_metrics(self):
      return [
          tfr.keras.metrics.get(
              'ndcg', topn=topn, name='weighted_ndcg_{}'.format(topn)
          ) for topn in [1, 5, 10]
      ]

    def train_and_validate(self, *arg, **kwargs):
      self._model.compile(
          optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
          loss=self.build_loss(),
          metrics=self.build_metrics(),
          weighted_metrics=self.build_weighted_metrics())
      self._model.fit(
          x=self._train_data,
          epochs=100,
          validation_data=self._valid_data)
  ```
  """

  @abc.abstractmethod
  def build_loss(self) -> Any:
    """Builds the loss handed to `model.compile()`.

    Example usage:

    ```python
    pipeline = BasicPipeline(model, train_data, valid_data)
    loss = pipeline.build_loss()
    ```

    Returns:
      A `tf.keras.losses.Loss`, or a dict or list of `tf.keras.losses.Loss`.
    """
    raise NotImplementedError("Calling an abstract method.")

  @abc.abstractmethod
  def build_metrics(self) -> Any:
    """Builds the ranking metrics handed to `model.compile()`.

    Example usage:

    ```python
    pipeline = BasicPipeline(model, train_data, valid_data)
    metrics = pipeline.build_metrics()
    ```

    Returns:
      A list or a dict of `tf.keras.metrics.Metric`s.
    """
    raise NotImplementedError("Calling an abstract method.")

  @abc.abstractmethod
  def build_weighted_metrics(self) -> Any:
    """Builds the weighted ranking metrics handed to `model.compile()`.

    Example usage:

    ```python
    pipeline = BasicPipeline(model, train_data, valid_data)
    weighted_metrics = pipeline.build_weighted_metrics()
    ```

    Returns:
      A list or a dict of `tf.keras.metrics.Metric`s.
    """
    raise NotImplementedError("Calling an abstract method.")

  @abc.abstractmethod
  def train_and_validate(self, *arg, **kwargs) -> Any:
    """Builds and runs the training pipeline.

    Example usage:

    ```python
    pipeline = BasicPipeline(model, train_data, valid_data)
    pipeline.train_and_validate()
    ```

    Args:
      *arg: Positional arguments the training pipeline may use.
      **kwargs: Keyword arguments the training pipeline may use.

    Returns:
      None, a trained `tf.keras.Model`, or a path to a saved `tf.keras.Model`.
    """
    raise NotImplementedError("Calling an abstract method.")
class AbstractDatasetBuilder(metaclass=abc.ABCMeta):
  """Interface that serves datasets and serving signatures.

  `AbstractDatasetBuilder` is the abstract base for objects that supply data
  in tfr.keras. An instance is handed to an `AbstractPipeline`, which calls it
  to obtain the training and validation datasets and the serving signatures
  used when saving the model for the corresponding data format.

  Subclasses provide:
    * `build_train_dataset()`: the `tf.data.Dataset` used for training.
    * `build_valid_dataset()`: the `tf.data.Dataset` used for validation.
    * `build_signatures()`: a dict of signatures that wrap the model in
      functions accepting input data of a given format.

  Example subclass implementation:

  ```python
  class NullDatasetBuilder(AbstractDatasetBuilder):

    def __init__(self, train_dataset, valid_dataset, signatures=None):
      self._train_dataset = train_dataset
      self._valid_dataset = valid_dataset
      self._signatures = signatures

    def build_train_dataset(self, *arg, **kwargs) -> tf.data.Dataset:
      return self._train_dataset

    def build_valid_dataset(self, *arg, **kwargs) -> tf.data.Dataset:
      return self._valid_dataset

    def build_signatures(self, *arg, **kwargs) -> Any:
      return self._signatures
  ```
  """

  @abc.abstractmethod
  def build_train_dataset(self, *arg, **kwargs) -> tf.data.Dataset:
    """Builds the dataset used for training.

    Example usage:

    ```python
    dataset_builder = NullDatasetBuilder(train_data, valid_data)
    train_dataset = dataset_builder.build_train_dataset()
    ```

    Args:
      *arg: Positional arguments that may be used to build the dataset.
      **kwargs: Keyword arguments that may be used to build the dataset.

    Returns:
      A `tf.data.Dataset`.
    """
    raise NotImplementedError("Calling an abstract method.")

  @abc.abstractmethod
  def build_valid_dataset(self, *arg, **kwargs) -> tf.data.Dataset:
    """Builds the dataset used for validation.

    Example usage:

    ```python
    dataset_builder = NullDatasetBuilder(train_data, valid_data)
    valid_dataset = dataset_builder.build_valid_dataset()
    ```

    Args:
      *arg: Positional arguments that may be used to build the dataset.
      **kwargs: Keyword arguments that may be used to build the dataset.

    Returns:
      A `tf.data.Dataset`.
    """
    raise NotImplementedError("Calling an abstract method.")

  @abc.abstractmethod
  def build_signatures(self, *arg, **kwargs) -> Any:
    """Builds the signatures used when exporting a SavedModel.

    Example usage:

    ```python
    dataset_builder = NullDatasetBuilder(train_data, valid_data)
    signatures = dataset_builder.build_signatures()
    ```

    Args:
      *arg: Positional arguments that may be used to build the signatures.
      **kwargs: Keyword arguments that may be used to build the signatures.

    Returns:
      None, or a dict of concrete functions.
    """
    raise NotImplementedError("Calling an abstract method.")
@dataclasses.dataclass
class PipelineHparams:
  """Hyperparameters used in `ModelFitPipeline`.

  Hyperparameters to be specified for the ranking pipeline.

  Attributes:
    model_dir: A path to output the model and training data.
    num_epochs: An integer to specify the number of epochs of training.
    steps_per_epoch: An integer to specify the number of steps per epoch, or
      None. When None, going over the training data once is counted as an
      epoch; with `model.fit` this only terminates if the training dataset is
      finite.
    validation_steps: An integer to specify the number of validation steps in
      each epoch. A mini-batch of data is evaluated in each step, so this is
      the number of mini-batches used for validation in each epoch.
    learning_rate: A float to indicate the learning rate of the optimizer.
    loss: A string or a map to strings that indicate the loss to be used. When
      `loss` is a string, all outputs and labels will be trained with the same
      loss. When `loss` is a map, outputs and labels will be trained with
      losses implied by the corresponding keys.
    loss_reduction: An option in `tf.keras.losses.Reduction` to specify the
      reduction method.
    optimizer: An option in `tf.keras.optimizers` identifiers to specify the
      optimizer to be used.
    loss_weights: None or a float or a map to floats that indicate the
      relative weights for each loss. When not specified, all losses are
      applied with the same weight 1.
    steps_per_execution: An integer to specify the number of steps executed in
      each operation. Tune this to optimize the training performance in
      distributed training.
    automatic_reduce_lr: A boolean to indicate whether to use the
      `ReduceLROnPlateau` callback.
    early_stopping_patience: Number of epochs with no improvement after which
      training will be stopped. A value of 0 disables early stopping. The
      monitored quantity and its direction are taken from
      `best_exporter_metric` and `best_exporter_metric_higher_better`.
    early_stopping_min_delta: Minimum change in the monitored quantity to
      qualify as an improvement, i.e. an absolute change of less than
      `early_stopping_min_delta` will count as no improvement.
    use_weighted_metrics: A boolean to indicate whether to use weighted
      metrics.
    export_best_model: A boolean to indicate whether to export the best model
      evaluated by the `best_exporter_metric` on the validation data.
    best_exporter_metric_higher_better: A boolean to indicate whether a higher
      value of `best_exporter_metric` is better.
    best_exporter_metric: A string to specify the metric used to monitor the
      training and to export the best model. Defaults to "loss".
    strategy: An option of strategies supported in `strategy_utils`. Choose
      from ["MirroredStrategy", "MultiWorkerMirroredStrategy",
      "ParameterServerStrategy", "TPUStrategy"].
    cluster_resolver: A cluster_resolver to build strategy.
    variable_partitioner: Variable partitioner to be used in
      ParameterServerStrategy.
    tpu: TPU address for TPUStrategy. Not used for other strategies.
  """
  model_dir: str
  num_epochs: int
  # Documented above as accepting None, hence Optional.
  steps_per_epoch: Optional[int]
  validation_steps: int
  learning_rate: float
  loss: Union[str, Dict[str, str]]
  loss_reduction: str = tf.losses.Reduction.AUTO
  optimizer: str = "adam"
  loss_weights: Optional[Union[float, Dict[str, float]]] = None
  steps_per_execution: int = 10
  automatic_reduce_lr: bool = False
  early_stopping_patience: int = 0
  early_stopping_min_delta: float = 0.0
  use_weighted_metrics: bool = False
  export_best_model: bool = False
  best_exporter_metric_higher_better: bool = False
  best_exporter_metric: str = "loss"
  strategy: Optional[str] = None
  cluster_resolver: Optional[
      tf.distribute.cluster_resolver.ClusterResolver] = None
  variable_partitioner: Optional[
      tf.distribute.experimental.partitioners.Partitioner] = None
  tpu: Optional[str] = ""
@dataclasses.dataclass
class DatasetHparams:
  """Hyperparameters consumed by `BaseDatasetBuilder`.

  These settings describe where the training and validation data live and how
  they are batched when the dataset_builder is created.

  Attributes:
    train_input_pattern: Glob pattern matching the training input files.
    valid_input_pattern: Glob pattern matching the validation input files.
    train_batch_size: Batch size of the training dataset.
    valid_batch_size: Batch size of the validation dataset.
    list_size: List size of the data. When None, data is padded to the
      longest list in each batch.
    valid_list_size: List size of the validation dataset. When unset, the
      validation dataset uses `list_size` instead.
    dataset_reader: A function or class that is called with a `filenames`
      tensor and (optionally) `reader_args` and returns a `Dataset`. Defaults
      to `tf.data.TFRecordDataset`.
    convert_labels_to_binary: Whether labels should be converted to binary.
  """
  train_input_pattern: str
  valid_input_pattern: str
  train_batch_size: int
  valid_batch_size: int
  list_size: Optional[int] = dataclasses.field(default=None)
  valid_list_size: Optional[int] = dataclasses.field(default=None)
  dataset_reader: Any = dataclasses.field(default=tf.data.TFRecordDataset)
  convert_labels_to_binary: bool = dataclasses.field(default=False)
class ModelFitPipeline(AbstractPipeline):
  """Pipeline using `model.fit` to train a ranking `tf.keras.Model`.

  `ModelFitPipeline` is an abstract class inheriting from `AbstractPipeline`
  that trains and validates a ranking `model` with `model.fit` under the
  distribution strategy specified in the hparams.

  To be implemented by subclasses:
    * `build_loss()`: Contains the logic to build a `tf.keras.losses.Loss` or
      a dict or list of `tf.keras.losses.Loss`s to be optimized in training.
    * `build_metrics()`: Contains the logic to build a list or dict of
      `tf.keras.metrics.Metric`s to monitor and evaluate the training.
    * `build_weighted_metrics()`: Contains the logic to build a list or dict
      of `tf.keras.metrics.Metric`s which will take the weights.

  Example subclass implementation:

  ```python
  class BasicModelFitPipeline(ModelFitPipeline):

    def build_loss(self):
      return tfr.keras.losses.get('softmax_loss')

    def build_metrics(self):
      return [
          tfr.keras.metrics.get(
              'ndcg', topn=topn, name='ndcg_{}'.format(topn)
          ) for topn in [1, 5, 10]
      ]

    def build_weighted_metrics(self):
      return [
          tfr.keras.metrics.get(
              'ndcg', topn=topn, name='weighted_ndcg_{}'.format(topn)
          ) for topn in [1, 5, 10]
      ]
  ```
  """

  def __init__(
      self,
      model_builder: model_lib.AbstractModelBuilder,
      dataset_builder: AbstractDatasetBuilder,
      hparams: PipelineHparams,
  ):
    """Initializes the instance.

    Args:
      model_builder: A `ModelBuilder` instance for model fit.
      dataset_builder: An `AbstractDatasetBuilder` instance to load train and
        validate datasets and create signatures for SavedModel.
      hparams: A `PipelineHparams` instance holding the pipeline
        hyperparameters.

    Raises:
      ValueError: If `model_builder` or `dataset_builder` is None or is not of
        the expected type.
    """
    self._validate_parameters(model_builder, dataset_builder)
    self._model_builder = model_builder
    self._dataset_builder = dataset_builder
    self._hparams = hparams
    # Optimizer is looked up by its identifier, configured with the learning
    # rate from the hparams.
    self._optimizer = tf.keras.optimizers.get({
        "class_name": self._hparams.optimizer,
        "config": {
            "learning_rate": self._hparams.learning_rate
        }
    })
    self._strategy = strategy_utils.get_strategy(
        self._hparams.strategy, self._hparams.cluster_resolver,
        self._hparams.variable_partitioner, self._hparams.tpu)

  def _validate_parameters(self, model_builder: model_lib.AbstractModelBuilder,
                           dataset_builder: AbstractDatasetBuilder):
    """Validates the passed-in model_builder and dataset_builder.

    Args:
      model_builder: A `ModelBuilder` instance.
      dataset_builder: A `DatasetBuilder` instance.

    Raises:
      ValueError: If the `model_builder` is None.
      ValueError: If the `model_builder` is not an `AbstractModelBuilder`.
      ValueError: If the `dataset_builder` is None.
      ValueError: If the `dataset_builder` is not an `AbstractDatasetBuilder`.
    """
    if model_builder is None:
      raise ValueError("The `model_builder` cannot be empty!")
    if not isinstance(model_builder, model_lib.AbstractModelBuilder):
      raise ValueError(
          "The argument `model_builder` needs to be of type "
          "tensorflow_ranking.keras.model.AbstractModelBuilder, not {}.".format(
              type(model_builder)))
    if dataset_builder is None:
      raise ValueError("The `dataset_builder` cannot be empty!")
    if not isinstance(dataset_builder, AbstractDatasetBuilder):
      # The isinstance check is against `AbstractDatasetBuilder`, so the
      # message names that class.
      raise ValueError(
          "The argument `dataset_builder` needs to be of type "
          "tensorflow_ranking.keras.pipeline.AbstractDatasetBuilder, "
          "not {}.".format(type(dataset_builder)))

  def _monitored_metric(self) -> str:
    """Returns the Keras log key of the validation quantity to monitor.

    `model.fit` reports validation quantities with a "val_" prefix. Any
    `best_exporter_metric` other than "loss" is additionally prefixed with
    "metric/"; this assumes the metrics returned by `build_metrics()` are named
    with that prefix, as `SimplePipeline` does.
    """
    metric_name = self._hparams.best_exporter_metric
    if metric_name != "loss":
      metric_name = "metric/" + metric_name
    return "val_" + metric_name

  def _monitor_mode(self) -> str:
    """Returns the Keras callback `mode` ("max" or "min") for the monitor."""
    return "max" if self._hparams.best_exporter_metric_higher_better else "min"

  def build_callbacks(self) -> List[tf.keras.callbacks.Callback]:
    """Sets up Callbacks.

    Always includes TensorBoard logging and BackupAndRestore into `model_dir`.
    Depending on the hparams, also adds best-model checkpointing
    (`export_best_model`), `ReduceLROnPlateau` (`automatic_reduce_lr`) and
    early stopping (non-zero `early_stopping_patience`). Checkpointing and
    early stopping monitor the same validation quantity.

    Example usage:

    ```python
    model_builder = ModelBuilder(...)
    dataset_builder = DatasetBuilder(...)
    hparams = PipelineHparams(...)
    pipeline = BasicModelFitPipeline(model_builder, dataset_builder, hparams)
    callbacks = pipeline.build_callbacks()
    ```

    Returns:
      A list of `tf.keras.callbacks.Callback` or a
      `tf.keras.callbacks.CallbackList` for tensorboard and checkpoint.
    """
    # Writing summary logs to file may have performance impact. Therefore, we
    # only write summary events every epoch.
    callbacks = [
        tf.keras.callbacks.TensorBoard(self._hparams.model_dir),
        tf.keras.callbacks.experimental.BackupAndRestore(
            backup_dir=self._hparams.model_dir)
    ]
    if self._hparams.export_best_model:
      # Defaults to the min of the validation loss.
      callbacks.append(
          tf.keras.callbacks.ModelCheckpoint(
              os.path.join(self._hparams.model_dir,
                           "best_checkpoint/ckpt-{epoch:04d}"),
              monitor=self._monitored_metric(),
              mode=self._monitor_mode(),
              save_weights_only=True,
              save_best_only=True))
    if self._hparams.automatic_reduce_lr:
      callbacks.append(
          tf.keras.callbacks.ReduceLROnPlateau(
              monitor="val_loss",
              min_delta=0.01 * self._hparams.learning_rate,
          ))
    if self._hparams.early_stopping_patience:
      callbacks.append(
          tf.keras.callbacks.EarlyStopping(
              monitor=self._monitored_metric(),
              min_delta=self._hparams.early_stopping_min_delta,
              patience=self._hparams.early_stopping_patience,
              mode=self._monitor_mode(),
          ))
    return callbacks

  def export_saved_model(self,
                         model: tf.keras.Model,
                         export_to: str,
                         checkpoint: Optional[str] = None):
    """Exports the trained model with signatures.

    Example usage:

    ```python
    model_builder = ModelBuilder(...)
    dataset_builder = DatasetBuilder(...)
    hparams = PipelineHparams(...)
    pipeline = BasicModelFitPipeline(model_builder, dataset_builder, hparams)
    pipeline.export_saved_model(model_builder.build(), 'saved_model/')
    ```

    Args:
      model: Model to be saved.
      export_to: Specifies the directory the model is exported to.
      checkpoint: If given, a checkpoint path (e.g. as returned by
        `tf.train.latest_checkpoint`) whose weights are loaded into `model`
        before exporting. Note that this overwrites the weights of `model`.
    """
    if checkpoint:
      model.load_weights(checkpoint)
    model.save(
        filepath=export_to,
        signatures=self._dataset_builder.build_signatures(model))

  def train_and_validate(self, verbose=0):
    """Main function to train the model under the configured strategy.

    Builds and compiles the model inside the distribution strategy scope, fits
    it, and exports it to `export/latest_model` under the output directory
    derived from `model_dir`. When `export_best_model` is set, the best
    checkpoint found under `<model_dir>/best_checkpoint` is additionally
    exported to `export/best_model_by_metric`.

    Example usage:

    ```python
    context_feature_spec = {}
    example_feature_spec = {
        "example_feature_1": tf.io.FixedLenFeature(
            shape=(1,), dtype=tf.float32, default_value=0.0)
    }
    mask_feature_name = "list_mask"
    label_spec = {
        "utility": tf.io.FixedLenFeature(
            shape=(1,), dtype=tf.float32, default_value=0.0)
    }
    dataset_hparams = DatasetHparams(
        train_input_pattern="train.dat",
        valid_input_pattern="valid.dat",
        train_batch_size=128,
        valid_batch_size=128)
    pipeline_hparams = PipelineHparams(
        model_dir="model/",
        num_epochs=2,
        steps_per_epoch=5,
        validation_steps=2,
        learning_rate=0.01,
        loss="softmax_loss")
    model_builder = SimpleModelBuilder(
        context_feature_spec, example_feature_spec, mask_feature_name)
    dataset_builder = SimpleDatasetBuilder(
        context_feature_spec,
        example_feature_spec,
        mask_feature_name,
        label_spec,
        dataset_hparams)
    pipeline = BasicModelFitPipeline(
        model_builder, dataset_builder, pipeline_hparams)
    pipeline.train_and_validate(verbose=1)
    ```

    Args:
      verbose: An int for the verbosity level.

    Raises:
      ValueError: If `export_best_model` is set but no best checkpoint is
        found after training.
    """
    strategy = self._strategy
    with strategy_utils.strategy_scope(strategy):
      model = self._model_builder.build()
      # Note that all losses and metrics need to be constructed within the
      # strategy scope. This is why we use member functions like `build_loss`
      # and don't use passed-in objects.
      model.compile(
          optimizer=self._optimizer,
          loss=self.build_loss(),
          metrics=self.build_metrics(),
          loss_weights=self._hparams.loss_weights,
          weighted_metrics=(self.build_weighted_metrics()
                            if self._hparams.use_weighted_metrics else None),
          steps_per_execution=self._hparams.steps_per_execution)
      # Move the following out of strategy.scope only after b/173547275 fixed.
      # Otherwise, MultiWorkerMirroredStrategy will fail.
      train_dataset, valid_dataset = (
          self._dataset_builder.build_train_dataset(),
          self._dataset_builder.build_valid_dataset())
      model.fit(
          x=train_dataset,
          epochs=self._hparams.num_epochs,
          steps_per_epoch=self._hparams.steps_per_epoch,
          validation_steps=self._hparams.validation_steps,
          validation_data=valid_dataset,
          callbacks=self.build_callbacks(),
          verbose=verbose)
      model_output_dir = strategy_utils.get_output_filepath(
          self._hparams.model_dir, strategy)
      self.export_saved_model(
          model,
          export_to=os.path.join(model_output_dir, "export/latest_model"))
      if self._hparams.export_best_model:
        # Checkpoints are written by the ModelCheckpoint callback configured
        # in `build_callbacks`.
        best_checkpoint = tf.train.latest_checkpoint(
            os.path.join(self._hparams.model_dir, "best_checkpoint"))
        if best_checkpoint:
          self.export_saved_model(
              model,
              export_to=os.path.join(model_output_dir,
                                     "export/best_model_by_metric"),
              checkpoint=best_checkpoint)
        else:
          raise ValueError("Didn't find the best checkpoint.")
def _get_metric(prefix, key, topn=None, **kwargs):
  """Constructs the metric `key`, named `<prefix><key>` plus `_<topn>` if set.

  Args:
    prefix: String prepended to the metric name, e.g. "metric/".
    key: Metric identifier forwarded to `metrics.get`.
    topn: Optional cutoff; when truthy it is appended to the name and also
      forwarded to `metrics.get`.
    **kwargs: Extra keyword arguments forwarded to `metrics.get`.

  Returns:
    Whatever `metrics.get` returns for these arguments.
  """
  metric_name = f"{prefix}{key}"
  if topn:
    metric_name += "_%s" % topn
  return metrics.get(key, name=metric_name, topn=topn, **kwargs)
class SimplePipeline(ModelFitPipeline):
  """Pipeline for single-task training.

  Trains with a single loss and is designed for `SimpleDatasetBuilder`. It can
  also be paired with `MultiLabelDatasetBuilder`; the same loss and the same
  metrics are then applied uniformly to every label and its prediction.

  Subclass it to customize the loss and the metrics.

  Example usage:

  ```python
  context_feature_spec = {}
  example_feature_spec = {
      "example_feature_1": tf.io.FixedLenFeature(
          shape=(1,), dtype=tf.float32, default_value=0.0)
  }
  mask_feature_name = "list_mask"
  label_spec = {
      "utility": tf.io.FixedLenFeature(
          shape=(1,), dtype=tf.float32, default_value=0.0)
  }
  dataset_hparams = DatasetHparams(
      train_input_pattern="train.dat",
      valid_input_pattern="valid.dat",
      train_batch_size=128,
      valid_batch_size=128)
  pipeline_hparams = PipelineHparams(
      model_dir="model/",
      num_epochs=2,
      steps_per_epoch=5,
      validation_steps=2,
      learning_rate=0.01,
      loss="softmax_loss")
  model_builder = SimpleModelBuilder(
      context_feature_spec, example_feature_spec, mask_feature_name)
  dataset_builder = SimpleDatasetBuilder(
      context_feature_spec,
      example_feature_spec,
      mask_feature_name,
      label_spec,
      dataset_hparams)
  pipeline = SimplePipeline(model_builder, dataset_builder, pipeline_hparams)
  pipeline.train_and_validate(verbose=1)
  ```
  """

  def build_loss(self) -> tf.keras.losses.Loss:
    """See `AbstractPipeline`."""
    loss_name = self._hparams.loss
    if isinstance(loss_name, str):
      return losses.get(
          loss=loss_name, reduction=self._hparams.loss_reduction)
    raise TypeError("In the simple pipeline, losses are expected to be "
                    "specified in a str.")

  def build_metrics(self) -> List[tf.keras.metrics.Metric]:
    """See `AbstractPipeline`."""
    ndcg_key = metrics.RankingMetricKey.NDCG
    # NDCG at several cutoffs; None means no cutoff (whole list).
    return [_get_metric("metric/", ndcg_key, topn=cutoff)
            for cutoff in [1, 5, 10, None]]

  def build_weighted_metrics(self) -> List[tf.keras.metrics.Metric]:
    """See `AbstractPipeline`."""
    ndcg_key = metrics.RankingMetricKey.NDCG
    return [_get_metric("weighted_metric/", ndcg_key, topn=cutoff)
            for cutoff in [1, 5, 10, None]]
class MultiTaskPipeline(ModelFitPipeline):
  """Pipeline for multi-task training.

  This handles a set of losses and labels. It is intended to mainly work with
  `MultiLabelDatasetBuilder`. The task names are the keys of `hparams.loss`,
  which must be a dict.

  Use subclassing to customize the losses and metrics.

  Example usage:

  ```python
  context_feature_spec = {}
  example_feature_spec = {
      "example_feature_1": tf.io.FixedLenFeature(
          shape=(1,), dtype=tf.float32, default_value=0.0)
  }
  mask_feature_name = "list_mask"
  label_spec_tuple = ("utility",
                      tf.io.FixedLenFeature(
                          shape=(1,),
                          dtype=tf.float32,
                          default_value=_PADDING_LABEL))
  label_spec = {"task1": label_spec_tuple, "task2": label_spec_tuple}
  weight_spec = ("weight",
                 tf.io.FixedLenFeature(
                     shape=(1,), dtype=tf.float32, default_value=1.))
  dataset_hparams = DatasetHparams(
      train_input_pattern="train.dat",
      valid_input_pattern="valid.dat",
      train_batch_size=128,
      valid_batch_size=128)
  pipeline_hparams = PipelineHparams(
      model_dir="model/",
      num_epochs=2,
      steps_per_epoch=5,
      validation_steps=2,
      learning_rate=0.01,
      loss={
          "task1": "softmax_loss",
          "task2": "pairwise_logistic_loss"
      },
      loss_weights={
          "task1": 1.0,
          "task2": 2.0
      },
      export_best_model=True)
  model_builder = MultiTaskModelBuilder(...)
  dataset_builder = MultiLabelDatasetBuilder(
      context_feature_spec,
      example_feature_spec,
      mask_feature_name,
      label_spec,
      dataset_hparams,
      sample_weight_spec=weight_spec)
  pipeline = MultiTaskPipeline(model_builder, dataset_builder, pipeline_hparams)
  pipeline.train_and_validate(verbose=1)
  ```
  """

  def _task_names(self) -> List[str]:
    """Returns the task names (keys of `hparams.loss`) in insertion order.

    Raises:
      TypeError: If `hparams.loss` is not a dict. Without this check, a str
        loss would be iterated character by character as task names.
    """
    if not isinstance(self._hparams.loss, dict):
      raise TypeError("In the multi-task pipeline, losses are expected to be "
                      "specified in a dict.")
    return list(self._hparams.loss)

  def build_loss(self) -> Dict[str, tf.keras.losses.Loss]:
    """See `AbstractPipeline`.

    Raises:
      TypeError: If `hparams.loss` is not a dict.
    """
    task_names = self._task_names()
    reduction = self._hparams.loss_reduction
    return {
        task_name: losses.get(
            loss=self._hparams.loss[task_name], reduction=reduction)
        for task_name in task_names
    }

  def build_metrics(self) -> Dict[str, List[tf.keras.metrics.Metric]]:
    """See `AbstractPipeline`.

    Raises:
      TypeError: If `hparams.loss` is not a dict.
    """

    def eval_metrics():
      return [
          _get_metric("metric/", metrics.RankingMetricKey.NDCG, topn=topn)
          for topn in [1, 5, 10, None]
      ]

    # Each task gets its own, freshly constructed metric objects.
    return {task_name: eval_metrics() for task_name in self._task_names()}

  def build_weighted_metrics(self) -> Dict[str, List[tf.keras.metrics.Metric]]:
    """See `AbstractPipeline`.

    Raises:
      TypeError: If `hparams.loss` is not a dict.
    """

    def eval_metrics():
      return [
          _get_metric(
              "weighted_metric/", metrics.RankingMetricKey.NDCG, topn=topn)
          for topn in [1, 5, 10, None]
      ]

    # Each task gets its own, freshly constructed metric objects.
    return {task_name: eval_metrics() for task_name in self._task_names()}
class NullDatasetBuilder(AbstractDatasetBuilder):
  """Pass-through dataset builder over already-constructed objects.

  It stores the datasets and signatures it is given and hands them back
  unchanged; any arguments to the `build_*` methods are ignored.

  Example usage:

  ```python
  train_dataset = tf.data.Dataset(...)
  valid_dataset = tf.data.Dataset(...)
  dataset_builder = NullDatasetBuilder(train_dataset, valid_dataset)
  ```
  """

  def __init__(self, train_dataset, valid_dataset, signatures=None):
    """Stores the datasets and signatures to be served.

    Args:
      train_dataset: A `tf.data.Dataset` used for training.
      valid_dataset: A `tf.data.Dataset` used for validation.
      signatures: A dict of signatures wrapping the model in functions that
        accept input data of given types, or None for no signatures.
    """
    self._train_dataset, self._valid_dataset, self._signatures = (
        train_dataset, valid_dataset, signatures)

  def build_train_dataset(self, *arg, **kwargs) -> tf.data.Dataset:
    """See `AbstractDatasetBuilder`."""
    del arg, kwargs  # Unused; the stored dataset is returned as-is.
    return self._train_dataset

  def build_valid_dataset(self, *arg, **kwargs) -> tf.data.Dataset:
    """See `AbstractDatasetBuilder`."""
    del arg, kwargs  # Unused; the stored dataset is returned as-is.
    return self._valid_dataset

  def build_signatures(self, *arg, **kwargs) -> Any:
    """See `AbstractDatasetBuilder`."""
    del arg, kwargs  # Unused; the stored signatures are returned as-is.
    return self._signatures
class BaseDatasetBuilder(AbstractDatasetBuilder):
"""Builds datasets from feature specs.
The `BaseDatasetBuilder` class is an abstract class inherit from
`AbstractDatasetBuilder` to serve training and validation datasets and
signatures for training `ModelFitPipeline`.
To be implemented by subclasses:
* `_features_and_labels()`: Contains the logic to map a dict of tensors of
dataset to feature tensors and label tensors.
Example subclass implementation:
```python
class SimpleDatasetBuilder(BaseDatasetBuilder):
def _features_and_labels(self, features):
label = features.pop("utility")
return features, label
```
"""
# TODO: Define these bulky types as globals at the top.
def __init__(self,
context_feature_spec: Dict[str, Union[tf.io.FixedLenFeature,
tf.io.VarLenFeature,
tf.io.RaggedFeature]],
example_feature_spec: Dict[str, Union[tf.io.FixedLenFeature,
tf.io.VarLenFeature,
tf.io.RaggedFeature]],
training_only_example_spec: Dict[str,
Union[tf.io.FixedLenFeature,
tf.io.VarLenFeature,
tf.io.RaggedFeature]],
mask_feature_name: str,
hparams: DatasetHparams,
training_only_context_spec: Optional[Dict[
str, Union[tf.io.FixedLenFeature, tf.io.VarLenFeature,
tf.io.RaggedFeature]]] = None):
"""Intializes the instance.
Args:
context_feature_spec: Maps context (aka, query) names to feature specs.
example_feature_spec: Maps example (aka, document) names to feature specs.
training_only_example_spec: Feature specs used for training only like
labels and per-example weights.
mask_feature_name: If set, populates the feature dictionary with this name
and the coresponding value is a `tf.bool` Tensor of shape [batch_size,
list_size] indicating the actual example is padded or not.
hparams: A dict containing model hyperparameters.
training_only_context_spec: Feature specs used for training only per-list
weights.
"""
self._context_feature_spec = context_feature_spec
self._example_feature_spec = example_feature_spec
self._training_only_example_spec = training_only_example_spec
self._mask_feature_name = mask_feature_name
self._hparams = hparams
self._training_only_context_spec = training_only_context_spec or {}
@abc.abstractmethod
def _features_and_labels(self, features: Dict[str, tf.Tensor]) -> Any:
"""Extracts labels and weights from features.
Args:
features: Maps feature name and label name to corresponding tensors.
Returns:
A tuple of a dict of the rest of features, labels and optional weights.
"""
raise NotImplementedError("Calling an abstract method.")
def _build_dataset(self,
file_pattern: str,
batch_size: int,
list_size: Optional[int] = None,
randomize_input: bool = True,
num_epochs: Optional[int] = None) -> tf.data.Dataset:
"""Returns `tf.data.Dataset` for training or validating the model.
Args:
file_pattern: File pattern for input data.
batch_size: Number of input examples to process per batch.
list_size: The list size for an ELWC example.
randomize_input: If true, randomize input example order. It should almost
always be true except for unittest/debug purposes.
num_epochs: Number of times the input dataset must be repeated. None to
repeat the data indefinitely.
Returns:
A `tf.data.Dataset`.
"""
# TODO: Remove defaults common in Estimator pipeline and here.
dataset = data.build_ranking_dataset(
file_pattern=file_pattern,
data_format=data.ELWC,
batch_size=batch_size,
list_size=list_size,
context_feature_spec=dict(
list(self._context_feature_spec.items()) +
list(self._training_only_context_spec.items())),
example_feature_spec=dict(
list(self._example_feature_spec.items()) +
list(self._training_only_example_spec.items())),
mask_feature_name=self._mask_feature_name,
reader=self._hparams.dataset_reader,
num_epochs=num_epochs,
shuffle=randomize_input,
shuffle_buffer_size=1000,
sloppy_ordering=True,
drop_final_batch=False,
shuffle_examples=False)
return dataset.map(
self._features_and_labels,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
def build_train_dataset(self) -> tf.data.Dataset:
"""See `AbstractDatasetBuilder`."""
train_list_size = self._hparams.list_size
return self._build_dataset(
file_pattern=self._hparams.train_input_pattern,
batch_size=self._hparams.train_batch_size,
list_size=train_list_size)
def build_valid_dataset(self) -> tf.data.Dataset:
"""See `AbstractDatasetBuilder`."""
valid_list_size = (self._hparams.valid_list_size or self._hparams.list_size)
return self._build_dataset(
file_pattern=self._hparams.valid_input_pattern,
batch_size=self._hparams.valid_batch_size,
list_size=valid_list_size,
randomize_input=False)