Skip to content

Commit

Permalink
CHanges based on comment
Browse files Browse the repository at this point in the history
  • Loading branch information
XuanWang-Amos committed Jan 20, 2024
1 parent 9ab8419 commit ef98926
Showing 1 changed file with 15 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import datetime
import logging
import signal
import sys
import threading
import time
from typing import (
Expand All @@ -31,6 +30,7 @@
Sequence,
Set,
Tuple,
Union,
)

import grpc
Expand Down Expand Up @@ -67,6 +67,7 @@
_METHOD_ENUM_TO_STR = {v: k for k, v in _METHOD_STR_TO_ENUM.items()}

PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]]
MetadataType = Sequence[Tuple[str, Union[str, bytes]]]
# FutureFromCall is both a grpc.Call and grpc.Future
FutureFromCallType = Any

Expand All @@ -82,7 +83,7 @@ class _StatsWatcher:
_no_remote_peer: int
_lock: threading.Lock
_condition: threading.Condition
_metadata_keys: List[str]
_metadata_keys: Set[str]
_include_all_metadata: bool
_metadata_by_peer: DefaultDict[
str, messages_pb2.LoadBalancerStatsResponse.MetadataByPeer
Expand All @@ -98,7 +99,7 @@ def __init__(self, start: int, end: int, metadata_keys: Iterable[str]):
)
self._condition = threading.Condition()
self._no_remote_peer = 0
self._metadata_keys = [key.lower() for key in metadata_keys]
self._metadata_keys = set(key.lower() for key in metadata_keys)
self._include_all_metadata = "*" in [
key.strip() for key in metadata_keys
]
Expand All @@ -109,25 +110,20 @@ def __init__(self, start: int, end: int, metadata_keys: Iterable[str]):
def _add_metadata(
self,
rpc_metadata: messages_pb2.LoadBalancerStatsResponse.RpcMetadata,
metadata: MetadataType,
type: messages_pb2.LoadBalancerStatsResponse.MetadataType,
metadata_to_add: MetadataType,
metadata_type: messages_pb2.LoadBalancerStatsResponse.MetadataType,
) -> None:
for key, value in metadata:
for key, value in metadata_to_add:
if self._include_all_metadata or key.lower() in self._metadata_keys:
metadata_entry = (
messages_pb2.LoadBalancerStatsResponse.MetadataEntry()
)
metadata_entry.key = key
metadata_entry.value = value
metadata_entry.type = type

rpc_metadata.metadata.append(metadata_entry)
rpc_metadata.metadata.append(
messages_pb2.LoadBalancerStatsResponse.MetadataEntry(key=key, value=value, type=metadata_type))

def on_rpc_complete(
self,
request_id: int,
peer: str,
method: str,
*,
initial_metadata: MetadataType,
trailing_metadata: MetadataType,
) -> None:
Expand Down Expand Up @@ -315,22 +311,22 @@ def _on_rpc_done(
rpc_id,
hostname,
method,
future.initial_metadata(),
future.trailing_metadata(),
initial_metadata = future.initial_metadata(),
trailing_metadata = future.trailing_metadata(),
)


def _remove_completed_rpcs(
futures: Mapping[int, FutureFromCallType], print_response: bool
rpc_futures: Mapping[int, FutureFromCallType], print_response: bool
) -> None:
logger.debug("Removing completed RPCs")
done = []
for future_id, (future, method) in futures.items():
for future_id, (future, method) in rpc_futures.items():
if future.done():
_on_rpc_done(future_id, future, method, args.print_response)
done.append(future_id)
for rpc_id in done:
del futures[rpc_id]
del rpc_futures[rpc_id]


def _cancel_all_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]]) -> None:
Expand Down

0 comments on commit ef98926

Please sign in to comment.