Skip to content

Commit

Permalink
Add fingerprint as an option
Browse files Browse the repository at this point in the history
Fixes #103
  • Loading branch information
morazow committed Dec 9, 2021
1 parent c6fe772 commit 3c3aa09
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 75 deletions.
97 changes: 59 additions & 38 deletions src/main/scala/com/exasol/spark/util/ExasolConfiguration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,54 @@ final case class ExasolConfiguration(
jdbc_options: String,
username: String,
password: String,
fingerprint: String,
max_nodes: Int,
create_table: Boolean,
drop_table: Boolean,
batch_size: Int
)

/**
* A companion object that creates [[ExasolConfiguration]] instances from key-value options.
*/
object ExasolConfiguration {

val IPv4_DIGITS: String = "(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"
val IPv4_REGEX: Regex = raw"""^$IPv4_DIGITS\.$IPv4_DIGITS\.$IPv4_DIGITS\.$IPv4_DIGITS$$""".r
// Capturing group for a single IPv4 octet (0-255; short forms with leading zeros also match).
private[this] val IPv4_DIGITS: String = "(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"
// Four octets separated by literal dots, anchored at the start and end of the input.
private[this] val IPv4_REGEX: Regex = raw"""^$IPv4_DIGITS\.$IPv4_DIGITS\.$IPv4_DIGITS\.$IPv4_DIGITS$$""".r
// Fallback for the "max_nodes" option; kept as a string because it is parsed with `toInt` like a user value.
private[this] val DEFAULT_MAX_NODES = "200"
// Fallback for the "batch_size" option; kept as a string because it is parsed with `toInt` like a user value.
private[this] val DEFAULT_BATCH_SIZE = "1000"

/**
 * Builds an [[ExasolConfiguration]] from key-value options.
 *
 * The `host` and `jdbc_options` values are validated before the configuration is created. Every
 * option that is absent from the map falls back to its default value.
 *
 * @param options key-value options
 * @return an instance of [[ExasolConfiguration]]
 * @throws IllegalArgumentException if the host or the JDBC options fail validation
 */
def apply(options: Map[String, String]): ExasolConfiguration = {
  // The fallback is taken by name so the local host lookup only runs when "host" is missing.
  def valueOrDefault(key: String, default: => String): String =
    options.getOrElse(key, default)

  val exasolHost = valueOrDefault("host", getLocalHost())
  val extraJdbcOptions = valueOrDefault("jdbc_options", "")
  checkHost(exasolHost)
  checkJdbcOptions(extraJdbcOptions)

  ExasolConfiguration(
    host = exasolHost,
    port = valueOrDefault("port", "8888").toInt,
    jdbc_options = extraJdbcOptions,
    username = valueOrDefault("username", "sys"),
    password = valueOrDefault("password", "exasol"),
    fingerprint = valueOrDefault("fingerprint", ""),
    max_nodes = valueOrDefault("max_nodes", DEFAULT_MAX_NODES).toInt,
    create_table = valueOrDefault("create_table", "false").toBoolean,
    drop_table = valueOrDefault("drop_table", "false").toBoolean,
    batch_size = valueOrDefault("batch_size", DEFAULT_BATCH_SIZE).toInt
  )
}

def getLocalHost(): String = InetAddress.getLocalHost.getHostAddress
// Address of the machine running this code; used as the default when no "host" option is given.
private[this] def getLocalHost(): String = {
  val localAddress = InetAddress.getLocalHost()
  localAddress.getHostAddress()
}

def checkHost(host: String): String = host match {
private[util] def checkHost(host: String): String = host match {
case IPv4_REGEX(_*) => host
case _ =>
throw new IllegalArgumentException(
Expand All @@ -72,8 +106,16 @@ object ExasolConfiguration {
)
}

def checkJdbcOptions(str: String): String = {
if (str.endsWith(";") || str.startsWith(";")) {
/**
 * Validates the extra JDBC options string.
 *
 * The string must not start or end with a semicolon. When it is not empty, each
 * semicolon-separated entry is checked to be a key=value pair.
 *
 * @param options semicolon-separated JDBC options, possibly empty
 */
private[this] def checkJdbcOptions(options: String): Unit = {
  checkStartsOrEndsWith(options, ";")
  // An empty string means no extra options were given, so there is nothing to split.
  if (options.nonEmpty) {
    checkContainsKeyValuePairs(options.split(";"), "=")
  }
}

private[this] def checkStartsOrEndsWith(input: String, pattern: String): Unit =
if (input.endsWith(pattern) || input.startsWith(pattern)) {
throw new IllegalArgumentException(
ExaError
.messageBuilder("E-SEC-5")
Expand All @@ -83,38 +125,17 @@ object ExasolConfiguration {
)
}

if (str.length > 0) {
str
.split(";")
.foreach { kv =>
if (kv.filter(_ == '=').length != 1) {
throw new IllegalArgumentException(
ExaError
.messageBuilder("E-SEC-6")
.message("Parameter {{PARAMETER}} does not have key=value format.", kv)
.mitigation("Please make sure parameters are encoded as key=value pairs.")
.toString()
)
}
}
/**
 * Checks that every entry consists of exactly one key and one value joined by `pattern`.
 *
 * An entry is rejected when it has no separator, more than one separator, or an empty key or
 * value.
 *
 * @param options the individual option entries, for example `debug=1`
 * @param pattern the literal separator between key and value
 * @throws IllegalArgumentException with error code E-SEC-6 for the first entry that is not a
 *         key-value pair
 */
private[this] def checkContainsKeyValuePairs(options: Array[String], pattern: String): Unit = {
  // Quote the separator so it is matched literally rather than as a regular expression.
  val separator = java.util.regex.Pattern.quote(pattern)
  options.foreach { keyValue =>
    // A negative limit keeps trailing empty fields. Without it an entry such as "a=b=" is
    // split into only two parts and would be accepted as a valid pair.
    val parts = keyValue.split(separator, -1)
    val isKeyValuePair = parts.length == 2 && parts.forall(_.nonEmpty)
    if (!isKeyValuePair) {
      throw new IllegalArgumentException(
        ExaError
          .messageBuilder("E-SEC-6")
          .message("Parameter {{PARAMETER}} does not have key=value format.", keyValue)
          .mitigation("Please make sure parameters are encoded as key=value pairs.")
          .toString()
      )
    }
  }
}
str
}

@SuppressWarnings(
Array("org.wartremover.warts.Overloading", "org.danielnixon.extrawarts.StringOpsPartial")
)
def apply(opts: Map[String, String]): ExasolConfiguration =
ExasolConfiguration(
host = checkHost(opts.getOrElse("host", getLocalHost())),
port = opts.getOrElse("port", "8888").toInt,
jdbc_options = checkJdbcOptions(opts.getOrElse("jdbc_options", "")),
username = opts.getOrElse("username", "sys"),
password = opts.getOrElse("password", "exasol"),
max_nodes = opts.getOrElse("max_nodes", "200").toInt,
create_table = opts.getOrElse("create_table", "false").toBoolean,
drop_table = opts.getOrElse("drop_table", "false").toBoolean,
batch_size = opts.getOrElse("batch_size", "1000").toInt
)

}
34 changes: 20 additions & 14 deletions src/main/scala/com/exasol/spark/util/ExasolConnectionManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,26 @@ import com.exasol.jdbc.EXAStatement
*/
final case class ExasolConnectionManager(config: ExasolConfiguration) {

private[this] val mainJdbcConnectionUrl = s"jdbc:exa:${config.host}:${config.port}"
// JDBC option looked up in `config.jdbc_options`; when it is present the fingerprint is not appended to the host.
private[this] val DO_NOT_VALIDATE_CERTIFICATE = "validateservercertificate=0"
// URL prefix used when building the main JDBC connection string.
private[this] val MAIN_CONNECTION_PREFIX = "jdbc:exa"
// URL prefix used when building the worker sub-connection strings.
private[this] val WORKER_CONNECTION_PREFIX = "jdbc:exa-worker"

/** A regular Exasol jdbc connection string */
def getJdbcConnectionString(): String =
getConnectionStringWithOptions(mainJdbcConnectionUrl)
def getJdbcConnectionString(): String =
  // Shape is <prefix>:<host, optionally with fingerprint>:<port>; the configured JDBC options are added afterwards.
  getConnectionStringWithOptions(
    s"$MAIN_CONNECTION_PREFIX:${getHostWithFingerprint(config.host)}:${config.port}"
  )

/**
 * Returns the host in the form `host/fingerprint` when a fingerprint should be used.
 *
 * The host is returned unchanged when no fingerprint is configured or when the JDBC options
 * contain `validateservercertificate=0`.
 *
 * @param host host name or address to decorate
 * @return the host, with the configured fingerprint appended when applicable
 */
private[this] def getHostWithFingerprint(host: String): String = {
  val skipFingerprint =
    config.fingerprint.isEmpty() || config.jdbc_options.contains(DO_NOT_VALIDATE_CERTIFICATE)
  if (skipFingerprint) host else s"$host/${config.fingerprint}"
}

def mainConnection(): EXAConnection =
ExasolConnectionManager.makeConnection(
getJdbcConnectionString(),
config.username,
config.password
)
ExasolConnectionManager.makeConnection(getJdbcConnectionString(), config.username, config.password)

def writerMainConnection(): EXAConnection =
ExasolConnectionManager.makeConnection(
Expand All @@ -50,11 +58,7 @@ final case class ExasolConnectionManager(config: ExasolConfiguration) {
* the user.
*/
def getConnection(): EXAConnection =
ExasolConnectionManager.createConnection(
getJdbcConnectionString(),
config.username,
config.password
)
ExasolConnectionManager.createConnection(getJdbcConnectionString(), config.username, config.password)

/**
* Starts a parallel sub-connections from the main JDBC connection.
Expand All @@ -80,7 +84,9 @@ final case class ExasolConnectionManager(config: ExasolConfiguration) {
.zipWithIndex
.toSeq
.map { case ((host, port), idx) =>
getConnectionStringWithOptions(s"jdbc:exa-worker:$host:$port;workerID=$idx;workertoken=$token")
val hostWithFingerprint = getHostWithFingerprint(host)
val url = s"$WORKER_CONNECTION_PREFIX:$hostWithFingerprint:$port;workerID=$idx;workertoken=$token"
getConnectionStringWithOptions(url)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,46 +1,95 @@
package com.exasol.spark.util

import com.exasol.jdbc.EXAConnection

import org.mockito.Mockito._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar

class ExasolConnectionManagerSuite extends AnyFunSuite with Matchers {
class ExasolConnectionManagerSuite extends AnyFunSuite with Matchers with MockitoSugar {

def getECMConnectionString(opts: Map[String, String]): String = {
val conf: ExasolConfiguration = ExasolConfiguration.apply(opts)
ExasolConnectionManager(conf).getJdbcConnectionString()
}
// Builds a connection manager from raw options, going through the same validation as production code.
def getManager(options: Map[String, String]): ExasolConnectionManager = {
  val configuration = ExasolConfiguration(options)
  ExasolConnectionManager(configuration)
}

@SuppressWarnings(Array("scala:S1313")) // Hardcoded IP addresses are safe in tests
val emptyOpts: Map[String, String] = Map("host" -> "10.0.0.1", "port" -> "8888")
// Returns the main JDBC connection string produced for the given options.
def getJdbcUrl(options: Map[String, String]): String = {
  val manager = getManager(options)
  manager.getJdbcConnectionString()
}

test("check extra exasol jdbc options are correctly configured for establishing connection") {
assert(getECMConnectionString(emptyOpts) === "jdbc:exa:10.0.0.1:8888")
// Minimal host and port options that the tests in this suite extend with further settings.
@SuppressWarnings(Array("scala:S1313")) // Hardcoded IP addresses are safe in tests
val requiredOptions: Map[String, String] = Map("host" -> "10.0.0.1", "port" -> "8888")

val correctOpts1 = emptyOpts ++ Map("jdbc_options" -> "debug=1")
assert(getECMConnectionString(correctOpts1) === "jdbc:exa:10.0.0.1:8888;debug=1")
test("check empty jdbc options returns correctly configured jdbc url") {
  // Without jdbc_options the URL consists only of prefix, host and port.
  val jdbcUrl = getJdbcUrl(requiredOptions)
  assert(jdbcUrl === "jdbc:exa:10.0.0.1:8888")
}

val correctOpts2 = emptyOpts ++ Map("jdbc_options" -> "debug=1;encryption=0")
assert(getECMConnectionString(correctOpts2) === "jdbc:exa:10.0.0.1:8888;debug=1;encryption=0")
test("check extra jdbc options are correctly configured for establishing connection") {
  // Each pair is (jdbc_options value, JDBC URL expected for it).
  val expectedUrlsByJdbcOptions = Seq(
    ("debug=1", "jdbc:exa:10.0.0.1:8888;debug=1"),
    ("debug=1;encryption=0", "jdbc:exa:10.0.0.1:8888;debug=1;encryption=0")
  )
  for ((jdbcOptions, expectedJdbcUrl) <- expectedUrlsByJdbcOptions) {
    val options = requiredOptions + ("jdbc_options" -> jdbcOptions)
    assert(getJdbcUrl(options) === expectedJdbcUrl)
  }
}

test("check exasol jdbc options has invalid property format") {
val incorrectOpt = emptyOpts ++ Map("jdbc_options" -> "debug==1;encryption=0")
test("throws when jdbc options have invalid key-value property format") {
val incorrectOpt = requiredOptions ++ Map("jdbc_options" -> "debug==1;encryption=0")
val thrown = intercept[IllegalArgumentException] {
getECMConnectionString(incorrectOpt)
getJdbcUrl(incorrectOpt)
}
val message = thrown.getMessage()
assert(message.startsWith("E-SEC-6"))
assert(message.contains("Parameter 'debug==1' does not have key=value format"))
}

test("check exasol jdbc options start with semicolon") {
val incorrectOpt = emptyOpts ++ Map("jdbc_options" -> ";debug=1;encryption=0")
val thrown = intercept[IllegalArgumentException] {
getECMConnectionString(incorrectOpt)
test("throws when jdbc options start or end with semicolon") {
Seq(";debug=1;encryption=0", "encryption=1;").foreach { case options =>
val incorrectOpt = requiredOptions ++ Map("jdbc_options" -> options)
val thrown = intercept[IllegalArgumentException] {
getJdbcUrl(incorrectOpt)
}
val message = thrown.getMessage()
assert(message.startsWith("E-SEC-5"))
assert(message.contains("JDBC options should not start or end with semicolon"))
}
val message = thrown.getMessage()
assert(message.startsWith("E-SEC-5"))
assert(message.contains("JDBC options should not start or end with semicolon"))
}

// Creates a mocked main connection that reports two workers and a fixed worker token.
private[this] def getMockedConnection(): EXAConnection = {
  val workerHosts = Array("worker1", "worker2")
  val workerPorts = Array(21001, 21010)
  val mockedConnection = mock[EXAConnection]
  when(mockedConnection.GetWorkerToken()).thenReturn(12345L)
  when(mockedConnection.GetWorkerPorts()).thenReturn(workerPorts)
  when(mockedConnection.GetWorkerHosts()).thenReturn(workerHosts)
  mockedConnection
}

test("returns list of worker connections") {
  val manager = getManager(requiredOptions)
  // One URL per mocked worker; the worker id is the worker's position in the list.
  val expected = Seq(
    getWorkerJdbcUrl("worker1", 21001, 0, 12345L),
    getWorkerJdbcUrl("worker2", 21010, 1, 12345L)
  )
  assert(manager.subConnections(getMockedConnection()) === expected)
}

test("returns jdbc url with fingerprint") {
  // The fingerprint is expected between the host and the port, separated from the host by a slash.
  val jdbcUrl = getJdbcUrl(requiredOptions + ("fingerprint" -> "dummy_fingerprint"))
  assert(jdbcUrl === "jdbc:exa:10.0.0.1/dummy_fingerprint:8888")
}

test("returns jdbc url without fingerprint if validateservercertificate=0") {
  // With certificate validation switched off the configured fingerprint must not appear in the URL.
  val validationDisabled = requiredOptions +
    ("jdbc_options" -> "validateservercertificate=0") +
    ("fingerprint" -> "dummy_fingerprint")
  assert(getJdbcUrl(validationDisabled) === "jdbc:exa:10.0.0.1:8888;validateservercertificate=0")
}

test("returns list of worker connections with fingerprint") {
  val manager =
    getManager(requiredOptions ++ Map("jdbc_options" -> "debug=1", "fingerprint" -> "fp"))
  // Every worker host carries the fingerprint and every URL ends with the extra JDBC options.
  val expected = Seq(("worker1/fp", 21001), ("worker2/fp", 21010)).zipWithIndex.map {
    case ((host, port), id) => s"${getWorkerJdbcUrl(host, port, id, 12345L)};debug=1"
  }
  assert(manager.subConnections(getMockedConnection()) === expected)
}

// Builds the worker JDBC URL the tests expect: the address followed by the worker id and token parameters.
private[this] def getWorkerJdbcUrl(host: String, port: Int, id: Int, token: Long): String = {
  val address = s"jdbc:exa-worker:$host:$port"
  Seq(address, s"workerID=$id", s"workertoken=$token").mkString(";")
}

}

0 comments on commit 3c3aa09

Please sign in to comment.