docker_run.py
#!/opt/conda/bin/python
"""Run the model inside of the docker container"""
import argparse
import gzip
import json
import logging
import os
import re
import threading
from pathlib import Path
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from azure.storage.filedatalake import DataLakeServiceClient
import config
from model.helpers import load_params
from run_model import run_all
class RunWithLocalStorage:
"""Methods for running with local storage"""
def __init__(self, filename: str):
self.params = load_params(f"queue/{filename}")
def finish(
self, results_file: str, saved_files: list, save_full_model_results: bool
) -> None:
"""Post model run steps
:param results_file: the path to the results file
:type results_file: str
        :param saved_files: file paths of the saved results (parquet) and params (JSON)
        :type saved_files: list
:param save_full_model_results: whether to save the full model results or not
:type save_full_model_results: bool
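
        for local storage the results are already written to local disk, so there is nothing more to do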
"""
    def progress_callback(self):
        """Progress callback factory

        for local storage there is no progress to report, so return a no-op callback
        """
        return lambda _: lambda _: None
class RunWithAzureStorage:
"""Methods for running with azure storage"""
def __init__(self, filename: str, app_version: str = "dev"):
logging.getLogger("azure.storage.common.storageclient").setLevel(
logging.WARNING
)
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(
logging.WARNING
)
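        # keep only the major.minor part of the app version, e.g. "1.2.3" -> "1.2"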
self._app_version = re.sub("(\\d+\\.\\d+)\\..*", "\\1", app_version)
self.params = self._get_params(filename)
self._get_data(self.params["start_year"], self.params["dataset"])
def _get_container(self, container_name: str):
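        """Get a client for a container in the blob storage account"""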
return BlobServiceClient(
account_url=f"https://{config.STORAGE_ACCOUNT}.blob.core.windows.net",
credential=DefaultAzureCredential(),
).get_container_client(container_name)
def _get_params(self, filename: str) -> dict:
"""Get the parameters for the model
:param filename: the name of the params file
:type filename: str
:return: the parameters for the model
:rtype: dict
"""
logging.info("downloading params: %s", filename)
self._queue_blob = self._get_container("queue").get_blob_client(filename)
params_content = self._queue_blob.download_blob().readall()
return json.loads(params_content)
def _get_data(self, year: str, dataset: str) -> None:
"""Get data to run the model
for local storage, the data is already available, so do nothing.
:param year: the year of data to load
:type year: str
:param year: the year of data to load
:type year: str
"""
logging.info("downloading data (%s / %s)", year, dataset)
fs_client = DataLakeServiceClient(
account_url=f"https://{config.STORAGE_ACCOUNT}.dfs.core.windows.net",
credential=DefaultAzureCredential(),
).get_file_system_client("data")
version = config.DATA_VERSION
paths = [p.name for p in fs_client.get_paths(version, recursive=False)]
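        # each top-level path holds one table; the files we need are partitioned by fyear and dataset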
for p in paths:
subpath = f"{p}/fyear={year}/dataset={dataset}"
os.makedirs(f"data{subpath.removeprefix(version)}", exist_ok=True)
for i in fs_client.get_paths(subpath):
filename = i.name
if not filename.endswith("parquet"):
continue
logging.info(" * %s", filename)
local_name = "data" + filename.removeprefix(version)
with open(local_name, "wb") as local_file:
file_client = fs_client.get_file_client(filename)
local_file.write(file_client.download_file().readall())
def _upload_results_json(self, results_file: str, metadata: dict) -> None:
"""Upload the results
once the model has run, upload the results to blob storage
:param results_file: the saved results file
:type results_file: str
:param metadata: the metadata to attach to the blob
:type metadata: dict
"""
container = self._get_container("results")
with open(f"results/{results_file}.json", "rb") as file:
container.upload_blob(
f"prod/{self._app_version}/{results_file}.json.gz",
gzip.compress(file.read()),
metadata=metadata,
overwrite=True,
)
def _upload_results_files(self, files: list, metadata: dict) -> None:
"""Upload the results
once the model has run, upload the files (parquet for model results and json for model params) to blob storage
:param files: list of files to be uploaded
:type files: list
:param metadata: the metadata to attach to the blob
:type metadata: dict
"""
container = self._get_container("results")
for file in files:
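            # strip the leading "results/" prefix (8 characters) to get the blob-relative path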
filename = file[8:]
if file.endswith(".json"):
metadata_to_use = metadata
else:
metadata_to_use = None
with open(file, "rb") as f:
container.upload_blob(
f"aggregated-model-results/{self._app_version}/{filename}",
f.read(),
overwrite=True,
metadata=metadata_to_use,
)
def _upload_full_model_results(self) -> None:
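        """Upload the full model results

        uploads any parquet files saved under results/{dataset}/{scenario}/{create_datetime}
        to blob storage
        """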
container = self._get_container("results")
dataset = self.params["dataset"]
scenario = self.params["scenario"]
create_datetime = self.params["create_datetime"]
path = Path(f"results/{dataset}/{scenario}/{create_datetime}")
for file in path.glob("**/*.parquet"):
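            # drop the leading "results/" prefix so the blob path starts at {dataset}/...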
filename = file.as_posix()[8:]
with open(file, "rb") as f:
container.upload_blob(
f"full-model-results/{self._app_version}/{filename}",
f.read(),
overwrite=True,
)
def _cleanup(self) -> None:
"""Cleanup
once the model has run, remove the file from the queue
"""
logging.info("cleaning up queue")
self._queue_blob.delete_blob()
def finish(
self, results_file: str, saved_files: list, save_full_model_results: bool
) -> None:
"""Post model run steps
:param results_file: the path to the results file
:type results_file: str
        :param saved_files: file paths of the saved results (parquet) and params (JSON)
        :type saved_files: list
:param save_full_model_results: whether to save the full model results or not
:type save_full_model_results: bool
"""
metadata = {
k: str(v)
for k, v in self.params.items()
if not isinstance(v, dict) and not isinstance(v, list)
}
self._upload_results_json(results_file, metadata)
self._upload_results_files(saved_files, metadata)
if save_full_model_results:
self._upload_full_model_results()
self._cleanup()
    def progress_callback(self):
        """Progress callback factory

        updates the metadata of the blob in the queue to report progress
        """
blob = self._queue_blob
current_progress = {
**blob.get_blob_properties()["metadata"],
"Inpatients": 0,
"Outpatients": 0,
"AaE": 0,
}
blob.set_blob_metadata({k: str(v) for k, v in current_progress.items()})
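
        # callback(model_type) returns an update function that records how many runs
        # of that model type have completed in the queue blob's metadata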
def callback(model_type):
def update(n_completed):
current_progress[model_type] = n_completed
blob.set_blob_metadata({k: str(v) for k, v in current_progress.items()})
return update
return callback
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser()
parser.add_argument(
"params_file",
nargs="?",
default="sample_params.json",
help="Name of the parameters file stored in Azure",
)
parser.add_argument(
"--local-storage",
"-l",
action="store_true",
help="Use local storage (instead of Azure)",
)
parser.add_argument("--save-full-model-results", action="store_true")
return parser.parse_args()
def main():
"""the main method"""
args = parse_args()
logging.basicConfig(
format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)
if args.local_storage:
runner = RunWithLocalStorage(args.params_file)
else:
runner = RunWithAzureStorage(args.params_file, config.APP_VERSION)
logging.info("running model for: %s", args.params_file)
logging.info("container timeout: %ds", config.CONTAINER_TIMEOUT_SECONDS)
logging.info("submitted by: %s", runner.params.get("user"))
logging.info("model_runs: %s", runner.params["model_runs"])
logging.info("start_year: %s", runner.params["start_year"])
logging.info("end_year: %s", runner.params["end_year"])
logging.info("app_version: %s", runner.params["app_version"])
saved_files, results_file = run_all(
runner.params, "data", runner.progress_callback, args.save_full_model_results
)
runner.finish(results_file, saved_files, args.save_full_model_results)
logging.info("complete")
def _exit_container():
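    """Force the container to exit when the timeout is reached"""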
logging.error("\nTimed out, killing container")
os._exit(1)
def init():
    """Entry point: run main() under a watchdog timer so the container cannot run forever"""
if __name__ == "__main__":
# start a timer to kill the container if we reach a timeout
t = threading.Timer(config.CONTAINER_TIMEOUT_SECONDS, _exit_container)
t.start()
# run the model
main()
# cancel the timer
t.cancel()
init()