@@ -13,7 +13,7 @@ class ADNNModelRepositoryDataInterfaces(ILogger):
     ResearchModelRepositoryDataInterface
     """

-    def __init__(self, data_connection_location: str = 'local', data_connection_credentials: str = None):
+    def __init__(self, data_connection_location: str = "local", data_connection_credentials: str = None):
         """
         ModelCheckpointsDataInterface
         :param data_connection_location: 'local' or s3 bucket 's3://my-bucket-name'
@@ -22,35 +22,33 @@ def __init__(self, data_connection_location: str = 'local', data_connection_cred
             AWS_PROFILE if left empty
         """
         super().__init__()
-        self.tb_events_file_prefix = 'events.out.tfevents'
-        self.log_file_prefix = 'log_'
-        self.latest_checkpoint_filename = 'ckpt_latest.pth'
-        self.best_checkpoint_filename = 'ckpt_best.pth'
+        self.tb_events_file_prefix = "events.out.tfevents"
+        self.log_file_prefix = "log_"
+        self.latest_checkpoint_filename = "ckpt_latest.pth"
+        self.best_checkpoint_filename = "ckpt_best.pth"

-        if data_connection_location.startswith('s3'):
-            assert data_connection_location.index('s3://') >= 0, 'S3 path must be formatted s3://bucket-name'
-            self.model_repo_bucket_name = data_connection_location.split('://')[1]
-            self.data_connection_source = 's3'
+        if data_connection_location.startswith("s3"):
+            assert data_connection_location.index("s3://") >= 0, "S3 path must be formatted s3://bucket-name"
+            self.model_repo_bucket_name = data_connection_location.split("://")[1]
+            self.data_connection_source = "s3"

             if data_connection_credentials is None:
-                data_connection_credentials = os.getenv('AWS_PROFILE')
+                data_connection_credentials = os.getenv("AWS_PROFILE")

             self.s3_connector = S3Connector(data_connection_credentials, self.model_repo_bucket_name)

-    @explicit_params_validation(validation_type='None')
+    @explicit_params_validation(validation_type="None")
     def load_all_remote_log_files(self, model_name: str, model_checkpoint_local_dir: str):
         """
         load_all_remote_checkpoint_files
         :param model_name:
         :param model_checkpoint_local_dir:
         :return:
         """
-        self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir,
-                                       logging_type='tensorboard')
-        self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir,
-                                       logging_type='text')
+        self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="tensorboard")
+        self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="text")

-    @explicit_params_validation(validation_type='None')
+    @explicit_params_validation(validation_type="None")
     def save_all_remote_checkpoint_files(self, model_name: str, model_checkpoint_local_dir: str, log_file_name: str):
         """
         save_all_remote_checkpoint_files - Saves all of the local Checkpoint data into Remote Repo
@@ -64,9 +62,10 @@ def save_all_remote_checkpoint_files(self, model_name: str, model_checkpoint_loc
         self.save_remote_checkpoints_file(model_name, model_checkpoint_local_dir, log_file_name)
         self.save_remote_tensorboard_event_files(model_name, model_checkpoint_local_dir)

-    @explicit_params_validation(validation_type='None')
-    def load_remote_checkpoints_file(self, ckpt_source_remote_dir: str, ckpt_destination_local_dir: str,
-                                     ckpt_file_name: str, overwrite_local_checkpoints_file: bool = False) -> str:
+    @explicit_params_validation(validation_type="None")
+    def load_remote_checkpoints_file(
+        self, ckpt_source_remote_dir: str, ckpt_destination_local_dir: str, ckpt_file_name: str, overwrite_local_checkpoints_file: bool = False
+    ) -> str:
         """
         load_remote_checkpoints_file - Loads a model's checkpoint from local/cloud file
         :param ckpt_source_remote_dir: The source folder to download from
@@ -76,27 +75,26 @@ def load_remote_checkpoints_file(self, ckpt_source_remote_dir: str, ckpt_destina
             is to overwrite a previous version of the same files
         :return: Model Checkpoint File Path -> Depends on model architecture
         """
-        ckpt_file_local_full_path = ckpt_destination_local_dir + '/' + ckpt_file_name
+        ckpt_file_local_full_path = ckpt_destination_local_dir + "/" + ckpt_file_name

-        if self.data_connection_source == 's3':
+        if self.data_connection_source == "s3":
             if overwrite_local_checkpoints_file:
                 # DELETE THE LOCAL VERSION ON THE MACHINE
                 if os.path.exists(ckpt_file_local_full_path):
                     os.remove(ckpt_file_local_full_path)

-            key_to_download = ckpt_source_remote_dir + '/' + ckpt_file_name
-            download_success = self.s3_connector.download_key(target_path=ckpt_file_local_full_path,
-                                                              key_to_download=key_to_download)
+            key_to_download = ckpt_source_remote_dir + "/" + ckpt_file_name
+            download_success = self.s3_connector.download_key(target_path=ckpt_file_local_full_path, key_to_download=key_to_download)

             if not download_success:
-                failed_download_path = 's3://' + self.model_repo_bucket_name + '/' + key_to_download
-                error_msg = 'Failed to Download Model Checkpoint from ' + failed_download_path
+                failed_download_path = "s3://" + self.model_repo_bucket_name + "/" + key_to_download
+                error_msg = "Failed to Download Model Checkpoint from " + failed_download_path
                 self._logger.error(error_msg)
                 raise ModelCheckpointNotFoundException(error_msg)

         return ckpt_file_local_full_path

-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def load_remote_logging_files(self, model_name: str, model_checkpoint_dir_name: str, logging_type: str):
         """
         load_remote_tensorboard_event_files - Downloads all of the tb_events Files from remote repository
@@ -106,24 +104,23 @@ def load_remote_logging_files(self, model_name: str, model_checkpoint_dir_name:
         :return:
         """
         if not os.path.isdir(model_checkpoint_dir_name):
-            raise ValueError('[' + sys._getframe().f_code.co_name + '] - Provided directory does not exist')
+            raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")

         # LOADS THE DATA FROM THE REMOTE REPOSITORY
         s3_bucket_path_prefix = model_name
-        if logging_type == 'tensorboard':
-            if self.data_connection_source == 's3':
-                self.s3_connector.download_keys_by_prefix(s3_bucket_path_prefix=s3_bucket_path_prefix,
-                                                          local_download_dir=model_checkpoint_dir_name,
-                                                          s3_file_path_prefix=self.tb_events_file_prefix)
-        elif logging_type == 'text':
-            if self.data_connection_source == 's3':
-                self.s3_connector.download_keys_by_prefix(s3_bucket_path_prefix=s3_bucket_path_prefix,
-                                                          local_download_dir=model_checkpoint_dir_name,
-                                                          s3_file_path_prefix=self.log_file_prefix)
-
-    @explicit_params_validation(validation_type='NoneOrEmpty')
-    def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_dir: str,
-                                     checkpoints_file_name: str) -> bool:
+        if logging_type == "tensorboard":
+            if self.data_connection_source == "s3":
+                self.s3_connector.download_keys_by_prefix(
+                    s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.tb_events_file_prefix
+                )
+        elif logging_type == "text":
+            if self.data_connection_source == "s3":
+                self.s3_connector.download_keys_by_prefix(
+                    s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.log_file_prefix
+                )
+
+    @explicit_params_validation(validation_type="NoneOrEmpty")
+    def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_dir: str, checkpoints_file_name: str) -> bool:
         """
         save_remote_checkpoints_file - Saves a Checkpoints file in the Remote Repo
         :param model_name: The Model Name for S3 Prefix
@@ -132,33 +129,33 @@ def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_d
         :return: True/False for Operation Success/Failure
         """
         # LOAD THE LOCAL VERSION
-        model_checkpoint_file_full_path = model_checkpoint_local_dir + '/' + checkpoints_file_name
+        model_checkpoint_file_full_path = model_checkpoint_local_dir + "/" + checkpoints_file_name

         # SAVE ON THE REMOTE S3 REPOSITORY
-        if self.data_connection_source == 's3':
-            model_checkpoint_s3_in_bucket_path = model_name + '/' + checkpoints_file_name
+        if self.data_connection_source == "s3":
+            model_checkpoint_s3_in_bucket_path = model_name + "/" + checkpoints_file_name
             return self.__update_or_upload_s3_key(model_checkpoint_file_full_path, model_checkpoint_s3_in_bucket_path)

-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def save_remote_tensorboard_event_files(self, model_name: str, model_checkpoint_dir_name: str):
         """
         save_remote_tensorboard_event_files - Saves all of the tensorboard files remotely
         :param model_name: Prefix for Cloud Storage
         :param model_checkpoint_dir_name: The directory where the files are stored in
         """
         if not os.path.isdir(model_checkpoint_dir_name):
-            raise ValueError('[' + sys._getframe().f_code.co_name + '] - Provided directory does not exist')
+            raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")

         for tb_events_file_name in os.listdir(model_checkpoint_dir_name):
             if tb_events_file_name.startswith(self.tb_events_file_prefix):
-                upload_success = self.save_remote_checkpoints_file(model_name=model_name,
-                                                                   model_checkpoint_local_dir=model_checkpoint_dir_name,
-                                                                   checkpoints_file_name=tb_events_file_name)
+                upload_success = self.save_remote_checkpoints_file(
+                    model_name=model_name, model_checkpoint_local_dir=model_checkpoint_dir_name, checkpoints_file_name=tb_events_file_name
+                )

                 if not upload_success:
-                    self._logger.error('Failed to upload tb_events_file: ' + tb_events_file_name)
+                    self._logger.error("Failed to upload tb_events_file: " + tb_events_file_name)

-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def __update_or_upload_s3_key(self, local_file_path: str, s3_key_path: str):
         """
         __update_or_upload_s3_key - Uploads/Updates an S3 Key based on a local file path
@@ -169,10 +166,10 @@ def __update_or_upload_s3_key(self, local_file_path: str, s3_key_path: str):
         # DELETE KEY TO UPDATE THE FILE IN S3
         delete_response = self.s3_connector.delete_key(s3_key_path)
         if delete_response:
-            self._logger.info('Removed previous checkpoint from S3')
+            self._logger.info("Removed previous checkpoint from S3")

         upload_success = self.s3_connector.upload_file(local_file_path, s3_key_path)
         if not upload_success:
-            self._logger.error('Failed to upload model checkpoint')
+            self._logger.error("Failed to upload model checkpoint")

         return upload_success
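
For context, a minimal usage sketch of the interface this diff touches, assuming an S3-backed repository. The import path, bucket name, model name, and local directory below are illustrative assumptions rather than values taken from this change; only the constructor and method signatures come from the diff itself.

# Usage sketch (not part of the diff); names marked below are assumptions.
import os

# Assumed module path for ADNNModelRepositoryDataInterfaces.
from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces

# Point the interface at an S3 bucket; credentials fall back to the AWS_PROFILE env var.
repo = ADNNModelRepositoryDataInterfaces(data_connection_location="s3://my-bucket-name")  # hypothetical bucket

local_dir = "/tmp/checkpoints/my_model"  # hypothetical local checkpoint directory
os.makedirs(local_dir, exist_ok=True)

# Download the latest checkpoint for a model prefix, overwriting any stale local copy.
ckpt_path = repo.load_remote_checkpoints_file(
    ckpt_source_remote_dir="my_model",  # hypothetical S3 prefix (the model name)
    ckpt_destination_local_dir=local_dir,
    ckpt_file_name="ckpt_latest.pth",  # matches the latest_checkpoint_filename set in __init__
    overwrite_local_checkpoints_file=True,
)

# Pull the TensorBoard event files and text logs for the same model into the local dir.
repo.load_all_remote_log_files(model_name="my_model", model_checkpoint_local_dir=local_dir)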