diff --git a/rueidiscompat/adapter.go b/rueidiscompat/adapter.go index 4d9324ea..9fade169 100644 --- a/rueidiscompat/adapter.go +++ b/rueidiscompat/adapter.go @@ -1100,20 +1100,25 @@ const ( func (c *Compat) BitCount(ctx context.Context, key string, bitCount *BitCount) *IntCmd { ret := &IntCmd{} + if bitCount == nil { - resp := c.client.Do(ctx, c.client.B().Bitcount().Key(key).Build()) + resp := c.client.Do(ctx, c.client.B().Bitcount().Key(key).Bit().Build()) // Default to BIT ret.val, ret.err = resp.ToInt64() return ret } builder := c.client.B().Bitcount().Key(key).Start(bitCount.Start).End(bitCount.End) + + // Default to BIT unless specified otherwise var resp rueidis.RedisResult - if bitCount.Unit == BitCountIndexBit { + switch bitCount.Unit { + case BitCountIndexByte: + resp = c.client.Do(ctx, builder.Byte().Build()) // Handling for BYTE + case BitCountIndexBit, "": // Default to BIT if not specified resp = c.client.Do(ctx, builder.Bit().Build()) - } else if bitCount.Unit == BitCountIndexByte { - resp = c.client.Do(ctx, builder.Byte().Build()) - } else { - panic("unsupported unit") + default: + ret.err = fmt.Errorf("unsupported unit: %s", bitCount.Unit) + return ret } ret.val, ret.err = resp.ToInt64() diff --git a/rueidiscompat/adapter_test.go b/rueidiscompat/adapter_test.go index 65700f07..e4d64d6c 100644 --- a/rueidiscompat/adapter_test.go +++ b/rueidiscompat/adapter_test.go @@ -36,9 +36,6 @@ import ( "testing" "time" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - "github.com/redis/rueidis" ) @@ -1128,28 +1125,41 @@ func testAdapter(resp3 bool) { Expect(get.Val()).To(Equal("Hello World")) }) - It("should BitCount", func() { + It("should BitCount with and without specified unit", func() { set := adapter.Set(ctx, "key", "foobar", 0) Expect(set.Err()).NotTo(HaveOccurred()) Expect(set.Val()).To(Equal("OK")) + // Test default (bit counting) bitCount := adapter.BitCount(ctx, "key", nil) Expect(bitCount.Err()).NotTo(HaveOccurred()) - Expect(bitCount.Val()).To(Equal(int64(26))) + Expect(bitCount.Val()).To(Equal(int64(26))) // Assuming the full bit count for "foobar" + // Test bit counting explicitly specified bitCount = adapter.BitCount(ctx, "key", &BitCount{ Start: 0, End: 0, + Unit: BitCountIndexBit, }) Expect(bitCount.Err()).NotTo(HaveOccurred()) - Expect(bitCount.Val()).To(Equal(int64(4))) + Expect(bitCount.Val()).To(Equal(int64(4))) // Count bits in the first byte + // Test byte counting bitCount = adapter.BitCount(ctx, "key", &BitCount{ Start: 1, End: 1, + Unit: BitCountIndexByte, }) Expect(bitCount.Err()).NotTo(HaveOccurred()) - Expect(bitCount.Val()).To(Equal(int64(6))) + Expect(bitCount.Val()).To(Equal(int64(6))) // Count bits in the second byte + + // Test handling unsupported units + bitCount = adapter.BitCount(ctx, "key", &BitCount{ + Start: 1, + End: 1, + Unit: "UNSUPPORTED", + }) + Expect(bitCount.Err()).To(HaveOccurred()) }) It("should BitOpAnd", func() {