# distributed-imagenet.yaml
# Experiment configuration for distributed BYOL self-supervised training of a
# ResNet-50 backbone on ImageNet.
name: imagenet_byol_distributed
entrypoint: model_def:BYOLTrial
# 1271158 = 1281167 - 10009, i.e. presumably the ImageNet-1k training set size
# minus data.validation_subset_size -- confirm if either value changes.
records_per_epoch: 1271158

resources:
  slots_per_trial: 64
  # 17179869184 = 16 * 1024^3; presumably bytes (16 GiB) -- confirm.
  shm_size: 17179869184

min_checkpoint_period:
  epochs: 1
min_validation_period:
  epochs: 1

profiling:
  begin_on_batch: 0
  enabled: true
  end_after_batch: null
  sync_timings: true

data:
  dataset_name: imagenet
  # TODO: Change this to the name of your bucket.
  gcs_bucket: bucket-name-here
  # TODO: Paths to files inside the above bucket, each containing a
  # newline-separated list of all image blob paths.
  # Generate using generate_blob_list.py.
  gcs_train_blob_list_path: imagenet/train-blob-list.txt
  gcs_validation_blob_list_path: imagenet/validation-blob-list.txt
  download_dir: /data
  num_workers: 8
  # Classifier LR is selected using a subset of the training dataset before
  # reporting on the test dataset.
  validation_subset_size: 10009
  eval_transform:
    resize_short_edge: 256
    center_crop_size: 224
  # The two training views share crop, flip, color-jitter and grayscale
  # settings; they differ only in gaussian_blur_prob and solarization_prob.
  train_transform1:
    random_crop_size: 224
    random_crop_min_scale: 0.08
    random_hflip_prob: 0.5
    color_jitter_prob: 0.8
    color_jitter_brightness: 0.4
    color_jitter_contrast: 0.4
    color_jitter_saturation: 0.2
    color_jitter_hue: 0.1
    grayscale_prob: 0.2
    gaussian_blur_prob: 1.0
    gaussian_blur_kernel_size: 23
    gaussian_blur_min_std: 0.1
    gaussian_blur_max_std: 2.0
    solarization_prob: 0.0
  train_transform2:
    random_crop_size: 224
    random_crop_min_scale: 0.08
    random_hflip_prob: 0.5
    color_jitter_prob: 0.8
    color_jitter_brightness: 0.4
    color_jitter_contrast: 0.4
    color_jitter_saturation: 0.2
    color_jitter_hue: 0.1
    grayscale_prob: 0.2
    gaussian_blur_prob: 0.1
    gaussian_blur_kernel_size: 23
    gaussian_blur_min_std: 0.1
    gaussian_blur_max_std: 2.0
    solarization_prob: 0.2

hyperparameters:
  training_mode: SELF_SUPERVISED
  validate_with_classifier: false
  backbone_name: resnet50
  global_batch_size: 4096
  classifier:
    learning_rates: [0.4, 0.3, 0.2, 0.1, 0.05]
    # Alpha and beta are described in section C.1 of the BYOL paper.
    # These settings match the "variant" evaluation protocol described there,
    # which reuses the same training augmentations for classification as the
    # self-supervised training.
    logit_clipping:
      enabled: true
      alpha: 20
    # Written as 0.01 rather than 1e-2: an exponent literal without a decimal
    # point is read as a string by YAML 1.1 loaders such as PyYAML.
    logit_regularization_beta: 0.01
    momentum: 0.9
    train_epochs: 80
  self_supervised:
    # lars_eta and momentum are not specified in the paper; copied from the
    # JAX implementation.
    lars_eta: 0.001
    momentum: 0.9
    # Moving average decay increases from this base value to 1.0 over
    # training, as described in section 3.3 of the BYOL paper.
    moving_average_decay_base: 0.996
    weight_decay: 1.5e-6
    learning_rate:
      # Learning rate scales linearly with batch size and is equal to base
      # when global_batch_size == base_batch_size. With the
      # global_batch_size of 4096 above: 0.2 * 4096 / 256 = 3.2.
      base: 0.2
      base_batch_size: 256
      warmup_epochs: 10

searcher:
  name: single
  metric: validation_loss
  smaller_is_better: true
  max_length:
    # Note: the BYOL paper trains for 1000 epochs for full convergence.
    epochs: 20