diff --git a/src/pd/retry.rs b/src/pd/retry.rs index 0e65e13b..3c17a49e 100644 --- a/src/pd/retry.rs +++ b/src/pd/retry.rs @@ -70,18 +70,12 @@ impl RetryClient { } } -macro_rules! retry { - ($self: ident, $tag: literal, |$cluster: ident| $call: expr) => {{ +macro_rules! retry_core { + ($self: ident, $tag: literal, $call: expr) => {{ let stats = pd_stats($tag); let mut last_err = Ok(()); for _ in 0..LEADER_CHANGE_RETRY { - // use the block here to drop the guard of the read lock, - // otherwise `reconnect` will try to acquire the write lock and results in a deadlock - let res = { - let $cluster = &mut $self.cluster.write().await.0; - let res = $call.await; - res - }; + let res = $call; match stats.done(res) { Ok(r) => return Ok(r), @@ -103,6 +97,28 @@ macro_rules! retry { }}; } +macro_rules! retry_mut { + ($self: ident, $tag: literal, |$cluster: ident| $call: expr) => {{ + retry_core!($self, $tag, { + // use the block here to drop the guard of the lock, + // otherwise `reconnect` will try to acquire the write lock and results in a deadlock + let $cluster = &mut $self.cluster.write().await.0; + $call.await + }) + }}; +} + +macro_rules! retry { + ($self: ident, $tag: literal, |$cluster: ident| $call: expr) => {{ + retry_core!($self, $tag, { + // use the block here to drop the guard of the lock, + // otherwise `reconnect` will try to acquire the write lock and results in a deadlock + let $cluster = &$self.cluster.read().await.0; + $call.await + }) + }}; +} + impl RetryClient { pub async fn connect( endpoints: &[String], @@ -127,7 +143,7 @@ impl RetryClientTrait for RetryClient { // These get_* functions will try multiple times to make a request, reconnecting as necessary. // It does not know about encoding. Caller should take care of it. async fn get_region(self: Arc, key: Vec) -> Result { - retry!(self, "get_region", |cluster| { + retry_mut!(self, "get_region", |cluster| { let key = key.clone(); async { cluster @@ -141,7 +157,7 @@ impl RetryClientTrait for RetryClient { } async fn get_region_by_id(self: Arc, region_id: RegionId) -> Result { - retry!(self, "get_region_by_id", |cluster| async { + retry_mut!(self, "get_region_by_id", |cluster| async { cluster .get_region_by_id(region_id, self.timeout) .await @@ -152,7 +168,7 @@ impl RetryClientTrait for RetryClient { } async fn get_store(self: Arc, id: StoreId) -> Result { - retry!(self, "get_store", |cluster| async { + retry_mut!(self, "get_store", |cluster| async { cluster .get_store(id, self.timeout) .await @@ -161,7 +177,7 @@ impl RetryClientTrait for RetryClient { } async fn get_all_stores(self: Arc) -> Result> { - retry!(self, "get_all_stores", |cluster| async { + retry_mut!(self, "get_all_stores", |cluster| async { cluster .get_all_stores(self.timeout) .await @@ -174,7 +190,7 @@ impl RetryClientTrait for RetryClient { } async fn update_safepoint(self: Arc, safepoint: u64) -> Result { - retry!(self, "update_gc_safepoint", |cluster| async { + retry_mut!(self, "update_gc_safepoint", |cluster| async { cluster .update_safepoint(safepoint, self.timeout) .await @@ -257,7 +273,7 @@ mod test { } async fn retry_err(client: Arc) -> Result<()> { - retry!(client, "test", |_c| ready(Err(internal_err!("whoops")))) + retry_mut!(client, "test", |_c| ready(Err(internal_err!("whoops")))) } async fn retry_ok(client: Arc) -> Result<()> { @@ -310,7 +326,7 @@ mod test { client: Arc, max_retries: Arc, ) -> Result<()> { - retry!(client, "test", |c| { + retry_mut!(client, "test", |c| { c.fetch_add(1, std::sync::atomic::Ordering::SeqCst); let max_retries = max_retries.fetch_sub(1, Ordering::SeqCst) - 1; diff --git a/src/pd/timestamp.rs b/src/pd/timestamp.rs index 672b587b..a1cc7fbd 100644 --- a/src/pd/timestamp.rs +++ b/src/pd/timestamp.rs @@ -98,15 +98,13 @@ async fn run_tso( let mut responses = pd_client.tso(request_stream).await?.into_inner(); while let Some(Ok(resp)) = responses.next().await { - let mut pending_requests = pending_requests.lock().await; - - // Wake up the sending future blocked by too many pending requests as we are consuming - // some of them here. - if pending_requests.len() == MAX_PENDING_COUNT { - sending_future_waker.wake(); + { + let mut pending_requests = pending_requests.lock().await; + allocate_timestamps(&resp, &mut pending_requests)?; } - allocate_timestamps(&resp, &mut pending_requests)?; + // Wake up the sending future blocked by too many pending requests or locked. + sending_future_waker.wake(); } // TODO: distinguish between unexpected stream termination and expected end of test info!("TSO stream terminated"); @@ -139,6 +137,7 @@ impl Stream for TsoRequestStream { { pending_requests } else { + this.self_waker.register(cx.waker()); return Poll::Pending; }; let mut requests = Vec::new(); @@ -148,8 +147,8 @@ impl Stream for TsoRequestStream { Poll::Ready(Some(sender)) => { requests.push(sender); } - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => break, + Poll::Ready(None) if requests.is_empty() => return Poll::Ready(None), + _ => break, } }