-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathupdate.go
118 lines (91 loc) · 2.34 KB
/
update.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package sqlcraft
import (
"strconv"
"strings"
"github.com/techforge-lat/dafi/v2"
)
// UpdateQuery builds an SQL UPDATE statement. Its builder methods use value
// receivers and return a modified copy, leaving the receiver unchanged; note
// that slices and maps passed to them are stored without being copied.
type UpdateQuery struct {
// table is the target table name written right after UPDATE.
table string
// columns are the SET targets; the i-th column is bound to placeholder $i+1.
columns []string
// returningValues are the columns listed in the RETURNING clause, if any.
returningValues []string
// values are the arguments for the SET placeholders; when non-empty, ToSQL
// requires exactly one value per column.
values []any
// isPartialUpdate renders each assignment as column = COALESCE($n, column)
// instead of column = $n.
isPartialUpdate bool
// sqlColumnByDomainField is handed to WhereSafe when rendering the filters;
// presumably it maps domain field names to SQL column names — TODO confirm.
sqlColumnByDomainField map[string]string
// filters are the conditions rendered into the WHERE clause by WhereSafe.
filters dafi.Filters
}
// Update starts an UPDATE statement against table. The returned query has
// empty (non-nil) column, RETURNING and value lists and no filters; configure
// it with the With*, Where and Returning methods, then render it with ToSQL.
func Update(table string) UpdateQuery {
	var q UpdateQuery
	q.table = table
	q.columns = make([]string, 0)
	q.returningValues = make([]string, 0)
	q.values = make([]any, 0)
	return q
}
// WithColumns returns a copy of the query whose SET clause assigns the given
// columns, in order, replacing any previously configured columns. The slice
// is stored as-is, not copied.
func (u UpdateQuery) WithColumns(columns ...string) UpdateQuery {
	updated := u
	updated.columns = columns
	return updated
}
// WithValues returns a copy of the query carrying the arguments for the SET
// placeholders, bound positionally to the configured columns. It replaces
// any previously supplied values; the slice is stored as-is, not copied.
func (u UpdateQuery) WithValues(values ...any) UpdateQuery {
	updated := u
	updated.values = values
	return updated
}
// Where returns a copy of the query filtered by the given conditions. Each
// call replaces the filters set by an earlier call rather than adding to them.
func (u UpdateQuery) Where(filters ...dafi.Filter) UpdateQuery {
	filtered := u
	filtered.filters = filters
	return filtered
}
// SQLColumnByDomainField returns a copy of the query that passes the given
// mapping to WhereSafe when its filters are rendered. Presumably the mapping
// translates domain field names to SQL column names — TODO confirm. The map
// is stored by reference, not copied.
func (u UpdateQuery) SQLColumnByDomainField(sqlColumnByDomainField map[string]string) UpdateQuery {
	mapped := u
	mapped.sqlColumnByDomainField = sqlColumnByDomainField
	return mapped
}
// Returning returns a copy of the query that appends a RETURNING clause
// listing the given columns, joined with ", ". Calling it with no columns
// suppresses the clause.
func (u UpdateQuery) Returning(columns ...string) UpdateQuery {
	withReturning := u
	withReturning.returningValues = columns
	return withReturning
}
// WithPartialUpdate returns a copy of the query in which every assignment is
// rendered as column = COALESCE($n, column), so a NULL argument leaves the
// stored column value unchanged.
func (u UpdateQuery) WithPartialUpdate() UpdateQuery {
	partial := u
	partial.isPartialUpdate = true
	return partial
}
// ToSQL renders the query as an UPDATE statement with $n positional
// placeholders and returns it together with its bound arguments.
//
// Each configured column becomes "column = $i" in the SET clause, or
// "column = COALESCE($i, column)" after WithPartialUpdate, where i is the
// column's 1-based position. Placeholders generated for the Where filters are
// numbered after the SET placeholders, and a RETURNING clause is appended when
// Returning was given at least one column.
//
// Result.Args holds the values given to WithValues followed by the filter
// arguments. It is always a freshly allocated slice, so it never shares a
// backing array with the query's own values or with Args returned by other
// ToSQL calls. When no values were supplied, the SET placeholders are still
// emitted and the caller must prepend the matching arguments itself.
//
// ToSQL returns ErrMissMatchValues when values were supplied but their count
// differs from the number of columns, and propagates any error from WhereSafe.
//
// NOTE(review): with no columns configured the output contains an empty SET
// clause, which is not valid SQL; callers must set at least one column.
func (u UpdateQuery) ToSQL() (Result, error) {
	if len(u.values) > 0 && len(u.values) != len(u.columns) {
		return Result{}, ErrMissMatchValues
	}

	var builder strings.Builder
	builder.WriteString("UPDATE ")
	builder.WriteString(u.table)
	builder.WriteString(" SET ")

	for i, column := range u.columns {
		if i > 0 {
			builder.WriteString(", ")
		}
		placeholder := "$" + strconv.Itoa(i+1)

		builder.WriteString(column)
		builder.WriteString(" = ")
		if u.isPartialUpdate {
			// COALESCE falls back to the current column value when the bound
			// argument is NULL, so NULL means "leave unchanged".
			builder.WriteString("COALESCE(")
			builder.WriteString(placeholder)
			builder.WriteString(", ")
			builder.WriteString(column)
			builder.WriteString(")")
		} else {
			builder.WriteString(placeholder)
		}
	}

	// Copy into a fresh slice: appending to u.values directly could write
	// into the caller's backing array when it has spare capacity, making the
	// Args of separate ToSQL calls overwrite one another.
	args := make([]any, 0, len(u.values))
	args = append(args, u.values...)

	if len(u.filters) > 0 {
		// The SET clause always uses $1..$len(columns), so filter placeholders
		// are offset by the column count (not the value count, which is 0 when
		// values are supplied by the caller later). Assumes WhereSafe's first
		// argument is the number of placeholders already emitted — TODO confirm.
		whereResult, err := WhereSafe(len(u.columns), u.sqlColumnByDomainField, u.filters...)
		if err != nil {
			return Result{}, err
		}
		args = append(args, whereResult.Args...)
		builder.WriteString(whereResult.Sql)
	}

	if len(u.returningValues) > 0 {
		builder.WriteString(" RETURNING ")
		builder.WriteString(strings.Join(u.returningValues, ", "))
	}

	return Result{
		Sql:  builder.String(),
		Args: args,
	}, nil
}