Skip to content

Commit

Permalink
pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Nov 13, 2024
1 parent a085e8c commit bce85e6
Showing 1 changed file with 59 additions and 3 deletions.
62 changes: 59 additions & 3 deletions pdelfin/beakerpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,64 @@ async def metrics_reporter():
await asyncio.sleep(10)

def submit_beaker_job(args):
from gantry.commands.run import run
from beaker import (
Beaker,
Constraints,
DataMount,
DataSource,
EnvVar,
ExperimentSpec,
ImageSource,
Priority,
ResultSpec,
SecretNotFound,
TaskContext,
TaskResources,
TaskSpec,
)

b = Beaker.from_env(default_workspace=args.beaker_workspace)
account = b.account.whoami()
beaker_image = "ai2/cuda11.8-ubuntu20.04"

task_name = f"pdelfin-{os.path.basename(args.workspace)}"
priority = "normal"

args_list = sum(([f"--{k}", str(v)] if not isinstance(v, bool) else [f"--{k}"] for k, v in vars(args).items() if v is not None), [])

# Create the experiment spec
experiment_spec = ExperimentSpec(
budget="ai2/oe-data",
description=task_name,
tasks=[
TaskSpec(
name=task_name,
propagate_failure=False,
propagate_preemption=False,
replicas=1,
context=TaskContext(
priority=Priority(priority),
preemptible=True,
),
image=ImageSource(beaker=beaker_image),
command=["python", "-m", "pdelfin.beakerpipeline"] + args_list,
env_vars=[
EnvVar(name="BEAKER_JOB_NAME", value=task_name),
EnvVar(name="OWNER", value=account.name),
#EnvVar(name="AWS_ACCESS_KEY_ID", secret=f"{account.name}-{S2_AWS_ACCESS_KEY_ID_SECRET_NAME}"),
#EnvVar(name="AWS_SECRET_ACCESS_KEY", secret=f"{account.name}-{S2_AWS_SECRET_ACCESS_KEY_SECRET_NAME}"),
],
resources=TaskResources(gpu_count=1),
constraints=Constraints(cluster=args.beaker_cluster),
result=ResultSpec(path="/noop-results"),
)
],
)

experiment_data = b.experiment.create(spec=experiment_spec, workspace=args.beaker_workspace)

print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")

run()

async def main():
parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
Expand All @@ -542,6 +597,7 @@ async def main():
# Beaker/job running stuff
parser.add_argument('--beaker', action='store_true', help='Submit this job to beaker instead of running locally')
parser.add_argument('--beaker_workspace', help='Beaker workspace to submit to', default='ai2/pdelfin')
parser.add_argument('--beaker_cluster', help='Beaker clusters you want to run on', default=["ai2/jupiter-cirrascale-2", "ai2/pluto-cirrascale", "ai2/saturn-cirrascale"])
args = parser.parse_args()

if args.workspace_profile:
Expand All @@ -561,7 +617,7 @@ async def main():
await populate_pdf_work_queue(args)

if args.beaker:
submit_beaker_job()
submit_beaker_job(args)
return

# Create a semaphore to control worker access
Expand Down

0 comments on commit bce85e6

Please sign in to comment.