Skip to content

Commit

Permalink
Refactor batch and passcode
Browse files Browse the repository at this point in the history
  • Loading branch information
minhduc140583 committed Jul 6, 2024
1 parent f3b5a79 commit 9182ee8
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 174 deletions.
25 changes: 19 additions & 6 deletions batch/batch_inserter.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,24 @@ import (
type BatchInserter[T any] struct {
collection *mongo.Collection
Map func(*T)
retryAll bool
}

func NewBatchInserter[T any](database *mongo.Database, collectionName string, options ...func(*T)) *BatchInserter[T] {
func NewBatchInserterWithRetry[T any](db *mongo.Database, collectionName string, retryAll bool, opts ...func(*T)) *BatchInserter[T] {
var t T
modelType := reflect.TypeOf(t)
if modelType.Kind() != reflect.Struct {
panic("T must be a struct")
}
var mp func(*T)
if len(options) > 0 {
mp = options[0]
if len(opts) > 0 {
mp = opts[0]
}
collection := database.Collection(collectionName)
return &BatchInserter[T]{collection: collection, Map: mp}
collection := db.Collection(collectionName)
return &BatchInserter[T]{collection: collection, Map: mp, retryAll: retryAll}
}
func NewBatchInserter[T any](db *mongo.Database, collectionName string, opts ...func(*T)) *BatchInserter[T] {
return NewBatchInserterWithRetry[T](db, collectionName, false, opts...)
}

func (w *BatchInserter[T]) Write(ctx context.Context, models []T) ([]int, error) {
Expand All @@ -33,5 +37,14 @@ func (w *BatchInserter[T]) Write(ctx context.Context, models []T) ([]int, error)
w.Map(&models[i])
}
}
return InsertMany[T](ctx, w.collection, models)
fails, err := InsertMany[T](ctx, w.collection, models)
if err != nil && len(fails) == 0 && w.retryAll {
l := len(models)
failIndices := make([]int, 0)
for i := 0; i < l; i++ {
failIndices = append(failIndices, i)
}
return failIndices, err
}
return fails, err
}
27 changes: 18 additions & 9 deletions batch/batch_updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,29 @@ type BatchUpdater[T any] struct {
collection *mongo.Collection
Idx int
Map func(*T)
retryAll bool
}

func NewBatchUpdaterWithId[T any](database *mongo.Database, collectionName string, options ...func(*T)) *BatchUpdater[T] {
func NewBatchUpdaterWithRetry[T any](db *mongo.Database, collectionName string, retryAll bool, opts ...func(*T)) *BatchUpdater[T] {
var t T
modelType := reflect.TypeOf(t)
if modelType.Kind() != reflect.Struct {
panic("T must be a struct")
}
idx := FindIdField(modelType)
if idx < 0 {
panic("T must contain Id field, which has '_id' bson tag")
}
var mp func(*T)
if len(options) > 0 {
mp = options[0]
if len(opts) > 0 {
mp = opts[0]
}
collection := database.Collection(collectionName)
return &BatchUpdater[T]{collection, idx, mp}
collection := db.Collection(collectionName)
return &BatchUpdater[T]{collection, idx, mp, retryAll}
}

func NewBatchUpdater[T any](database *mongo.Database, collectionName string, options ...func(*T)) *BatchUpdater[T] {
return NewBatchUpdaterWithId[T](database, collectionName, options...)
func NewBatchUpdater[T any](db *mongo.Database, collectionName string, retryAll bool, opts ...func(*T)) *BatchUpdater[T] {
return NewBatchUpdaterWithRetry[T](db, collectionName, false, opts...)
}

func (w *BatchUpdater[T]) Write(ctx context.Context, models []T) ([]int, error) {
failIndices := make([]int, 0)
var err error
Expand All @@ -49,6 +51,13 @@ func (w *BatchUpdater[T]) Write(ctx context.Context, models []T) ([]int, error)
for _, writeError := range bulkWriteException.WriteErrors {
failIndices = append(failIndices, writeError.Index)
}
} else if w.retryAll {
l := len(models)
fails := make([]int, 0)
for i := 0; i < l; i++ {
fails = append(fails, i)
}
return fails, err
}
return failIndices, err
}
26 changes: 20 additions & 6 deletions batch/batch_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,28 @@ type BatchWriter[T any] struct {
collection *mongo.Collection
Idx int
Map func(*T)
retryAll bool
}

func NewBatchWriter[T any](database *mongo.Database, collectionName string, options ...func(*T)) *BatchWriter[T] {
func NewBatchWriterWithRetry[T any](db *mongo.Database, collectionName string, retryAll bool, opts ...func(*T)) *BatchWriter[T] {
var t T
modelType := reflect.TypeOf(t)
if modelType.Kind() != reflect.Struct {
panic("T must be a struct")
}
idx := FindIdField(modelType)
if idx < 0 {
panic("T must contain Id field, which has '_id' bson tag")
}
var mp func(*T)
if len(options) > 0 {
mp = options[0]
if len(opts) > 0 {
mp = opts[0]
}
idx := FindIdField(modelType)
collection := database.Collection(collectionName)
return &BatchWriter[T]{collection, idx, mp}
collection := db.Collection(collectionName)
return &BatchWriter[T]{collection, idx, mp, retryAll}
}
func NewBatchWriter[T any](db *mongo.Database, collectionName string, retryAll bool, opts ...func(*T)) *BatchWriter[T] {
return NewBatchWriterWithRetry[T](db, collectionName, false, opts...)
}
func (w *BatchWriter[T]) Write(ctx context.Context, models []T) ([]int, error) {
failIndices := make([]int, 0)
Expand All @@ -45,6 +52,13 @@ func (w *BatchWriter[T]) Write(ctx context.Context, models []T) ([]int, error) {
for _, writeError := range bulkWriteException.WriteErrors {
failIndices = append(failIndices, writeError.Index)
}
} else if w.retryAll {
l := len(models)
fails := make([]int, 0)
for i := 0; i < l; i++ {
fails = append(fails, i)
}
return fails, err
}
return failIndices, err
}
3 changes: 0 additions & 3 deletions go.mod

This file was deleted.

Empty file removed go.sum
Empty file.
170 changes: 20 additions & 150 deletions passcode/passcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@ package passcode
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"log"
"reflect"
"strings"
"time"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)

type PasscodeRepository struct {
collection *mongo.Collection
collection *mongo.Collection
passcodeName string
expiredAtName string
}
Expand All @@ -38,8 +37,19 @@ func (p *PasscodeRepository) Save(ctx context.Context, id string, passcode strin
pass["_id"] = id
pass[p.passcodeName] = passcode
pass[p.expiredAtName] = expiredAt
idQuery := bson.M{"_id": id}
return UpsertOne(ctx, p.collection, idQuery, pass)
updateQuery := bson.M{
"$set": pass,
}
filter := bson.M{"_id": id}
opts := options.Update().SetUpsert(true)
res, err := p.collection.UpdateOne(ctx, filter, updateQuery, opts)
if res.ModifiedCount > 0 {
return res.ModifiedCount, err
} else if res.UpsertedCount > 0 {
return res.UpsertedCount, err
} else {
return res.MatchedCount, err
}
}

func (p *PasscodeRepository) Load(ctx context.Context, id string) (string, time.Time, error) {
Expand All @@ -63,150 +73,10 @@ func (p *PasscodeRepository) Load(ctx context.Context, id string) (string, time.
}

func (p *PasscodeRepository) Delete(ctx context.Context, id string) (int64, error) {
idQuery := bson.M{"_id": id}
return DeleteOne(ctx, p.collection, idQuery)
}

func DeleteOne(ctx context.Context, coll *mongo.Collection, query bson.M) (int64, error) {
result, err := coll.DeleteOne(ctx, query)
filter := bson.M{"_id": id}
result, err := p.collection.DeleteOne(ctx, filter)
if result == nil {
return 0, err
}
return result.DeletedCount, err
}
func Exist(ctx context.Context, collection *mongo.Collection, id interface{}, objectId bool) (bool, error) {
query := bson.M{"_id": id}
if objectId {
objId, err := primitive.ObjectIDFromHex(id.(string))
if err != nil {
return false, err
}
query = bson.M{"_id": objId}
}
x := collection.FindOne(ctx, query)
if x.Err() != nil {
if fmt.Sprint(x.Err()) == "mongo: no documents in result" {
return false, nil
} else {
return false, x.Err()
}
}
return true, nil
}
func UpsertOne(ctx context.Context, collection *mongo.Collection, filter bson.M, model interface{}) (int64, error) {
defaultObjID, _ := primitive.ObjectIDFromHex("000000000000")

if idValue := filter["_id"]; idValue == "" || idValue == 0 || idValue == defaultObjID {
return InsertOne(ctx, collection, model)
} else {
isExisted, err := Exist(ctx, collection, idValue, false)
if err != nil {
return 0, err
}
if isExisted {
update := bson.M{
"$set": model,
}
result := collection.FindOneAndUpdate(ctx, filter, update)
if result.Err() != nil {
if fmt.Sprint(result.Err()) == "mongo: no documents in result" {
return 0, nil
} else {
return 0, result.Err()
}
}
return 1, result.Err()
} else {
return InsertOne(ctx, collection, model)
}
}
}
func InsertOne(ctx context.Context, collection *mongo.Collection, model interface{}) (int64, error) {
result, err := collection.InsertOne(ctx, model)
if err != nil {
errMsg := err.Error()
if strings.Index(errMsg, "duplicate key error collection:") >= 0 {
return 0, nil
} else {
return 0, err
}
} else {
if idValue, ok := result.InsertedID.(primitive.ObjectID); ok {
valueOfModel := reflect.Indirect(reflect.ValueOf(model))
typeOfModel := valueOfModel.Type()
idIndex, _, _ := FindIdField(typeOfModel)
if idIndex != -1 {
mapObjectIdToModel(idValue, valueOfModel, idIndex)
}
}
return 1, err
}
}
func FindIdField(modelType reflect.Type) (int, string, string) {
return FindField(modelType, "_id")
}
func FindField(modelType reflect.Type, bsonName string) (int, string, string) {
numField := modelType.NumField()
for i := 0; i < numField; i++ {
field := modelType.Field(i)
bsonTag := field.Tag.Get("bson")
tags := strings.Split(bsonTag, ",")
json := field.Name
if tag1, ok1 := field.Tag.Lookup("json"); ok1 {
json = strings.Split(tag1, ",")[0]
}
for _, tag := range tags {
if strings.TrimSpace(tag) == bsonName {
return i, field.Name, json
}
}
}
return -1, "", ""
}
func mapObjectIdToModel(id primitive.ObjectID, valueOfModel reflect.Value, idIndex int) {
switch reflect.Indirect(valueOfModel).Field(idIndex).Kind() {
case reflect.String:
if _, err := setValue(valueOfModel, idIndex, id.Hex()); err != nil {
log.Println("Err: " + err.Error())
}
break
default:
if _, err := setValue(valueOfModel, idIndex, id); err != nil {
log.Println("Err: " + err.Error())
}
break
}
}
func setValue(model interface{}, index int, value interface{}) (interface{}, error) {
vo := reflect.Indirect(reflect.ValueOf(model))
switch reflect.ValueOf(model).Kind() {
case reflect.Ptr:
{
vo.Field(index).Set(reflect.ValueOf(value))
return model, nil
}
default:
if modelWithTypeValue, ok := model.(reflect.Value); ok {
_, err := setValueWithTypeValue(modelWithTypeValue, index, value)
return modelWithTypeValue.Interface(), err
}
}
return model, nil
}
func setValueWithTypeValue(model reflect.Value, index int, value interface{}) (reflect.Value, error) {
trueValue := reflect.Indirect(model)
switch trueValue.Kind() {
case reflect.Struct:
{
val := reflect.Indirect(reflect.ValueOf(value))
if trueValue.Field(index).Kind() == val.Kind() {
trueValue.Field(index).Set(reflect.ValueOf(value))
return trueValue, nil
} else {
return trueValue, fmt.Errorf("value's kind must same as field's kind")
}
}
default:
return trueValue, nil
}
}

0 comments on commit 9182ee8

Please sign in to comment.