Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query string passthrough #6

Merged
merged 3 commits into from
May 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions asgiproxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@ def get_upstream_url(self, *, scope: Scope) -> str:
"""
raise NotImplementedError("...")

def get_upstream_url_with_query(self, *, scope: Scope) -> str:
"""
Get the upstream URL for a client request, including any query parameters to include.
"""
# The default implementation simply appends the original URL's query string to the
# upstream URL generated by `get_upstream_url`.
url = self.get_upstream_url(scope=scope)
query_string = scope.get("query_string")
if query_string:
sep = "&" if "?" in url else "?"
url += "{}{}".format(sep, query_string.decode("utf-8"))
return url

def process_client_headers(self, *, scope: Scope, headers: Headers) -> Headerlike:
"""
Process client HTTP headers before they're passed upstream.
Expand All @@ -39,8 +52,7 @@ def get_upstream_http_options(
"""
return dict(
method=client_request.method,
url=self.get_upstream_url(scope=scope),
params=client_request.query_params.multi_items(),
url=self.get_upstream_url_with_query(scope=scope),
data=data,
headers=self.process_client_headers(
scope=scope,
Expand Down
6 changes: 6 additions & 0 deletions tests/configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from asgiproxy.config import BaseURLProxyConfigMixin, ProxyConfig


class ExampleComProxyConfig(BaseURLProxyConfigMixin, ProxyConfig):
upstream_base_url = "http://example.com"
rewrite_host_header = "example.com"
48 changes: 30 additions & 18 deletions tests/test_asgiproxy.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,42 @@
import pytest
from asgiref.testing import ApplicationCommunicator
from starlette.requests import Request

from asgiproxy.config import BaseURLProxyConfigMixin, ProxyConfig
from asgiproxy.context import ProxyContext
from asgiproxy.proxies.http import proxy_http
from tests.utils import http_response_from_asgi_messages
from tests.configs import ExampleComProxyConfig
from tests.utils import http_response_from_asgi_messages, make_http_scope


@pytest.mark.parametrize("full_url, expected_url", [
("http://127.0.0.1/pathlet/?encode&flep&murp", "http://example.com/pathlet/?encode&flep&murp"),
("http://127.0.0.1/pathlet/", "http://example.com/pathlet/"),
])
def test_query_string_passthrough(full_url, expected_url):
proxy_config = ExampleComProxyConfig()
scope = make_http_scope(
full_url=full_url,
headers={
"Accept": "text/html",
"User-Agent": "Foo",
},
)
client_request = Request(scope)
opts = proxy_config.get_upstream_http_options(
scope=scope, client_request=client_request, data=None
)
assert opts["url"] == expected_url


@pytest.mark.asyncio
async def test_asgiproxy():
scope = {
"type": "http",
"http_version": "1.1",
"method": "GET",
"path": "/",
"query_string": "",
"headers": [
(b"Host", b"http://127.0.0.1/"),
(b"User-Agent", b"Foo"),
(b"Accept", b"text/html"),
],
}

class ExampleComProxyConfig(BaseURLProxyConfigMixin, ProxyConfig):
upstream_base_url = "http://example.com"
rewrite_host_header = "example.com"
scope = make_http_scope(
full_url="http://127.0.0.1/",
headers={
"Accept": "text/html",
"User-Agent": "Foo",
},
)

context = ProxyContext(config=ExampleComProxyConfig())

Expand Down
26 changes: 26 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import gzip
from typing import List
from urllib.parse import urlparse

from asgiref.typing import HTTPScope
from starlette.datastructures import Headers


Expand All @@ -20,3 +22,27 @@ def http_response_from_asgi_messages(messages: List[dict]):
"body": body_bytes,
"content": content,
}


def make_http_scope(*, method="GET", full_url: str, headers=None) -> HTTPScope:
if headers is None:
headers = {}
headers = {str(key).lower(): str(value) for (key, value) in headers.items()}
url_parts = urlparse(full_url)
headers.setdefault("host", url_parts.netloc)
# noinspection PyTypeChecker
return {
"type": "http",
"asgi": {"version": "3.0", "spec_version": "3.0"},
"http_version": "1.1",
"method": method,
"scheme": url_parts.scheme,
"path": url_parts.path,
"raw_path": url_parts.path.encode(),
"query_string": url_parts.query.encode(),
"root_path": "/",
"headers": [(key.encode(), value.encode()) for (key, value) in headers.items()],
"extensions": {},
"client": None,
"server": None,
}