Skip to content

Commit

Permalink
add type hints to protobuf and client
Browse files Browse the repository at this point in the history
  • Loading branch information
daddycocoaman committed Oct 6, 2022
1 parent 95f2bf2 commit 667a676
Show file tree
Hide file tree
Showing 17 changed files with 12,204 additions and 6,817 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Homepage = "https://github.com/moloch--/sliver-py"
dependencies = [
"grpcio~=1.47",
"grpcio-tools~=1.47",
"mypy-protobuf~=3.3.0",
]

[tool.hatch.envs.dev]
Expand Down
125 changes: 125 additions & 0 deletions scripts/protobufgen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import os
import pkg_resources
from pathlib import Path
from grpc_tools import protoc
from rich.console import Console


console = Console(log_time=False, log_path=False)
ROOT_DIR = Path(__file__).parents[1]
os.chdir(ROOT_DIR)

IN_DIR = ROOT_DIR / "sliver/protobuf"
OUT_DIR = ROOT_DIR / "src/sliver/pb"

COMMON_PROTO_PATH = IN_DIR / "commonpb/common.proto"
SLIVER_PROTO_PATH = IN_DIR / "sliverpb/sliver.proto"
CLIENT_PROTO_PATH = IN_DIR / "clientpb/client.proto"
GRPC_PROTO_PATH = IN_DIR / "rpcpb/services.proto"

# There is a more accurate way to do all of this using the ast module but this works for now
try:
# Cleanup old files
console.log("[bold green]Removing old generated files...")
for file in OUT_DIR.glob("**/*.py"):
if file.name.split("_")[0] in ["common", "sliver", "client", "services"]:
file.unlink()
console.log(f"Removed {file}")

console.log("[bold green]Generating new files...")
proto_pyd = pkg_resources.resource_filename("grpc_tools", "_proto")

# Generate commonpb
protoc.main(
f"-I{proto_pyd} -I {IN_DIR.relative_to(ROOT_DIR)} --mypy_out=readable_stubs:{OUT_DIR} --python_out={OUT_DIR} {COMMON_PROTO_PATH.relative_to(ROOT_DIR)}".split()
)
console.log(f"Generated {COMMON_PROTO_PATH.name}")

# Generate sliverpb
protoc.main(
f"-I{proto_pyd} -I {IN_DIR.relative_to(ROOT_DIR)} --mypy_out=readable_stubs:{OUT_DIR} --python_out={OUT_DIR} {SLIVER_PROTO_PATH.relative_to(ROOT_DIR)}".split()
)
console.log(f"Generated {SLIVER_PROTO_PATH.name}")

# Generate clientpb
protoc.main(
f"-I{proto_pyd} -I {IN_DIR.relative_to(ROOT_DIR)} --mypy_out=readable_stubs:{OUT_DIR} --python_out={OUT_DIR} {CLIENT_PROTO_PATH.relative_to(ROOT_DIR)}".split()
)
console.log(f"Generated {CLIENT_PROTO_PATH.name}")

# Generate rpcpb
protoc.main(
f"-I{proto_pyd} -I {IN_DIR.relative_to(ROOT_DIR)} --mypy_out=readable_stubs:{OUT_DIR} --mypy_grpc_out={OUT_DIR} --python_out={OUT_DIR} --grpc_python_out={OUT_DIR} {GRPC_PROTO_PATH.relative_to(ROOT_DIR)}".split()
)
console.log(f"Generated {GRPC_PROTO_PATH.name}")

# Rewrite imports for py files
console.log("[bold green]Rewriting imports for py files...")
for file in OUT_DIR.glob("**/*.py"):
if file.name.split("_")[0] in ["sliver", "client", "services"]:
content = (
file.read_text()
.replace(
"from commonpb import common_pb2 as commonpb_dot_common__pb2",
"from ..commonpb import common_pb2 as commonpb_dot_common__pb2",
)
.replace(
"from sliverpb import sliver_pb2 as sliverpb_dot_sliver__pb2",
"from ..sliverpb import sliver_pb2 as sliverpb_dot_sliver__pb2",
)
.replace(
"from clientpb import client_pb2 as clientpb_dot_client__pb2",
"from ..clientpb import client_pb2 as clientpb_dot_client__pb2",
)
)
# Need to make sure grpc.experimental is imported
if file.name == "services_pb2_grpc.py":
content = (content
.replace("grpc.Channel", "grpc.aio.Channel")
.replace("import grpc", "import grpc\nimport grpc.experimental")
) # fmt: skip

file.write_text(content)
console.log(f"Rewrote imports for {file}")

# Rewrite imports for pyi files
console.log("[bold green]Rewriting imports for pyi files...")
for file in OUT_DIR.glob("**/*.pyi"):
if file.name.split("_")[0] in ["sliver", "client", "services"]:
content = (
file.read_text()
.replace(
"import commonpb.common_pb2",
"from ..commonpb import common_pb2",
)
.replace(
"commonpb.common_pb2",
"common_pb2",
)
.replace(
"import sliverpb.sliver_pb2",
"from ..sliverpb import sliver_pb2",
)
.replace(
"sliverpb.sliver_pb2",
"sliver_pb2",
)
.replace(
"import clientpb.client_pb2",
"from ..clientpb import client_pb2",
)
.replace(
"clientpb.client_pb2",
"client_pb2",
)
)

# Need to correct type hints. This is a hacky way to do it but it works
if file.name == "services_pb2_grpc.pyi":
content = content.replace("grpc.Channel", "grpc.aio.Channel")

file.write_text(content)
console.log(f"Rewrote imports for {file}")
except Exception as e:
console.log("[bold red]Failed to generate files!")
console.log(e)
2 changes: 1 addition & 1 deletion sliver
Submodule sliver updated 2829 files
14 changes: 7 additions & 7 deletions src/sliver/beacon.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,28 @@ class BaseBeacon(object):
def __init__(
self,
beacon: client_pb2.Beacon,
channel: grpc.Channel,
channel: grpc.aio.Channel,
timeout: int = TIMEOUT,
):
"""Base class for Beacon classes.
:param beacon: Beacon protobuf object.
:type beacon: client_pb2.Beacon
:param channel: A gRPC channel.
:type channel: grpc.Channel
:type channel: grpc.aio.Channel
:param timeout: Seconds to wait for timeout, defaults to TIMEOUT
:type timeout: int, optional
"""
self._log = logging.getLogger(self.__class__.__name__)
self._channel: grpc.Channel = channel
self._beacon: client_pb2.Beacon = beacon
self._channel = channel
self._beacon = beacon
self._stub = SliverRPCStub(channel)
self.timeout = timeout
self.beacon_tasks: Dict[str, Tuple[asyncio.Future, Any]] = {}
asyncio.get_event_loop().create_task(self.taskresult_events())

@property
def beacon_id(self) -> int:
def beacon_id(self) -> str:
"""Beacon ID"""
return self._beacon.ID

Expand All @@ -63,7 +63,7 @@ def name(self) -> str:
return self._beacon.Name

@property
def hostname(self) -> int:
def hostname(self) -> str:
"""Beacon hostname"""
return self._beacon.Hostname

Expand Down Expand Up @@ -118,7 +118,7 @@ def filename(self) -> str:
return self._beacon.Filename

@property
def last_checkin(self) -> str:
def last_checkin(self) -> int:
"""Last check in time"""
return self._beacon.LastCheckin

Expand Down
Loading

0 comments on commit 667a676

Please sign in to comment.