Skip to content

Commit

Permalink
feat: custom headers are also supported for the query (gRPC request) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
bednar authored Mar 5, 2024
1 parent 5861cf4 commit da932f7
Show file tree
Hide file tree
Showing 4 changed files with 288 additions and 25 deletions.
21 changes: 21 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,26 @@
## 0.7.0 [unreleased]

### Features

1. [#107](https://github.com/InfluxCommunity/influxdb3-java/pull/107): Custom headers are also supported for the query (gRPC request)

```java
ClientConfig config = new ClientConfig.Builder()
.host("https://us-east-1-1.aws.cloud2.influxdata.com")
.token("my-token".toCharArray())
.database("my-database")
.headers(Map.of("X-Tracing-Id", "123"))
.build();

try (InfluxDBClient client = InfluxDBClient.getInstance(config)) {
//
// your code here
//
} catch (Exception e) {
throw new RuntimeException(e);
}
```

## 0.6.0 [2024-03-01]

### Features
Expand Down
31 changes: 24 additions & 7 deletions src/main/java/com/influxdb/v3/client/config/ClientConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@
* </li>
* <li><code>proxy</code> - HTTP proxy selector</li>
* <li><code>authenticator</code> - HTTP proxy authenticator</li>
* <li><code>headers</code> - set of HTTP headers to be added to requests</li>
* <li><code>headers</code> - headers to be added to requests</li>
* </ul>
* <p>
* If you want to create a client with custom configuration, you can use following code:
* <pre>
* ClientConfig config = new Config.Builder()
* ClientConfig config = new ClientConfig.Builder()
* .host("https://us-east-1-1.aws.cloud2.influxdata.com")
* .token("my-token".toCharArray())
* .database("my-database")
Expand Down Expand Up @@ -217,9 +217,9 @@ public Authenticator getAuthenticator() {
}

/**
* Gets custom HTTP headers.
* Gets custom headers for requests.
*
* @return the HTTP headers
* @return the headers
*/
@Nullable
public Map<String, String> getHeaders() {
Expand Down Expand Up @@ -465,9 +465,26 @@ public Builder authenticator(@Nullable final Authenticator authenticator) {
}

/**
* Sets the custom HTTP headers that will be included in requests.
* Sets the custom headers that will be added to requests. This is useful for adding custom headers to requests,
* such as tracing headers. To add custom headers use following code:
* <pre>
* ClientConfig config = new ClientConfig.Builder()
* .host("https://us-east-1-1.aws.cloud2.influxdata.com")
* .token("my-token".toCharArray())
* .database("my-database")
* .headers(Map.of("X-Tracing-Id", "123"))
* .build();
*
* @param headers Set of HTTP headers.
* try (InfluxDBClient client = InfluxDBClient.getInstance(config)) {
* //
* // your code here
* //
* } catch (Exception e) {
* throw new RuntimeException(e);
* }
* </pre>
*
* @param headers the headers to be added to requests
* @return this
*/
@Nonnull
Expand Down Expand Up @@ -526,7 +543,7 @@ public ClientConfig build(@Nonnull final String connectionString) throws Malform
/**
* Build an instance of {@code ClientConfig} from environment variables and/or system properties.
*
* @param env environment variables
* @param env environment variables
* @param properties system properties
* @return the configuration for an {@code InfluxDBClient}.
*/
Expand Down
60 changes: 42 additions & 18 deletions src/main/java/com/influxdb/v3/client/internal/FlightSqlClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand All @@ -60,32 +61,30 @@ final class FlightSqlClient implements AutoCloseable {
private final ObjectMapper objectMapper = new ObjectMapper();

FlightSqlClient(@Nonnull final ClientConfig config) {
this(config, null);
}

/**
* Constructor for testing purposes.
*
* @param config the client configuration
* @param client the flight client, if null a new client will be created
*/
FlightSqlClient(@Nonnull final ClientConfig config, @Nullable final FlightClient client) {
Arguments.checkNotNull(config, "config");

MetadataAdapter metadata = new MetadataAdapter(new Metadata());
if (config.getToken() != null && config.getToken().length > 0) {
metadata.insert("Authorization", "Bearer " + new String(config.getToken()));
}

this.headers = new HeaderCallOption(metadata);

Location location;
try {
URI uri = new URI(config.getHost());
if ("https".equals(uri.getScheme())) {
location = Location.forGrpcTls(uri.getHost(), uri.getPort() != -1 ? uri.getPort() : 443);
} else {
location = Location.forGrpcInsecure(uri.getHost(), uri.getPort() != -1 ? uri.getPort() : 80);
if (config.getHeaders() != null) {
for (Map.Entry<String, String> entry : config.getHeaders().entrySet()) {
metadata.insert(entry.getKey(), entry.getValue());
}
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}

client = FlightClient.builder()
.location(location)
.allocator(new RootAllocator(Long.MAX_VALUE))
.verifyServer(!config.getDisableServerCertificateValidation())
.build();
this.headers = new HeaderCallOption(metadata);
this.client = (client != null) ? client : createFlightClient(config);
}

@Nonnull
Expand All @@ -100,7 +99,7 @@ Stream<VectorSchemaRoot> execute(@Nonnull final String query,
put("query_type", queryType.name().toLowerCase());
}};

if (queryParameters.size() > 0) {
if (!queryParameters.isEmpty()) {
ticketData.put("params", queryParameters);
}

Expand All @@ -124,6 +123,31 @@ public void close() throws Exception {
client.close();
}

@Nonnull
private FlightClient createFlightClient(@Nonnull final ClientConfig config) {
Location location = createLocation(config);

return FlightClient.builder()
.location(location)
.allocator(new RootAllocator(Long.MAX_VALUE))
.verifyServer(!config.getDisableServerCertificateValidation())
.build();
}

@Nonnull
private Location createLocation(@Nonnull final ClientConfig config) {
try {
URI uri = new URI(config.getHost());
if ("https".equals(uri.getScheme())) {
return Location.forGrpcTls(uri.getHost(), uri.getPort() != -1 ? uri.getPort() : 443);
} else {
return Location.forGrpcInsecure(uri.getHost(), uri.getPort() != -1 ? uri.getPort() : 80);
}
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
}

private static final class FlightSqlIterator implements Iterator<VectorSchemaRoot>, AutoCloseable {

private final List<AutoCloseable> autoCloseable = new ArrayList<>();
Expand Down
201 changes: 201 additions & 0 deletions src/test/java/com/influxdb/v3/client/internal/FlightSqlClientTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
* The MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
package com.influxdb.v3.client.internal;

import java.net.URISyntaxException;
import java.util.Map;

import io.grpc.internal.GrpcUtil;
import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.CallInfo;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightClientMiddleware;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.memory.RootAllocator;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import com.influxdb.v3.client.config.ClientConfig;
import com.influxdb.v3.client.query.QueryType;

public class FlightSqlClientTest {

private static final String LOCALHOST = "localhost";
private final Location grpcLocation = Location.forGrpcInsecure(LOCALHOST, 0);
private final String serverLocation = String.format("http://%s:%d", LOCALHOST, grpcLocation.getUri().getPort());

private final CallHeadersMiddleware callHeadersMiddleware = new CallHeadersMiddleware();

private RootAllocator allocator;
private FlightServer server;
private FlightClient client;

@BeforeEach
void reset() {
callHeadersMiddleware.headers = null;
}

@BeforeEach
void setUp() throws Exception {
allocator = new RootAllocator(Long.MAX_VALUE);
server = FlightServer.builder(allocator, grpcLocation, new NoOpFlightProducer()).build().start();
client = FlightClient.builder(allocator, server.getLocation()).intercept(callHeadersMiddleware).build();
callHeadersMiddleware.headers = null;
}

@AfterEach
void tearDown() throws Exception {
if (client != null) {
client.close();
}
if (server != null) {
server.shutdown();
server.awaitTermination();
}
if (allocator != null) {
allocator.close();
}
}

@Test
public void invalidHost() {
ClientConfig clientConfig = new ClientConfig.Builder()
.host("xyz://a bc")
.token("my-token".toCharArray())
.build();

Assertions.assertThatThrownBy(() -> {
try (FlightSqlClient ignored = new FlightSqlClient(clientConfig)) {
Assertions.fail("Should not be here");
}
})
.isInstanceOf(RuntimeException.class)
.hasCauseInstanceOf(URISyntaxException.class)
.hasMessageContaining("xyz://a bc");
}

@Test
public void callHeaders() throws Exception {
ClientConfig clientConfig = new ClientConfig.Builder()
.host(serverLocation)
.token("my-token".toCharArray())
.build();

try (FlightSqlClient flightSqlClient = new FlightSqlClient(clientConfig, client)) {

flightSqlClient.execute("select * from cpu", "mydb", QueryType.SQL, Map.of());

final CallHeaders incomingHeaders = callHeadersMiddleware.headers;

Assertions.assertThat(incomingHeaders.keys()).containsOnly(
"authorization",
GrpcUtil.MESSAGE_ACCEPT_ENCODING
);
Assertions.assertThat(incomingHeaders.get("authorization")).isEqualTo("Bearer my-token");
}
}

@Test
public void callHeadersWithoutToken() throws Exception {
ClientConfig clientConfig = new ClientConfig.Builder()
.host(serverLocation)
.build();

try (FlightSqlClient flightSqlClient = new FlightSqlClient(clientConfig, client)) {

flightSqlClient.execute("select * from cpu", "mydb", QueryType.SQL, Map.of());

final CallHeaders incomingHeaders = callHeadersMiddleware.headers;

Assertions.assertThat(incomingHeaders.keys()).containsOnly(GrpcUtil.MESSAGE_ACCEPT_ENCODING);
Assertions.assertThat(incomingHeaders.get("authorization")).isNull();
}
}

@Test
public void callHeadersEmptyToken() throws Exception {
ClientConfig clientConfig = new ClientConfig.Builder()
.host(serverLocation)
.token("".toCharArray())
.build();

try (FlightSqlClient flightSqlClient = new FlightSqlClient(clientConfig, client)) {

flightSqlClient.execute("select * from cpu", "mydb", QueryType.SQL, Map.of());

final CallHeaders incomingHeaders = callHeadersMiddleware.headers;

Assertions.assertThat(incomingHeaders.keys()).containsOnly(GrpcUtil.MESSAGE_ACCEPT_ENCODING);
Assertions.assertThat(incomingHeaders.get("authorization")).isNull();
}
}

@Test
public void callHeadersCustomHeader() throws Exception {
ClientConfig clientConfig = new ClientConfig.Builder()
.host(serverLocation)
.token("my-token".toCharArray())
.headers(Map.of("X-Tracing-Id", "123"))
.build();

try (FlightSqlClient flightSqlClient = new FlightSqlClient(clientConfig, client)) {

flightSqlClient.execute("select * from cpu", "mydb", QueryType.SQL, Map.of());

final CallHeaders incomingHeaders = callHeadersMiddleware.headers;

Assertions.assertThat(incomingHeaders.keys()).containsOnly(
"authorization",
"x-tracing-id",
GrpcUtil.MESSAGE_ACCEPT_ENCODING
);
Assertions.assertThat(incomingHeaders.get("X-Tracing-Id")).isEqualTo("123");
}
}

static class CallHeadersMiddleware implements FlightClientMiddleware.Factory {
CallHeaders headers;

@Override
public FlightClientMiddleware onCallStarted(final CallInfo info) {
return new FlightClientMiddleware() {
@Override
public void onBeforeSendingHeaders(final CallHeaders outgoingHeaders) {
headers = outgoingHeaders;
}

@Override
public void onHeadersReceived(final CallHeaders incomingHeaders) {
}

@Override
public void onCallCompleted(final CallStatus status) {
}
};
}
}
}

0 comments on commit da932f7

Please sign in to comment.