Skip to content

Commit

Permalink
feat(coordinator): assign static prover first and avoid reassigning f…
Browse files Browse the repository at this point in the history
…ailing task to same prover
  • Loading branch information
yiweichi committed Dec 30, 2024
1 parent 45b23ed commit 791fcaf
Show file tree
Hide file tree
Showing 11 changed files with 371 additions and 108 deletions.
1 change: 1 addition & 0 deletions coordinator/conf/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"prover_manager": {
"provers_per_session": 1,
"session_attempts": 5,
"external_prover_threshold": 32,
"bundle_collection_time_sec": 180,
"batch_collection_time_sec": 180,
"chunk_collection_time_sec": 180,
Expand Down
2 changes: 2 additions & 0 deletions coordinator/internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ type ProverManager struct {
// Number of attempts that a session can be retried if previous attempts failed.
// Currently we only consider proving timeout as failure here.
SessionAttempts uint8 `json:"session_attempts"`
// Threshold for activating the external prover based on unassigned task count.
ExternalProverThreshold int64 `json:"external_prover_threshold"`
// Zk verifier config.
Verifier *VerifierConfig `json:"verifier"`
// BatchCollectionTimeSec batch Proof collection time (in seconds).
Expand Down
57 changes: 44 additions & 13 deletions coordinator/internal/logic/provertask/batch_prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
"time"

"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -63,29 +64,59 @@ func (bp *BatchProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinato

maxActiveAttempts := bp.cfg.ProverManager.ProversPerSession
maxTotalAttempts := bp.cfg.ProverManager.SessionAttempts
if strings.HasPrefix(taskCtx.ProverName, ExternalProverNamePrefix) {
unassignedBatchCount, err := bp.batchOrm.GetUnassignedBatchCount(ctx.Copy(), maxActiveAttempts, maxTotalAttempts)
if err != nil {
log.Error("failed to get unassigned chunk proving tasks count", "height", getTaskParameter.ProverHeight, "err", err)
return nil, ErrCoordinatorInternalFailure
}
// Assign external prover if unassigned task number exceeds threshold
if unassignedBatchCount < bp.cfg.ProverManager.ExternalProverThreshold {
return nil, nil
}
}

var batchTask *orm.Batch
for i := 0; i < 5; i++ {
var getTaskError error
var tmpBatchTask *orm.Batch
tmpBatchTask, getTaskError = bp.batchOrm.GetAssignedBatch(ctx.Copy(), maxActiveAttempts, maxTotalAttempts)
var assignedOffset, unassignedOffset = 0, 0
tmpAssignedBatchTasks, getTaskError := bp.batchOrm.GetAssignedBatches(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, 50)
if getTaskError != nil {
log.Error("failed to get assigned batch proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError)
log.Error("failed to get assigned chunk proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError)
return nil, ErrCoordinatorInternalFailure
}

// Why here need get again? In order to support a task can assign to multiple prover, need also assign `ProvingTaskAssigned`
// batch to prover. But use `proving_status in (1, 2)` will not use the postgres index. So need split the sql.
if tmpBatchTask == nil {
tmpBatchTask, getTaskError = bp.batchOrm.GetUnassignedBatch(ctx.Copy(), maxActiveAttempts, maxTotalAttempts)
if getTaskError != nil {
log.Error("failed to get unassigned batch proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError)
return nil, ErrCoordinatorInternalFailure
}
// chunk to prover. But use `proving_status in (1, 2)` will not use the postgres index. So need split the sql.
tmpUnassignedBatchTask, getTaskError := bp.batchOrm.GetUnassignedBatches(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, 50)
if getTaskError != nil {
log.Error("failed to get unassigned chunk proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError)
return nil, ErrCoordinatorInternalFailure
}
for {
tmpBatchTask = nil
if assignedOffset < len(tmpAssignedBatchTasks) {
tmpBatchTask = tmpAssignedBatchTasks[assignedOffset]
assignedOffset++
} else if unassignedOffset < len(tmpUnassignedBatchTask) {
tmpBatchTask = tmpUnassignedBatchTask[unassignedOffset]
unassignedOffset++
}

if tmpBatchTask == nil {
log.Debug("get empty batch", "height", getTaskParameter.ProverHeight)
return nil, nil
}

if tmpBatchTask == nil {
log.Debug("get empty batch", "height", getTaskParameter.ProverHeight)
return nil, nil
// Don't dispatch the same failing job to the same prover
proverTask, err := bp.proverTaskOrm.GetTaskOfProver(ctx.Copy(), message.ProofTypeBatch, tmpBatchTask.Hash, taskCtx.PublicKey, taskCtx.ProverVersion)
if err != nil {
log.Error("failed to get prover task of prover", "proof_type", message.ProofTypeBatch.String(), "taskID", tmpBatchTask.Hash, "key", taskCtx.PublicKey, "Prover_version", taskCtx.ProverVersion, "error", err)
return nil, ErrCoordinatorInternalFailure
}
if proverTask == nil || types.ProverProveStatus(proverTask.ProvingStatus) != types.ProverProofInvalid {
break
}
}

rowsAffected, updateAttemptsErr := bp.batchOrm.UpdateBatchAttempts(ctx.Copy(), tmpBatchTask.Index, tmpBatchTask.ActiveAttempts, tmpBatchTask.TotalAttempts)
Expand Down
57 changes: 44 additions & 13 deletions coordinator/internal/logic/provertask/bundle_prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
"time"

"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -63,29 +64,59 @@ func (bp *BundleProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinat

maxActiveAttempts := bp.cfg.ProverManager.ProversPerSession
maxTotalAttempts := bp.cfg.ProverManager.SessionAttempts
if strings.HasPrefix(taskCtx.ProverName, ExternalProverNamePrefix) {
unassignedBundleCount, err := bp.bundleOrm.GetUnassignedBundleCount(ctx.Copy(), maxActiveAttempts, maxTotalAttempts)
if err != nil {
log.Error("failed to get unassigned chunk proving tasks count", "height", getTaskParameter.ProverHeight, "err", err)
return nil, ErrCoordinatorInternalFailure
}
// Assign external prover if unassigned task number exceeds threshold
if unassignedBundleCount < bp.cfg.ProverManager.ExternalProverThreshold {
return nil, nil
}
}

var bundleTask *orm.Bundle
for i := 0; i < 5; i++ {
var getTaskError error
var tmpBundleTask *orm.Bundle
tmpBundleTask, getTaskError = bp.bundleOrm.GetAssignedBundle(ctx.Copy(), maxActiveAttempts, maxTotalAttempts)
var assignedOffset, unassignedOffset = 0, 0
tmpAssignedBundleTasks, getTaskError := bp.bundleOrm.GetAssignedBundles(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, 50)
if getTaskError != nil {
log.Error("failed to get assigned bundle proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError)
log.Error("failed to get assigned chunk proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError)
return nil, ErrCoordinatorInternalFailure
}

// Why here need get again? In order to support a task can assign to multiple prover, need also assign `ProvingTaskAssigned`
// bundle to prover. But use `proving_status in (1, 2)` will not use the postgres index. So need split the sql.
if tmpBundleTask == nil {
tmpBundleTask, getTaskError = bp.bundleOrm.GetUnassignedBundle(ctx.Copy(), maxActiveAttempts, maxTotalAttempts)
if getTaskError != nil {
log.Error("failed to get unassigned bundle proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError)
return nil, ErrCoordinatorInternalFailure
}
// chunk to prover. But use `proving_status in (1, 2)` will not use the postgres index. So need split the sql.
tmpUnassignedBundleTask, getTaskError := bp.bundleOrm.GetUnassignedBundles(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, 50)
if getTaskError != nil {
log.Error("failed to get unassigned chunk proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError)
return nil, ErrCoordinatorInternalFailure
}
for {
tmpBundleTask = nil
if assignedOffset < len(tmpAssignedBundleTasks) {
tmpBundleTask = tmpAssignedBundleTasks[assignedOffset]
assignedOffset++
} else if unassignedOffset < len(tmpUnassignedBundleTask) {
tmpBundleTask = tmpUnassignedBundleTask[unassignedOffset]
unassignedOffset++
}

if tmpBundleTask == nil {
log.Debug("get empty bundle", "height", getTaskParameter.ProverHeight)
return nil, nil
}

if tmpBundleTask == nil {
log.Debug("get empty bundle", "height", getTaskParameter.ProverHeight)
return nil, nil
// Don't dispatch the same failing job to the same prover
proverTask, err := bp.proverTaskOrm.GetTaskOfProver(ctx.Copy(), message.ProofTypeBatch, tmpBundleTask.Hash, taskCtx.PublicKey, taskCtx.ProverVersion)
if err != nil {
log.Error("failed to get prover task of prover", "proof_type", message.ProofTypeBatch.String(), "taskID", tmpBundleTask.Hash, "key", taskCtx.PublicKey, "Prover_version", taskCtx.ProverVersion, "error", err)
return nil, ErrCoordinatorInternalFailure
}
if proverTask == nil || types.ProverProveStatus(proverTask.ProvingStatus) != types.ProverProofInvalid {
break
}
}

rowsAffected, updateAttemptsErr := bp.bundleOrm.UpdateBundleAttempts(ctx.Copy(), tmpBundleTask.Hash, tmpBundleTask.ActiveAttempts, tmpBundleTask.TotalAttempts)
Expand Down
53 changes: 42 additions & 11 deletions coordinator/internal/logic/provertask/chunk_prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
"time"

"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -61,29 +62,59 @@ func (cp *ChunkProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinato

maxActiveAttempts := cp.cfg.ProverManager.ProversPerSession
maxTotalAttempts := cp.cfg.ProverManager.SessionAttempts
if strings.HasPrefix(taskCtx.ProverName, ExternalProverNamePrefix) {
unassignedChunkCount, err := cp.chunkOrm.GetUnassignedChunkCount(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, getTaskParameter.ProverHeight)
if err != nil {
log.Error("failed to get unassigned chunk proving tasks count", "height", getTaskParameter.ProverHeight, "err", err)
return nil, ErrCoordinatorInternalFailure
}
// Assign external prover if unassigned task number exceeds threshold
if unassignedChunkCount < cp.cfg.ProverManager.ExternalProverThreshold {
return nil, nil
}
}

var chunkTask *orm.Chunk
for i := 0; i < 5; i++ {
var getTaskError error
var tmpChunkTask *orm.Chunk
tmpChunkTask, getTaskError = cp.chunkOrm.GetAssignedChunk(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, getTaskParameter.ProverHeight)
var assignedOffset, unassignedOffset = 0, 0
tmpAssignedChunkTasks, getTaskError := cp.chunkOrm.GetAssignedChunks(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, getTaskParameter.ProverHeight, 50)
if getTaskError != nil {
log.Error("failed to get assigned chunk proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError)
return nil, ErrCoordinatorInternalFailure
}

// Why here need get again? In order to support a task can assign to multiple prover, need also assign `ProvingTaskAssigned`
// chunk to prover. But use `proving_status in (1, 2)` will not use the postgres index. So need split the sql.
if tmpChunkTask == nil {
tmpChunkTask, getTaskError = cp.chunkOrm.GetUnassignedChunk(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, getTaskParameter.ProverHeight)
if getTaskError != nil {
log.Error("failed to get unassigned chunk proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError)
return nil, ErrCoordinatorInternalFailure
}
tmpUnassignedChunkTask, getTaskError := cp.chunkOrm.GetUnassignedChunk(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, getTaskParameter.ProverHeight, 50)
if getTaskError != nil {
log.Error("failed to get unassigned chunk proving tasks", "height", getTaskParameter.ProverHeight, "err", getTaskError)
return nil, ErrCoordinatorInternalFailure
}
for {
tmpChunkTask = nil
if assignedOffset < len(tmpAssignedChunkTasks) {
tmpChunkTask = tmpAssignedChunkTasks[assignedOffset]
assignedOffset++
} else if unassignedOffset < len(tmpUnassignedChunkTask) {
tmpChunkTask = tmpUnassignedChunkTask[unassignedOffset]
unassignedOffset++
}

if tmpChunkTask == nil {
log.Debug("get empty chunk", "height", getTaskParameter.ProverHeight)
return nil, nil
}

if tmpChunkTask == nil {
log.Debug("get empty chunk", "height", getTaskParameter.ProverHeight)
return nil, nil
// Don't dispatch the same failing job to the same prover
proverTask, err := cp.proverTaskOrm.GetTaskOfProver(ctx.Copy(), message.ProofTypeChunk, tmpChunkTask.Hash, taskCtx.PublicKey, taskCtx.ProverVersion)
if err != nil {
log.Error("failed to get prover task of prover", "proof_type", message.ProofTypeChunk.String(), "taskID", tmpChunkTask.Hash, "key", taskCtx.PublicKey, "Prover_version", taskCtx.ProverVersion, "error", err)
return nil, ErrCoordinatorInternalFailure
}
if proverTask == nil || types.ProverProveStatus(proverTask.ProvingStatus) != types.ProverProofInvalid {
break
}
}

rowsAffected, updateAttemptsErr := cp.chunkOrm.UpdateChunkAttempts(ctx.Copy(), tmpChunkTask.Index, tmpChunkTask.ActiveAttempts, tmpChunkTask.TotalAttempts)
Expand Down
5 changes: 5 additions & 0 deletions coordinator/internal/logic/provertask/prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ var (
getTaskCounterVec *prometheus.CounterVec = nil
)

var (
// ExternalProverNamePrefix prefix of prover name
ExternalProverNamePrefix = "external"
)

// ProverTask the interface of a collector who send data to prover
type ProverTask interface {
Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error)
Expand Down
60 changes: 28 additions & 32 deletions coordinator/internal/orm/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,38 +78,47 @@ func (*Batch) TableName() string {
return "batch"
}

// GetUnassignedBatch retrieves unassigned batch based on the specified limit.
// GetUnassignedBatches retrieves unassigned batches based on the specified limit.
// The returned batches are sorted in ascending order by their index.
func (o *Batch) GetUnassignedBatch(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8) (*Batch, error) {
var batch Batch
func (o *Batch) GetUnassignedBatches(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8, limit uint64) ([]*Batch, error) {
var batch []*Batch
db := o.db.WithContext(ctx)
sql := fmt.Sprintf("SELECT * FROM batch WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND chunk_proofs_status = %d AND batch.deleted_at IS NULL ORDER BY batch.index LIMIT 1;",
int(types.ProvingTaskUnassigned), maxTotalAttempts, maxActiveAttempts, int(types.ChunkProofsStatusReady))
sql := fmt.Sprintf("SELECT * FROM batch WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND chunk_proofs_status = %d AND batch.deleted_at IS NULL ORDER BY batch.index LIMIT %d;",
int(types.ProvingTaskUnassigned), maxTotalAttempts, maxActiveAttempts, int(types.ChunkProofsStatusReady), limit)
err := db.Raw(sql).Scan(&batch).Error
if err != nil {
return nil, fmt.Errorf("Batch.GetUnassignedBatch error: %w", err)
return nil, fmt.Errorf("Batch.GetUnassignedBatches error: %w", err)
}
if batch.Hash == "" {
return nil, nil
return batch, nil
}

// GetUnassignedBatchCount retrieves unassigned batch count based on the specified limit.
func (o *Batch) GetUnassignedBatchCount(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8) (int64, error) {
var count int64
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("proving_status = ?", int(types.ProvingTaskUnassigned))
db = db.Where("total_attempts < ?", maxTotalAttempts)
db = db.Where("active_attempts < ?", maxActiveAttempts)
db = db.Where("batch.deleted_at IS NULL")
if err := db.Count(&count).Error; err != nil {
return 0, fmt.Errorf("Batch.GetUnassignedBatchCount error: %w", err)
}
return &batch, nil
return count, nil
}

// GetAssignedBatch retrieves assigned batch based on the specified limit.
// GetAssignedBatches retrieves assigned batches based on the specified limit.
// The returned batches are sorted in ascending order by their index.
func (o *Batch) GetAssignedBatch(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8) (*Batch, error) {
var batch Batch
func (o *Batch) GetAssignedBatches(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8, limit uint64) ([]*Batch, error) {
var batch []*Batch
db := o.db.WithContext(ctx)
sql := fmt.Sprintf("SELECT * FROM batch WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND chunk_proofs_status = %d AND batch.deleted_at IS NULL ORDER BY batch.index LIMIT 1;",
int(types.ProvingTaskAssigned), maxTotalAttempts, maxActiveAttempts, int(types.ChunkProofsStatusReady))
sql := fmt.Sprintf("SELECT * FROM batch WHERE proving_status = %d AND total_attempts < %d AND active_attempts < %d AND chunk_proofs_status = %d AND batch.deleted_at IS NULL ORDER BY batch.index LIMIT %d;",
int(types.ProvingTaskAssigned), maxTotalAttempts, maxActiveAttempts, int(types.ChunkProofsStatusReady), limit)
err := db.Raw(sql).Scan(&batch).Error
if err != nil {
return nil, fmt.Errorf("Batch.GetAssignedBatch error: %w", err)
}
if batch.Hash == "" {
return nil, nil
return nil, fmt.Errorf("Batch.GetAssignedBatches error: %w", err)
}
return &batch, nil
return batch, nil
}

// GetUnassignedAndChunksUnreadyBatches get the batches which is unassigned and chunks is not ready
Expand All @@ -132,19 +141,6 @@ func (o *Batch) GetUnassignedAndChunksUnreadyBatches(ctx context.Context, offset
return batches, nil
}

// GetAssignedBatches retrieves all batches whose proving_status is either types.ProvingTaskAssigned.
func (o *Batch) GetAssignedBatches(ctx context.Context) ([]*Batch, error) {
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("proving_status = ?", int(types.ProvingTaskAssigned))

var assignedBatches []*Batch
if err := db.Find(&assignedBatches).Error; err != nil {
return nil, fmt.Errorf("Batch.GetAssignedBatches error: %w", err)
}
return assignedBatches, nil
}

// GetProvingStatusByHash retrieves the proving status of a batch given its hash.
func (o *Batch) GetProvingStatusByHash(ctx context.Context, hash string) (types.ProvingStatus, error) {
db := o.db.WithContext(ctx)
Expand Down
Loading

0 comments on commit 791fcaf

Please sign in to comment.