Skip to content

Commit 5ba424c

Browse files
committed
refactor(api): Simplify Protobuf codec
1 parent e91cbbc commit 5ba424c

File tree

2 files changed

+90
-52
lines changed

2 files changed

+90
-52
lines changed

src/servers/proto.rs

+19-52
Original file line numberDiff line numberDiff line change
@@ -3,74 +3,45 @@
33
use std::net::SocketAddr;
44

55
use futures::prelude::*;
6-
use prost::Message;
76
use thiserror::Error;
87
use tokio::net::TcpStream;
8+
use tokio_util::codec::Framed;
99

1010
use crate::{
1111
api::proto::{self, message, ProtoApiError},
12-
global::{Global, InputMessage, InputSourceHandle, InputSourceName},
12+
global::{Global, InputSourceName},
1313
};
1414

15+
mod codec;
16+
use codec::*;
17+
1518
#[derive(Debug, Error)]
1619
pub enum ProtoServerError {
1720
#[error("i/o error: {0}")]
1821
Io(#[from] futures_io::Error),
19-
#[error("decode error: {}", 0)]
20-
DecodeError(#[from] prost::DecodeError),
22+
#[error("decode error: {0}")]
23+
Codec(#[from] ProtoCodecError),
2124
#[error(transparent)]
2225
Api(#[from] ProtoApiError),
2326
}
2427

25-
fn encode_response(buf: &mut bytes::BytesMut, msg: impl prost::Message) -> bytes::Bytes {
26-
// Clear the buffer to start fresh
27-
buf.clear();
28-
29-
// Reserve enough space for the response
30-
let len = msg.encoded_len();
31-
if buf.capacity() < len {
32-
buf.reserve(len * 2);
33-
}
34-
35-
// Encode the message
36-
msg.encode(buf).unwrap();
37-
buf.split().freeze()
38-
}
39-
40-
fn success_response(peer_addr: SocketAddr, buf: &mut bytes::BytesMut) -> bytes::Bytes {
28+
fn success_response(peer_addr: SocketAddr) -> message::HyperionReply {
4129
let mut reply = message::HyperionReply::default();
4230
reply.r#type = message::hyperion_reply::Type::Reply.into();
4331
reply.success = Some(true);
4432

4533
trace!("({}) sending success: {:?}", peer_addr, reply);
46-
encode_response(buf, reply)
34+
reply
4735
}
4836

49-
fn error_response(
50-
peer_addr: SocketAddr,
51-
buf: &mut bytes::BytesMut,
52-
error: impl std::fmt::Display,
53-
) -> bytes::Bytes {
37+
fn error_response(peer_addr: SocketAddr, error: impl std::fmt::Display) -> message::HyperionReply {
5438
let mut reply = message::HyperionReply::default();
5539
reply.r#type = message::hyperion_reply::Type::Reply.into();
5640
reply.success = Some(false);
5741
reply.error = Some(error.to_string());
5842

5943
trace!("({}) sending error: {:?}", peer_addr, reply);
60-
encode_response(buf, reply)
61-
}
62-
63-
fn handle_request(
64-
peer_addr: SocketAddr,
65-
request_bytes: bytes::BytesMut,
66-
source: &InputSourceHandle<InputMessage>,
67-
) -> Result<(), ProtoServerError> {
68-
let request_bytes = request_bytes.freeze();
69-
let request = message::HyperionRequest::decode(request_bytes.clone())?;
70-
71-
trace!("({}) got request: {:?}", peer_addr, request);
72-
73-
Ok(proto::handle_request(request, source)?)
44+
reply
7445
}
7546

7647
pub async fn handle_client(
@@ -79,35 +50,31 @@ pub async fn handle_client(
7950
) -> Result<(), ProtoServerError> {
8051
debug!("accepted new connection from {}", peer_addr);
8152

82-
let framed = tokio_util::codec::LengthDelimitedCodec::builder()
83-
.length_field_length(4)
84-
.new_framed(socket);
85-
let (mut writer, mut reader) = framed.split();
53+
let (mut writer, mut reader) = Framed::new(socket, ProtoCodec::new()).split();
8654

8755
// unwrap: cannot fail because the priority is None
8856
let source = global
8957
.register_input_source(InputSourceName::Protobuf { peer_addr }, None)
9058
.await
9159
.unwrap();
9260

93-
// buffer for building responses
94-
let mut reply_buf = bytes::BytesMut::with_capacity(128);
95-
96-
while let Some(request_bytes) = reader.next().await {
97-
let request_bytes = match request_bytes {
61+
while let Some(request) = reader.next().await {
62+
let request = match request {
9863
Ok(rb) => rb,
9964
Err(error) => {
10065
error!("({}) error reading frame: {}", peer_addr, error);
10166
continue;
10267
}
10368
};
10469

105-
let reply = match handle_request(peer_addr, request_bytes, &source) {
106-
Ok(()) => success_response(peer_addr, &mut reply_buf),
70+
trace!("({}) got request: {:?}", peer_addr, request);
71+
72+
let reply = match proto::handle_request(request, &source) {
73+
Ok(()) => success_response(peer_addr),
10774
Err(error) => {
10875
error!("({}) error processing request: {}", peer_addr, error);
10976

110-
error_response(peer_addr, &mut reply_buf, error)
77+
error_response(peer_addr, error)
11178
}
11279
};
11380

src/servers/proto/codec.rs

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
use bytes::BytesMut;
2+
use prost::Message;
3+
use thiserror::Error;
4+
use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
5+
6+
use crate::api::proto::message;
7+
8+
#[derive(Debug, Error)]
9+
pub enum ProtoCodecError {
10+
#[error("i/o error: {0}")]
11+
Io(#[from] futures_io::Error),
12+
#[error(transparent)]
13+
LengthDelimited(#[from] tokio_util::codec::LengthDelimitedCodecError),
14+
#[error(transparent)]
15+
Decode(#[from] prost::DecodeError),
16+
#[error(transparent)]
17+
Encode(#[from] prost::EncodeError),
18+
}
19+
20+
/// JSON tokio codec
21+
pub struct ProtoCodec {
22+
/// Line parsing codec
23+
inner: LengthDelimitedCodec,
24+
/// Buffer for encoding messages
25+
buf: BytesMut,
26+
}
27+
28+
impl ProtoCodec {
29+
/// Create a new ProtoCodec
30+
pub fn new() -> Self {
31+
Self {
32+
inner: LengthDelimitedCodec::builder()
33+
.length_field_length(4)
34+
.new_codec(),
35+
buf: BytesMut::new(),
36+
}
37+
}
38+
}
39+
40+
impl Decoder for ProtoCodec {
41+
type Item = message::HyperionRequest;
42+
type Error = ProtoCodecError;
43+
44+
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
45+
match self.inner.decode(src) {
46+
Ok(inner_result) => Ok(match inner_result {
47+
Some(ref data) => Some(message::HyperionRequest::decode(data.clone().freeze())?),
48+
None => None,
49+
}),
50+
Err(error) => Err(error.into()),
51+
}
52+
}
53+
}
54+
55+
impl Encoder<message::HyperionReply> for ProtoCodec {
56+
type Error = ProtoCodecError;
57+
58+
fn encode(
59+
&mut self,
60+
item: message::HyperionReply,
61+
dst: &mut BytesMut,
62+
) -> Result<(), Self::Error> {
63+
self.buf.clear();
64+
self.buf.reserve(item.encoded_len());
65+
66+
match item.encode(&mut self.buf) {
67+
Ok(_) => Ok(self.inner.encode(self.buf.clone().freeze(), dst)?),
68+
Err(encode_error) => Err(encode_error.into()),
69+
}
70+
}
71+
}

0 commit comments

Comments
 (0)