Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add TLS support #103

Merged
merged 5 commits into from
Dec 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@ jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ubuntu-latest
services:
redis:
image: redis:7.2.3-bookworm # https://hub.docker.com/_/redis
ports:
- 6379:6379
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -41,9 +36,18 @@ jobs:
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
- name: Start redis server
run: |
echo "Starting redis server"
pwd
test/conf/redis.sh
sleep 5
echo "Redis started"
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v2
- uses: codecov/codecov-action@v3
id: codecov
continue-on-error: true
with:
files: lcov.info
fail_ci_if_error: false
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ version = "2.0.0"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
MbedTLS = "739be429-bea8-5141-9913-cc70e7f3736d"

[compat]
julia = "^1"
DataStructures = "^0.18"
MbedTLS = "0.6.8, 0.7, 1"

[extras]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down
2 changes: 2 additions & 0 deletions src/Redis.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Redis
using Dates
using Sockets
using MbedTLS

import Base.get, Base.keys, Base.time

Expand Down Expand Up @@ -59,6 +60,7 @@ export sentinel_masters, sentinel_master, sentinel_slaves, sentinel_getmasteradd
export REDIS_PERSISTENT_KEY, REDIS_EXPIRED_KEY

include("exceptions.jl")
include("transport/transport.jl")
include("connection.jl")
include("parser.jl")
include("client.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/client.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ end
function subscription_loop(conn::SubscriptionConnection, err_callback::Function)
while is_connected(conn)
try
l = getline(conn.socket)
reply = parseline(l, conn.socket)
l = getline(conn.transport)
reply = parseline(l, conn.transport)
reply = convert_reply(reply)
message = SubscriptionMessage(reply)
if message.message_type == SubscriptionMessageType.Message
Expand Down
85 changes: 53 additions & 32 deletions src/connection.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import Sockets.connect, Sockets.TCPSocket, Base.StatusActive, Base.StatusOpen, Base.StatusPaused

abstract type RedisConnectionBase end
abstract type SubscribableConnection<:RedisConnectionBase end

Expand All @@ -8,31 +6,31 @@
port::Integer
password::AbstractString
db::Integer
socket::TCPSocket
transport::Transport.RedisTransport
end

struct SentinelConnection <: SubscribableConnection
host::AbstractString
port::Integer
password::AbstractString
db::Integer
socket::TCPSocket
transport::Transport.RedisTransport
end

struct TransactionConnection <: RedisConnectionBase
host::AbstractString
port::Integer
password::AbstractString
db::Integer
socket::TCPSocket
transport::Transport.RedisTransport
end

mutable struct PipelineConnection <: RedisConnectionBase
host::AbstractString
port::Integer
password::AbstractString
db::Integer
socket::TCPSocket
transport::Transport.RedisTransport
num_commands::Integer
end

Expand All @@ -43,77 +41,100 @@
db::Integer
callbacks::Dict{AbstractString, Function}
pcallbacks::Dict{AbstractString, Function}
socket::TCPSocket
transport::Transport.RedisTransport
end

function RedisConnection(; host="127.0.0.1", port=6379, password="", db=0)
Transport.get_sslconfig(s::RedisConnectionBase) = Transport.get_sslconfig(s.transport)

function RedisConnection(; host="127.0.0.1", port=6379, password="", db=0, sslconfig=nothing)
try
socket = connect(host, port)
connection = RedisConnection(host, port, password, db, socket)
connection = RedisConnection(
host,
port,
password,
db,
Transport.transport(host, port, sslconfig)
)
on_connect(connection)
catch
throw(ConnectionException("Failed to connect to Redis server"))
end
end

function SentinelConnection(; host="127.0.0.1", port=26379, password="", db=0)
function SentinelConnection(; host="127.0.0.1", port=26379, password="", db=0, sslconfig=nothing)

Check warning on line 64 in src/connection.jl

View check run for this annotation

Codecov / codecov/patch

src/connection.jl#L64

Added line #L64 was not covered by tests
try
socket = connect(host, port)
sentinel_connection = SentinelConnection(host, port, password, db, socket)
sentinel_connection = SentinelConnection(

Check warning on line 66 in src/connection.jl

View check run for this annotation

Codecov / codecov/patch

src/connection.jl#L66

Added line #L66 was not covered by tests
host,
port,
password,
db,
Transport.transport(host, port, sslconfig)
)
on_connect(sentinel_connection)
catch
throw(ConnectionException("Failed to connect to Redis sentinel"))
end
end

function TransactionConnection(parent::RedisConnection)
function TransactionConnection(parent::RedisConnection; sslconfig=Transport.get_sslconfig(parent))
try
socket = connect(parent.host, parent.port)
transaction_connection = TransactionConnection(parent.host,
parent.port, parent.password, parent.db, socket)
transaction_connection = TransactionConnection(
parent.host,
parent.port,
parent.password,
parent.db,
Transport.transport(parent.host, parent.port, sslconfig)
)
on_connect(transaction_connection)
catch
throw(ConnectionException("Failed to create transaction"))
end
end

function PipelineConnection(parent::RedisConnection)
function PipelineConnection(parent::RedisConnection; sslconfig=Transport.get_sslconfig(parent))
try
socket = connect(parent.host, parent.port)
pipeline_connection = PipelineConnection(parent.host,
parent.port, parent.password, parent.db, socket, 0)
pipeline_connection = PipelineConnection(
parent.host,
parent.port,
parent.password,
parent.db,
Transport.transport(parent.host, parent.port, sslconfig),
0
)
on_connect(pipeline_connection)
catch
throw(ConnectionException("Failed to create pipeline"))
end
end

function SubscriptionConnection(parent::SubscribableConnection)
function SubscriptionConnection(parent::SubscribableConnection; sslconfig=Transport.get_sslconfig(parent))
try
socket = connect(parent.host, parent.port)
subscription_connection = SubscriptionConnection(parent.host,
parent.port, parent.password, parent.db, Dict{AbstractString, Function}(),
Dict{AbstractString, Function}(), socket)
subscription_connection = SubscriptionConnection(
parent.host,
parent.port,
parent.password,
parent.db,
Dict{AbstractString, Function}(),
Dict{AbstractString, Function}(),
Transport.transport(parent.host, parent.port, sslconfig)
)
on_connect(subscription_connection)
catch
throw(ConnectionException("Failed to create subscription"))
end
end

function on_connect(conn::RedisConnectionBase)
# disable nagle and enable quickack to speed up the usually small exchanges
Sockets.nagle(conn.socket, false)
Sockets.quickack(conn.socket, true)

Transport.set_props!(conn.transport)
conn.password != "" && auth(conn, conn.password)
conn.db != 0 && select(conn, conn.db)
conn
end

function disconnect(conn::RedisConnectionBase)
close(conn.socket)
Transport.close(conn.transport)
end

function is_connected(conn::RedisConnectionBase)
conn.socket.status == StatusActive || conn.socket.status == StatusOpen || conn.socket.status == StatusPaused
Transport.is_connected(conn.transport)
end
28 changes: 14 additions & 14 deletions src/parser.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
Formatting of incoming Redis Replies
"""
function getline(s::TCPSocket)
l = chomp(readline(s))
function getline(t::Transport.RedisTransport)
l = chomp(Transport.read_line(t))
length(l) > 1 || throw(ProtocolException("Invalid response received: $l"))
return l
end
Expand All @@ -12,15 +12,15 @@ convert_reply(reply::Array) = [convert_reply(r) for r in reply]
convert_reply(x) = x

function read_reply(conn::RedisConnectionBase)
l = getline(conn.socket)
reply = parseline(l, conn.socket)
l = getline(conn.transport)
reply = parseline(l, conn.transport)
convert_reply(reply)
end

parse_error(l::AbstractString) = throw(ServerException(l))

function parse_bulk_string(s::TCPSocket, slen::Int)
b = read(s, slen+2) # add crlf
function parse_bulk_string(t::Transport.RedisTransport, slen::Int)
b = Transport.read_nbytes(t, slen+2) # add crlf
if length(b) != slen + 2
throw(ProtocolException(
"Bulk string read error: expected $slen bytes; received $(length(b))"
Expand All @@ -30,17 +30,17 @@ function parse_bulk_string(s::TCPSocket, slen::Int)
end
end

function parse_array(s::TCPSocket, slen::Int)
function parse_array(t::Transport.RedisTransport, slen::Int)
a = Array{Any, 1}(undef, slen)
for i = 1:slen
l = getline(s)
r = parseline(l, s)
l = getline(t)
r = parseline(l, t)
a[i] = r
end
return a
end

function parseline(l::AbstractString, s::TCPSocket)
function parseline(l::AbstractString, t::Transport.RedisTransport)
reply_type = l[1]
reply_token = l[2:end]
if reply_type == '+'
Expand All @@ -52,14 +52,14 @@ function parseline(l::AbstractString, s::TCPSocket)
if slen == -1
nothing
else
parse_bulk_string(s, slen)
parse_bulk_string(t, slen)
end
elseif reply_type == '*'
slen = parse(Int, reply_token)
if slen == -1
nothing
else
parse_array(s, slen)
parse_array(t, slen)
end
elseif reply_type == '-'
parse_error(reply_token)
Expand Down Expand Up @@ -90,8 +90,8 @@ function execute_command_without_reply(conn::RedisConnectionBase, command)
is_connected(conn) || throw(ConnectionException("Socket is disconnected"))
iob = IOBuffer()
pack_command(iob, command)
lock(conn.socket.lock) do
write(conn.socket, take!(iob))
Transport.io_lock(conn.transport) do
Transport.write_bytes(conn.transport, take!(iob))
end
end

Expand Down
19 changes: 19 additions & 0 deletions src/transport/tcp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
struct TCPTransport <: RedisTransport
sock::TCPSocket
end

read_line(t::TCPTransport) = readline(t.sock)
read_nbytes(t::TCPTransport, m::Int) = read(t.sock, m)
write_bytes(t::TCPTransport, b::Vector{UInt8}) = write(t.sock, b)
Base.close(t::TCPTransport) = close(t.sock)
function set_props!(t::TCPTransport)
# disable nagle and enable quickack to speed up the usually small exchanges
Sockets.nagle(t.sock, false)
Sockets.quickack(t.sock, true)
end
get_sslconfig(::TCPTransport) = nothing
io_lock(f, t::TCPTransport) = lock(f, t.sock.lock)
function is_connected(t::TCPTransport)
status = t.sock.status
status == StatusActive || status == StatusOpen || status == StatusPaused
end
56 changes: 56 additions & 0 deletions src/transport/tls.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
struct TLSTransport <: RedisTransport
sock::TCPSocket
ctx::MbedTLS.SSLContext
sslconfig::MbedTLS.SSLConfig
buff::IOBuffer

function TLSTransport(sock::TCPSocket, sslconfig::MbedTLS.SSLConfig)
ctx = MbedTLS.SSLContext()
MbedTLS.setup!(ctx, sslconfig)
MbedTLS.associate!(ctx, sock)
MbedTLS.handshake(ctx)

return new(sock, ctx, sslconfig, PipeBuffer())
end
end

function read_into_buffer_until(cond::Function, t::TLSTransport)
cond(t) && return

buff = Vector{UInt8}(undef, MbedTLS.MBEDTLS_SSL_MAX_CONTENT_LEN)
pbuff = pointer(buff)

while !cond(t) && !eof(t.ctx)
nread = readbytes!(t.ctx, buff; all=false)
if nread > 0
unsafe_write(t.buff, pbuff, nread)
end
end
end

function read_line(t::TLSTransport)
read_into_buffer_until(t) do t
iob = t.buff
(bytesavailable(t.buff) > 0) && (UInt8('\n') in view(iob.data, iob.ptr:iob.size))
end
return readline(t.buff)
end
function read_nbytes(t::TLSTransport, m::Int)
read_into_buffer_until(t) do t
bytesavailable(t.buff) >= m
end
return read(t.buff, m)
end
write_bytes(t::TLSTransport, b::Vector{UInt8}) = write(t.ctx, b)
Base.close(t::TLSTransport) = close(t.ctx)
function set_props!(s::TLSTransport)
# disable nagle and enable quickack to speed up the usually small exchanges
Sockets.nagle(s.sock, false)
Sockets.quickack(s.sock, true)
end
get_sslconfig(t::TLSTransport) = t.sslconfig
io_lock(f, t::TLSTransport) = lock(f, t.sock.lock)
function is_connected(t::TLSTransport)
status = t.sock.status
status == StatusActive || status == StatusOpen || status == StatusPaused
end
Loading
Loading