diff --git a/api/graphql/schema.resolvers.go b/api/graphql/schema.resolvers.go index a6c2c2b..a000dc7 100644 --- a/api/graphql/schema.resolvers.go +++ b/api/graphql/schema.resolvers.go @@ -8,8 +8,9 @@ import ( "errors" "fmt" - "github.com/beyondstorage/beyond-tp/models" "github.com/beyondstorage/go-toolbox/zapcontext" + + "github.com/beyondstorage/beyond-tp/models" ) func (r *mutationResolver) CreateTask(ctx context.Context, input *CreateTask) (*Task, error) { @@ -51,7 +52,7 @@ func (r *mutationResolver) RunTask(ctx context.Context, id string) (*Task, error return nil, err } - if err = r.runTask(ctx, task); err != nil { + if err = r.DB.RunTask(task.Id); err != nil { return nil, err } return formatTask(task), nil diff --git a/api/graphql/task.go b/api/graphql/task.go deleted file mode 100644 index f1f3870..0000000 --- a/api/graphql/task.go +++ /dev/null @@ -1,26 +0,0 @@ -package graphql - -import ( - "context" - - "github.com/beyondstorage/go-toolbox/zapcontext" - "go.uber.org/zap" - "google.golang.org/protobuf/types/known/timestamppb" - - "github.com/beyondstorage/beyond-tp/models" -) - -// runTask handle publish task and update -func (r *mutationResolver) runTask(ctx context.Context, task *models.Task) error { - gc := GinContextFrom(ctx) - logger := zapcontext.FromGin(gc) - - task.UpdatedAt = timestamppb.Now() - task.Status = models.TaskStatus_Ready - err := r.DB.UpdateTask(task) - if err != nil { - logger.Error("save task", zap.String("id", task.Id), zap.Error(err)) - return err - } - return nil -} diff --git a/cmd/beyondtp/task.go b/cmd/beyondtp/task.go index de890d9..45ff869 100644 --- a/cmd/beyondtp/task.go +++ b/cmd/beyondtp/task.go @@ -142,7 +142,7 @@ func taskRunRun(c *cobra.Command, args []string) error { task.UpdatedAt = timestamppb.Now() task.Status = models.TaskStatus_Ready - err = db.UpdateTask(task) + err = db.UpdateTask(nil, task) if err != nil { logger.Error("save task", zap.String("id", task.Id), zap.Error(err)) return err diff --git a/models/key.go b/models/key.go index 8a0b165..87251db 100644 --- a/models/key.go +++ b/models/key.go @@ -1,6 +1,8 @@ package models import ( + "strings" + "github.com/Xuanwo/go-bufferpool" ) @@ -73,7 +75,7 @@ func StaffTaskPrefix(staffId string) []byte { return b.BytesCopy() } -// StaffTaskKey Style: st:: +// StaffTaskKey Style: s_t:: func StaffTaskKey(staffId, taskId string) []byte { b := pool.Get() defer b.Free() @@ -147,3 +149,12 @@ func init() { // Set init size to 64 to prevent alloc extra space. pool = bufferpool.New(64) } + +// GetTaskIDFromStaffTaskKey Style: s_t:: +func GetTaskIDFromStaffTaskKey(key string) string { + results := strings.Split(key, ":") + if len(results) < 3 { + return "" + } + return results[2] +} diff --git a/models/task.go b/models/task.go index d4b5b18..540abbd 100644 --- a/models/task.go +++ b/models/task.go @@ -3,10 +3,12 @@ package models import ( "context" "errors" + "fmt" "github.com/dgraph-io/badger/v3" protobuf "github.com/golang/protobuf/proto" "github.com/google/uuid" + "go.uber.org/zap" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -35,7 +37,7 @@ func NewTaskFromBytes(bs []byte) *Task { return t } -// Insert will insert task and update all staffs task queue. +// InsertTask will insert a task. func (d *DB) InsertTask(txn *badger.Txn, t *Task) (err error) { if txn == nil { txn = d.db.NewTransaction(true) @@ -44,13 +46,6 @@ func (d *DB) InsertTask(txn *badger.Txn, t *Task) (err error) { }() } - for _, v := range t.StaffIds { - err = d.InsertStaffTask(txn, v, t.Id) - if err != nil { - return - } - } - bs, err := protobuf.Marshal(t) if err != nil { return err @@ -62,9 +57,13 @@ func (d *DB) InsertTask(txn *badger.Txn, t *Task) (err error) { return } -func (d *DB) UpdateTask(t *Task) error { - txn := d.db.NewTransaction(true) - defer txn.Discard() +func (d *DB) UpdateTask(txn *badger.Txn, t *Task) (err error) { + if txn == nil { + txn = d.db.NewTransaction(true) + defer func() { + err = d.CloseTxn(txn, err) + }() + } bs, err := protobuf.Marshal(t) if err != nil { @@ -75,7 +74,7 @@ func (d *DB) UpdateTask(t *Task) error { if err = txn.Set(TaskKey(t.Id), bs); err != nil { return err } - return txn.Commit() + return } // DeleteTask delete a task by given ID from DB @@ -90,6 +89,34 @@ func (d *DB) DeleteTask(id string) error { return txn.Commit() } +func (d *DB) RunTask(id string) (err error) { + task, err := d.GetTask(id) + if err != nil { + return err + } + + task.UpdatedAt = timestamppb.Now() + task.Status = TaskStatus_Running + + txn := d.db.NewTransaction(true) + defer func() { + err = d.CloseTxn(txn, err) + }() + + err = d.UpdateTask(txn, task) + if err != nil { + return fmt.Errorf("update task %s failed: [%w]", task.Id, err) + } + + for _, staffId := range task.StaffIds { + err = d.InsertStaffTask(txn, staffId, task.Id) + if err != nil { + return fmt.Errorf("insert staff task %s to staff %s failed: [%w]", task.Id, staffId, err) + } + } + return +} + // GetTask get task from db and parsed into struct with specific ID func (d *DB) GetTask(id string) (t *Task, err error) { txn := d.db.NewTransaction(false) @@ -141,6 +168,24 @@ func (d *DB) SubscribeTask(ctx context.Context, fn func(t *Task)) (err error) { }, TaskPrefix) } +func (d *DB) StaffWatchTaskRun(staffID string, fn func(staffTaskKey string) error) error { + return d.db.Subscribe(context.TODO(), func(kv *badger.KVList) error { + for _, v := range kv.Kv { + // do not handle key delete + if v.Value == nil { + continue + } + d.logger.Debug("key change", zap.String("key", string(v.Key)), zap.String("val", string(v.Value)), zap.Bool("del", v.Value == nil)) + err := fn(string(v.Key)) + if err != nil { + d.logger.Error("handle task key", zap.String("staff_task_key", string(v.Key))) + return err + } + } + return nil + }, StaffTaskPrefix(staffID)) +} + func (d *DB) WaitTask(ctx context.Context, taskId string) (err error) { _, err = d.GetTask(taskId) // If job doesn't exist, we can return directly. diff --git a/task/manager.go b/task/manager.go index 642ff17..ce2d356 100644 --- a/task/manager.go +++ b/task/manager.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net" - "time" "github.com/beyondstorage/go-toolbox/zapcontext" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" @@ -117,50 +116,36 @@ func (p *Manager) Elect(ctx context.Context, req *models.ElectRequest) (reply *m func (p *Manager) Poll(req *models.PollRequest, srv models.Staff_PollServer) (err error) { logger := p.logger - for { - reply := &models.PollReply{} - - taskId, err := p.db.NextStaffTask(nil, req.StaffId) + logger.Debug("start poll", zap.String("staff", req.StaffId)) + err = p.db.StaffWatchTaskRun(req.StaffId, func(staffTaskKey string) error { + taskID := models.GetTaskIDFromStaffTaskKey(staffTaskKey) + task, err := p.db.GetTask(taskID) if err != nil { - logger.Error("next staff task", zap.Error(err)) - return err - } - - // task_id == "" means there is no task for out staff. - if taskId == "" { - reply.Status = models.PollStatus_Empty - - err = srv.Send(reply) - if err != nil { - return err - } - - // FIXME: we need to find a way to watch staff task changes. - time.Sleep(60 * time.Second) - return err - } - - task, err := p.db.GetTask(taskId) - if err != nil { - logger.Error("get task", zap.Error(err)) + logger.Error("get task when staff watch", zap.String("staff", req.StaffId), zap.Error(err)) return err } + reply := &models.PollReply{} reply.Status = models.PollStatus_Valid reply.Task = task err = srv.Send(reply) if err != nil { + logger.Error("poll send task", + zap.String("task", task.Id), zap.String("staff", req.StaffId), zap.Error(err)) return err } - logger.Info("polled task", zap.String("id", task.Id)) + logger.Info("polled task, ready to remove", zap.String("id", task.Id), zap.String("staff", req.StaffId)) - err = p.db.DeleteStaffTask(nil, req.StaffId, taskId) + err = p.db.DeleteStaffTask(nil, req.StaffId, taskID) if err != nil { return err } - } + return nil + }) + + return } func (p *Manager) Finish(ctx context.Context, req *models.FinishRequest) (reply *models.FinishReply, err error) { @@ -176,7 +161,7 @@ func (p *Manager) Finish(ctx context.Context, req *models.FinishRequest) (reply t.Status = models.TaskStatus_Finished - err = p.db.UpdateTask(t) + err = p.db.UpdateTask(nil, t) if err != nil { logger.Error("update task", zap.String("id", req.TaskId)) return diff --git a/task/staff.go b/task/staff.go index 290aacf..94b4fd5 100644 --- a/task/staff.go +++ b/task/staff.go @@ -45,6 +45,7 @@ func NewStaff(ctx context.Context, cfg StaffConfig) (s *Staff, err error) { ctx: ctx, logger: logger, } + logger.Info("staff created", zap.String("id", s.id)) return } diff --git a/task/worker_test.go b/task/worker_test.go index 07508c6..4e0daaa 100644 --- a/task/worker_test.go +++ b/task/worker_test.go @@ -4,6 +4,7 @@ import ( "context" "os" "testing" + "time" "github.com/beyondstorage/go-toolbox/zapcontext" "github.com/google/uuid" @@ -44,10 +45,10 @@ func TestWorker(t *testing.T) { w, err := NewStaff(ctx, StaffConfig{ Host: "localhost", ManagerAddr: "localhost:7000", - DataPath: "/tmp", + DataPath: "/tmp/badger", }) if err != nil { - t.Error(err) + t.Fatal(err) } staffIds = append(staffIds, w.id) @@ -55,7 +56,7 @@ func TestWorker(t *testing.T) { go w.Start(ctx) } - copyFileTask := &models.Task{ + task := &models.Task{ Id: uuid.NewString(), Type: models.TaskType_CopyDir, Status: models.TaskStatus_Ready, @@ -66,19 +67,26 @@ func TestWorker(t *testing.T) { }, } - err := p.db.InsertTask(nil, copyFileTask) + err := p.db.InsertTask(nil, task) if err != nil { - t.Errorf("insert task: %v", err) + t.Fatalf("insert task: %v", err) } - err = p.db.WaitTask(ctx, copyFileTask.Id) + time.Sleep(time.Second) + + err = p.db.RunTask(task.Id) + if err != nil { + t.Fatalf("run task: %v", err) + } + + err = p.db.WaitTask(ctx, task.Id) if err != nil { - t.Errorf("wait task: %v", err) + t.Fatalf("wait task: %v", err) } t.Logf("task has been finished") err = p.Stop(ctx) if err != nil { - t.Errorf("stop: %v", err) + t.Fatalf("stop: %v", err) } }