Skip to content

Commit

Permalink
Enable distributed training across multiple GPUs or TPUs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 533851979
  • Loading branch information
Uncertainty Baselines Team authored and copybara-github committed May 21, 2023
1 parent 7fa69de commit 6098a87
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 28 deletions.
1 change: 1 addition & 0 deletions experimental/shoshin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class ModelTrainingParameters:
reweighting_signal: Optional[str] = 'bias'
reweighting_lambda: Optional[float] = 0.5
reweighting_error_percentile_threshold: Optional[float] = 0.2
tpu_bns: Optional[str] = ''

def asdict(self):
return dataclasses.asdict(self)
Expand Down
5 changes: 4 additions & 1 deletion experimental/shoshin/train_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
'directory. If False, only logs to console.')
flags.DEFINE_string('ensemble_dir', '', 'If specified, loads the models at '
'this directory to consider the ensemble.')
# The 'tpu' flag will be set by the `build_tpu_jobs` in the launch script.
flags.DEFINE_string('tpu', '', 'The BNS address of the first TPU worker.')


def main(_) -> None:
Expand Down Expand Up @@ -120,7 +122,8 @@ def main(_) -> None:
use_pytorch_style_resnet=config.model.use_pytorch_style_resnet,
do_reweighting=config.reweighting.do_reweighting,
reweighting_lambda=config.reweighting.lambda_value,
reweighting_signal=config.reweighting.signal
reweighting_signal=config.reweighting.signal,
tpu_bns=FLAGS.tpu
)
model_params.train_bias = config.train_bias
output_dir = config.output_dir
Expand Down
96 changes: 69 additions & 27 deletions experimental/shoshin/train_tf_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ def _compute_average_metrics(
accs.append(m.result())
subgroup_label = m.name.split('_')[1]
weighted_accs.append(
m.result() * float(self.subgroup_sizes[subgroup_label]) / total_size
m.result()
* float(self.subgroup_sizes[subgroup_label])
/ total_size
/ tf.distribute.get_strategy().num_replicas_in_sync
)
self.avg_acc.reset_state()
self.avg_acc.update_state(accs)
Expand All @@ -123,7 +126,6 @@ def _compute_average_metrics(
def train_step(self, inputs):
features = inputs['input_feature']
labels = inputs['label']
example_ids = inputs['example_id']
subgroup_labels = inputs['subgroup_label']

y_true_main = tf.one_hot(labels, depth=self.num_classes)
Expand All @@ -136,6 +138,8 @@ def train_step(self, inputs):
self.reweighting_signal == 'bias'):
if self.id_to_bias_table is None:
raise ValueError('id_to_bias_table must not be None.')
# TODO(b/280491870): Change example id type to support TPU training.
example_ids = inputs['example_id']
y_true_bias = self.id_to_bias_table.lookup(example_ids)
y_true_bias_original = y_true_bias
y_true_bias = tf.one_hot(y_true_bias, depth=2)
Expand Down Expand Up @@ -168,8 +172,11 @@ def train_step(self, inputs):
below_threshold_example_multiplex)

total_loss = self.compiled_loss(
y_true, y_pred, sample_weight=sample_weight)
total_loss += sum(self.losses) # Regularization loss.
y_true,
y_pred,
sample_weight=sample_weight,
regularization_losses=self.losses
)

gradients = tape.gradient(total_loss, self.model.trainable_variables)
self.optimizer.apply_gradients(
Expand Down Expand Up @@ -198,14 +205,14 @@ def train_step(self, inputs):
def test_step(self, inputs):
features = inputs['input_feature']
labels = inputs['label']
example_ids = inputs['example_id']
subgroup_labels = inputs['subgroup_label']
y_true_main = tf.one_hot(labels, depth=2)
y_pred = self(features, training=False)
y_true = {'main': y_true_main}
if self.train_bias:
if self.id_to_bias_table is None:
raise ValueError('id_to_bias_table must not be None.')
example_ids = inputs['example_id']
y_true_bias = self.id_to_bias_table.lookup(example_ids)
y_true['bias'] = tf.one_hot(y_true_bias, depth=2)

Expand Down Expand Up @@ -345,6 +352,40 @@ def evaluate_model(
logging.info(results)


def _create_strategy(
tpu_bns: Optional[str] = '') -> tf.distribute.Strategy:
"""Creates distribution strategy used in training.
Args:
tpu_bns: The bns address of the first TPU worker.
Returns:
tf.distribute.Strategy
"""
use_tpu = False
if tpu_bns: # Use tpu if tpu_bns is specified
use_tpu = True
else: # Use tpu if tpu is available
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
if device.device_type == 'TPU':
use_tpu = True
break

if use_tpu:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=tpu_bns
)
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
logging.info('Train with TPU Strategy.')
return tf.distribute.experimental.TPUStrategy(resolver)
else:
# MirroredStrategy will use the available GPUs or CPUs.
logging.info('Train with Mirrored Strategy.')
return tf.distribute.MirroredStrategy()


def init_model(
model_params: models.ModelTrainingParameters,
experiment_name: str,
Expand All @@ -361,28 +402,29 @@ def init_model(
Returns:
Initialized TwoHeadedOutputModel model.
"""
model_class = models.get_model(model_params.model_name)
base_model = model_class(model_params=model_params)

two_head_model = TwoHeadedOutputModel(
model=base_model,
num_subgroups=model_params.num_subgroups,
subgroup_sizes=model_params.subgroup_sizes,
worst_group_label=model_params.worst_group_label,
train_bias=model_params.train_bias,
name=experiment_name,
do_reweighting=model_params.do_reweighting,
reweighting_signal=model_params.reweighting_signal,
reweighting_lambda=model_params.reweighting_lambda,
error_percentile_threshold=model_params
.reweighting_error_percentile_threshold,
num_classes=model_params.num_classes)

if model_params.train_bias or model_params.do_reweighting:
if example_id_to_bias_table:
two_head_model.update_id_to_bias_table(example_id_to_bias_table)

two_head_model = compile_model(two_head_model, model_params)
strategy = _create_strategy(tpu_bns=model_params.tpu_bns)
with strategy.scope():
model_class = models.get_model(model_params.model_name)
base_model = model_class(model_params=model_params)

two_head_model = TwoHeadedOutputModel(
model=base_model,
num_subgroups=model_params.num_subgroups,
subgroup_sizes=model_params.subgroup_sizes,
worst_group_label=model_params.worst_group_label,
train_bias=model_params.train_bias,
name=experiment_name,
do_reweighting=model_params.do_reweighting,
reweighting_signal=model_params.reweighting_signal,
reweighting_lambda=model_params.reweighting_lambda,
error_percentile_threshold=model_params
.reweighting_error_percentile_threshold)

if model_params.train_bias or model_params.do_reweighting:
if example_id_to_bias_table:
two_head_model.update_id_to_bias_table(example_id_to_bias_table)

two_head_model = compile_model(two_head_model, model_params)
return two_head_model


Expand Down

0 comments on commit 6098a87

Please sign in to comment.