From d9cb89065ff093aa61ed8ecdc383860fd0d9efa0 Mon Sep 17 00:00:00 2001
From: Arve Knudsen <arve.knudsen@gmail.com>
Date: Fri, 24 Jan 2025 09:51:14 +0100
Subject: [PATCH] Pass cluster

Signed-off-by: Arve Knudsen <arve.knudsen@gmail.com>
---
 Makefile                                       |  3 +++
 go.mod                                         |  2 +-
 go.sum                                         |  4 ++--
 pkg/api/middleware.go                          |  7 ++-----
 pkg/frontend/config.go                         |  9 +++++++--
 pkg/frontend/transport/roundtripper.go         | 10 +++++++---
 pkg/frontend/v1/frontend_test.go               |  2 +-
 .../v2/frontend_scheduler_adapter_test.go      |  6 +++---
 pkg/frontend/v2/frontend_test.go               |  2 +-
 pkg/mimir/modules.go                           |  6 +++++-
 .../grafana/dskit/clusterutil/clusterutil.go   |  6 ++++++
 .../grafana/dskit/httpgrpc/httpgrpc.go         | 18 +++++++++++++++++-
 vendor/modules.txt                             |  3 ++-
 13 files changed, 57 insertions(+), 21 deletions(-)
 create mode 100644 vendor/github.com/grafana/dskit/clusterutil/clusterutil.go

diff --git a/Makefile b/Makefile
index e4ac16b55d4..faeb5481afa 100644
--- a/Makefile
+++ b/Makefile
@@ -453,6 +453,9 @@ lint: check-makefiles
 	faillint -paths \
 		"github.com/twmb/franz-go/pkg/kgo.{AllowAutoTopicCreation}" \
 		./pkg/... ./cmd/... ./tools/... ./integration/...
+	# We need to ensure that when creating http-grpc requests, the X-Cluster header is included.
+	faillint -paths "github.com/grafana/dskit/httpgrpc.{FromHTTPRequest}=github.com/grafana/dskit/httpgrpc.FromHTTPRequestWithCluster" \
+		./pkg/... ./cmd/... ./tools/... ./integration/...
 
 format: ## Run gofmt and goimports.
 	find . $(DONT_FIND) -name '*.pb.go' -prune -o -type f -name '*.go' -exec gofmt -w -s {} \;
diff --git a/go.mod b/go.mod
index 5867142d6e3..7ce21e3e5f3 100644
--- a/go.mod
+++ b/go.mod
@@ -22,7 +22,7 @@ require (
 	github.com/golang/snappy v0.0.4
 	github.com/google/gopacket v1.1.19
 	github.com/gorilla/mux v1.8.1
-	github.com/grafana/dskit v0.0.0-20250123101449-feb230ca9dc2
+	github.com/grafana/dskit v0.0.0-20250124130032-aff6c876915b
 	github.com/grafana/e2e v0.1.2-0.20240118170847-db90b84177fc
 	github.com/hashicorp/golang-lru v1.0.2 // indirect
 	github.com/influxdata/influxdb/v2 v2.7.11
diff --git a/go.sum b/go.sum
index b0eed302f75..36ee3843d79 100644
--- a/go.sum
+++ b/go.sum
@@ -1271,8 +1271,8 @@ github.com/grafana-tools/sdk v0.0.0-20220919052116-6562121319fc h1:PXZQA2WCxe85T
 github.com/grafana-tools/sdk v0.0.0-20220919052116-6562121319fc/go.mod h1:AHHlOEv1+GGQ3ktHMlhuTUwo3zljV3QJbC0+8o2kn+4=
 github.com/grafana/alerting v0.0.0-20250113170557-b4ab2ba363a8 h1:mdI6P22PgFD7bQ0Yf4h8cfHSldak4nxogvlsTHZyZmc=
 github.com/grafana/alerting v0.0.0-20250113170557-b4ab2ba363a8/go.mod h1:QsnoKX/iYZxA4Cv+H+wC7uxutBD8qi8ZW5UJvD2TYmU=
-github.com/grafana/dskit v0.0.0-20250123101449-feb230ca9dc2 h1:2ZN2dTx3NDEvREKr5UH6qaAtI+g6wkItol9TzZMGhcQ=
-github.com/grafana/dskit v0.0.0-20250123101449-feb230ca9dc2/go.mod h1:SPLNCARd4xdjCkue0O6hvuoveuS1dGJjDnfxYe405YQ=
+github.com/grafana/dskit v0.0.0-20250124130032-aff6c876915b h1:y34FUoHHxGMbsX8nsGdCdJGfazKlGjf/7QhvgFMn0HA=
+github.com/grafana/dskit v0.0.0-20250124130032-aff6c876915b/go.mod h1:SPLNCARd4xdjCkue0O6hvuoveuS1dGJjDnfxYe405YQ=
 github.com/grafana/e2e v0.1.2-0.20240118170847-db90b84177fc h1:BW+LjKJDz0So5LI8UZfW5neWeKpSkWqhmGjQFzcFfLM=
 github.com/grafana/e2e v0.1.2-0.20240118170847-db90b84177fc/go.mod h1:JVmqPBe8A/pZWwRoJW5ZjyALeY5OXMzPl7LrVXOdZAI=
 github.com/grafana/franz-go v0.0.0-20241009100846-782ba1442937 h1:fwwnG/NcygoS6XbAaEyK2QzMXI/BZIEJvQ3CD+7XZm8=
diff --git a/pkg/api/middleware.go b/pkg/api/middleware.go
index 0151618ea63..147eeb7740f 100644
--- a/pkg/api/middleware.go
+++ b/pkg/api/middleware.go
@@ -6,6 +6,7 @@ import (
 
 	"github.com/go-kit/log"
 	"github.com/go-kit/log/level"
+	"github.com/grafana/dskit/clusterutil"
 	"github.com/grafana/dskit/middleware"
 )
 
@@ -13,7 +14,7 @@ import (
 func ClusterValidationMiddleware(cluster string, logger log.Logger) middleware.Interface {
 	return middleware.Func(func(next http.Handler) http.Handler {
 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-			reqCluster := r.Header.Get(clusterHeader)
+			reqCluster := r.Header.Get(clusterutil.ClusterHeader)
 			if reqCluster != cluster {
 				level.Warn(logger).Log("msg", "rejecting request intended for wrong cluster", "cluster", cluster, "request_cluster", reqCluster)
 				http.Error(w, fmt.Sprintf("request intended for cluster %q - this is cluster %q", reqCluster, cluster),
@@ -25,7 +26,3 @@ func ClusterValidationMiddleware(cluster string, logger log.Logger) middleware.I
 		})
 	})
 }
-
-const (
-	clusterHeader = "X-Cluster"
-)
diff --git a/pkg/frontend/config.go b/pkg/frontend/config.go
index a6a397f391c..0195f731da3 100644
--- a/pkg/frontend/config.go
+++ b/pkg/frontend/config.go
@@ -30,6 +30,8 @@ type CombinedFrontendConfig struct {
 	QueryMiddleware querymiddleware.Config `yaml:",inline"`
 
 	DownstreamURL string `yaml:"downstream_url" category:"advanced"`
+
+	Cluster string `yaml:"-"`
 }
 
 func (cfg *CombinedFrontendConfig) RegisterFlags(f *flag.FlagSet, logger log.Logger) {
@@ -66,6 +68,9 @@ func InitFrontend(
 	reg prometheus.Registerer,
 	codec querymiddleware.Codec,
 ) (http.RoundTripper, *v1.Frontend, *v2.Frontend, error) {
+	if cfg.Cluster == "" {
+		panic("cluster not defined")
+	}
 	switch {
 	case cfg.DownstreamURL != "":
 		// If the user has specified a downstream Prometheus, then we should use that.
@@ -88,7 +93,7 @@ func InitFrontend(
 		}
 
 		fr, err := v2.NewFrontend(cfg.FrontendV2, v2Limits, log, reg, codec)
-		return transport.AdaptGrpcRoundTripperToHTTPRoundTripper(fr), nil, fr, err
+		return transport.AdaptGrpcRoundTripperToHTTPRoundTripper(fr, cfg.Cluster), nil, fr, err
 
 	default:
 		// No scheduler = use original frontend.
@@ -96,6 +101,6 @@ func InitFrontend(
 		if err != nil {
 			return nil, nil, nil, err
 		}
-		return transport.AdaptGrpcRoundTripperToHTTPRoundTripper(fr), fr, nil, nil
+		return transport.AdaptGrpcRoundTripperToHTTPRoundTripper(fr, cfg.Cluster), fr, nil, nil
 	}
 }
diff --git a/pkg/frontend/transport/roundtripper.go b/pkg/frontend/transport/roundtripper.go
index 141fade2ee3..062ba63190a 100644
--- a/pkg/frontend/transport/roundtripper.go
+++ b/pkg/frontend/transport/roundtripper.go
@@ -20,13 +20,17 @@ type GrpcRoundTripper interface {
 	RoundTripGRPC(context.Context, *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, io.ReadCloser, error)
 }
 
-func AdaptGrpcRoundTripperToHTTPRoundTripper(r GrpcRoundTripper) http.RoundTripper {
-	return &grpcRoundTripperAdapter{roundTripper: r}
+func AdaptGrpcRoundTripperToHTTPRoundTripper(r GrpcRoundTripper, cluster string) http.RoundTripper {
+	return &grpcRoundTripperAdapter{
+		roundTripper: r,
+		cluster:      cluster,
+	}
 }
 
 // This adapter wraps GrpcRoundTripper and converted it into http.RoundTripper
 type grpcRoundTripperAdapter struct {
 	roundTripper GrpcRoundTripper
+	cluster      string
 }
 
 type buffer struct {
@@ -39,7 +43,7 @@ func (b *buffer) Bytes() []byte {
 }
 
 func (a *grpcRoundTripperAdapter) RoundTrip(r *http.Request) (*http.Response, error) {
-	req, err := httpgrpc.FromHTTPRequest(r)
+	req, err := httpgrpc.FromHTTPRequestWithCluster(r, a.cluster)
 	if err != nil {
 		return nil, err
 	}
diff --git a/pkg/frontend/v1/frontend_test.go b/pkg/frontend/v1/frontend_test.go
index 1ee7cfe66e4..5eaa9f27f78 100644
--- a/pkg/frontend/v1/frontend_test.go
+++ b/pkg/frontend/v1/frontend_test.go
@@ -337,7 +337,7 @@ func testFrontend(t *testing.T, config Config, handler http.Handler, test func(a
 	handlerCfg := transport.HandlerConfig{QueryStatsEnabled: true}
 	flagext.DefaultValues(&handlerCfg)
 
-	rt := transport.AdaptGrpcRoundTripperToHTTPRoundTripper(v1)
+	rt := transport.AdaptGrpcRoundTripperToHTTPRoundTripper(v1, "")
 	r := mux.NewRouter()
 	r.PathPrefix("/").Handler(middleware.Merge(
 		middleware.AuthenticateUser,
diff --git a/pkg/frontend/v2/frontend_scheduler_adapter_test.go b/pkg/frontend/v2/frontend_scheduler_adapter_test.go
index 7504019debb..24e767fef9b 100644
--- a/pkg/frontend/v2/frontend_scheduler_adapter_test.go
+++ b/pkg/frontend/v2/frontend_scheduler_adapter_test.go
@@ -136,7 +136,7 @@ func TestExtractAdditionalQueueDimensions(t *testing.T) {
 			reqs := []*http.Request{rangeHTTPReq, labelValuesHTTPReq}
 
 			for _, req := range reqs {
-				httpgrpcReq, err := httpgrpc.FromHTTPRequest(req)
+				httpgrpcReq, err := httpgrpc.FromHTTPRequestWithCluster(req, "")
 				require.NoError(t, err)
 
 				additionalQueueDimensions, err := adapter.extractAdditionalQueueDimensions(
@@ -179,7 +179,7 @@ func TestExtractAdditionalQueueDimensions(t *testing.T) {
 			ctx := user.InjectOrgID(context.Background(), "tenant-0")
 
 			instantHTTPReq := makeInstantHTTPRequest(ctx, testData.time)
-			httpgrpcReq, err := httpgrpc.FromHTTPRequest(instantHTTPReq)
+			httpgrpcReq, err := httpgrpc.FromHTTPRequestWithCluster(instantHTTPReq, "")
 			require.NoError(t, err)
 
 			additionalQueueDimensions, err := adapter.extractAdditionalQueueDimensions(
@@ -229,7 +229,7 @@ func TestQueryDecoding(t *testing.T) {
 			ctx := user.InjectOrgID(context.Background(), "tenant-0")
 
 			labelValuesHTTPReq := makeLabelValuesHTTPRequest(ctx, testData.start, testData.end)
-			httpgrpcReq, err := httpgrpc.FromHTTPRequest(labelValuesHTTPReq)
+			httpgrpcReq, err := httpgrpc.FromHTTPRequestWithCluster(labelValuesHTTPReq, "")
 			require.NoError(t, err)
 
 			additionalQueueDimensions, err := adapter.extractAdditionalQueueDimensions(
diff --git a/pkg/frontend/v2/frontend_test.go b/pkg/frontend/v2/frontend_test.go
index e4f8e5b91a6..e7d3d3feaf6 100644
--- a/pkg/frontend/v2/frontend_test.go
+++ b/pkg/frontend/v2/frontend_test.go
@@ -532,7 +532,7 @@ func TestFrontendStreamingResponse(t *testing.T) {
 			})
 
 			req := httptest.NewRequest("GET", "/api/v1/cardinality/active_series?selector=metric", nil)
-			rt := transport.AdaptGrpcRoundTripperToHTTPRoundTripper(f)
+			rt := transport.AdaptGrpcRoundTripperToHTTPRoundTripper(f, "")
 
 			resp, err := rt.RoundTrip(req.WithContext(user.InjectOrgID(context.Background(), userID)))
 			require.NoError(t, err)
diff --git a/pkg/mimir/modules.go b/pkg/mimir/modules.go
index dfd4a7721a1..2b1dbe80b3e 100644
--- a/pkg/mimir/modules.go
+++ b/pkg/mimir/modules.go
@@ -15,6 +15,7 @@ import (
 
 	"github.com/go-kit/log"
 	"github.com/go-kit/log/level"
+	"github.com/grafana/dskit/clusterutil"
 	"github.com/grafana/dskit/dns"
 	httpgrpc_server "github.com/grafana/dskit/httpgrpc/server"
 	"github.com/grafana/dskit/kv"
@@ -718,7 +719,9 @@ func (t *Mimir) initFlusher() (serv services.Service, err error) {
 // initQueryFrontendCodec initializes query frontend codec.
 // NOTE: Grafana Enterprise Metrics depends on this.
 func (t *Mimir) initQueryFrontendCodec() (services.Service, error) {
-	t.QueryFrontendCodec = querymiddleware.NewPrometheusCodec(t.Registerer, t.Cfg.Querier.EngineConfig.LookbackDelta, t.Cfg.Frontend.QueryMiddleware.QueryResultResponseFormat, t.Cfg.Frontend.QueryMiddleware.ExtraPropagateHeaders)
+	// Always pass through the cluster header.
+	propagateHeaders := append([]string{clusterutil.ClusterHeader}, t.Cfg.Frontend.QueryMiddleware.ExtraPropagateHeaders...)
+	t.QueryFrontendCodec = querymiddleware.NewPrometheusCodec(t.Registerer, t.Cfg.Querier.EngineConfig.LookbackDelta, t.Cfg.Frontend.QueryMiddleware.QueryResultResponseFormat, propagateHeaders)
 	return nil, nil
 }
 
@@ -783,6 +786,7 @@ func (t *Mimir) initQueryFrontendTripperware() (serv services.Service, err error
 }
 
 func (t *Mimir) initQueryFrontend() (serv services.Service, err error) {
+	t.Cfg.Frontend.Cluster = t.Cfg.Server.Cluster
 	t.Cfg.Frontend.FrontendV2.QuerySchedulerDiscovery = t.Cfg.QueryScheduler.ServiceDiscovery
 	t.Cfg.Frontend.FrontendV2.LookBackDelta = t.Cfg.Querier.EngineConfig.LookbackDelta
 	t.Cfg.Frontend.FrontendV2.QueryStoreAfter = t.Cfg.Querier.QueryStoreAfter
diff --git a/vendor/github.com/grafana/dskit/clusterutil/clusterutil.go b/vendor/github.com/grafana/dskit/clusterutil/clusterutil.go
new file mode 100644
index 00000000000..a02d2d4fbd2
--- /dev/null
+++ b/vendor/github.com/grafana/dskit/clusterutil/clusterutil.go
@@ -0,0 +1,6 @@
+package clusterutil
+
+const (
+	// ClusterHeader is the name of the cluster identifying HTTP header.
+	ClusterHeader = "X-Cluster"
+)
diff --git a/vendor/github.com/grafana/dskit/httpgrpc/httpgrpc.go b/vendor/github.com/grafana/dskit/httpgrpc/httpgrpc.go
index 616023899b7..350b2c8a90b 100644
--- a/vendor/github.com/grafana/dskit/httpgrpc/httpgrpc.go
+++ b/vendor/github.com/grafana/dskit/httpgrpc/httpgrpc.go
@@ -17,6 +17,7 @@ import (
 	"github.com/gogo/status"
 	"google.golang.org/grpc/metadata"
 
+	"github.com/grafana/dskit/clusterutil"
 	"github.com/grafana/dskit/grpcutil"
 	"github.com/grafana/dskit/log"
 )
@@ -42,7 +43,7 @@ func (nopCloser) Close() error { return nil }
 // BytesBuffer returns the underlaying `bytes.buffer` used to build this io.ReadCloser.
 func (n nopCloser) BytesBuffer() *bytes.Buffer { return n.Buffer }
 
-// FromHTTPRequest converts an ordinary http.Request into an httpgrpc.HTTPRequest
+// FromHTTPRequest converts an ordinary http.Request into an httpgrpc.HTTPRequest.
 func FromHTTPRequest(r *http.Request) (*HTTPRequest, error) {
 	body, err := io.ReadAll(r.Body)
 	if err != nil {
@@ -56,6 +57,21 @@ func FromHTTPRequest(r *http.Request) (*HTTPRequest, error) {
 	}, nil
 }
 
+// FromHTTPRequestWithCluster converts an ordinary http.Request into an httpgrpc.HTTPRequest.
+// It's the same as FromHTTPRequest except that if cluster is non-empty, it has to be equal to the
+// middleware.ClusterHeader header (or an error is returned).
+func FromHTTPRequestWithCluster(r *http.Request, cluster string) (*HTTPRequest, error) {
+	if cluster != "" {
+		if c := r.Header.Get(clusterutil.ClusterHeader); c != cluster {
+			return nil, fmt.Errorf(
+				"httpgrpc.FromHTTPRequest: %q header should be %q, but is %q",
+				clusterutil.ClusterHeader, cluster, c,
+			)
+		}
+	}
+	return FromHTTPRequest(r)
+}
+
 // ToHTTPRequest converts httpgrpc.HTTPRequest to http.Request.
 func ToHTTPRequest(ctx context.Context, r *HTTPRequest) (*http.Request, error) {
 	req, err := http.NewRequest(r.Method, r.Url, nopCloser{Buffer: bytes.NewBuffer(r.Body)})
diff --git a/vendor/modules.txt b/vendor/modules.txt
index 502993c92ed..7c669993da4 100644
--- a/vendor/modules.txt
+++ b/vendor/modules.txt
@@ -644,12 +644,13 @@ github.com/grafana/alerting/receivers/webex
 github.com/grafana/alerting/receivers/webhook
 github.com/grafana/alerting/receivers/wecom
 github.com/grafana/alerting/templates
-# github.com/grafana/dskit v0.0.0-20250123101449-feb230ca9dc2
+# github.com/grafana/dskit v0.0.0-20250124130032-aff6c876915b
 ## explicit; go 1.21
 github.com/grafana/dskit/backoff
 github.com/grafana/dskit/ballast
 github.com/grafana/dskit/cache
 github.com/grafana/dskit/cancellation
+github.com/grafana/dskit/clusterutil
 github.com/grafana/dskit/concurrency
 github.com/grafana/dskit/crypto/tls
 github.com/grafana/dskit/dns