diff --git a/assets/config/config.yaml b/assets/config/config.yaml index 9cdf0a2695c..16cec689a67 100644 --- a/assets/config/config.yaml +++ b/assets/config/config.yaml @@ -16,4 +16,4 @@ operatorTests: scyllaDBVersions: updateFrom: "6.2.0-rc2" upgradeFrom: "6.1.2" - nodeSetupImage: "quay.io/scylladb/scylla-operator-images:node-setup-v0.0.2@sha256:210b1dd9bd60a5bf4056783f3132bdeef0cf9ab0a19eff0b620b2dfa5c4e5d61" + nodeSetupImage: "quay.io/scylladb/scylla-operator-images:node-setup-v0.0.3@sha256:c6b3de240cc5c884d5c617485bae35c51572cdfd39b6431d2e1f759c7d7feea1" diff --git a/pkg/util/cql/frame.go b/pkg/util/cql/frame.go new file mode 100644 index 00000000000..fd562ab9e7e --- /dev/null +++ b/pkg/util/cql/frame.go @@ -0,0 +1,75 @@ +// Copyright (c) 2024 ScyllaDB. + +package cql + +import ( + "bytes" + "fmt" +) + +const ( + headerLen = 9 + + // OptionsFrame is a minimal OPTIONS CQL frame. + // Ref: https://github.com/apache/cassandra/blob/f278f6774fc76465c182041e081982105c3e7dbb/doc/native_protocol_v4.spec + OptionsFrame = `\x04\x00\x00\x00\x05\x00\x00\x00\x00` +) + +type FrameParser struct { + buf *bytes.Buffer +} + +func NewFrameParser(buf *bytes.Buffer) *FrameParser { + return &FrameParser{ + buf: buf, + } +} + +func (fp *FrameParser) SkipHeader() { + _ = fp.readBytes(headerLen) +} + +func (fp *FrameParser) readByte() byte { + p, err := fp.buf.ReadByte() + if err != nil { + panic(fmt.Errorf("can't read byte from buffer: %w", err)) + } + return p +} + +func (fp *FrameParser) ReadShort() uint16 { + return uint16(fp.readByte())<<8 | uint16(fp.readByte()) +} + +func (fp *FrameParser) ReadStringMultiMap() map[string][]string { + n := fp.ReadShort() + m := make(map[string][]string, n) + for i := uint16(0); i < n; i++ { + k := fp.ReadString() + v := fp.ReadStringList() + m[k] = v + } + return m +} + +func (fp *FrameParser) readBytes(n int) []byte { + p := make([]byte, 0, n) + for i := 0; i < n; i++ { + p = append(p, fp.readByte()) + } + + return p +} + +func (fp *FrameParser) ReadString() string { + return string(fp.readBytes(int(fp.ReadShort()))) +} + +func (fp *FrameParser) ReadStringList() []string { + n := fp.ReadShort() + l := make([]string, 0, n) + for i := uint16(0); i < n; i++ { + l = append(l, fp.ReadString()) + } + return l +} diff --git a/pkg/util/cql/frame_test.go b/pkg/util/cql/frame_test.go new file mode 100644 index 00000000000..8dbd5a0fbd7 --- /dev/null +++ b/pkg/util/cql/frame_test.go @@ -0,0 +1,184 @@ +// Copyright (c) 2024 ScyllaDB. + +package cql + +import ( + "bytes" + "testing" + + "k8s.io/apimachinery/pkg/api/equality" +) + +func TestFrameParserReadShort(t *testing.T) { + tt := []struct { + name string + buffer []byte + expectedShort uint16 + expectedBuffer []byte + }{ + { + name: "consumes two bytes and returns short number", + buffer: []byte{ + 0x01, 0x00, + 0x02, 0x00, + 0x03, 0x00, + 0x04, 0x00, + }, + expectedShort: 256, + expectedBuffer: []byte{ + 0x02, 0x00, + 0x03, 0x00, + 0x04, 0x00, + }, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + buf := bytes.NewBuffer(tc.buffer) + fp := NewFrameParser(buf) + gotShort := fp.ReadShort() + if gotShort != tc.expectedShort { + t.Errorf("got %v short, expected %v", gotShort, tc.expectedShort) + } + if !equality.Semantic.DeepEqual(buf.Bytes(), tc.expectedBuffer) { + t.Errorf("got %v buffer, expected %v", buf, tc.expectedBuffer) + } + }) + } +} + +func TestFrameParserSkipHeader(t *testing.T) { + tt := []struct { + name string + buffer []byte + expectedBuffer []byte + }{ + { + name: "consumes first 9 bytes", + buffer: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x10}, + expectedBuffer: []byte{0x10}, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + buf := bytes.NewBuffer(tc.buffer) + fp := NewFrameParser(buf) + fp.SkipHeader() + if !equality.Semantic.DeepEqual(buf.Bytes(), tc.expectedBuffer) { + t.Errorf("got %v buffer, expected %v", buf, tc.expectedBuffer) + } + }) + } +} + +func TestFrameParserReadString(t *testing.T) { + tt := []struct { + name string + buffer []byte + expectedString string + expectedBuffer []byte + }{ + { + name: "consumes bytes from buffer by reading string length (uint16) and content", + buffer: []byte{ + 0x00, 0x05, // 5 - string length + 0x68, 0x65, 0x6c, 0x6c, 0x6f, // hello + }, + expectedString: "hello", + expectedBuffer: []byte{}, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + buf := bytes.NewBuffer(tc.buffer) + fp := NewFrameParser(buf) + gotString := fp.ReadString() + if gotString != tc.expectedString { + t.Errorf("got %v string, expected %v", gotString, tc.expectedString) + } + if !equality.Semantic.DeepEqual(buf.Bytes(), tc.expectedBuffer) { + t.Errorf("got %v buffer, expected %v", buf, tc.expectedBuffer) + } + }) + } +} + +func TestFrameParserReadStringList(t *testing.T) { + tt := []struct { + name string + buffer []byte + expectedStrings []string + expectedBuffer []byte + }{ + { + name: "consumes bytes from buffer by reading string list length and strings", + buffer: []byte{ + 0x00, 0x02, // 2 - slice length + 0x00, 0x05, // 5 - string length + 0x68, 0x65, 0x6c, 0x6c, 0x6f, // hello + 0x00, 0x05, // 5 - string length + 0x77, 0x6f, 0x72, 0x6c, 0x64, // world + }, + expectedStrings: []string{"hello", "world"}, + expectedBuffer: []byte{}, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + buf := bytes.NewBuffer(tc.buffer) + fp := NewFrameParser(buf) + gotString := fp.ReadStringList() + if !equality.Semantic.DeepEqual(gotString, tc.expectedStrings) { + t.Errorf("got %v strings, expected %v", gotString, tc.expectedStrings) + } + if !equality.Semantic.DeepEqual(buf.Bytes(), tc.expectedBuffer) { + t.Errorf("got %v buffer, expected %v", buf, tc.expectedBuffer) + } + }) + } +} + +func TestFrameParserReadStringMultiMap(t *testing.T) { + tt := []struct { + name string + buffer []byte + expectedMultiMap map[string][]string + expectedBuffer []byte + }{ + { + name: "consumes bytes from buffer by reading string list length and strings", + buffer: []byte{ + 0x00, 0x01, // 1 - map elements + 0x00, 0x03, // 3 - key length + 0x66, 0x6f, 0x6f, // foo + 0x00, 0x02, // 2 - slice length + 0x00, 0x05, // 5 - string length + 0x68, 0x65, 0x6c, 0x6c, 0x6f, // hello + 0x00, 0x05, // 5 - string length + 0x77, 0x6f, 0x72, 0x6c, 0x64, // world + }, + expectedMultiMap: map[string][]string{ + "foo": {"hello", "world"}, + }, + expectedBuffer: []byte{}, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + buf := bytes.NewBuffer(tc.buffer) + fp := NewFrameParser(buf) + gotMultiMap := fp.ReadStringMultiMap() + if !equality.Semantic.DeepEqual(gotMultiMap, tc.expectedMultiMap) { + t.Errorf("got %v multimap, expected %v", gotMultiMap, tc.expectedMultiMap) + } + if !equality.Semantic.DeepEqual(buf.Bytes(), tc.expectedBuffer) { + t.Errorf("got %v buffer, expected %v", buf, tc.expectedBuffer) + } + }) + } +} diff --git a/test/e2e/set/scyllacluster/scyllacluster_shardawareness.go b/test/e2e/set/scyllacluster/scyllacluster_shardawareness.go index 0279a93faee..a058ff62375 100644 --- a/test/e2e/set/scyllacluster/scyllacluster_shardawareness.go +++ b/test/e2e/set/scyllacluster/scyllacluster_shardawareness.go @@ -3,24 +3,23 @@ package scyllacluster import ( + "bytes" "context" "fmt" - "net" - "sync" - "time" + "math/rand" - "github.com/gocql/gocql" g "github.com/onsi/ginkgo/v2" o "github.com/onsi/gomega" - "github.com/scylladb/gocqlx/v2" + configassests "github.com/scylladb/scylla-operator/assets/config" scyllav1 "github.com/scylladb/scylla-operator/pkg/api/scylla/v1" "github.com/scylladb/scylla-operator/pkg/controllerhelpers" + "github.com/scylladb/scylla-operator/pkg/util/cql" "github.com/scylladb/scylla-operator/test/e2e/framework" "github.com/scylladb/scylla-operator/test/e2e/utils" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/utils/pointer" ) var _ = g.Describe("ScyllaCluster", func() { @@ -28,9 +27,7 @@ var _ = g.Describe("ScyllaCluster", func() { g.It("should allow to build connection pool using shard aware ports", func() { const ( - nonShardAwarePort = 9042 - shardAwarePort = 19042 - nrShards = 4 + nrShards = 4 ) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -69,57 +66,81 @@ var _ = g.Describe("ScyllaCluster", func() { o.Expect(err).NotTo(o.HaveOccurred()) o.Expect(hosts).To(o.HaveLen(1)) - connections := make(map[uint16]string) - var connectionsMut sync.Mutex - - clusterConfig := gocql.NewCluster(hosts...) - clusterConfig.Dialer = DialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { - sourcePort := gocql.ScyllaGetSourcePort(ctx) - localAddr, err := net.ResolveTCPAddr(network, fmt.Sprintf(":%d", sourcePort)) - if err != nil { - return nil, err - } - - framework.Infof("Connecting to %s using %d source port", addr, sourcePort) - connectionsMut.Lock() - connections[sourcePort] = addr - connectionsMut.Unlock() + clientPod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "client", + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "client", + Image: configassests.Project.OperatorTests.NodeSetupImage, + Command: []string{ + "sleep", + "infinity", + }, + }, + }, + TerminationGracePeriodSeconds: pointer.Int64(1), + RestartPolicy: corev1.RestartPolicyNever, + }, + } + clientPod, err = f.KubeClient().CoreV1().Pods(f.Namespace()).Create(ctx, clientPod, metav1.CreateOptions{}) - d := &net.Dialer{LocalAddr: localAddr} - return d.DialContext(ctx, network, addr) - }) + waitCtx2, waitCtx2Cancel := utils.ContextForPodStartup(ctx) + defer waitCtx2Cancel() + clientPod, err = controllerhelpers.WaitForPodState(waitCtx2, f.KubeClient().CoreV1().Pods(clientPod.Namespace), clientPod.Name, controllerhelpers.WaitForStateOptions{}, utils.PodIsRunning) - framework.By("Waiting for the driver to establish connection to shards") - session, err := gocqlx.WrapSession(clusterConfig.CreateSession()) - o.Expect(err).NotTo(o.HaveOccurred()) - defer session.Close() + const ( + scyllaShardKey = "SCYLLA_SHARD" + shardAwarePort = 19042 - err = wait.PollUntilContextTimeout(ctx, 100*time.Millisecond, 5*time.Second, true, func(context.Context) (done bool, err error) { - return len(connections) == nrShards, nil - }) - o.Expect(err).NotTo(o.HaveOccurred()) + connectionAttempts = 10 + ) - shardAwareAttempts := 0 - for sourcePort, addr := range connections { - // Control connection is also put in pool, and it always uses the default port. - if sourcePort == 0 { - o.Expect(addr).To(o.HaveSuffix(fmt.Sprintf("%d", nonShardAwarePort))) - continue + for shard := range nrShards { + port := shardPort(shard, nrShards) + + for i := range connectionAttempts { + framework.By("Establishing connection number %d to shard number %d", i, shard) + stdout, stderr, err := utils.ExecWithOptions(f.ClientConfig(), f.KubeClient().CoreV1(), utils.ExecOptions{ + Command: []string{ + "/usr/bin/bash", + "-euEo", + "pipefail", + "-O", + "inherit_errexit", + "-c", + fmt.Sprintf(`echo -e '%s' | nc -p %d %s %d`, cql.OptionsFrame, port, hosts[0], shardAwarePort)}, + Namespace: clientPod.Namespace, + PodName: clientPod.Name, + ContainerName: clientPod.Name, + CaptureStdout: true, + CaptureStderr: true, + }) + o.Expect(err).NotTo(o.HaveOccurred(), stdout, stderr) + o.Expect(stderr).To(o.BeEmpty()) + o.Expect(stdout).ToNot(o.BeEmpty()) + + fp := cql.NewFrameParser(bytes.NewBuffer([]byte(stdout))) + fp.SkipHeader() + + scyllaSupported := fp.ReadStringMultiMap() + o.Expect(scyllaSupported[scyllaShardKey]).To(o.HaveLen(1)) + o.Expect(scyllaSupported[scyllaShardKey][0]).To(o.Equal(fmt.Sprintf("%d", shard))) } - o.Expect(addr).To(o.HaveSuffix(fmt.Sprintf("%d", shardAwarePort))) - shardAwareAttempts++ } - - // Control connection used for shard number discovery, lands on some random shard. - // This connection is also put in pool, and driver only establish connections to missing shards - // using shard-aware-port. - // Connections to shard-aware-port are guaranteed to land on shard driver wants. - o.Expect(shardAwareAttempts).To(o.Equal(nrShards - 1)) }) }) -type DialerFunc func(ctx context.Context, network, addr string) (net.Conn, error) - -func (f DialerFunc) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - return f(ctx, network, addr) +// Ref: https://github.com/scylladb/scylla-rust-driver/blob/de7d8a5c78ea0702bf6da80197f7c495a145c188/scylla/src/routing.rs#L104-L110 +func shardPort(shard, nrShards int) int { + const ( + maxPort = 65535 + minPort = 49152 + ) + maxRange := maxPort - nrShards + 1 + minRange := minPort + nrShards - 1 + r := rand.Intn(maxRange-minRange+1) + minRange + return r/nrShards*nrShards + shard }