Skip to content

Commit

Permalink
Implement DMP-110: Refactor task distribute (#122)
Browse files Browse the repository at this point in the history
* Implement DMP-110

Signed-off-by: Prnyself <[email protected]>

* model: Add RunTask

Signed-off-by: Prnyself <[email protected]>

* task: Refactor task run

Signed-off-by: Prnyself <[email protected]>

* add info in log and test

Signed-off-by: Prnyself <[email protected]>

* fix go.sum

Signed-off-by: Prnyself <[email protected]>

* fix UpdateTask in cmd

Signed-off-by: Prnyself <[email protected]>
  • Loading branch information
Prnyself authored Jul 27, 2021
1 parent 93391cb commit 783ad49
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 80 deletions.
5 changes: 3 additions & 2 deletions api/graphql/schema.resolvers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import (
"errors"
"fmt"

"github.com/beyondstorage/beyond-tp/models"
"github.com/beyondstorage/go-toolbox/zapcontext"

"github.com/beyondstorage/beyond-tp/models"
)

func (r *mutationResolver) CreateTask(ctx context.Context, input *CreateTask) (*Task, error) {
Expand Down Expand Up @@ -51,7 +52,7 @@ func (r *mutationResolver) RunTask(ctx context.Context, id string) (*Task, error
return nil, err
}

if err = r.runTask(ctx, task); err != nil {
if err = r.DB.RunTask(task.Id); err != nil {
return nil, err
}
return formatTask(task), nil
Expand Down
26 changes: 0 additions & 26 deletions api/graphql/task.go

This file was deleted.

2 changes: 1 addition & 1 deletion cmd/beyondtp/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func taskRunRun(c *cobra.Command, args []string) error {
task.UpdatedAt = timestamppb.Now()
task.Status = models.TaskStatus_Ready

err = db.UpdateTask(task)
err = db.UpdateTask(nil, task)
if err != nil {
logger.Error("save task", zap.String("id", task.Id), zap.Error(err))
return err
Expand Down
13 changes: 12 additions & 1 deletion models/key.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package models

import (
"strings"

"github.com/Xuanwo/go-bufferpool"
)

Expand Down Expand Up @@ -73,7 +75,7 @@ func StaffTaskPrefix(staffId string) []byte {
return b.BytesCopy()
}

// StaffTaskKey Style: st:<staff_id>:<task_id>
// StaffTaskKey Style: s_t:<staff_id>:<task_id>
func StaffTaskKey(staffId, taskId string) []byte {
b := pool.Get()
defer b.Free()
Expand Down Expand Up @@ -147,3 +149,12 @@ func init() {
// Set init size to 64 to prevent alloc extra space.
pool = bufferpool.New(64)
}

// GetTaskIDFromStaffTaskKey Style: s_t:<staff_id>:<task_id>
func GetTaskIDFromStaffTaskKey(key string) string {
results := strings.Split(key, ":")
if len(results) < 3 {
return ""
}
return results[2]
}
69 changes: 57 additions & 12 deletions models/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package models
import (
"context"
"errors"
"fmt"

"github.com/dgraph-io/badger/v3"
protobuf "github.com/golang/protobuf/proto"
"github.com/google/uuid"
"go.uber.org/zap"
"google.golang.org/protobuf/types/known/timestamppb"
)

Expand Down Expand Up @@ -35,7 +37,7 @@ func NewTaskFromBytes(bs []byte) *Task {
return t
}

// Insert will insert task and update all staffs task queue.
// InsertTask will insert a task.
func (d *DB) InsertTask(txn *badger.Txn, t *Task) (err error) {
if txn == nil {
txn = d.db.NewTransaction(true)
Expand All @@ -44,13 +46,6 @@ func (d *DB) InsertTask(txn *badger.Txn, t *Task) (err error) {
}()
}

for _, v := range t.StaffIds {
err = d.InsertStaffTask(txn, v, t.Id)
if err != nil {
return
}
}

bs, err := protobuf.Marshal(t)
if err != nil {
return err
Expand All @@ -62,9 +57,13 @@ func (d *DB) InsertTask(txn *badger.Txn, t *Task) (err error) {
return
}

func (d *DB) UpdateTask(t *Task) error {
txn := d.db.NewTransaction(true)
defer txn.Discard()
func (d *DB) UpdateTask(txn *badger.Txn, t *Task) (err error) {
if txn == nil {
txn = d.db.NewTransaction(true)
defer func() {
err = d.CloseTxn(txn, err)
}()
}

bs, err := protobuf.Marshal(t)
if err != nil {
Expand All @@ -75,7 +74,7 @@ func (d *DB) UpdateTask(t *Task) error {
if err = txn.Set(TaskKey(t.Id), bs); err != nil {
return err
}
return txn.Commit()
return
}

// DeleteTask delete a task by given ID from DB
Expand All @@ -90,6 +89,34 @@ func (d *DB) DeleteTask(id string) error {
return txn.Commit()
}

func (d *DB) RunTask(id string) (err error) {
task, err := d.GetTask(id)
if err != nil {
return err
}

task.UpdatedAt = timestamppb.Now()
task.Status = TaskStatus_Running

txn := d.db.NewTransaction(true)
defer func() {
err = d.CloseTxn(txn, err)
}()

err = d.UpdateTask(txn, task)
if err != nil {
return fmt.Errorf("update task %s failed: [%w]", task.Id, err)
}

for _, staffId := range task.StaffIds {
err = d.InsertStaffTask(txn, staffId, task.Id)
if err != nil {
return fmt.Errorf("insert staff task %s to staff %s failed: [%w]", task.Id, staffId, err)
}
}
return
}

// GetTask get task from db and parsed into struct with specific ID
func (d *DB) GetTask(id string) (t *Task, err error) {
txn := d.db.NewTransaction(false)
Expand Down Expand Up @@ -141,6 +168,24 @@ func (d *DB) SubscribeTask(ctx context.Context, fn func(t *Task)) (err error) {
}, TaskPrefix)
}

func (d *DB) StaffWatchTaskRun(staffID string, fn func(staffTaskKey string) error) error {
return d.db.Subscribe(context.TODO(), func(kv *badger.KVList) error {
for _, v := range kv.Kv {
// do not handle key delete
if v.Value == nil {
continue
}
d.logger.Debug("key change", zap.String("key", string(v.Key)), zap.String("val", string(v.Value)), zap.Bool("del", v.Value == nil))
err := fn(string(v.Key))
if err != nil {
d.logger.Error("handle task key", zap.String("staff_task_key", string(v.Key)))
return err
}
}
return nil
}, StaffTaskPrefix(staffID))
}

func (d *DB) WaitTask(ctx context.Context, taskId string) (err error) {
_, err = d.GetTask(taskId)
// If job doesn't exist, we can return directly.
Expand Down
45 changes: 15 additions & 30 deletions task/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net"
"time"

"github.com/beyondstorage/go-toolbox/zapcontext"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
Expand Down Expand Up @@ -117,50 +116,36 @@ func (p *Manager) Elect(ctx context.Context, req *models.ElectRequest) (reply *m
func (p *Manager) Poll(req *models.PollRequest, srv models.Staff_PollServer) (err error) {
logger := p.logger

for {
reply := &models.PollReply{}

taskId, err := p.db.NextStaffTask(nil, req.StaffId)
logger.Debug("start poll", zap.String("staff", req.StaffId))
err = p.db.StaffWatchTaskRun(req.StaffId, func(staffTaskKey string) error {
taskID := models.GetTaskIDFromStaffTaskKey(staffTaskKey)
task, err := p.db.GetTask(taskID)
if err != nil {
logger.Error("next staff task", zap.Error(err))
return err
}

// task_id == "" means there is no task for out staff.
if taskId == "" {
reply.Status = models.PollStatus_Empty

err = srv.Send(reply)
if err != nil {
return err
}

// FIXME: we need to find a way to watch staff task changes.
time.Sleep(60 * time.Second)
return err
}

task, err := p.db.GetTask(taskId)
if err != nil {
logger.Error("get task", zap.Error(err))
logger.Error("get task when staff watch", zap.String("staff", req.StaffId), zap.Error(err))
return err
}

reply := &models.PollReply{}
reply.Status = models.PollStatus_Valid
reply.Task = task

err = srv.Send(reply)
if err != nil {
logger.Error("poll send task",
zap.String("task", task.Id), zap.String("staff", req.StaffId), zap.Error(err))
return err
}

logger.Info("polled task", zap.String("id", task.Id))
logger.Info("polled task, ready to remove", zap.String("id", task.Id), zap.String("staff", req.StaffId))

err = p.db.DeleteStaffTask(nil, req.StaffId, taskId)
err = p.db.DeleteStaffTask(nil, req.StaffId, taskID)
if err != nil {
return err
}
}
return nil
})

return
}

func (p *Manager) Finish(ctx context.Context, req *models.FinishRequest) (reply *models.FinishReply, err error) {
Expand All @@ -176,7 +161,7 @@ func (p *Manager) Finish(ctx context.Context, req *models.FinishRequest) (reply

t.Status = models.TaskStatus_Finished

err = p.db.UpdateTask(t)
err = p.db.UpdateTask(nil, t)
if err != nil {
logger.Error("update task", zap.String("id", req.TaskId))
return
Expand Down
1 change: 1 addition & 0 deletions task/staff.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func NewStaff(ctx context.Context, cfg StaffConfig) (s *Staff, err error) {
ctx: ctx,
logger: logger,
}
logger.Info("staff created", zap.String("id", s.id))
return
}

Expand Down
24 changes: 16 additions & 8 deletions task/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"os"
"testing"
"time"

"github.com/beyondstorage/go-toolbox/zapcontext"
"github.com/google/uuid"
Expand Down Expand Up @@ -44,18 +45,18 @@ func TestWorker(t *testing.T) {
w, err := NewStaff(ctx, StaffConfig{
Host: "localhost",
ManagerAddr: "localhost:7000",
DataPath: "/tmp",
DataPath: "/tmp/badger",
})
if err != nil {
t.Error(err)
t.Fatal(err)
}

staffIds = append(staffIds, w.id)

go w.Start(ctx)
}

copyFileTask := &models.Task{
task := &models.Task{
Id: uuid.NewString(),
Type: models.TaskType_CopyDir,
Status: models.TaskStatus_Ready,
Expand All @@ -66,19 +67,26 @@ func TestWorker(t *testing.T) {
},
}

err := p.db.InsertTask(nil, copyFileTask)
err := p.db.InsertTask(nil, task)
if err != nil {
t.Errorf("insert task: %v", err)
t.Fatalf("insert task: %v", err)
}

err = p.db.WaitTask(ctx, copyFileTask.Id)
time.Sleep(time.Second)

err = p.db.RunTask(task.Id)
if err != nil {
t.Fatalf("run task: %v", err)
}

err = p.db.WaitTask(ctx, task.Id)
if err != nil {
t.Errorf("wait task: %v", err)
t.Fatalf("wait task: %v", err)
}
t.Logf("task has been finished")

err = p.Stop(ctx)
if err != nil {
t.Errorf("stop: %v", err)
t.Fatalf("stop: %v", err)
}
}

0 comments on commit 783ad49

Please sign in to comment.