From 4cc05ae23458587ffe02d329163e3c5d16dfea3b Mon Sep 17 00:00:00 2001 From: Alex Gherghisan Date: Wed, 5 Feb 2025 15:13:06 +0000 Subject: [PATCH] feat: broker sends back job after accepting result --- .../src/interfaces/prover-broker.ts | 13 ++- .../foundation/src/queue/serial_queue.ts | 2 +- .../src/proving_broker/proving_agent.test.ts | 60 +++++++++++-- .../src/proving_broker/proving_agent.ts | 41 ++++++--- .../src/proving_broker/proving_broker.test.ts | 89 ++++++++++++++++--- .../src/proving_broker/proving_broker.ts | 80 +++++++++++------ .../prover-client/src/proving_broker/rpc.ts | 10 ++- 7 files changed, 237 insertions(+), 58 deletions(-) diff --git a/yarn-project/circuit-types/src/interfaces/prover-broker.ts b/yarn-project/circuit-types/src/interfaces/prover-broker.ts index c586ad1157e..5a03de48419 100644 --- a/yarn-project/circuit-types/src/interfaces/prover-broker.ts +++ b/yarn-project/circuit-types/src/interfaces/prover-broker.ts @@ -60,7 +60,11 @@ export interface ProvingJobConsumer { * @param id - The ID of the job to report success for * @param result - The result of the job */ - reportProvingJobSuccess(id: ProvingJobId, result: ProofUri): Promise; + reportProvingJobSuccess( + id: ProvingJobId, + result: ProofUri, + filter?: ProvingJobFilter, + ): Promise; /** * Marks a proving job as errored @@ -68,7 +72,12 @@ export interface ProvingJobConsumer { * @param err - The error that occurred while processing the job * @param retry - Whether to retry the job */ - reportProvingJobError(id: ProvingJobId, err: string, retry?: boolean): Promise; + reportProvingJobError( + id: ProvingJobId, + err: string, + retry?: boolean, + filter?: ProvingJobFilter, + ): Promise; /** * Sends a heartbeat to the broker to indicate that the agent is still working on the given proving job diff --git a/yarn-project/foundation/src/queue/serial_queue.ts b/yarn-project/foundation/src/queue/serial_queue.ts index c00e565de80..74727c18805 100644 --- a/yarn-project/foundation/src/queue/serial_queue.ts +++ b/yarn-project/foundation/src/queue/serial_queue.ts @@ -61,7 +61,7 @@ export class SerialQueue { * @param fn - The function to enqueue. * @returns A resolution promise. Rejects if the function does, or if the function could not be enqueued. */ - public put(fn: () => Promise): Promise { + public put(fn: () => T | Promise): Promise> { return new Promise((resolve, reject) => { const accepted = this.queue.put(async () => { try { diff --git a/yarn-project/prover-client/src/proving_broker/proving_agent.test.ts b/yarn-project/prover-client/src/proving_broker/proving_agent.test.ts index edaa6bba016..40ac83bb3e8 100644 --- a/yarn-project/prover-client/src/proving_broker/proving_agent.test.ts +++ b/yarn-project/prover-client/src/proving_broker/proving_agent.test.ts @@ -32,6 +32,7 @@ describe('ProvingAgent', () => { let agent: ProvingAgent; let proofDB: jest.Mocked; const agentPollIntervalMs = 1000; + let allowList: ProvingRequestType[]; beforeEach(() => { jest.useFakeTimers(); @@ -50,7 +51,8 @@ describe('ProvingAgent', () => { saveProofOutput: jest.fn(), }; - agent = new ProvingAgent(jobSource, proofDB, prover, [ProvingRequestType.BASE_PARITY]); + allowList = [ProvingRequestType.BASE_PARITY]; + agent = new ProvingAgent(jobSource, proofDB, prover, allowList); }); afterEach(async () => { @@ -110,7 +112,7 @@ describe('ProvingAgent', () => { await jest.advanceTimersByTimeAsync(agentPollIntervalMs); expect(proofDB.saveProofOutput).toHaveBeenCalledWith(job.id, job.type, result); - expect(jobSource.reportProvingJobSuccess).toHaveBeenCalledWith(job.id, 'output-uri'); + expect(jobSource.reportProvingJobSuccess).toHaveBeenCalledWith(job.id, 'output-uri', { allowList }); }); it('reports errors to the job source', async () => { @@ -122,7 +124,7 @@ describe('ProvingAgent', () => { agent.start(); await jest.advanceTimersByTimeAsync(agentPollIntervalMs); - expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(job.id, 'test error', false); + expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(job.id, 'test error', false, { allowList }); }); it('sets the retry flag on when reporting an error', async () => { @@ -135,7 +137,7 @@ describe('ProvingAgent', () => { agent.start(); await jest.advanceTimersByTimeAsync(agentPollIntervalMs); - expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(job.id, err.message, true); + expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(job.id, err.message, true, { allowList }); }); it('reports jobs in progress to the job source', async () => { @@ -222,6 +224,52 @@ describe('ProvingAgent', () => { secondProof.resolve(makeBaseParityResult()); }); + it('immediately starts working on the next job', async () => { + const job1 = makeBaseParityJob(); + const job2 = makeBaseParityJob(); + + jest + .spyOn(prover, 'getBaseParityProof') + .mockResolvedValueOnce(makeBaseParityResult()) + .mockResolvedValueOnce(makeBaseParityResult()); + + proofDB.getProofInput.mockResolvedValueOnce(job1.inputs).mockResolvedValueOnce(job2.inputs); + proofDB.saveProofOutput.mockResolvedValue('' as ProofUri); + + jobSource.getProvingJob.mockResolvedValueOnce(job1); + jobSource.reportProvingJobSuccess.mockResolvedValueOnce(job2); + + agent.start(); + + await jest.advanceTimersByTimeAsync(agentPollIntervalMs); + await jest.advanceTimersByTimeAsync(0); + await Promise.resolve(); + expect(jobSource.reportProvingJobSuccess).toHaveBeenCalledWith(job1.job.id, expect.any(String), { allowList }); + expect(jobSource.reportProvingJobSuccess).toHaveBeenCalledWith(job2.job.id, expect.any(String), { allowList }); + }); + + it('immediately starts working after reporting an error', async () => { + const job1 = makeBaseParityJob(); + const job2 = makeBaseParityJob(); + + jest + .spyOn(prover, 'getBaseParityProof') + .mockRejectedValueOnce(new Error('test error')) + .mockResolvedValueOnce(makeBaseParityResult()); + + proofDB.getProofInput.mockResolvedValueOnce(job1.inputs).mockResolvedValueOnce(job2.inputs); + proofDB.saveProofOutput.mockResolvedValue('' as ProofUri); + + jobSource.getProvingJob.mockResolvedValueOnce(job1); + jobSource.reportProvingJobError.mockResolvedValueOnce(job2); + + agent.start(); + + await jest.advanceTimersByTimeAsync(agentPollIntervalMs); + expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(job1.job.id, expect.any(String), false, { allowList }); + expect(jobSource.reportProvingJobSuccess).toHaveBeenCalledWith(job2.job.id, expect.any(String), { allowList }); + }); + it('reports an error if inputs cannot be loaded', async () => { const { job, time } = makeBaseParityJob(); jobSource.getProvingJob.mockResolvedValueOnce({ job, time }); @@ -230,7 +278,9 @@ describe('ProvingAgent', () => { agent.start(); await jest.advanceTimersByTimeAsync(agentPollIntervalMs); - expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(job.id, 'Failed to load proof inputs', true); + expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(job.id, 'Failed to load proof inputs', true, { + allowList, + }); }); function makeBaseParityJob(): { job: ProvingJob; time: number; inputs: ProvingJobInputs } { diff --git a/yarn-project/prover-client/src/proving_broker/proving_agent.ts b/yarn-project/prover-client/src/proving_broker/proving_agent.ts index bba0099463b..729cc06df7c 100644 --- a/yarn-project/prover-client/src/proving_broker/proving_agent.ts +++ b/yarn-project/prover-client/src/proving_broker/proving_agent.ts @@ -94,20 +94,37 @@ export class ProvingAgent implements Traceable { return; } + if (this.idleTimer) { + this.instrumentation.recordIdleTime(this.idleTimer); + } + this.idleTimer = undefined; + + const { job, time } = maybeJob; + await this.startJob(job, time); + } + + private async startJob(job: ProvingJob, startedAt: number): Promise { let abortedProofJobId: string | undefined; let abortedProofName: string | undefined; + if (this.currentJobController?.getStatus() === ProvingJobControllerStatus.PROVING) { abortedProofJobId = this.currentJobController.getJobId(); abortedProofName = this.currentJobController.getProofTypeName(); this.currentJobController?.abort(); } - const { job, time } = maybeJob; let inputs: ProvingJobInputs; try { inputs = await this.proofStore.getProofInput(job.inputsUri); } catch (err) { - await this.broker.reportProvingJobError(job.id, 'Failed to load proof inputs', true); + const maybeJob = await this.broker.reportProvingJobError(job.id, 'Failed to load proof inputs', true, { + allowList: this.proofAllowList, + }); + + if (maybeJob) { + return this.startJob(maybeJob.job, maybeJob.time); + } + return; } @@ -115,7 +132,7 @@ export class ProvingAgent implements Traceable { job.id, inputs, job.epochNumber, - time, + startedAt, this.circuitProver, this.handleJobResult, ); @@ -134,11 +151,6 @@ export class ProvingAgent implements Traceable { ); } - if (this.idleTimer) { - this.instrumentation.recordIdleTime(this.idleTimer); - } - this.idleTimer = undefined; - this.currentJobController.start(); } @@ -148,15 +160,22 @@ export class ProvingAgent implements Traceable { err: Error | undefined, result: ProvingJobResultsMap[T] | undefined, ) => { - this.idleTimer = new Timer(); + let maybeJob: { job: ProvingJob; time: number } | undefined; if (err) { const retry = err.name === ProvingError.NAME ? (err as ProvingError).retry : false; this.log.error(`Job id=${jobId} type=${ProvingRequestType[type]} failed err=${err.message} retry=${retry}`, err); - return this.broker.reportProvingJobError(jobId, err.message, retry); + maybeJob = await this.broker.reportProvingJobError(jobId, err.message, retry, { allowList: this.proofAllowList }); } else if (result) { const outputUri = await this.proofStore.saveProofOutput(jobId, type, result); this.log.info(`Job id=${jobId} type=${ProvingRequestType[type]} completed outputUri=${truncate(outputUri)}`); - return this.broker.reportProvingJobSuccess(jobId, outputUri); + maybeJob = await this.broker.reportProvingJobSuccess(jobId, outputUri, { allowList: this.proofAllowList }); + } + + if (maybeJob) { + const { job, time } = maybeJob; + await this.startJob(job, time); + } else { + this.idleTimer = new Timer(); } }; } diff --git a/yarn-project/prover-client/src/proving_broker/proving_broker.test.ts b/yarn-project/prover-client/src/proving_broker/proving_broker.test.ts index 14e2f7e0624..bc6b18c58ba 100644 --- a/yarn-project/prover-client/src/proving_broker/proving_broker.test.ts +++ b/yarn-project/prover-client/src/proving_broker/proving_broker.test.ts @@ -407,6 +407,75 @@ describe.each([ await getAndAssertNextJobId(baseRollup1); }); + it('returns a new job when reporting job success', async () => { + const id = makeRandomProvingJobId(); + await broker.enqueueProvingJob({ + id, + type: ProvingRequestType.BASE_PARITY, + epochNumber: 1, + inputsUri: makeInputsUri(), + }); + await broker.getProvingJob(); + await assertJobStatus(id, 'in-progress'); + + const id2 = makeRandomProvingJobId(); + await broker.enqueueProvingJob({ + id: id2, + type: ProvingRequestType.BASE_PARITY, + epochNumber: 1, + inputsUri: makeInputsUri(), + }); + await expect( + broker.reportProvingJobSuccess(id, 'result' as ProofUri, { allowList: [ProvingRequestType.BASE_PARITY] }), + ).resolves.toEqual({ job: expect.objectContaining({ id: id2 }), time: expect.any(Number) }); + }); + + it('returns a new job when reporting permanent error', async () => { + const id = makeRandomProvingJobId(); + await broker.enqueueProvingJob({ + id, + type: ProvingRequestType.BASE_PARITY, + epochNumber: 1, + inputsUri: makeInputsUri(), + }); + await broker.getProvingJob(); + await assertJobStatus(id, 'in-progress'); + + const id2 = makeRandomProvingJobId(); + await broker.enqueueProvingJob({ + id: id2, + type: ProvingRequestType.BASE_PARITY, + epochNumber: 1, + inputsUri: makeInputsUri(), + }); + await expect( + broker.reportProvingJobError(id, 'result' as ProofUri, false, { allowList: [ProvingRequestType.BASE_PARITY] }), + ).resolves.toEqual({ job: expect.objectContaining({ id: id2 }), time: expect.any(Number) }); + }); + + it('returns a new job when reporting retry-able error', async () => { + const id = makeRandomProvingJobId(); + await broker.enqueueProvingJob({ + id, + type: ProvingRequestType.BASE_PARITY, + epochNumber: 1, + inputsUri: makeInputsUri(), + }); + await broker.getProvingJob(); + await assertJobStatus(id, 'in-progress'); + + const id2 = makeRandomProvingJobId(); + await broker.enqueueProvingJob({ + id: id2, + type: ProvingRequestType.BASE_PARITY, + epochNumber: 1, + inputsUri: makeInputsUri(), + }); + await expect( + broker.reportProvingJobError(id, 'result' as ProofUri, true, { allowList: [ProvingRequestType.BASE_PARITY] }), + ).resolves.toEqual({ job: expect.objectContaining({ id: id2 }), time: expect.any(Number) }); + }); + it('returns a new job when reporting progress if current one is cancelled', async () => { const id = makeRandomProvingJobId(); await broker.enqueueProvingJob({ @@ -590,12 +659,10 @@ describe.each([ // after the restart the new broker thinks job1 is available // inform the agent of the job completion - await expect(broker.reportProvingJobSuccess(job1.id, makeOutputsUri())).resolves.toBeUndefined(); - await assertJobStatus(job1.id, 'fulfilled'); - - // make sure the the broker sends the next job to the agent - await getAndAssertNextJobId(job2.id); - + await expect(broker.reportProvingJobSuccess(job1.id, makeOutputsUri())).resolves.toEqual({ + job: job2, + time: expect.any(Number), + }); await assertJobStatus(job1.id, 'fulfilled'); await assertJobStatus(job2.id, 'in-progress'); }); @@ -618,12 +685,14 @@ describe.each([ await getAndAssertNextJobId(id1); await assertJobStatus(id1, 'in-progress'); - await broker.reportProvingJobSuccess(id1, makeOutputsUri()); + await expect(broker.reportProvingJobSuccess(id1, makeOutputsUri())).resolves.toEqual({ + job: expect.objectContaining({ id: id2 }), + time: expect.any(Number), + }); await assertJobStatus(id1, 'fulfilled'); - - await getAndAssertNextJobId(id2); await assertJobStatus(id2, 'in-progress'); - await broker.reportProvingJobError(id2, 'test error'); + + await expect(broker.reportProvingJobError(id2, 'test error')).resolves.toEqual(undefined); await assertJobStatus(id2, 'rejected'); }); diff --git a/yarn-project/prover-client/src/proving_broker/proving_broker.ts b/yarn-project/prover-client/src/proving_broker/proving_broker.ts index 583fa0eb803..a66f25d71da 100644 --- a/yarn-project/prover-client/src/proving_broker/proving_broker.ts +++ b/yarn-project/prover-client/src/proving_broker/proving_broker.ts @@ -1,4 +1,5 @@ import { + type GetProvingJobResponse, type ProofUri, type ProvingJob, type ProvingJobConsumer, @@ -205,16 +206,25 @@ export class ProvingBroker implements ProvingJobProducer, ProvingJobConsumer, Tr return this.requestQueue.put(() => this.#getCompletedJobs(ids)); } - public getProvingJob(filter?: ProvingJobFilter): Promise<{ job: ProvingJob; time: number } | undefined> { + public getProvingJob(filter?: ProvingJobFilter): Promise { return this.requestQueue.put(() => this.#getProvingJob(filter)); } - public reportProvingJobSuccess(id: ProvingJobId, value: ProofUri): Promise { - return this.requestQueue.put(() => this.#reportProvingJobSuccess(id, value)); + public reportProvingJobSuccess( + id: ProvingJobId, + value: ProofUri, + filter?: ProvingJobFilter, + ): Promise { + return this.requestQueue.put(() => this.#reportProvingJobSuccess(id, value, filter)); } - public reportProvingJobError(id: ProvingJobId, err: string, retry = false): Promise { - return this.requestQueue.put(() => this.#reportProvingJobError(id, err, retry)); + public reportProvingJobError( + id: ProvingJobId, + err: string, + retry = false, + filter?: ProvingJobFilter, + ): Promise { + return this.requestQueue.put(() => this.#reportProvingJobError(id, err, retry, filter)); } public reportProvingJobProgress( @@ -305,9 +315,7 @@ export class ProvingBroker implements ProvingJobProducer, ProvingJobConsumer, Tr } // eslint-disable-next-line require-await - async #getProvingJob( - filter: ProvingJobFilter = { allowList: [] }, - ): Promise<{ job: ProvingJob; time: number } | undefined> { + #getProvingJob(filter: ProvingJobFilter = { allowList: [] }): { job: ProvingJob; time: number } | undefined { const allowedProofs: ProvingRequestType[] = Array.isArray(filter.allowList) && filter.allowList.length > 0 ? [...filter.allowList] @@ -343,7 +351,12 @@ export class ProvingBroker implements ProvingJobProducer, ProvingJobConsumer, Tr return undefined; } - async #reportProvingJobError(id: ProvingJobId, err: string, retry = false): Promise { + async #reportProvingJobError( + id: ProvingJobId, + err: string, + retry = false, + filter?: ProvingJobFilter, + ): Promise { const info = this.inProgress.get(id); const item = this.jobsCache.get(id); const retries = this.retries.get(id) ?? 0; @@ -365,7 +378,7 @@ export class ProvingBroker implements ProvingJobProducer, ProvingJobConsumer, Tr this.logger.warn(`Proving job id=${id} is already settled, ignoring err=${err}`, { provingJobId: id, }); - return; + return this.#getProvingJob(filter); } if (retry && retries + 1 < this.maxRetries && !this.isJobStale(item)) { @@ -375,10 +388,16 @@ export class ProvingBroker implements ProvingJobProducer, ProvingJobConsumer, Tr provingJobId: id, }, ); + + // assign another job to this agent + // do this first, before we put the failed job back in the queue + const maybeAnotherJob = this.#getProvingJob(filter); + this.retries.set(id, retries + 1); this.enqueueJobInternal(item); this.instrumentation.incRetriedJobs(item.type); - return; + + return maybeAnotherJob; } this.logger.info( @@ -412,22 +431,24 @@ export class ProvingBroker implements ProvingJobProducer, ProvingJobConsumer, Tr throw saveErr; } + + return this.#getProvingJob(filter); } #reportProvingJobProgress( id: ProvingJobId, startedAt: number, filter?: ProvingJobFilter, - ): Promise<{ job: ProvingJob; time: number } | undefined> { + ): { job: ProvingJob; time: number } | undefined { const job = this.jobsCache.get(id); if (!job) { this.logger.warn(`Proving job id=${id} does not exist`, { provingJobId: id }); - return filter ? this.#getProvingJob(filter) : Promise.resolve(undefined); + return this.#getProvingJob(filter); } if (this.resultsCache.has(id)) { this.logger.warn(`Proving job id=${id} has already been completed`, { provingJobId: id }); - return filter ? this.#getProvingJob(filter) : Promise.resolve(undefined); + return this.#getProvingJob(filter); } const metadata = this.inProgress.get(id); @@ -445,7 +466,7 @@ export class ProvingBroker implements ProvingJobProducer, ProvingJobConsumer, Tr startedAt, lastUpdatedAt: this.msTimeSource(), }); - return Promise.resolve(undefined); + return undefined; } else if (startedAt <= metadata.startedAt) { if (startedAt < metadata.startedAt) { this.logger.info( @@ -457,21 +478,24 @@ export class ProvingBroker implements ProvingJobProducer, ProvingJobConsumer, Tr } metadata.startedAt = startedAt; metadata.lastUpdatedAt = now; - return Promise.resolve(undefined); - } else if (filter) { - this.logger.warn( - `Proving job id=${id} type=${ - ProvingRequestType[job.type] - } already being worked on by another agent. Sending new one`, - { provingJobId: id }, - ); - return this.#getProvingJob(filter); - } else { - return Promise.resolve(undefined); + return undefined; } + + this.logger.warn( + `Proving job id=${id} type=${ + ProvingRequestType[job.type] + } already being worked on by another agent. Sending new one`, + { provingJobId: id }, + ); + + return this.#getProvingJob(filter); } - async #reportProvingJobSuccess(id: ProvingJobId, value: ProofUri): Promise { + async #reportProvingJobSuccess( + id: ProvingJobId, + value: ProofUri, + filter?: ProvingJobFilter, + ): Promise { const info = this.inProgress.get(id); const item = this.jobsCache.get(id); const retries = this.retries.get(id) ?? 0; @@ -521,6 +545,8 @@ export class ProvingBroker implements ProvingJobProducer, ProvingJobConsumer, Tr throw saveErr; } + + return this.#getProvingJob(filter); } @trackSpan('ProvingBroker.cleanupPass') diff --git a/yarn-project/prover-client/src/proving_broker/rpc.ts b/yarn-project/prover-client/src/proving_broker/rpc.ts index 74a66580357..63d2a5ca4c9 100644 --- a/yarn-project/prover-client/src/proving_broker/rpc.ts +++ b/yarn-project/prover-client/src/proving_broker/rpc.ts @@ -34,12 +34,18 @@ export const ProvingJobProducerSchema: ApiSchemaFor = { export const ProvingJobConsumerSchema: ApiSchemaFor = { getProvingJob: z.function().args(optional(ProvingJobFilterSchema)).returns(GetProvingJobResponse.optional()), - reportProvingJobError: z.function().args(ProvingJobId, z.string(), optional(z.boolean())).returns(z.void()), + reportProvingJobError: z + .function() + .args(ProvingJobId, z.string(), optional(z.boolean()), optional(ProvingJobFilterSchema)) + .returns(GetProvingJobResponse.optional()), reportProvingJobProgress: z .function() .args(ProvingJobId, z.number(), optional(ProvingJobFilterSchema)) .returns(GetProvingJobResponse.optional()), - reportProvingJobSuccess: z.function().args(ProvingJobId, ProofUri).returns(z.void()), + reportProvingJobSuccess: z + .function() + .args(ProvingJobId, ProofUri, optional(ProvingJobFilterSchema)) + .returns(GetProvingJobResponse.optional()), }; export const ProvingJobBrokerSchema: ApiSchemaFor = {