Skip to content

Commit

Permalink
Use errors and fmt package to wrap errors (#33)
Browse files Browse the repository at this point in the history
To wrap errors `fmt.Errorf` with the `%w` format specifier (introduced in golang 1.13) is now used instead of `errors.Wrap` which was provided by the `github.com/pkg/errors` package.
  • Loading branch information
boekkooi-terramate authored Aug 28, 2024
1 parent 6669041 commit b943418
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 43 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ require (
github.com/jackc/pgx/v4 v4.18.2
github.com/lib/pq v1.10.2
github.com/oklog/ulid v1.3.1
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.8.1
)

Expand All @@ -25,6 +24,7 @@ require (
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgtype v1.14.0 // indirect
github.com/lithammer/shortuuid/v3 v3.0.7 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.10.0 // indirect
golang.org/x/crypto v0.20.0 // indirect
Expand Down
12 changes: 6 additions & 6 deletions pkg/sql/publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package sql

import (
"context"
"errors"
"fmt"
"sync"

"github.com/pkg/errors"

"github.com/ThreeDotsLabs/watermill"
"github.com/ThreeDotsLabs/watermill/message"
)
Expand Down Expand Up @@ -53,7 +53,7 @@ type Publisher struct {
func NewPublisher(db ContextExecutor, config PublisherConfig, logger watermill.LoggerAdapter) (*Publisher, error) {
config.setDefaults()
if err := config.validate(); err != nil {
return nil, errors.Wrap(err, "invalid config")
return nil, fmt.Errorf("invalid config: %w", err)
}

if db == nil {
Expand Down Expand Up @@ -105,7 +105,7 @@ func (p *Publisher) Publish(topic string, messages ...*message.Message) (err err

insertQuery, err := p.config.SchemaAdapter.InsertQuery(topic, messages)
if err != nil {
return errors.Wrap(err, "cannot create insert query")
return fmt.Errorf("cannot create insert query: %w", err)
}

p.logger.Trace("Inserting message to SQL", watermill.LogFields{
Expand All @@ -115,7 +115,7 @@ func (p *Publisher) Publish(topic string, messages ...*message.Message) (err err

_, err = p.db.ExecContext(context.Background(), insertQuery.Query, insertQuery.Args...)
if err != nil {
return errors.Wrap(err, "could not insert message as row")
return fmt.Errorf("could not insert message as row: %w", err)
}

return nil
Expand All @@ -138,7 +138,7 @@ func (p *Publisher) initializeSchema(topic string) error {
p.config.SchemaAdapter,
nil,
); err != nil {
return errors.Wrap(err, "cannot initialize schema")
return fmt.Errorf("cannot initialize schema: %w", err)
}

p.initializedTopics.Store(topic, struct{}{})
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package sql

import (
"context"
"fmt"

"github.com/ThreeDotsLabs/watermill"
"github.com/pkg/errors"
)

func initializeSchema(
Expand All @@ -30,9 +30,9 @@ func initializeSchema(
})

for _, q := range initializingQueries {
_, err := db.ExecContext(ctx, q.Query, q.Args...)
_, err = db.ExecContext(ctx, q.Query, q.Args...)
if err != nil {
return errors.Wrap(err, "could not initialize schema")
return fmt.Errorf("could not initialize schema: %w", err)
}
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/schema_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package sql
import (
"database/sql"
"encoding/json"
"fmt"

"github.com/ThreeDotsLabs/watermill/message"
"github.com/pkg/errors"
)

// SchemaAdapter produces the SQL queries and arguments appropriately for a specific schema and dialect
Expand Down Expand Up @@ -49,7 +49,7 @@ func defaultInsertArgs(msgs message.Messages) ([]interface{}, error) {
for _, msg := range msgs {
metadata, err := json.Marshal(msg.Metadata)
if err != nil {
return nil, errors.Wrapf(err, "could not marshal metadata into JSON for message %s", msg.UUID)
return nil, fmt.Errorf("could not marshal metadata into JSON for message %s: %w", msg.UUID, err)
}

args = append(args, msg.UUID, []byte(msg.Payload), metadata)
Expand Down
5 changes: 2 additions & 3 deletions pkg/sql/schema_adapter_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"strings"

"github.com/ThreeDotsLabs/watermill/message"
"github.com/pkg/errors"
)

// DefaultMySQLSchema is a default implementation of SchemaAdapter based on MySQL.
Expand Down Expand Up @@ -102,15 +101,15 @@ func (s DefaultMySQLSchema) UnmarshalMessage(row Scanner) (Row, error) {
r := Row{}
err := row.Scan(&r.Offset, &r.UUID, &r.Payload, &r.Metadata)
if err != nil {
return Row{}, errors.Wrap(err, "could not scan message row")
return Row{}, fmt.Errorf("could not scan message row: %w", err)
}

msg := message.NewMessage(string(r.UUID), r.Payload)

if r.Metadata != nil {
err = json.Unmarshal(r.Metadata, &msg.Metadata)
if err != nil {
return Row{}, errors.Wrap(err, "could not unmarshal metadata as JSON")
return Row{}, fmt.Errorf("could not unmarshal metadata as JSON: %w", err)
}
}

Expand Down
5 changes: 2 additions & 3 deletions pkg/sql/schema_adapter_postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"strings"

"github.com/ThreeDotsLabs/watermill/message"
"github.com/pkg/errors"
)

// DefaultPostgreSQLSchema is a default implementation of SchemaAdapter based on PostgreSQL.
Expand Down Expand Up @@ -112,15 +111,15 @@ func (s DefaultPostgreSQLSchema) UnmarshalMessage(row Scanner) (Row, error) {

err := row.Scan(&r.Offset, &transactionID, &r.UUID, &r.Payload, &r.Metadata)
if err != nil {
return Row{}, errors.Wrap(err, "could not scan message row")
return Row{}, fmt.Errorf("could not scan message row: %w", err)
}

msg := message.NewMessage(string(r.UUID), r.Payload)

if r.Metadata != nil {
err = json.Unmarshal(r.Metadata, &msg.Metadata)
if err != nil {
return Row{}, errors.Wrap(err, "could not unmarshal metadata as JSON")
return Row{}, fmt.Errorf("could not unmarshal metadata as JSON: %w", err)
}
}

Expand Down
32 changes: 16 additions & 16 deletions pkg/sql/subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package sql
import (
"context"
"database/sql"
stdErrors "errors"
"errors"
"fmt"
"sync"
"time"

"github.com/oklog/ulid"
"github.com/pkg/errors"

"github.com/ThreeDotsLabs/watermill"
"github.com/ThreeDotsLabs/watermill/message"
Expand Down Expand Up @@ -126,7 +126,7 @@ func NewSubscriber(db Beginner, config SubscriberConfig, logger watermill.Logger
config.setDefaults()
err := config.validate()
if err != nil {
return nil, errors.Wrap(err, "invalid config")
return nil, fmt.Errorf("invalid config: %w", err)
}

if logger == nil {
Expand All @@ -135,7 +135,7 @@ func NewSubscriber(db Beginner, config SubscriberConfig, logger watermill.Logger

idBytes, idStr, err := newSubscriberID()
if err != nil {
return &Subscriber{}, errors.Wrap(err, "cannot generate subscriber id")
return &Subscriber{}, fmt.Errorf("cannot generate subscriber id: %w", err)
}
logger = logger.With(watermill.LogFields{"subscriber_id": idStr})

Expand All @@ -159,7 +159,7 @@ func newSubscriberID() ([]byte, string, error) {
id := watermill.NewULID()
idBytes, err := ulid.MustParseStrict(id).MarshalBinary()
if err != nil {
return nil, "", errors.Wrap(err, "cannot marshal subscriber id")
return nil, "", fmt.Errorf("cannot marshal subscriber id: %w", err)
}

return idBytes, id, nil
Expand Down Expand Up @@ -191,7 +191,7 @@ func (s *Subscriber) Subscribe(ctx context.Context, topic string) (o <-chan *mes

_, err := tx.ExecContext(ctx, q.Query, q.Args...)
if err != nil {
return errors.Wrap(err, "cannot execute before subscribing query")
return fmt.Errorf("cannot execute before subscribing query: %w", err)
}
}
return nil
Expand Down Expand Up @@ -263,20 +263,20 @@ func (s *Subscriber) query(
}
tx, err := s.db.BeginTx(ctx, txOptions)
if err != nil {
return false, errors.Wrap(err, "could not begin tx for querying")
return false, fmt.Errorf("could not begin tx for querying: %w", err)
}

defer func() {
if err != nil {
rollbackErr := tx.Rollback()
if rollbackErr != nil && rollbackErr != sql.ErrTxDone {
if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
logger.Error("could not rollback tx for querying message", rollbackErr, watermill.LogFields{
"query_err": err,
})
}
} else {
commitErr := tx.Commit()
if commitErr != nil && commitErr != sql.ErrTxDone {
if commitErr != nil && !errors.Is(commitErr, sql.ErrTxDone) {
logger.Error("could not commit tx for querying message", commitErr, nil)
}
}
Expand All @@ -293,12 +293,12 @@ func (s *Subscriber) query(
})
rows, err := tx.QueryContext(ctx, selectQuery.Query, selectQuery.Args...)
if err != nil {
return false, errors.Wrap(err, "could not query message")
return false, fmt.Errorf("could not query message: %w", err)
}

defer func() {
if rowsCloseErr := rows.Close(); rowsCloseErr != nil {
err = stdErrors.Join(err, errors.Wrap(err, "could not close rows"))
err = errors.Join(err, fmt.Errorf("could not close rows: %w", err))
}
}()

Expand All @@ -309,10 +309,10 @@ func (s *Subscriber) query(

for rows.Next() {
row, err := s.config.SchemaAdapter.UnmarshalMessage(rows)
if errors.Cause(err) == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return true, nil
} else if err != nil {
return false, errors.Wrap(err, "could not unmarshal message from query")
return false, fmt.Errorf("could not unmarshal message from query: %w", err)
}

messageRows = append(messageRows, row)
Expand All @@ -321,7 +321,7 @@ func (s *Subscriber) query(
for _, row := range messageRows {
acked, err := s.processMessage(ctx, topic, row, tx, out, logger)
if err != nil {
return false, errors.Wrap(err, "could not process message")
return false, fmt.Errorf("could not process message: %w", err)
}
if !acked {
break
Expand All @@ -348,7 +348,7 @@ func (s *Subscriber) query(

result, err := tx.ExecContext(ctx, ackQuery.Query, ackQuery.Args...)
if err != nil {
return false, errors.Wrap(err, "could not get args for acking the message")
return false, fmt.Errorf("could not get args for acking the message: %w", err)
}

rowsAffected, _ := result.RowsAffected()
Expand Down Expand Up @@ -388,7 +388,7 @@ func (s *Subscriber) processMessage(

_, err := tx.ExecContext(ctx, consumedQuery.Query, consumedQuery.Args...)
if err != nil {
return false, errors.Wrap(err, "cannot send consumed query")
return false, fmt.Errorf("cannot send consumed query: %w", err)
}

logger.Trace("Executed query to confirm message consumed", nil)
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/topic.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package sql

import (
"errors"
"fmt"
"regexp"

"github.com/pkg/errors"
)

var disallowedTopicCharacters = regexp.MustCompile(`[^A-Za-z0-9\-\$\:\.\_]`)
Expand All @@ -14,7 +14,7 @@ var ErrInvalidTopicName = errors.New("topic name should not contain characters m
// Topics are translated into SQL tables and patched into some queries, so this is done to prevent injection as well.
func validateTopicName(topic string) error {
if disallowedTopicCharacters.MatchString(topic) {
return errors.Wrap(ErrInvalidTopicName, topic)
return fmt.Errorf("%s: %w", topic, ErrInvalidTopicName)
}

return nil
Expand Down
11 changes: 5 additions & 6 deletions pkg/sql/topic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ import (
"testing"
"time"

"github.com/ThreeDotsLabs/watermill-sql/v3/pkg/sql"
"github.com/ThreeDotsLabs/watermill/message"

"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ThreeDotsLabs/watermill-sql/v3/pkg/sql"
"github.com/ThreeDotsLabs/watermill/message"
)

func TestValidateTopicName(t *testing.T) {
Expand All @@ -22,11 +21,11 @@ func TestValidateTopicName(t *testing.T) {

err := publisher.Publish(cleverlyNamedTopic, message.NewMessage("uuid", nil))
require.Error(t, err)
assert.Equal(t, sql.ErrInvalidTopicName, errors.Cause(err))
assert.ErrorIs(t, err, sql.ErrInvalidTopicName)

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, err = subscriber.Subscribe(ctx, cleverlyNamedTopic)
require.Error(t, err)
assert.Equal(t, sql.ErrInvalidTopicName, errors.Cause(err))
assert.ErrorIs(t, err, sql.ErrInvalidTopicName)
}

0 comments on commit b943418

Please sign in to comment.