Skip to content

Commit

Permalink
allow non-ssl rpc clients (2) (#17510)
Browse files Browse the repository at this point in the history
* allow non-ssl rpc clients (2)

* add missing tests

* another test
  • Loading branch information
altendky authored Feb 7, 2024
1 parent a64aafa commit 9ed9abe
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 12 deletions.
9 changes: 8 additions & 1 deletion chia/cmds/cmds_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from chia.wallet.util.tx_config import CoinSelectionConfig, CoinSelectionConfigLoader, TXConfig, TXConfigLoader

NODE_TYPES: Dict[str, Type[RpcClient]] = {
"base": RpcClient,
"farmer": FarmerRpcClient,
"wallet": WalletRpcClient,
"full_node": FullNodeRpcClient,
Expand All @@ -41,6 +42,7 @@
}

node_config_section_names: Dict[Type[RpcClient], str] = {
RpcClient: "base",
FarmerRpcClient: "farmer",
WalletRpcClient: "wallet",
FullNodeRpcClient: "full_node",
Expand Down Expand Up @@ -92,6 +94,7 @@ async def get_any_service_client(
rpc_port: Optional[int] = None,
root_path: Optional[Path] = None,
consume_errors: bool = True,
use_ssl: bool = True,
) -> AsyncIterator[Tuple[_T_RpcClient, Dict[str, Any]]]:
"""
Yields a tuple with a RpcClient for the applicable node type a dictionary of the node's configuration,
Expand All @@ -112,7 +115,11 @@ async def get_any_service_client(
if rpc_port is None:
rpc_port = config[node_type]["rpc_port"]
# select node client type based on string
node_client = await client_type.create(self_hostname, uint16(rpc_port), root_path, config)
if use_ssl:
node_client = await client_type.create(self_hostname, uint16(rpc_port), root_path=root_path, net_config=config)
else:
node_client = await client_type.create(self_hostname, uint16(rpc_port), root_path=None, net_config=None)

try:
# check if we can connect to node
await validate_client_connection(node_client, node_type, rpc_port, consume_errors)
Expand Down
38 changes: 27 additions & 11 deletions chia/rpc/rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class RpcClient:

url: str
session: aiohttp.ClientSession
ssl_context: SSLContext
ssl_context: Optional[SSLContext]
hostname: str
port: uint16
closing_task: Optional[asyncio.Task] = None
Expand All @@ -41,19 +41,35 @@ async def create(
cls: Type[_T_RpcClient],
self_hostname: str,
port: uint16,
root_path: Path,
net_config: Dict[str, Any],
root_path: Optional[Path],
net_config: Optional[Dict[str, Any]],
) -> _T_RpcClient:
ca_crt_path, ca_key_path = private_ssl_ca_paths(root_path, net_config)
crt_path = root_path / net_config["daemon_ssl"]["private_crt"]
key_path = root_path / net_config["daemon_ssl"]["private_key"]
timeout = net_config.get("rpc_timeout", 300)
if (root_path is not None) != (net_config is not None):
raise ValueError("Either both or neither of root_path and net_config must be provided")

ssl_context: Optional[SSLContext]
if root_path is None:
scheme = "http"
ssl_context = None
else:
assert root_path is not None
assert net_config is not None
scheme = "https"
ca_crt_path, ca_key_path = private_ssl_ca_paths(root_path, net_config)
crt_path = root_path / net_config["daemon_ssl"]["private_crt"]
key_path = root_path / net_config["daemon_ssl"]["private_key"]
ssl_context = ssl_context_for_client(ca_crt_path, ca_key_path, crt_path, key_path)

timeout = 300
if net_config is not None:
timeout = net_config.get("rpc_timeout", timeout)

self = cls(
hostname=self_hostname,
port=port,
url=f"https://{self_hostname}:{str(port)}/",
url=f"{scheme}://{self_hostname}:{str(port)}/",
session=aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)),
ssl_context=ssl_context_for_client(ca_crt_path, ca_key_path, crt_path, key_path),
ssl_context=ssl_context,
)

return self
Expand All @@ -64,8 +80,8 @@ async def create_as_context(
cls: Type[_T_RpcClient],
self_hostname: str,
port: uint16,
root_path: Path,
net_config: Dict[str, Any],
root_path: Optional[Path] = None,
net_config: Optional[Dict[str, Any]] = None,
) -> AsyncIterator[_T_RpcClient]:
self = await cls.create(
self_hostname=self_hostname,
Expand Down
1 change: 1 addition & 0 deletions tests/cmds/cmd_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ async def test_get_any_service_client(
rpc_port: Optional[int] = None,
root_path: Optional[Path] = None,
consume_errors: bool = True,
use_ssl: bool = True,
) -> AsyncIterator[Tuple[_T_RpcClient, Dict[str, Any]]]:
if root_path is None:
root_path = default_root
Expand Down
27 changes: 27 additions & 0 deletions tests/cmds/test_cmds_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from __future__ import annotations

from pathlib import Path

import pytest

from chia.cmds.cmds_util import get_any_service_client
from chia.rpc.rpc_client import RpcClient
from tests.util.misc import RecordingWebServer


@pytest.mark.anyio
async def test_get_any_service_client_works_without_ssl(
root_path_populated_with_config: Path,
recording_web_server: RecordingWebServer,
) -> None:
expected_result = {"success": True, "keepy": "uppy"}

async with get_any_service_client(
client_type=RpcClient,
rpc_port=recording_web_server.web_server.listen_port,
root_path=root_path_populated_with_config,
use_ssl=False,
) as [rpc_client, _]:
result = await rpc_client.fetch(path="", request_json={"response": expected_result})

assert result == expected_result
Empty file added tests/rpc/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions tests/rpc/test_rpc_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional

import pytest

from chia.rpc.rpc_client import RpcClient
from chia.util.ints import uint16
from tests.util.misc import Marks, RecordingWebServer, datacases


@dataclass
class InvalidCreateCase:
id: str
root_path: Optional[Path] = None
net_config: Optional[Dict[str, Any]] = None
marks: Marks = ()


@datacases(
InvalidCreateCase(id="just root path", root_path=Path("/root/path")),
InvalidCreateCase(id="just net config", net_config={}),
)
@pytest.mark.anyio
async def test_rpc_client_create_raises_for_invalid_root_path_net_config_combinations(
case: InvalidCreateCase,
) -> None:
with pytest.raises(ValueError, match="Either both or neither of"):
await RpcClient.create(
self_hostname="",
port=uint16(0),
root_path=case.root_path,
net_config=case.net_config,
)


@pytest.mark.anyio
async def test_rpc_client_works_without_ssl(recording_web_server: RecordingWebServer) -> None:
expected_result = {"success": True, "daddy": "putdown"}

async with RpcClient.create_as_context(
self_hostname=recording_web_server.web_server.hostname,
port=recording_web_server.web_server.listen_port,
) as rpc_client:
result = await rpc_client.fetch(path="", request_json={"response": expected_result})

assert result == expected_result

0 comments on commit 9ed9abe

Please sign in to comment.