diff --git a/Cargo.toml b/Cargo.toml index 08acc91f..88815951 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ futures-sink = "0.3" lazy_static = "1" lru = "0.8.1" mio = { version = "0.8.0", features = ["os-poll", "net"] } -mysql_common = { version = "0.29.0", default-features = false } +mysql_common = { version = "0.29.2", default-features = false } once_cell = "1.7.2" pem = "1.0.1" percent-encoding = "2.1.0" diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 8dec489e..b4ead418 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -16,7 +16,7 @@ use mysql_common::{ packets::{ binlog_request::BinlogRequest, AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket, HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest, - ResultSetTerminator, SslRequest, + OldEofPacket, ResultSetTerminator, SslRequest, }, proto::MySerialize, }; @@ -697,9 +697,18 @@ impl Conn { /// Returns `true` for ProgressReport packet. fn handle_packet(&mut self, packet: &PooledBuf) -> Result { let ok_packet = if self.has_pending_result() { - ParseBuf(&*packet) - .parse::>(self.capabilities()) - .map(|x| x.into_inner()) + if self + .capabilities() + .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) + { + ParseBuf(&*packet) + .parse::>(self.capabilities()) + .map(|x| x.into_inner()) + } else { + ParseBuf(&*packet) + .parse::>(self.capabilities()) + .map(|x| x.into_inner()) + } } else { ParseBuf(&*packet) .parse::>(self.capabilities()) @@ -1059,7 +1068,7 @@ impl Conn { mod test { use bytes::Bytes; use futures_util::stream::{self, StreamExt}; - use mysql_common::binlog::events::EventData; + use mysql_common::{binlog::events::EventData, constants::MAX_PAYLOAD_LEN}; use tokio::time::timeout; use std::time::Duration; @@ -1448,15 +1457,15 @@ mod test { #[tokio::test] async fn should_perform_queries() -> super::Result<()> { - let long_string = ::std::iter::repeat('A') - .take(18 * 1024 * 1024) - .collect::(); let mut conn = Conn::new(get_opts()).await?; - let result: Vec<(String, u8)> = conn - .query(format!(r"SELECT '{}', 231", long_string)) - .await?; + for x in (MAX_PAYLOAD_LEN - 2)..=(MAX_PAYLOAD_LEN + 2) { + let long_string = ::std::iter::repeat('A').take(x).collect::(); + let result: Vec<(String, u8)> = conn + .query(format!(r"SELECT '{}', 231", long_string)) + .await?; + assert_eq!((long_string, 231_u8), result[0]); + } conn.disconnect().await?; - assert_eq!((long_string, 231_u8), result[0]); Ok(()) } diff --git a/src/queryable/mod.rs b/src/queryable/mod.rs index b6bcc2b6..52bb4252 100644 --- a/src/queryable/mod.rs +++ b/src/queryable/mod.rs @@ -8,8 +8,8 @@ use futures_util::FutureExt; use mysql_common::{ + constants::MAX_PAYLOAD_LEN, io::ParseBuf, - packets::{OkPacketDeserializer, ResultSetTerminator}, proto::{Binary, Text}, row::RowDeserializer, value::ServerSide, @@ -42,10 +42,11 @@ pub trait Protocol: fmt::Debug + Send + Sync + 'static { fn result_set_meta(columns: Arc<[Column]>) -> ResultSetMeta; fn read_result_set_row(packet: &[u8], columns: Arc<[Column]>) -> Result; fn is_last_result_set_packet(capabilities: CapabilityFlags, packet: &[u8]) -> bool { - packet.len() < 8 - && ParseBuf(packet) - .parse::>(capabilities) - .is_ok() + if capabilities.contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) { + packet[0] == 0xFE && packet.len() < MAX_PAYLOAD_LEN + } else { + packet[0] == 0xFE && packet.len() < 8 + } } }