Skip to content

Commit b31f3a8

Browse files
black on factories and data_interface (#620)
Co-authored-by: Ofri Masad <[email protected]>
1 parent d6f84a7 commit b31f3a8

File tree

6 files changed

+51
-59
lines changed

6 files changed

+51
-59
lines changed

src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py

+51-54
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class ADNNModelRepositoryDataInterfaces(ILogger):
1313
ResearchModelRepositoryDataInterface
1414
"""
1515

16-
def __init__(self, data_connection_location: str = 'local', data_connection_credentials: str = None):
16+
def __init__(self, data_connection_location: str = "local", data_connection_credentials: str = None):
1717
"""
1818
ModelCheckpointsDataInterface
1919
:param data_connection_location: 'local' or s3 bucket 's3://my-bucket-name'
@@ -22,35 +22,33 @@ def __init__(self, data_connection_location: str = 'local', data_connection_cred
2222
AWS_PROFILE if left empty
2323
"""
2424
super().__init__()
25-
self.tb_events_file_prefix = 'events.out.tfevents'
26-
self.log_file_prefix = 'log_'
27-
self.latest_checkpoint_filename = 'ckpt_latest.pth'
28-
self.best_checkpoint_filename = 'ckpt_best.pth'
25+
self.tb_events_file_prefix = "events.out.tfevents"
26+
self.log_file_prefix = "log_"
27+
self.latest_checkpoint_filename = "ckpt_latest.pth"
28+
self.best_checkpoint_filename = "ckpt_best.pth"
2929

30-
if data_connection_location.startswith('s3'):
31-
assert data_connection_location.index('s3://') >= 0, 'S3 path must be formatted s3://bucket-name'
32-
self.model_repo_bucket_name = data_connection_location.split('://')[1]
33-
self.data_connection_source = 's3'
30+
if data_connection_location.startswith("s3"):
31+
assert data_connection_location.index("s3://") >= 0, "S3 path must be formatted s3://bucket-name"
32+
self.model_repo_bucket_name = data_connection_location.split("://")[1]
33+
self.data_connection_source = "s3"
3434

3535
if data_connection_credentials is None:
36-
data_connection_credentials = os.getenv('AWS_PROFILE')
36+
data_connection_credentials = os.getenv("AWS_PROFILE")
3737

3838
self.s3_connector = S3Connector(data_connection_credentials, self.model_repo_bucket_name)
3939

40-
@explicit_params_validation(validation_type='None')
40+
@explicit_params_validation(validation_type="None")
4141
def load_all_remote_log_files(self, model_name: str, model_checkpoint_local_dir: str):
4242
"""
4343
load_all_remote_checkpoint_files
4444
:param model_name:
4545
:param model_checkpoint_local_dir:
4646
:return:
4747
"""
48-
self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir,
49-
logging_type='tensorboard')
50-
self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir,
51-
logging_type='text')
48+
self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="tensorboard")
49+
self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="text")
5250

53-
@explicit_params_validation(validation_type='None')
51+
@explicit_params_validation(validation_type="None")
5452
def save_all_remote_checkpoint_files(self, model_name: str, model_checkpoint_local_dir: str, log_file_name: str):
5553
"""
5654
save_all_remote_checkpoint_files - Saves all of the local Checkpoint data into Remote Repo
@@ -64,9 +62,10 @@ def save_all_remote_checkpoint_files(self, model_name: str, model_checkpoint_loc
6462
self.save_remote_checkpoints_file(model_name, model_checkpoint_local_dir, log_file_name)
6563
self.save_remote_tensorboard_event_files(model_name, model_checkpoint_local_dir)
6664

67-
@explicit_params_validation(validation_type='None')
68-
def load_remote_checkpoints_file(self, ckpt_source_remote_dir: str, ckpt_destination_local_dir: str,
69-
ckpt_file_name: str, overwrite_local_checkpoints_file: bool = False) -> str:
65+
@explicit_params_validation(validation_type="None")
66+
def load_remote_checkpoints_file(
67+
self, ckpt_source_remote_dir: str, ckpt_destination_local_dir: str, ckpt_file_name: str, overwrite_local_checkpoints_file: bool = False
68+
) -> str:
7069
"""
7170
load_remote_checkpoints_file - Loads a model's checkpoint from local/cloud file
7271
:param ckpt_source_remote_dir: The source folder to download from
@@ -76,27 +75,26 @@ def load_remote_checkpoints_file(self, ckpt_source_remote_dir: str, ckpt_destina
7675
is to overwrite a previous version of the same files
7776
:return: Model Checkpoint File Path -> Depends on model architecture
7877
"""
79-
ckpt_file_local_full_path = ckpt_destination_local_dir + '/' + ckpt_file_name
78+
ckpt_file_local_full_path = ckpt_destination_local_dir + "/" + ckpt_file_name
8079

81-
if self.data_connection_source == 's3':
80+
if self.data_connection_source == "s3":
8281
if overwrite_local_checkpoints_file:
8382
# DELETE THE LOCAL VERSION ON THE MACHINE
8483
if os.path.exists(ckpt_file_local_full_path):
8584
os.remove(ckpt_file_local_full_path)
8685

87-
key_to_download = ckpt_source_remote_dir + '/' + ckpt_file_name
88-
download_success = self.s3_connector.download_key(target_path=ckpt_file_local_full_path,
89-
key_to_download=key_to_download)
86+
key_to_download = ckpt_source_remote_dir + "/" + ckpt_file_name
87+
download_success = self.s3_connector.download_key(target_path=ckpt_file_local_full_path, key_to_download=key_to_download)
9088

9189
if not download_success:
92-
failed_download_path = 's3://' + self.model_repo_bucket_name + '/' + key_to_download
93-
error_msg = 'Failed to Download Model Checkpoint from ' + failed_download_path
90+
failed_download_path = "s3://" + self.model_repo_bucket_name + "/" + key_to_download
91+
error_msg = "Failed to Download Model Checkpoint from " + failed_download_path
9492
self._logger.error(error_msg)
9593
raise ModelCheckpointNotFoundException(error_msg)
9694

9795
return ckpt_file_local_full_path
9896

99-
@explicit_params_validation(validation_type='NoneOrEmpty')
97+
@explicit_params_validation(validation_type="NoneOrEmpty")
10098
def load_remote_logging_files(self, model_name: str, model_checkpoint_dir_name: str, logging_type: str):
10199
"""
102100
load_remote_tensorboard_event_files - Downloads all of the tb_events Files from remote repository
@@ -106,24 +104,23 @@ def load_remote_logging_files(self, model_name: str, model_checkpoint_dir_name:
106104
:return:
107105
"""
108106
if not os.path.isdir(model_checkpoint_dir_name):
109-
raise ValueError('[' + sys._getframe().f_code.co_name + '] - Provided directory does not exist')
107+
raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")
110108

111109
# LOADS THE DATA FROM THE REMOTE REPOSITORY
112110
s3_bucket_path_prefix = model_name
113-
if logging_type == 'tensorboard':
114-
if self.data_connection_source == 's3':
115-
self.s3_connector.download_keys_by_prefix(s3_bucket_path_prefix=s3_bucket_path_prefix,
116-
local_download_dir=model_checkpoint_dir_name,
117-
s3_file_path_prefix=self.tb_events_file_prefix)
118-
elif logging_type == 'text':
119-
if self.data_connection_source == 's3':
120-
self.s3_connector.download_keys_by_prefix(s3_bucket_path_prefix=s3_bucket_path_prefix,
121-
local_download_dir=model_checkpoint_dir_name,
122-
s3_file_path_prefix=self.log_file_prefix)
123-
124-
@explicit_params_validation(validation_type='NoneOrEmpty')
125-
def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_dir: str,
126-
checkpoints_file_name: str) -> bool:
111+
if logging_type == "tensorboard":
112+
if self.data_connection_source == "s3":
113+
self.s3_connector.download_keys_by_prefix(
114+
s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.tb_events_file_prefix
115+
)
116+
elif logging_type == "text":
117+
if self.data_connection_source == "s3":
118+
self.s3_connector.download_keys_by_prefix(
119+
s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.log_file_prefix
120+
)
121+
122+
@explicit_params_validation(validation_type="NoneOrEmpty")
123+
def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_dir: str, checkpoints_file_name: str) -> bool:
127124
"""
128125
save_remote_checkpoints_file - Saves a Checkpoints file in the Remote Repo
129126
:param model_name: The Model Name for S3 Prefix
@@ -132,33 +129,33 @@ def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_d
132129
:return: True/False for Operation Success/Failure
133130
"""
134131
# LOAD THE LOCAL VERSION
135-
model_checkpoint_file_full_path = model_checkpoint_local_dir + '/' + checkpoints_file_name
132+
model_checkpoint_file_full_path = model_checkpoint_local_dir + "/" + checkpoints_file_name
136133

137134
# SAVE ON THE REMOTE S3 REPOSITORY
138-
if self.data_connection_source == 's3':
139-
model_checkpoint_s3_in_bucket_path = model_name + '/' + checkpoints_file_name
135+
if self.data_connection_source == "s3":
136+
model_checkpoint_s3_in_bucket_path = model_name + "/" + checkpoints_file_name
140137
return self.__update_or_upload_s3_key(model_checkpoint_file_full_path, model_checkpoint_s3_in_bucket_path)
141138

142-
@explicit_params_validation(validation_type='NoneOrEmpty')
139+
@explicit_params_validation(validation_type="NoneOrEmpty")
143140
def save_remote_tensorboard_event_files(self, model_name: str, model_checkpoint_dir_name: str):
144141
"""
145142
save_remote_tensorboard_event_files - Saves all of the tensorboard files remotely
146143
:param model_name: Prefix for Cloud Storage
147144
:param model_checkpoint_dir_name: The directory where the files are stored in
148145
"""
149146
if not os.path.isdir(model_checkpoint_dir_name):
150-
raise ValueError('[' + sys._getframe().f_code.co_name + '] - Provided directory does not exist')
147+
raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")
151148

152149
for tb_events_file_name in os.listdir(model_checkpoint_dir_name):
153150
if tb_events_file_name.startswith(self.tb_events_file_prefix):
154-
upload_success = self.save_remote_checkpoints_file(model_name=model_name,
155-
model_checkpoint_local_dir=model_checkpoint_dir_name,
156-
checkpoints_file_name=tb_events_file_name)
151+
upload_success = self.save_remote_checkpoints_file(
152+
model_name=model_name, model_checkpoint_local_dir=model_checkpoint_dir_name, checkpoints_file_name=tb_events_file_name
153+
)
157154

158155
if not upload_success:
159-
self._logger.error('Failed to upload tb_events_file: ' + tb_events_file_name)
156+
self._logger.error("Failed to upload tb_events_file: " + tb_events_file_name)
160157

161-
@explicit_params_validation(validation_type='NoneOrEmpty')
158+
@explicit_params_validation(validation_type="NoneOrEmpty")
162159
def __update_or_upload_s3_key(self, local_file_path: str, s3_key_path: str):
163160
"""
164161
__update_or_upload_s3_key - Uploads/Updates an S3 Key based on a local file path
@@ -169,10 +166,10 @@ def __update_or_upload_s3_key(self, local_file_path: str, s3_key_path: str):
169166
# DELETE KEY TO UPDATE THE FILE IN S3
170167
delete_response = self.s3_connector.delete_key(s3_key_path)
171168
if delete_response:
172-
self._logger.info('Removed previous checkpoint from S3')
169+
self._logger.info("Removed previous checkpoint from S3")
173170

174171
upload_success = self.s3_connector.upload_file(local_file_path, s3_key_path)
175172
if not upload_success:
176-
self._logger.error('Failed to upload model checkpoint')
173+
self._logger.error("Failed to upload model checkpoint")
177174

178175
return upload_success

src/super_gradients/common/factories/callbacks_factory.py

-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,5 @@
33

44

55
class CallbacksFactory(BaseFactory):
6-
76
def __init__(self):
87
super().__init__(CALLBACKS)

src/super_gradients/common/factories/list_factory.py

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55

66
class ListFactory(AbstractFactory):
7-
87
def __init__(self, factry: AbstractFactory):
98
self.factry = factry
109

src/super_gradients/common/factories/losses_factory.py

-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,5 @@
33

44

55
class LossesFactory(BaseFactory):
6-
76
def __init__(self):
87
super().__init__(LOSSES)

src/super_gradients/common/factories/metrics_factory.py

-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,5 @@
33

44

55
class MetricsFactory(BaseFactory):
6-
76
def __init__(self):
87
super().__init__(METRICS)

src/super_gradients/common/factories/samplers_factory.py

-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,5 @@
33

44

55
class SamplersFactory(BaseFactory):
6-
76
def __init__(self):
87
super().__init__(SAMPLERS)

0 commit comments

Comments
 (0)