cloud_batch_submit.py
#!/usr/bin/env python3
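"""Submits a Google Cloud Batch job that runs the `cuking` container.

A minimal usage sketch; the flags are the ones defined below, but all
values shown are hypothetical placeholders:

    ./cloud_batch_submit.py \
        --location=us-central1 \
        --project-id=my-project \
        --tag-name=latest \
        --service-account=cuking@my-project.iam.gserviceaccount.com \
        --input-uri=gs://my-bucket/input \
        --output-uri=gs://my-bucket/output \
        --requester-pays-project=my-project \
        --kin-threshold=0.1 \
        --split-factor=4 \
        --write-success-file
"""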
import argparse
import json
import subprocess
import textwrap
import os
import time
import uuid

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # See Cloud Build variable substitutions.
    parser.add_argument('--location', required=True)
    parser.add_argument('--project-id', required=True)
    parser.add_argument('--tag-name', required=True)
    # The service account to run the job should match the one used in the
    # VM instance template.
    parser.add_argument('--service-account', required=True)
    parser.add_argument(
        '--write-success-file',
        help='Waits for the Cloud Batch job to finish and writes a `_SUCCESS` file to the output directory if the job finished successfully.',
        action='store_true',
        default=False,
    )
    # See cuking.cu for argument help.
    parser.add_argument('--input-uri', required=True)
    parser.add_argument('--output-uri', required=True)
    parser.add_argument('--requester-pays-project', required=True)
    parser.add_argument('--kin-threshold', type=float, required=True)
    parser.add_argument('--split-factor', type=int, required=True)
    args = parser.parse_args()
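
    # Note: the job below runs split_factor * (split_factor + 1) / 2 tasks,
    # the number of unordered pairs of input shards (including each shard
    # paired with itself); e.g. split_factor=4 yields 10 tasks. Each task
    # presumably derives its shard pair from its BATCH_TASK_INDEX.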
    batch_job_json = textwrap.dedent(
        """\
        {{
          "taskGroups": [
            {{
              "taskSpec": {{
                "runnables": [
                  {{
                    "script": {{
                      "text": "sudo docker run --name cuking --gpus all {location}-docker.pkg.dev/{project_id}/images/cuking:{tag_name} cuking --input_uri={input_uri} --output_uri={output_uri} --requester_pays_project={requester_pays_project} --kin_threshold={kin_threshold} --split_factor={split_factor} --shard_index=${{BATCH_TASK_INDEX}}"
                    }}
                  }}
                ],
                "computeResource": {{
                  "cpuMilli": 12000,
                  "memoryMib": 87040
                }},
                "maxRunDuration": "36000s"
              }},
              "taskCount": {task_count}
            }}
          ],
          "allocationPolicy": {{
            "serviceAccount": {{
              "email": "{service_account}"
            }},
            "instances": [
              {{
                "instanceTemplate": "cuking-instance-template"
              }}
            ]
          }},
          "logsPolicy": {{
            "destination": "CLOUD_LOGGING"
          }}
        }}
        """
    ).format(**vars(args), task_count=args.split_factor * (args.split_factor + 1) // 2)

    JSON_FILENAME = 'batch_job.json'
    with open(JSON_FILENAME, 'w') as f:
        print(batch_job_json, file=f)

    # Use a UUID in the job name to avoid collisions.
    job_name = f'cuking-{uuid.uuid4()}'
    cmd = [
        'gcloud',
        'batch',
        'jobs',
        'submit',
        job_name,
        f'--location={args.location}',
        f'--config={JSON_FILENAME}',
    ]
    print(f'Submitting job:\n {" ".join(cmd)}')
    subprocess.run(cmd, check=True)

    status_cmd = [
        'gcloud',
        'batch',
        'jobs',
        'describe',
        job_name,
        f'--location={args.location}',
    ]
    print(f'\nTo check the status of the job, run:\n {" ".join(status_cmd)}')
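
    # Poll the job until it reaches a terminal state (only when
    # --write-success-file is set). A Cloud Batch job typically passes
    # through states like QUEUED, SCHEDULED, and RUNNING before ending
    # up SUCCEEDED or FAILED.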
    if args.write_success_file:
        while True:
            proc = subprocess.run(
                status_cmd + ['--format=json'], check=True, capture_output=True
            )
            output = proc.stdout.decode('utf-8').strip()
            state = json.loads(output)['status']['state']
            print(f'Current job state: {state}')
            if state == 'SUCCEEDED':
                print('Writing `_SUCCESS` file...')
                # Copying from stdin ('-') with empty input creates an empty
                # `_SUCCESS` marker object in the output directory. `text=True`
                # is required for a str `input`; without it, subprocess.run
                # raises a TypeError expecting bytes.
                subprocess.run(
                    [
                        'gcloud',
                        'storage',
                        'cp',
                        '-',
                        os.path.join(args.output_uri, '_SUCCESS'),
                    ],
                    input='',
                    text=True,
                    check=True,
                )
                break
            elif state == 'FAILED':
                print('Job failed, exiting...')
                break
            print('Waiting...')
            time.sleep(5 * 60)  # Wait 5 minutes between status checks.
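
    # A downstream step can then gate on the marker object, e.g. (bucket
    # path hypothetical):
    #   gcloud storage ls gs://my-bucket/output/_SUCCESS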