From 6098a87d51330c7ca9168f655a30bbcb549e7a67 Mon Sep 17 00:00:00 2001 From: Uncertainty Baselines Team Date: Sun, 21 May 2023 06:51:16 -0700 Subject: [PATCH] Enable distributed training across multiple GPUs or TPUs. PiperOrigin-RevId: 533851979 --- experimental/shoshin/models.py | 1 + experimental/shoshin/train_tf.py | 5 +- experimental/shoshin/train_tf_lib.py | 96 ++++++++++++++++++++-------- 3 files changed, 74 insertions(+), 28 deletions(-) diff --git a/experimental/shoshin/models.py b/experimental/shoshin/models.py index 26da9a5f..1e0bbe01 100644 --- a/experimental/shoshin/models.py +++ b/experimental/shoshin/models.py @@ -68,6 +68,7 @@ class ModelTrainingParameters: reweighting_signal: Optional[str] = 'bias' reweighting_lambda: Optional[float] = 0.5 reweighting_error_percentile_threshold: Optional[float] = 0.2 + tpu_bns: Optional[str] = '' def asdict(self): return dataclasses.asdict(self) diff --git a/experimental/shoshin/train_tf.py b/experimental/shoshin/train_tf.py index 0747cb5e..42b1be82 100644 --- a/experimental/shoshin/train_tf.py +++ b/experimental/shoshin/train_tf.py @@ -42,6 +42,8 @@ 'directory. If False, only logs to console.') flags.DEFINE_string('ensemble_dir', '', 'If specified, loads the models at ' 'this directory to consider the ensemble.') +# The 'tpu' flag will be set by the `build_tpu_jobs` in the launch script. +flags.DEFINE_string('tpu', '', 'The BNS address of the first TPU worker.') def main(_) -> None: @@ -120,7 +122,8 @@ def main(_) -> None: use_pytorch_style_resnet=config.model.use_pytorch_style_resnet, do_reweighting=config.reweighting.do_reweighting, reweighting_lambda=config.reweighting.lambda_value, - reweighting_signal=config.reweighting.signal + reweighting_signal=config.reweighting.signal, + tpu_bns=FLAGS.tpu ) model_params.train_bias = config.train_bias output_dir = config.output_dir diff --git a/experimental/shoshin/train_tf_lib.py b/experimental/shoshin/train_tf_lib.py index 86274d76..db80c22c 100644 --- a/experimental/shoshin/train_tf_lib.py +++ b/experimental/shoshin/train_tf_lib.py @@ -109,7 +109,10 @@ def _compute_average_metrics( accs.append(m.result()) subgroup_label = m.name.split('_')[1] weighted_accs.append( - m.result() * float(self.subgroup_sizes[subgroup_label]) / total_size + m.result() + * float(self.subgroup_sizes[subgroup_label]) + / total_size + / tf.distribute.get_strategy().num_replicas_in_sync ) self.avg_acc.reset_state() self.avg_acc.update_state(accs) @@ -123,7 +126,6 @@ def _compute_average_metrics( def train_step(self, inputs): features = inputs['input_feature'] labels = inputs['label'] - example_ids = inputs['example_id'] subgroup_labels = inputs['subgroup_label'] y_true_main = tf.one_hot(labels, depth=self.num_classes) @@ -136,6 +138,8 @@ def train_step(self, inputs): self.reweighting_signal == 'bias'): if self.id_to_bias_table is None: raise ValueError('id_to_bias_table must not be None.') + # TODO(b/280491870): Change example id type to support TPU training. + example_ids = inputs['example_id'] y_true_bias = self.id_to_bias_table.lookup(example_ids) y_true_bias_original = y_true_bias y_true_bias = tf.one_hot(y_true_bias, depth=2) @@ -168,8 +172,11 @@ def train_step(self, inputs): below_threshold_example_multiplex) total_loss = self.compiled_loss( - y_true, y_pred, sample_weight=sample_weight) - total_loss += sum(self.losses) # Regularization loss. + y_true, + y_pred, + sample_weight=sample_weight, + regularization_losses=self.losses + ) gradients = tape.gradient(total_loss, self.model.trainable_variables) self.optimizer.apply_gradients( @@ -198,7 +205,6 @@ def train_step(self, inputs): def test_step(self, inputs): features = inputs['input_feature'] labels = inputs['label'] - example_ids = inputs['example_id'] subgroup_labels = inputs['subgroup_label'] y_true_main = tf.one_hot(labels, depth=2) y_pred = self(features, training=False) @@ -206,6 +212,7 @@ def test_step(self, inputs): if self.train_bias: if self.id_to_bias_table is None: raise ValueError('id_to_bias_table must not be None.') + example_ids = inputs['example_id'] y_true_bias = self.id_to_bias_table.lookup(example_ids) y_true['bias'] = tf.one_hot(y_true_bias, depth=2) @@ -345,6 +352,40 @@ def evaluate_model( logging.info(results) +def _create_strategy( + tpu_bns: Optional[str] = '') -> tf.distribute.Strategy: + """Creates distribution strategy used in training. + + Args: + tpu_bns: The bns address of the first TPU worker. + + Returns: + tf.distribute.Strategy + """ + use_tpu = False + if tpu_bns: # Use tpu if tpu_bns is specified + use_tpu = True + else: # Use tpu if tpu is available + visible_devices = tf.config.get_visible_devices() + for device in visible_devices: + if device.device_type == 'TPU': + use_tpu = True + break + + if use_tpu: + resolver = tf.distribute.cluster_resolver.TPUClusterResolver( + tpu=tpu_bns + ) + tf.config.experimental_connect_to_cluster(resolver) + tf.tpu.experimental.initialize_tpu_system(resolver) + logging.info('Train with TPU Strategy.') + return tf.distribute.experimental.TPUStrategy(resolver) + else: + # MirroredStrategy will use the available GPUs or CPUs. + logging.info('Train with Mirrored Strategy.') + return tf.distribute.MirroredStrategy() + + def init_model( model_params: models.ModelTrainingParameters, experiment_name: str, @@ -361,28 +402,29 @@ def init_model( Returns: Initialized TwoHeadedOutputModel model. """ - model_class = models.get_model(model_params.model_name) - base_model = model_class(model_params=model_params) - - two_head_model = TwoHeadedOutputModel( - model=base_model, - num_subgroups=model_params.num_subgroups, - subgroup_sizes=model_params.subgroup_sizes, - worst_group_label=model_params.worst_group_label, - train_bias=model_params.train_bias, - name=experiment_name, - do_reweighting=model_params.do_reweighting, - reweighting_signal=model_params.reweighting_signal, - reweighting_lambda=model_params.reweighting_lambda, - error_percentile_threshold=model_params - .reweighting_error_percentile_threshold, - num_classes=model_params.num_classes) - - if model_params.train_bias or model_params.do_reweighting: - if example_id_to_bias_table: - two_head_model.update_id_to_bias_table(example_id_to_bias_table) - - two_head_model = compile_model(two_head_model, model_params) + strategy = _create_strategy(tpu_bns=model_params.tpu_bns) + with strategy.scope(): + model_class = models.get_model(model_params.model_name) + base_model = model_class(model_params=model_params) + + two_head_model = TwoHeadedOutputModel( + model=base_model, + num_subgroups=model_params.num_subgroups, + subgroup_sizes=model_params.subgroup_sizes, + worst_group_label=model_params.worst_group_label, + train_bias=model_params.train_bias, + name=experiment_name, + do_reweighting=model_params.do_reweighting, + reweighting_signal=model_params.reweighting_signal, + reweighting_lambda=model_params.reweighting_lambda, + error_percentile_threshold=model_params + .reweighting_error_percentile_threshold) + + if model_params.train_bias or model_params.do_reweighting: + if example_id_to_bias_table: + two_head_model.update_id_to_bias_table(example_id_to_bias_table) + + two_head_model = compile_model(two_head_model, model_params) return two_head_model