Skip to content

Commit

Permalink
fix: clone query in scanAndCountConcurrently to avoid data race
Browse files Browse the repository at this point in the history
Close #1117
  • Loading branch information
j2gg0s committed Feb 8, 2025
1 parent afb8fda commit 66fdc39
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 14 deletions.
142 changes: 128 additions & 14 deletions query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -967,27 +967,27 @@ func (q *SelectQuery) scanAndCountConcurrently(
var mu sync.Mutex
var firstErr error

if q.limit >= 0 {
wg.Add(1)
go func() {
defer wg.Done()

if err := q.Scan(ctx, dest...); err != nil {
mu.Lock()
if firstErr == nil {
firstErr = err
}
mu.Unlock()
countQuery := q.Clone()

wg.Add(1)
go func() {
defer wg.Done()

if err := q.Scan(ctx, dest...); err != nil {
mu.Lock()
if firstErr == nil {
firstErr = err
}
}()
}
mu.Unlock()
}
}()

wg.Add(1)
go func() {
defer wg.Done()

var err error
count, err = q.Count(ctx)
count, err = countQuery.Count(ctx)
if err != nil {
mu.Lock()
if firstErr == nil {
Expand Down Expand Up @@ -1077,6 +1077,120 @@ func (q *SelectQuery) String() string {
return string(buf)
}

func (q *SelectQuery) Clone() *SelectQuery {
if q == nil {
return nil
}

cloneArgs := func(args []schema.QueryWithArgs) []schema.QueryWithArgs {
if len(args) == 0 {
return nil
}
clone := make([]schema.QueryWithArgs, len(args))
copy(clone, args)
return clone
}
cloneHints := func(hints *indexHints) *indexHints {
if hints == nil {
return nil
}
return &indexHints{
names: cloneArgs(hints.names),
forJoin: cloneArgs(hints.forJoin),
forOrderBy: cloneArgs(hints.forOrderBy),
forGroupBy: cloneArgs(hints.forGroupBy),
}
}

clone := &SelectQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
db: q.db,
table: q.table,
model: q.model,
tableModel: q.tableModel,
with: make([]withQuery, len(q.with)),
tables: cloneArgs(q.tables),
columns: cloneArgs(q.columns),
modelTableName: q.modelTableName,
},
where: make([]schema.QueryWithSep, len(q.where)),
},

idxHintsQuery: idxHintsQuery{
use: cloneHints(q.idxHintsQuery.use),
ignore: cloneHints(q.idxHintsQuery.ignore),
force: cloneHints(q.idxHintsQuery.force),
},

orderLimitOffsetQuery: orderLimitOffsetQuery{
order: cloneArgs(q.order),
limit: q.limit,
offset: q.offset,
},

distinctOn: cloneArgs(q.distinctOn),
joins: make([]joinQuery, len(q.joins)),
group: cloneArgs(q.group),
having: cloneArgs(q.having),
union: make([]union, len(q.union)),
comment: q.comment,
}

for i, w := range q.with {
clone.with[i] = withQuery{
name: w.name,
recursive: w.recursive,
query: w.query, // TODO: maybe clone is need
}
}

if !q.modelTableName.IsZero() {
clone.modelTableName = schema.SafeQuery(
q.modelTableName.Query,
append([]any(nil), q.modelTableName.Args...),
)
}

for i, w := range q.where {
clone.where[i] = schema.SafeQueryWithSep(
w.Query,
append([]any(nil), w.Args...),
w.Sep,
)
}

for i, j := range q.joins {
clone.joins[i] = joinQuery{
join: schema.SafeQuery(j.join.Query, append([]any(nil), j.join.Args...)),
on: make([]schema.QueryWithSep, len(j.on)),
}
for k, on := range j.on {
clone.joins[i].on[k] = schema.SafeQueryWithSep(
on.Query,
append([]any(nil), on.Args...),
on.Sep,
)
}
}

for i, u := range q.union {
clone.union[i] = union{
expr: u.expr,
query: u.query.Clone(),
}
}

if !q.selFor.IsZero() {
clone.selFor = schema.SafeQuery(
q.selFor.Query,
append([]any(nil), q.selFor.Args...),
)
}

return clone
}

//------------------------------------------------------------------------------

func (q *SelectQuery) QueryBuilder() QueryBuilder {
Expand Down
1 change: 1 addition & 0 deletions schema/sqlfmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ func (s Ident) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {

//------------------------------------------------------------------------------

// NOTE: It should not be modified after creation.
type QueryWithArgs struct {
Query string
Args []interface{}
Expand Down

0 comments on commit 66fdc39

Please sign in to comment.