Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ssl client certs for authentication. #31

Merged
merged 10 commits into from
May 4, 2022
19 changes: 16 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,18 +189,31 @@ function consumeWithConfiguration(reader: object, limit: number, configurationJs
* @param {number} partitions The number of partitions.
* @param {number} replicationFactor The replication factor in a clustered setup.
* @param {string} compression The compression algorithm.
* @param {string} auth Authentication credentials for SASL PLAIN/SCRAM.
* @returns {string} A string containing the error.
*/
function createTopic(address: string, topic: string, partitions: number, replicationFactor number, compression string) => string {}
function createTopic(address: string, topic: string, partitions: number, replicationFactor: number, compression: string, auth: string) => string {}

/**
* List all topics in Kafka.
* Delete a topic from Kafka. It raises an error if the topic doesn't exist.
*
* @function
* @param {string} address The broker address.
* @param {string} topic The topic name.
* @param {string} auth Authentication credentials for SASL PLAIN/SCRAM.
* @returns {string} A string containing the error.
*/
function deleteTopic(address: string, topic: string, auth: string) => string {}

/**
* List all topics in Kafka.
*
* @function
* @param {string} address The broker address.
* @param {string} auth Authentication credentials for SASL PLAIN/SCRAM.
* @returns {string} A nested list of strings containing a list of topics and the error (if any).
*/
function listTopics(address: string) => [[string], string] {}
function listTopics(address: string, auth: string) => [[string], string] {}
```

</details>
Expand Down
96 changes: 87 additions & 9 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ package kafka

import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"io/ioutil"
"log"
"os"
"time"

kafkago "github.com/segmentio/kafka-go"
Expand All @@ -11,34 +15,36 @@ import (
)

const (
None = "none"
Plain = "plain"
SHA256 = "sha256"
SHA512 = "sha512"
)

type Credentials struct {
Username string `json:"username"`
Password string `json:"password"`
Algorithm string `json:"algorithm"`
Username string `json:"username"`
Password string `json:"password"`
Algorithm string `json:"algorithm"`
ClientCertPem string `json:"clientCertPem"`
ClientKeyPem string `json:"clientKeyPem"`
ServerCaPem string `json:"serverCaPem"`
}

func unmarshalCredentials(auth string) (creds *Credentials, err error) {
creds = &Credentials{
Algorithm: Plain,
Algorithm: None,
}

err = json.Unmarshal([]byte(auth), &creds)

return
}

func getDialer(creds *Credentials) (dialer *kafkago.Dialer) {
func getDialerFromCreds(creds *Credentials) (dialer *kafkago.Dialer) {
dialer = &kafkago.Dialer{
Timeout: 10 * time.Second,
DualStack: true,
TLS: &tls.Config{
MinVersion: tls.VersionTLS12,
},
TLS: tlsConfig(creds),
}

if creds.Algorithm == Plain {
Expand All @@ -48,7 +54,7 @@ func getDialer(creds *Credentials) (dialer *kafkago.Dialer) {
}
dialer.SASLMechanism = mechanism
return
} else {
} else if creds.Algorithm == SHA256 || creds.Algorithm == SHA512 {
hashes := make(map[string]scram.Algorithm)
hashes["sha256"] = scram.SHA256
hashes["sha512"] = scram.SHA512
Expand All @@ -65,4 +71,76 @@ func getDialer(creds *Credentials) (dialer *kafkago.Dialer) {
dialer.SASLMechanism = mechanism
return
}
return
}

func getDialerFromAuth(auth string) (dialer *kafkago.Dialer) {
if auth != "" {
// Parse the auth string
creds, err := unmarshalCredentials(auth)
if err != nil {
ReportError(err, "Unable to unmarshal credentials")
return nil
}

// Try to create an authenticated dialer from the credentials
// with TLS enabled if the credentials specify a client cert
// and key.
dialer = getDialerFromCreds(creds)
if dialer == nil {
ReportError(nil, "Dialer cannot authenticate")
return nil
}
} else {
// Create a normal (unauthenticated) dialer
dialer = &kafkago.Dialer{
Timeout: 10 * time.Second,
DualStack: false,
}
}

return
}

func fileExists(filename string) bool {
_, err := os.Stat(filename)
return err == nil
}

func tlsConfig(creds *Credentials) *tls.Config {
var clientCertFile = &creds.ClientCertPem
if !fileExists(*clientCertFile) {
ReportError(nil, "client certificate file not found")
return nil
}

var clientKeyFile = &creds.ClientKeyPem
if !fileExists(*clientKeyFile) {
ReportError(nil, "client key file not found")
return nil
}

var cert, err = tls.LoadX509KeyPair(*clientCertFile, *clientKeyFile)
if err != nil {
log.Fatalf("Error creating x509 keypair from client cert file %s and client key file %s", *clientCertFile, *clientKeyFile)
}

var caCertFile = &creds.ServerCaPem
if !fileExists(*caCertFile) {
ReportError(nil, "CA certificate file not found")
return nil
}

caCert, err := ioutil.ReadFile(*caCertFile)
if err != nil {
log.Fatalf("Error opening cert file %s, Error: %s", *caCertFile, err)
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)

return &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: caCertPool,
MinVersion: tls.VersionTLS12,
}
}
17 changes: 9 additions & 8 deletions configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ import (
)

type ConsumerConfiguration struct {
KeyDeserializer string `json:"keyDeserializer"`
ValueDeserializer string `json:"valueDeserializer"`
KeyDeserializer string `json:"keyDeserializer"`
ValueDeserializer string `json:"valueDeserializer"`
}

type ProducerConfiguration struct {
KeySerializer string `json:"keySerializer"`
ValueSerializer string `json:"valueSerializer"`
KeySerializer string `json:"keySerializer"`
ValueSerializer string `json:"valueSerializer"`
}

type BasicAuth struct {
Expand All @@ -23,11 +23,12 @@ type BasicAuth struct {
type SchemaRegistryConfiguration struct {
Url string `json:"url"`
BasicAuth BasicAuth `json:"basicAuth"`
UseLatest bool `json:"useLatest"`
}

type Configuration struct {
Consumer ConsumerConfiguration `json:"consumer"`
Producer ProducerConfiguration `json:"producer"`
Consumer ConsumerConfiguration `json:"consumer"`
Producer ProducerConfiguration `json:"producer"`
SchemaRegistry SchemaRegistryConfiguration `json:"schemaRegistry"`
}

Expand All @@ -43,7 +44,7 @@ func useKafkaAvroDeserializer(configuration Configuration, keyOrValue string) bo
return false
}
if keyOrValue == "key" && configuration.Consumer.KeyDeserializer == "io.confluent.kafka.serializers.KafkaAvroDeserializer" ||
keyOrValue == "value" && configuration.Consumer.ValueDeserializer == "io.confluent.kafka.serializers.KafkaAvroDeserializer" {
keyOrValue == "value" && configuration.Consumer.ValueDeserializer == "io.confluent.kafka.serializers.KafkaAvroDeserializer" {
return true
}
return false
Expand All @@ -55,7 +56,7 @@ func useKafkaAvroSerializer(configuration Configuration, keyOrValue string) bool
return false
}
if keyOrValue == "key" && configuration.Producer.KeySerializer == "io.confluent.kafka.serializers.KafkaAvroSerializer" ||
keyOrValue == "value" && configuration.Producer.ValueSerializer == "io.confluent.kafka.serializers.KafkaAvroSerializer" {
keyOrValue == "value" && configuration.Producer.ValueSerializer == "io.confluent.kafka.serializers.KafkaAvroSerializer" {
return true
}
return false
Expand Down
18 changes: 1 addition & 17 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,6 @@ import (
func (*Kafka) Reader(
brokers []string, topic string, partition int,
groupID string, offset int64, auth string) *kafkago.Reader {
var dialer *kafkago.Dialer

if auth != "" {
creds, err := unmarshalCredentials(auth)
if err != nil {
ReportError(err, "Unable to unmarshal credentials")
return nil
}

dialer = getDialer(creds)
if dialer == nil {
ReportError(nil, "Dialer cannot authenticate")
return nil
}
}

if groupID != "" {
partition = 0
}
Expand All @@ -40,7 +24,7 @@ func (*Kafka) Reader(
MaxWait: time.Millisecond * 200,
RebalanceTimeout: time.Second * 5,
QueueCapacity: 1,
Dialer: dialer,
Dialer: getDialerFromAuth(auth),
})

if offset > 0 {
Expand Down
18 changes: 1 addition & 17 deletions producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,12 @@ var (
)

func (*Kafka) Writer(brokers []string, topic string, auth string, compression string) *kafkago.Writer {
var dialer *kafkago.Dialer

if auth != "" {
creds, err := unmarshalCredentials(auth)
if err != nil {
ReportError(err, "Unable to unmarshal credentials")
return nil
}

dialer = getDialer(creds)
if dialer == nil {
ReportError(nil, "Dialer cannot authenticate")
return nil
}
}

writerConfig := kafkago.WriterConfig{
Brokers: brokers,
Topic: topic,
Balancer: &kafkago.LeastBytes{},
BatchSize: 1,
Dialer: dialer,
Dialer: getDialerFromAuth(auth),
Async: false,
}

Expand Down
Loading