Skip to content

Commit

Permalink
feat(coordinator): assign static prover first and avoid reassigning f…
Browse files Browse the repository at this point in the history
…ailed task to same prover (#1584)

Co-authored-by: yiweichi <[email protected]>
Co-authored-by: colin <[email protected]>
  • Loading branch information
3 people authored Jan 15, 2025
1 parent fa0927c commit a6f2457
Show file tree
Hide file tree
Showing 19 changed files with 287 additions and 44 deletions.
2 changes: 1 addition & 1 deletion common/version/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"runtime/debug"
)

var tag = "v4.4.85"
var tag = "v4.4.86"

var commit = func() string {
if info, ok := debug.ReadBuildInfo(); ok {
Expand Down
1 change: 1 addition & 0 deletions coordinator/conf/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"prover_manager": {
"provers_per_session": 1,
"session_attempts": 5,
"external_prover_threshold": 32,
"bundle_collection_time_sec": 180,
"batch_collection_time_sec": 180,
"chunk_collection_time_sec": 180,
Expand Down
2 changes: 2 additions & 0 deletions coordinator/internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ type ProverManager struct {
// Number of attempts that a session can be retried if previous attempts failed.
// Currently we only consider proving timeout as failure here.
SessionAttempts uint8 `json:"session_attempts"`
// Threshold for activating the external prover based on unassigned task count.
ExternalProverThreshold int64 `json:"external_prover_threshold"`
// Zk verifier config.
Verifier *VerifierConfig `json:"verifier"`
// BatchCollectionTimeSec batch Proof collection time (in seconds).
Expand Down
13 changes: 9 additions & 4 deletions coordinator/internal/controller/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,11 @@ func (a *AuthController) PayloadFunc(data interface{}) jwt.MapClaims {
}

return jwt.MapClaims{
types.HardForkName: v.HardForkName,
types.PublicKey: v.PublicKey,
types.ProverName: v.Message.ProverName,
types.ProverVersion: v.Message.ProverVersion,
types.HardForkName: v.HardForkName,
types.PublicKey: v.PublicKey,
types.ProverName: v.Message.ProverName,
types.ProverVersion: v.Message.ProverVersion,
types.ProverProviderTypeKey: v.Message.ProverProviderType,
}
}

Expand All @@ -96,5 +97,9 @@ func (a *AuthController) IdentityHandler(c *gin.Context) interface{} {
c.Set(types.HardForkName, hardForkName)
}

if providerType, ok := claims[types.ProverProviderTypeKey]; ok {
c.Set(types.ProverProviderTypeKey, providerType)
}

return nil
}
11 changes: 11 additions & 0 deletions coordinator/internal/logic/auth/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ func (l *LoginLogic) Check(login *types.LoginParameter) error {
}
}
}

if login.Message.ProverProviderType != types.ProverProviderTypeInternal && login.Message.ProverProviderType != types.ProverProviderTypeExternal {
// for backward compatibility, set ProverProviderType as internal
if login.Message.ProverProviderType == types.ProverProviderTypeUndefined {
login.Message.ProverProviderType = types.ProverProviderTypeInternal
} else {
log.Error("invalid prover_provider_type", "value", login.Message.ProverProviderType, "prover name", login.Message.ProverName, "prover version", login.Message.ProverVersion)
return errors.New("invalid prover provider type.")
}
}

return nil
}

Expand Down
27 changes: 27 additions & 0 deletions coordinator/internal/logic/provertask/batch_prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"scroll-tech/coordinator/internal/config"
"scroll-tech/coordinator/internal/orm"
coordinatorType "scroll-tech/coordinator/internal/types"
cutils "scroll-tech/coordinator/internal/utils"
)

// BatchProverTask is prover task implement for batch proof
Expand Down Expand Up @@ -63,6 +64,18 @@ func (bp *BatchProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinato

maxActiveAttempts := bp.cfg.ProverManager.ProversPerSession
maxTotalAttempts := bp.cfg.ProverManager.SessionAttempts
if taskCtx.ProverProviderType == uint8(coordinatorType.ProverProviderTypeExternal) {
unassignedBatchCount, getCountError := bp.batchOrm.GetUnassignedBatchCount(ctx.Copy(), maxActiveAttempts, maxTotalAttempts)
if getCountError != nil {
log.Error("failed to get unassigned batch proving tasks count", "height", getTaskParameter.ProverHeight, "err", getCountError)
return nil, ErrCoordinatorInternalFailure
}
// Assign external prover if unassigned task number exceeds threshold
if unassignedBatchCount < bp.cfg.ProverManager.ExternalProverThreshold {
return nil, nil
}
}

var batchTask *orm.Batch
for i := 0; i < 5; i++ {
var getTaskError error
Expand All @@ -88,6 +101,20 @@ func (bp *BatchProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinato
return nil, nil
}

// Don't dispatch the same failing job to the same prover
proverTasks, getTaskError := bp.proverTaskOrm.GetFailedProverTasksByHash(ctx.Copy(), message.ProofTypeBatch, tmpBatchTask.Hash, 2)
if getTaskError != nil {
log.Error("failed to get prover tasks", "proof type", message.ProofTypeBatch.String(), "task ID", tmpBatchTask.Hash, "error", getTaskError)
return nil, ErrCoordinatorInternalFailure
}
for i := 0; i < len(proverTasks); i++ {
if proverTasks[i].ProverPublicKey == taskCtx.PublicKey ||
taskCtx.ProverProviderType == uint8(coordinatorType.ProverProviderTypeExternal) && cutils.IsExternalProverNameMatch(proverTasks[i].ProverName, taskCtx.ProverName) {
log.Debug("get empty batch, the prover already failed this task", "height", getTaskParameter.ProverHeight)
return nil, nil
}
}

rowsAffected, updateAttemptsErr := bp.batchOrm.UpdateBatchAttempts(ctx.Copy(), tmpBatchTask.Index, tmpBatchTask.ActiveAttempts, tmpBatchTask.TotalAttempts)
if updateAttemptsErr != nil {
log.Error("failed to update batch attempts", "height", getTaskParameter.ProverHeight, "err", updateAttemptsErr)
Expand Down
27 changes: 27 additions & 0 deletions coordinator/internal/logic/provertask/bundle_prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"scroll-tech/coordinator/internal/config"
"scroll-tech/coordinator/internal/orm"
coordinatorType "scroll-tech/coordinator/internal/types"
cutils "scroll-tech/coordinator/internal/utils"
)

// BundleProverTask is prover task implement for bundle proof
Expand Down Expand Up @@ -63,6 +64,18 @@ func (bp *BundleProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinat

maxActiveAttempts := bp.cfg.ProverManager.ProversPerSession
maxTotalAttempts := bp.cfg.ProverManager.SessionAttempts
if taskCtx.ProverProviderType == uint8(coordinatorType.ProverProviderTypeExternal) {
unassignedBundleCount, getCountError := bp.bundleOrm.GetUnassignedBundleCount(ctx.Copy(), maxActiveAttempts, maxTotalAttempts)
if getCountError != nil {
log.Error("failed to get unassigned bundle proving tasks count", "height", getTaskParameter.ProverHeight, "err", getCountError)
return nil, ErrCoordinatorInternalFailure
}
// Assign external prover if unassigned task number exceeds threshold
if unassignedBundleCount < bp.cfg.ProverManager.ExternalProverThreshold {
return nil, nil
}
}

var bundleTask *orm.Bundle
for i := 0; i < 5; i++ {
var getTaskError error
Expand All @@ -88,6 +101,20 @@ func (bp *BundleProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinat
return nil, nil
}

// Don't dispatch the same failing job to the same prover
proverTasks, getTaskError := bp.proverTaskOrm.GetFailedProverTasksByHash(ctx.Copy(), message.ProofTypeBundle, tmpBundleTask.Hash, 2)
if getTaskError != nil {
log.Error("failed to get prover tasks", "proof type", message.ProofTypeBundle.String(), "task ID", tmpBundleTask.Hash, "error", getTaskError)
return nil, ErrCoordinatorInternalFailure
}
for i := 0; i < len(proverTasks); i++ {
if proverTasks[i].ProverPublicKey == taskCtx.PublicKey ||
taskCtx.ProverProviderType == uint8(coordinatorType.ProverProviderTypeExternal) && cutils.IsExternalProverNameMatch(proverTasks[i].ProverName, taskCtx.ProverName) {
log.Debug("get empty bundle, the prover already failed this task", "height", getTaskParameter.ProverHeight)
return nil, nil
}
}

rowsAffected, updateAttemptsErr := bp.bundleOrm.UpdateBundleAttempts(ctx.Copy(), tmpBundleTask.Hash, tmpBundleTask.ActiveAttempts, tmpBundleTask.TotalAttempts)
if updateAttemptsErr != nil {
log.Error("failed to update bundle attempts", "height", getTaskParameter.ProverHeight, "err", updateAttemptsErr)
Expand Down
27 changes: 27 additions & 0 deletions coordinator/internal/logic/provertask/chunk_prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"scroll-tech/coordinator/internal/config"
"scroll-tech/coordinator/internal/orm"
coordinatorType "scroll-tech/coordinator/internal/types"
cutils "scroll-tech/coordinator/internal/utils"
)

// ChunkProverTask the chunk prover task
Expand Down Expand Up @@ -61,6 +62,18 @@ func (cp *ChunkProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinato

maxActiveAttempts := cp.cfg.ProverManager.ProversPerSession
maxTotalAttempts := cp.cfg.ProverManager.SessionAttempts
if taskCtx.ProverProviderType == uint8(coordinatorType.ProverProviderTypeExternal) {
unassignedChunkCount, getCountError := cp.chunkOrm.GetUnassignedChunkCount(ctx.Copy(), maxActiveAttempts, maxTotalAttempts, getTaskParameter.ProverHeight)
if getCountError != nil {
log.Error("failed to get unassigned chunk proving tasks count", "height", getTaskParameter.ProverHeight, "err", getCountError)
return nil, ErrCoordinatorInternalFailure
}
// Assign external prover if unassigned task number exceeds threshold
if unassignedChunkCount < cp.cfg.ProverManager.ExternalProverThreshold {
return nil, nil
}
}

var chunkTask *orm.Chunk
for i := 0; i < 5; i++ {
var getTaskError error
Expand All @@ -86,6 +99,20 @@ func (cp *ChunkProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinato
return nil, nil
}

// Don't dispatch the same failing job to the same prover
proverTasks, getTaskError := cp.proverTaskOrm.GetFailedProverTasksByHash(ctx.Copy(), message.ProofTypeChunk, tmpChunkTask.Hash, 2)
if getTaskError != nil {
log.Error("failed to get prover tasks", "proof type", message.ProofTypeChunk.String(), "task ID", tmpChunkTask.Hash, "error", getTaskError)
return nil, ErrCoordinatorInternalFailure
}
for i := 0; i < len(proverTasks); i++ {
if proverTasks[i].ProverPublicKey == taskCtx.PublicKey ||
taskCtx.ProverProviderType == uint8(coordinatorType.ProverProviderTypeExternal) && cutils.IsExternalProverNameMatch(proverTasks[i].ProverName, taskCtx.ProverName) {
log.Debug("get empty chunk, the prover already failed this task", "height", getTaskParameter.ProverHeight)
return nil, nil
}
}

rowsAffected, updateAttemptsErr := cp.chunkOrm.UpdateChunkAttempts(ctx.Copy(), tmpChunkTask.Index, tmpChunkTask.ActiveAttempts, tmpChunkTask.TotalAttempts)
if updateAttemptsErr != nil {
log.Error("failed to update chunk attempts", "height", getTaskParameter.ProverHeight, "err", updateAttemptsErr)
Expand Down
15 changes: 11 additions & 4 deletions coordinator/internal/logic/provertask/prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ type BaseProverTask struct {
}

type proverTaskContext struct {
PublicKey string
ProverName string
ProverVersion string
HardForkNames map[string]struct{}
PublicKey string
ProverName string
ProverVersion string
ProverProviderType uint8
HardForkNames map[string]struct{}
}

// checkParameter check the prover task parameter illegal
Expand All @@ -76,6 +77,12 @@ func (b *BaseProverTask) checkParameter(ctx *gin.Context) (*proverTaskContext, e
}
ptc.ProverVersion = proverVersion.(string)

ProverProviderType, ProverProviderTypeExist := ctx.Get(coordinatorType.ProverProviderTypeKey)
if !ProverProviderTypeExist {
return nil, errors.New("get prover provider type from context failed")
}
ptc.ProverProviderType = uint8(ProverProviderType.(float64))

hardForkNamesStr, hardForkNameExist := ctx.Get(coordinatorType.HardForkName)
if !hardForkNameExist {
return nil, errors.New("get hard fork name from context failed")
Expand Down
16 changes: 16 additions & 0 deletions coordinator/internal/orm/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,22 @@ func (o *Batch) GetUnassignedBatch(ctx context.Context, maxActiveAttempts, maxTo
return &batch, nil
}

// GetUnassignedBatchCount retrieves unassigned batch count.
func (o *Batch) GetUnassignedBatchCount(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8) (int64, error) {
var count int64
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("proving_status = ?", int(types.ProvingTaskUnassigned))
db = db.Where("total_attempts < ?", maxTotalAttempts)
db = db.Where("active_attempts < ?", maxActiveAttempts)
db = db.Where("chunk_proofs_status = ?", int(types.ChunkProofsStatusReady))
db = db.Where("batch.deleted_at IS NULL")
if err := db.Count(&count).Error; err != nil {
return 0, fmt.Errorf("Batch.GetUnassignedBatchCount error: %w", err)
}
return count, nil
}

// GetAssignedBatch retrieves assigned batch based on the specified limit.
// The returned batches are sorted in ascending order by their index.
func (o *Batch) GetAssignedBatch(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8) (*Batch, error) {
Expand Down
16 changes: 16 additions & 0 deletions coordinator/internal/orm/bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,22 @@ func (o *Bundle) GetUnassignedBundle(ctx context.Context, maxActiveAttempts, max
return &bundle, nil
}

// GetUnassignedBundleCount retrieves unassigned bundle count.
func (o *Bundle) GetUnassignedBundleCount(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8) (int64, error) {
var count int64
db := o.db.WithContext(ctx)
db = db.Model(&Bundle{})
db = db.Where("proving_status = ?", int(types.ProvingTaskUnassigned))
db = db.Where("total_attempts < ?", maxTotalAttempts)
db = db.Where("active_attempts < ?", maxActiveAttempts)
db = db.Where("batch_proofs_status = ?", int(types.BatchProofsStatusReady))
db = db.Where("bundle.deleted_at IS NULL")
if err := db.Count(&count).Error; err != nil {
return 0, fmt.Errorf("Bundle.GetUnassignedBundleCount error: %w", err)
}
return count, nil
}

// GetAssignedBundle retrieves assigned bundle based on the specified limit.
// The returned bundle sorts in ascending order by their index.
func (o *Bundle) GetAssignedBundle(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8) (*Bundle, error) {
Expand Down
16 changes: 16 additions & 0 deletions coordinator/internal/orm/chunk.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,22 @@ func (o *Chunk) GetUnassignedChunk(ctx context.Context, maxActiveAttempts, maxTo
return &chunk, nil
}

// GetUnassignedChunkCount retrieves unassigned chunk count.
func (o *Chunk) GetUnassignedChunkCount(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8, height uint64) (int64, error) {
var count int64
db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Where("proving_status = ?", int(types.ProvingTaskUnassigned))
db = db.Where("total_attempts < ?", maxTotalAttempts)
db = db.Where("active_attempts < ?", maxActiveAttempts)
db = db.Where("end_block_number <= ?", height)
db = db.Where("chunk.deleted_at IS NULL")
if err := db.Count(&count).Error; err != nil {
return 0, fmt.Errorf("Chunk.GetUnassignedChunkCount error: %w", err)
}
return count, nil
}

// GetAssignedChunk retrieves assigned chunk based on the specified limit.
// The returned chunks are sorted in ascending order by their index.
func (o *Chunk) GetAssignedChunk(ctx context.Context, maxActiveAttempts, maxTotalAttempts uint8, height uint64) (*Chunk, error) {
Expand Down
21 changes: 21 additions & 0 deletions coordinator/internal/orm/prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,27 @@ func (o *ProverTask) GetProverTasksByHashes(ctx context.Context, taskType messag
return proverTasks, nil
}

// GetFailedProverTasksByHash retrieves the failed ProverTask records associated with the specified hash.
// The returned prover task objects are sorted in descending order by their ids.
func (o *ProverTask) GetFailedProverTasksByHash(ctx context.Context, taskType message.ProofType, hash string, limit int) ([]*ProverTask, error) {
db := o.db.WithContext(ctx)
db = db.Model(&ProverTask{})
db = db.Where("task_type", int(taskType))
db = db.Where("task_id", hash)
db = db.Where("proving_status = ?", int(types.ProverProofInvalid))
db = db.Order("id desc")

if limit != 0 {
db = db.Limit(limit)
}

var proverTasks []*ProverTask
if err := db.Find(&proverTasks).Error; err != nil {
return nil, fmt.Errorf("ProverTask.GetFailedProverTasksByHash error: %w, hash: %v", err, hash)
}
return proverTasks, nil
}

// GetProverTaskByUUIDAndPublicKey get prover task taskID by uuid and public key
func (o *ProverTask) GetProverTaskByUUIDAndPublicKey(ctx context.Context, uuid, publicKey string) (*ProverTask, error) {
db := o.db.WithContext(ctx)
Expand Down
Loading

0 comments on commit a6f2457

Please sign in to comment.