diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index 2c445c986..0bbf7a971 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -259,6 +259,16 @@ jobs: env: MYSQL_URI: mysql://root:password@tcp(localhost:${{ job.services.mariadb.ports[3306] }})/ + - name: Test Redis cluster + run: | + git clone https://github.com/gorse-cloud/redis-stack.git + pushd redis-stack + docker-compose up -d + popd + go test ./storage/cache -run ^TestRedis + env: + REDIS_URI: redis+cluster://localhost:7000 + golangci: name: lint runs-on: ubuntu-latest diff --git a/config/config.toml b/config/config.toml index 4b1a6fb8e..d595098e2 100644 --- a/config/config.toml +++ b/config/config.toml @@ -3,6 +3,7 @@ # The database for caching, support Redis, MySQL, Postgres and MongoDB: # redis://:@:/ # rediss://:@:/ +# redis+cluster://:@:,:,...,: # mysql://[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] # postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full # postgresql://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full diff --git a/storage/cache/database.go b/storage/cache/database.go index f550774c3..650e85d36 100644 --- a/storage/cache/database.go +++ b/storage/cache/database.go @@ -311,6 +311,19 @@ func Open(path, tablePrefix string, opts ...storage.Option) (Database, error) { return nil, errors.Trace(err) } return database, nil + } else if strings.HasPrefix(path, storage.RedisClusterPrefix) { + opt, err := ParseRedisClusterURL(path) + if err != nil { + return nil, err + } + database := new(Redis) + database.client = redis.NewClusterClient(opt) + database.TablePrefix = storage.TablePrefix(tablePrefix) + if err = redisotel.InstrumentTracing(database.client, redisotel.WithAttributes(semconv.DBSystemRedis)); err != nil { + log.Logger().Error("failed to add tracing for redis", zap.Error(err)) + return nil, errors.Trace(err) + } + return database, nil } else if strings.HasPrefix(path, storage.MongoPrefix) || strings.HasPrefix(path, storage.MongoSrvPrefix) { // connect to database database := new(MongoDB) diff --git a/storage/cache/redis.go b/storage/cache/redis.go index 84f4cfb74..18cbb507f 100644 --- a/storage/cache/redis.go +++ b/storage/cache/redis.go @@ -19,6 +19,7 @@ import ( "encoding/base64" "fmt" "io" + "net/url" "strconv" "strings" "time" @@ -88,7 +89,13 @@ func (r *Redis) Init() error { func (r *Redis) Scan(work func(string) error) error { ctx := context.Background() - return r.scan(ctx, r.client, work) + if clusterClient, isCluster := r.client.(*redis.ClusterClient); isCluster { + return clusterClient.ForEachMaster(ctx, func(ctx context.Context, client *redis.Client) error { + return r.scan(ctx, client, work) + }) + } else { + return r.scan(ctx, r.client, work) + } } func (r *Redis) scan(ctx context.Context, client redis.UniversalClient, work func(string) error) error { @@ -115,10 +122,16 @@ func (r *Redis) scan(ctx context.Context, client redis.UniversalClient, work fun func (r *Redis) Purge() error { ctx := context.Background() - return r.purge(ctx, r.client) + if clusterClient, isCluster := r.client.(*redis.ClusterClient); isCluster { + return clusterClient.ForEachMaster(ctx, func(ctx context.Context, client *redis.Client) error { + return r.purge(ctx, client, isCluster) + }) + } else { + return r.purge(ctx, r.client, isCluster) + } } -func (r *Redis) purge(ctx context.Context, client redis.UniversalClient) error { +func (r *Redis) purge(ctx context.Context, client redis.UniversalClient, isCluster bool) error { var ( result []string cursor uint64 @@ -130,8 +143,20 @@ func (r *Redis) purge(ctx context.Context, client redis.UniversalClient) error { return errors.Trace(err) } if len(result) > 0 { - if err = client.Del(ctx, result...).Err(); err != nil { - return errors.Trace(err) + if isCluster { + p := client.Pipeline() + for _, key := range result { + if err = p.Del(ctx, key).Err(); err != nil { + return errors.Trace(err) + } + } + if _, err = p.Exec(ctx); err != nil { + return errors.Trace(err) + } + } else { + if err = client.Del(ctx, result...).Err(); err != nil { + return errors.Trace(err) + } } } if cursor == 0 { @@ -488,3 +513,159 @@ func escape(s string) string { ) return r.Replace(s) } + +func ParseRedisClusterURL(redisURL string) (*redis.ClusterOptions, error) { + options := &redis.ClusterOptions{} + uri := redisURL + + var err error + if strings.HasPrefix(redisURL, storage.RedisClusterPrefix) { + uri = uri[len(storage.RedisClusterPrefix):] + } else { + return nil, fmt.Errorf("scheme must be \"redis+cluster\"") + } + + if idx := strings.Index(uri, "@"); idx != -1 { + userInfo := uri[:idx] + uri = uri[idx+1:] + + username := userInfo + var password string + + if idx := strings.Index(userInfo, ":"); idx != -1 { + username = userInfo[:idx] + password = userInfo[idx+1:] + } + + // Validate and process the username. + if strings.Contains(username, "/") { + return nil, fmt.Errorf("unescaped slash in username") + } + options.Username, err = url.PathUnescape(username) + if err != nil { + return nil, errors.Wrap(err, fmt.Errorf("invalid username")) + } + + // Validate and process the password. + if strings.Contains(password, ":") { + return nil, fmt.Errorf("unescaped colon in password") + } + if strings.Contains(password, "/") { + return nil, fmt.Errorf("unescaped slash in password") + } + options.Password, err = url.PathUnescape(password) + if err != nil { + return nil, errors.Wrap(err, fmt.Errorf("invalid password")) + } + } + + // fetch the hosts field + hosts := uri + if idx := strings.IndexAny(uri, "/?@"); idx != -1 { + if uri[idx] == '@' { + return nil, fmt.Errorf("unescaped @ sign in user info") + } + hosts = uri[:idx] + } + + options.Addrs = strings.Split(hosts, ",") + uri = uri[len(hosts):] + if len(uri) > 0 && uri[0] == '/' { + uri = uri[1:] + } + + // grab connection arguments from URI + connectionArgsFromQueryString, err := extractQueryArgsFromURI(uri) + if err != nil { + return nil, err + } + for _, pair := range connectionArgsFromQueryString { + err = addOption(options, pair) + if err != nil { + return nil, err + } + } + + return options, nil +} + +func extractQueryArgsFromURI(uri string) ([]string, error) { + if len(uri) == 0 { + return nil, nil + } + + if uri[0] != '?' { + return nil, errors.New("must have a ? separator between path and query") + } + + uri = uri[1:] + if len(uri) == 0 { + return nil, nil + } + return strings.FieldsFunc(uri, func(r rune) bool { return r == ';' || r == '&' }), nil +} + +type optionHandler struct { + int *int + bool *bool + duration *time.Duration +} + +func addOption(options *redis.ClusterOptions, pair string) error { + kv := strings.SplitN(pair, "=", 2) + if len(kv) != 2 || kv[0] == "" { + return fmt.Errorf("invalid option") + } + + key, err := url.QueryUnescape(kv[0]) + if err != nil { + return errors.Wrap(err, errors.Errorf("invalid option key %q", kv[0])) + } + + value, err := url.QueryUnescape(kv[1]) + if err != nil { + return errors.Wrap(err, errors.Errorf("invalid option value %q", kv[1])) + } + + handlers := map[string]optionHandler{ + "max_retries": {int: &options.MaxRetries}, + "min_retry_backoff": {duration: &options.MinRetryBackoff}, + "max_retry_backoff": {duration: &options.MaxRetryBackoff}, + "dial_timeout": {duration: &options.DialTimeout}, + "read_timeout": {duration: &options.ReadTimeout}, + "write_timeout": {duration: &options.WriteTimeout}, + "pool_fifo": {bool: &options.PoolFIFO}, + "pool_size": {int: &options.PoolSize}, + "pool_timeout": {duration: &options.PoolTimeout}, + "min_idle_conns": {int: &options.MinIdleConns}, + "max_idle_conns": {int: &options.MaxIdleConns}, + "conn_max_idle_time": {duration: &options.ConnMaxIdleTime}, + "conn_max_lifetime": {duration: &options.ConnMaxLifetime}, + } + + lowerKey := strings.ToLower(key) + if handler, ok := handlers[lowerKey]; ok { + if handler.int != nil { + *handler.int, err = strconv.Atoi(value) + if err != nil { + return errors.Wrap(err, fmt.Errorf("invalid '%s' value: %q", key, value)) + } + } else if handler.duration != nil { + *handler.duration, err = time.ParseDuration(value) + if err != nil { + return errors.Wrap(err, fmt.Errorf("invalid '%s' value: %q", key, value)) + } + } else if handler.bool != nil { + *handler.bool, err = strconv.ParseBool(value) + if err != nil { + return errors.Wrap(err, fmt.Errorf("invalid '%s' value: %q", key, value)) + } + } else { + return fmt.Errorf("redis: unexpected option: %s", key) + } + } else { + return fmt.Errorf("redis: unexpected option: %s", key) + } + + return nil +} diff --git a/storage/cache/redis_test.go b/storage/cache/redis_test.go index ceefdd88a..bfbd377d3 100644 --- a/storage/cache/redis_test.go +++ b/storage/cache/redis_test.go @@ -114,3 +114,27 @@ func BenchmarkRedis(b *testing.B) { // benchmark benchmark(b, database) } + +func TestParseRedisClusterURL(t *testing.T) { + options, err := ParseRedisClusterURL("redis+cluster://username:password@127.0.0.1:6379,127.0.0.1:6380,127.0.0.1:6381/?" + + "max_retries=1000&dial_timeout=1h&pool_fifo=true") + if assert.NoError(t, err) { + assert.Equal(t, "username", options.Username) + assert.Equal(t, "password", options.Password) + assert.Equal(t, []string{"127.0.0.1:6379", "127.0.0.1:6380", "127.0.0.1:6381"}, options.Addrs) + assert.Equal(t, 1000, options.MaxRetries) + assert.Equal(t, time.Hour, options.DialTimeout) + assert.True(t, options.PoolFIFO) + } + + _, err = ParseRedisClusterURL("redis://") + assert.Error(t, err) + _, err = ParseRedisClusterURL("redis+cluster://username:password@127.0.0.1:6379/?max_retries=a") + assert.Error(t, err) + _, err = ParseRedisClusterURL("redis+cluster://username:password@127.0.0.1:6379/?dial_timeout=a") + assert.Error(t, err) + _, err = ParseRedisClusterURL("redis+cluster://username:password@127.0.0.1:6379/?pool_fifo=a") + assert.Error(t, err) + _, err = ParseRedisClusterURL("redis+cluster://username:password@127.0.0.1:6379/?a=1") + assert.Error(t, err) +} diff --git a/storage/scheme.go b/storage/scheme.go index 4bce902e0..d0fa9936b 100644 --- a/storage/scheme.go +++ b/storage/scheme.go @@ -29,17 +29,18 @@ import ( ) const ( - MySQLPrefix = "mysql://" - MongoPrefix = "mongodb://" - MongoSrvPrefix = "mongodb+srv://" - PostgresPrefix = "postgres://" - PostgreSQLPrefix = "postgresql://" - ClickhousePrefix = "clickhouse://" - CHHTTPPrefix = "chhttp://" - CHHTTPSPrefix = "chhttps://" - SQLitePrefix = "sqlite://" - RedisPrefix = "redis://" - RedissPrefix = "rediss://" + MySQLPrefix = "mysql://" + MongoPrefix = "mongodb://" + MongoSrvPrefix = "mongodb+srv://" + PostgresPrefix = "postgres://" + PostgreSQLPrefix = "postgresql://" + ClickhousePrefix = "clickhouse://" + CHHTTPPrefix = "chhttp://" + CHHTTPSPrefix = "chhttps://" + SQLitePrefix = "sqlite://" + RedisPrefix = "redis://" + RedissPrefix = "rediss://" + RedisClusterPrefix = "redis+cluster://" ) func AppendURLParams(rawURL string, params []lo.Tuple2[string, string]) (string, error) {