Skip to content

Commit

Permalink
feat: broker sends back job after accepting result
Browse files Browse the repository at this point in the history
  • Loading branch information
alexghr committed Feb 5, 2025
1 parent 7a2870f commit 81de3c4
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 57 deletions.
13 changes: 11 additions & 2 deletions yarn-project/circuit-types/src/interfaces/prover-broker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,24 @@ export interface ProvingJobConsumer {
* @param id - The ID of the job to report success for
* @param result - The result of the job
*/
reportProvingJobSuccess(id: ProvingJobId, result: ProofUri): Promise<void>;
reportProvingJobSuccess(
id: ProvingJobId,
result: ProofUri,
filter?: ProvingJobFilter,
): Promise<GetProvingJobResponse | undefined>;

/**
* Marks a proving job as errored
* @param id - The ID of the job to report an error for
* @param err - The error that occurred while processing the job
* @param retry - Whether to retry the job
*/
reportProvingJobError(id: ProvingJobId, err: string, retry?: boolean): Promise<void>;
reportProvingJobError(
id: ProvingJobId,
err: string,
retry?: boolean,
filter?: ProvingJobFilter,
): Promise<GetProvingJobResponse | undefined>;

/**
* Sends a heartbeat to the broker to indicate that the agent is still working on the given proving job
Expand Down
2 changes: 1 addition & 1 deletion yarn-project/foundation/src/queue/serial_queue.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ export class SerialQueue {
* @param fn - The function to enqueue.
* @returns A resolution promise. Rejects if the function does, or if the function could not be enqueued.
*/
public put<T>(fn: () => Promise<T>): Promise<T> {
public put<T>(fn: () => T): Promise<Awaited<T>> {
return new Promise((resolve, reject) => {
const accepted = this.queue.put(async () => {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import { makeBaseParityInputs, makeParityPublicInputs } from '@aztec/circuits.js
import { randomBytes } from '@aztec/foundation/crypto';
import { AbortError } from '@aztec/foundation/error';
import { promiseWithResolvers } from '@aztec/foundation/promise';
import { sleep } from '@aztec/foundation/sleep';

import { jest } from '@jest/globals';

Expand All @@ -32,6 +33,7 @@ describe('ProvingAgent', () => {
let agent: ProvingAgent;
let proofDB: jest.Mocked<ProofStore>;
const agentPollIntervalMs = 1000;
let allowList: ProvingRequestType[];

beforeEach(() => {
jest.useFakeTimers();
Expand All @@ -50,7 +52,8 @@ describe('ProvingAgent', () => {
saveProofOutput: jest.fn(),
};

agent = new ProvingAgent(jobSource, proofDB, prover, [ProvingRequestType.BASE_PARITY]);
allowList = [ProvingRequestType.BASE_PARITY];
agent = new ProvingAgent(jobSource, proofDB, prover, allowList);
});

afterEach(async () => {
Expand Down Expand Up @@ -110,7 +113,7 @@ describe('ProvingAgent', () => {

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(proofDB.saveProofOutput).toHaveBeenCalledWith(job.id, job.type, result);
expect(jobSource.reportProvingJobSuccess).toHaveBeenCalledWith(job.id, 'output-uri');
expect(jobSource.reportProvingJobSuccess).toHaveBeenCalledWith(job.id, 'output-uri', { allowList });
});

it('reports errors to the job source', async () => {
Expand All @@ -122,7 +125,7 @@ describe('ProvingAgent', () => {
agent.start();

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(job.id, 'test error', false);
expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(job.id, 'test error', false, { allowList });
});

it('sets the retry flag on when reporting an error', async () => {
Expand All @@ -135,7 +138,7 @@ describe('ProvingAgent', () => {
agent.start();

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(job.id, err.message, true);
expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(job.id, err.message, true, { allowList });
});

it('reports jobs in progress to the job source', async () => {
Expand Down Expand Up @@ -222,6 +225,52 @@ describe('ProvingAgent', () => {
secondProof.resolve(makeBaseParityResult());
});

it('immediately starts working on the next job', async () => {
const job1 = makeBaseParityJob();
const job2 = makeBaseParityJob();

jest
.spyOn(prover, 'getBaseParityProof')
.mockResolvedValueOnce(makeBaseParityResult())
.mockResolvedValueOnce(makeBaseParityResult());

proofDB.getProofInput.mockResolvedValueOnce(job1.inputs).mockResolvedValueOnce(job2.inputs);
proofDB.saveProofOutput.mockResolvedValue('' as ProofUri);

jobSource.getProvingJob.mockResolvedValueOnce(job1);
jobSource.reportProvingJobSuccess.mockResolvedValueOnce(job2);

agent.start();

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
await jest.advanceTimersByTimeAsync(0);
await Promise.resolve();
expect(jobSource.reportProvingJobSuccess).toHaveBeenCalledWith(job1.job.id, expect.any(String), { allowList });
expect(jobSource.reportProvingJobSuccess).toHaveBeenCalledWith(job2.job.id, expect.any(String), { allowList });
});

it('immediately starts working after reporting an error', async () => {
const job1 = makeBaseParityJob();
const job2 = makeBaseParityJob();

jest
.spyOn(prover, 'getBaseParityProof')
.mockRejectedValueOnce(new Error('test error'))
.mockResolvedValueOnce(makeBaseParityResult());

proofDB.getProofInput.mockResolvedValueOnce(job1.inputs).mockResolvedValueOnce(job2.inputs);
proofDB.saveProofOutput.mockResolvedValue('' as ProofUri);

jobSource.getProvingJob.mockResolvedValueOnce(job1);
jobSource.reportProvingJobError.mockResolvedValueOnce(job2);

agent.start();

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(job1.job.id, expect.any(String), false, { allowList });
expect(jobSource.reportProvingJobSuccess).toHaveBeenCalledWith(job2.job.id, expect.any(String), { allowList });
});

it('reports an error if inputs cannot be loaded', async () => {
const { job, time } = makeBaseParityJob();
jobSource.getProvingJob.mockResolvedValueOnce({ job, time });
Expand All @@ -230,7 +279,9 @@ describe('ProvingAgent', () => {
agent.start();

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(job.id, 'Failed to load proof inputs', true);
expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(job.id, 'Failed to load proof inputs', true, {
allowList,
});
});

function makeBaseParityJob(): { job: ProvingJob; time: number; inputs: ProvingJobInputs } {
Expand Down
41 changes: 30 additions & 11 deletions yarn-project/prover-client/src/proving_broker/proving_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,28 +94,45 @@ export class ProvingAgent implements Traceable {
return;
}

if (this.idleTimer) {
this.instrumentation.recordIdleTime(this.idleTimer);
}
this.idleTimer = undefined;

const { job, time } = maybeJob;
await this.startJob(job, time);
}

private async startJob(job: ProvingJob, startedAt: number): Promise<void> {
let abortedProofJobId: string | undefined;
let abortedProofName: string | undefined;

if (this.currentJobController?.getStatus() === ProvingJobControllerStatus.PROVING) {
abortedProofJobId = this.currentJobController.getJobId();
abortedProofName = this.currentJobController.getProofTypeName();
this.currentJobController?.abort();
}

const { job, time } = maybeJob;
let inputs: ProvingJobInputs;
try {
inputs = await this.proofStore.getProofInput(job.inputsUri);
} catch (err) {
await this.broker.reportProvingJobError(job.id, 'Failed to load proof inputs', true);
const maybeJob = await this.broker.reportProvingJobError(job.id, 'Failed to load proof inputs', true, {
allowList: this.proofAllowList,
});

if (maybeJob) {
return this.startJob(maybeJob.job, maybeJob.time);
}

return;
}

this.currentJobController = new ProvingJobController(
job.id,
inputs,
job.epochNumber,
time,
startedAt,
this.circuitProver,
this.handleJobResult,
);
Expand All @@ -134,11 +151,6 @@ export class ProvingAgent implements Traceable {
);
}

if (this.idleTimer) {
this.instrumentation.recordIdleTime(this.idleTimer);
}
this.idleTimer = undefined;

this.currentJobController.start();
}

Expand All @@ -148,15 +160,22 @@ export class ProvingAgent implements Traceable {
err: Error | undefined,
result: ProvingJobResultsMap[T] | undefined,
) => {
this.idleTimer = new Timer();
let maybeJob: { job: ProvingJob; time: number } | undefined;
if (err) {
const retry = err.name === ProvingError.NAME ? (err as ProvingError).retry : false;
this.log.error(`Job id=${jobId} type=${ProvingRequestType[type]} failed err=${err.message} retry=${retry}`, err);
return this.broker.reportProvingJobError(jobId, err.message, retry);
maybeJob = await this.broker.reportProvingJobError(jobId, err.message, retry, { allowList: this.proofAllowList });
} else if (result) {
const outputUri = await this.proofStore.saveProofOutput(jobId, type, result);
this.log.info(`Job id=${jobId} type=${ProvingRequestType[type]} completed outputUri=${truncate(outputUri)}`);
return this.broker.reportProvingJobSuccess(jobId, outputUri);
maybeJob = await this.broker.reportProvingJobSuccess(jobId, outputUri, { allowList: this.proofAllowList });
}

if (maybeJob) {
const { job, time } = maybeJob;
await this.startJob(job, time);
} else {
this.idleTimer = new Timer();
}
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,75 @@ describe.each([
await getAndAssertNextJobId(baseRollup1);
});

it('returns a new job when reporting job success', async () => {
const id = makeRandomProvingJobId();
await broker.enqueueProvingJob({
id,
type: ProvingRequestType.BASE_PARITY,
epochNumber: 1,
inputsUri: makeInputsUri(),
});
await broker.getProvingJob();
await assertJobStatus(id, 'in-progress');

const id2 = makeRandomProvingJobId();
await broker.enqueueProvingJob({
id: id2,
type: ProvingRequestType.BASE_PARITY,
epochNumber: 1,
inputsUri: makeInputsUri(),
});
await expect(
broker.reportProvingJobSuccess(id, 'result' as ProofUri, { allowList: [ProvingRequestType.BASE_PARITY] }),
).resolves.toEqual({ job: expect.objectContaining({ id: id2 }), time: expect.any(Number) });
});

it('returns a new job when reporting permanent error', async () => {
const id = makeRandomProvingJobId();
await broker.enqueueProvingJob({
id,
type: ProvingRequestType.BASE_PARITY,
epochNumber: 1,
inputsUri: makeInputsUri(),
});
await broker.getProvingJob();
await assertJobStatus(id, 'in-progress');

const id2 = makeRandomProvingJobId();
await broker.enqueueProvingJob({
id: id2,
type: ProvingRequestType.BASE_PARITY,
epochNumber: 1,
inputsUri: makeInputsUri(),
});
await expect(
broker.reportProvingJobError(id, 'result' as ProofUri, false, { allowList: [ProvingRequestType.BASE_PARITY] }),
).resolves.toEqual({ job: expect.objectContaining({ id: id2 }), time: expect.any(Number) });
});

it('returns a new job when reporting retry-able error', async () => {
const id = makeRandomProvingJobId();
await broker.enqueueProvingJob({
id,
type: ProvingRequestType.BASE_PARITY,
epochNumber: 1,
inputsUri: makeInputsUri(),
});
await broker.getProvingJob();
await assertJobStatus(id, 'in-progress');

const id2 = makeRandomProvingJobId();
await broker.enqueueProvingJob({
id: id2,
type: ProvingRequestType.BASE_PARITY,
epochNumber: 1,
inputsUri: makeInputsUri(),
});
await expect(
broker.reportProvingJobError(id, 'result' as ProofUri, true, { allowList: [ProvingRequestType.BASE_PARITY] }),
).resolves.toEqual({ job: expect.objectContaining({ id: id2 }), time: expect.any(Number) });
});

it('returns a new job when reporting progress if current one is cancelled', async () => {
const id = makeRandomProvingJobId();
await broker.enqueueProvingJob({
Expand Down Expand Up @@ -590,12 +659,10 @@ describe.each([
// after the restart the new broker thinks job1 is available
// inform the agent of the job completion

await expect(broker.reportProvingJobSuccess(job1.id, makeOutputsUri())).resolves.toBeUndefined();
await assertJobStatus(job1.id, 'fulfilled');

// make sure the the broker sends the next job to the agent
await getAndAssertNextJobId(job2.id);

await expect(broker.reportProvingJobSuccess(job1.id, makeOutputsUri())).resolves.toEqual({
job: job2,
time: expect.any(Number),
});
await assertJobStatus(job1.id, 'fulfilled');
await assertJobStatus(job2.id, 'in-progress');
});
Expand All @@ -618,12 +685,14 @@ describe.each([

await getAndAssertNextJobId(id1);
await assertJobStatus(id1, 'in-progress');
await broker.reportProvingJobSuccess(id1, makeOutputsUri());
await expect(broker.reportProvingJobSuccess(id1, makeOutputsUri())).resolves.toEqual({
job: expect.objectContaining({ id: id2 }),
time: expect.any(Number),
});
await assertJobStatus(id1, 'fulfilled');

await getAndAssertNextJobId(id2);
await assertJobStatus(id2, 'in-progress');
await broker.reportProvingJobError(id2, 'test error');

await expect(broker.reportProvingJobError(id2, 'test error')).resolves.toEqual(undefined);
await assertJobStatus(id2, 'rejected');
});

Expand Down
Loading

0 comments on commit 81de3c4

Please sign in to comment.