From f6eea01ed1b2c4e59e734ea1df655cc1e6bbe964 Mon Sep 17 00:00:00 2001 From: Lyonnet Date: Sun, 1 Dec 2024 00:53:17 +0800 Subject: [PATCH] Fix url quote with `requote_uri` from requests library --- curl_cffi/requests/session.py | 58 +++++++++++++++++++++++++++++++-- tests/unittest/test_requests.py | 26 ++++++++++----- 2 files changed, 72 insertions(+), 12 deletions(-) diff --git a/curl_cffi/requests/session.py b/curl_cffi/requests/session.py index a66173b..d9edabb 100644 --- a/curl_cffi/requests/session.py +++ b/curl_cffi/requests/session.py @@ -30,7 +30,7 @@ from .. import AsyncCurl, Curl, CurlError, CurlHttpVersion, CurlInfo, CurlOpt, CurlSslVersion from ..curl import CURL_WRITEFUNC_ERROR, CurlMime from .cookies import Cookies, CookieTypes, CurlMorsel -from .exceptions import ImpersonateError, RequestException, SessionClosed, code2error +from .exceptions import InvalidURL, ImpersonateError, RequestException, SessionClosed, code2error from .headers import Headers, HeaderTypes from .impersonate import BrowserType # noqa: F401 from .impersonate import ( @@ -150,7 +150,6 @@ def _update_url_params(url: str, params: Union[Dict, List, Tuple]) -> str: new_args_counter = Counter(x[0] for x in params) for key, value in params: # Bool and Dict values should be converted to json-friendly values - # you may throw this part away if you don't like it :) if isinstance(value, (bool, dict)): value = dumps(value) # 1 to 1 mapping, we have to search and update it. 
@@ -176,6 +175,78 @@ def _update_url_params(url: str, params: Union[Dict, List, Tuple]) -> str:
     return new_url
 
 
+# Adapted from: https://github.com/psf/requests/blob/1ae6fc3137a11e11565ed22436aa1e77277ac98c/src%2Frequests%2Futils.py#L633-L682
+# License: Apache 2.0
+
+# The unreserved URI characters (RFC 3986)
+UNRESERVED_SET = frozenset(
+    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~"
+)
+
+# ASCII hexadecimal digits, the only characters valid in a percent-escape
+_HEX_DIGITS = frozenset("0123456789ABCDEFabcdef")
+
+
+def unquote_unreserved(uri: str) -> str:
+    """Un-escape any percent-escape sequences in a URI that are unreserved
+    characters. This leaves all reserved, illegal and non-ASCII bytes encoded.
+
+    Args:
+        uri: The URI string to process.
+
+    Returns:
+        ``uri`` with escapes of unreserved characters (e.g. ``%7E``) replaced
+        by the character itself; every other ``%`` sequence is kept as is.
+
+    Raises:
+        InvalidURL: If ``%`` is followed by two alphanumeric characters that
+            are not both ASCII hexadecimal digits.
+    """
+    head, *tail = uri.split("%")
+    pieces = [head]
+    for part in tail:
+        h = part[:2]
+        if len(h) == 2 and h.isalnum():
+            # Validate explicitly instead of relying on int(h, 16): int() also
+            # accepts non-ASCII Unicode digits and would silently decode them.
+            if not _HEX_DIGITS.issuperset(h):
+                raise InvalidURL(f"Invalid percent-escape sequence: '{h}'")
+            c = chr(int(h, 16))
+            if c in UNRESERVED_SET:
+                pieces.append(c + part[2:])
+                continue
+        # Not an unreserved escape: put back the "%" that split() removed
+        pieces.append(f"%{part}")
+    return "".join(pieces)
+
+
+def requote_uri(uri: str) -> str:
+    """Re-quote the given URI.
+
+    This function passes the given URI through an unquote/quote cycle to
+    ensure that it is fully and consistently quoted.
+
+    Args:
+        uri: The URI string to normalize.
+
+    Returns:
+        The URI with illegal characters percent-encoded. If the URI contains
+        an invalid percent-escape, every "%" is encoded as well.
+    """
+    safe_with_percent = "!#$%&'()*+,/:;=?@[]~|"
+    safe_without_percent = "!#$&'()*+,/:;=?@[]~|"
+    try:
+        # Unquote only the unreserved characters
+        # Then quote only illegal characters (do not quote reserved,
+        # unreserved, or '%')
+        return quote(unquote_unreserved(uri), safe=safe_with_percent)
+    except InvalidURL:
+        # We couldn't unquote the given URI, so let's try quoting it, but
+        # there may be unquoted '%'s in the URI. We need to make sure they're
+        # properly quoted so they do not cause issues elsewhere.
+        return quote(uri, safe=safe_without_percent)
+
+
 # TODO: should we move this function to headers.py?
def _update_header_line(header_lines: List[str], key: str, value: str, replace: bool = False): """Update header line list by key value pair.""" @@ -430,8 +480,10 @@ def _set_curl_options( url = _update_url_params(url, params) if self.base_url: url = urljoin(self.base_url, url) - if quote is not False: + if quote: url = _quote_path_and_params(url, quote_str=quote) + if quote is not False: + url = requote_uri(url) c.setopt(CurlOpt.URL, url.encode()) # data/body/json diff --git a/tests/unittest/test_requests.py b/tests/unittest/test_requests.py index 8a78774..348658d 100644 --- a/tests/unittest/test_requests.py +++ b/tests/unittest/test_requests.py @@ -164,8 +164,8 @@ def test_url_encode(server): # should not change url = "http://127.0.0.1:8000/%2f%2f%2f" - r = requests.get(str(url)) - assert r.url == str(url) + r = requests.get(url) + assert r.url == url url = "http://127.0.0.1:8000/imaginary-pagination:7" r = requests.get(str(url)) @@ -175,15 +175,17 @@ def test_url_encode(server): r = requests.get(str(url)) assert r.url == url - # Non-ASCII URL should be percent encoded as UTF-8 sequence - non_ascii_url = "http://127.0.0.1:8000/search?q=测试" - encoded_non_ascii_url = "http://127.0.0.1:8000/search?q=%E6%B5%8B%E8%AF%95" + # NOTE: this seems to be unnecessary - r = requests.get(non_ascii_url) - assert r.url == encoded_non_ascii_url + # Non-ASCII URL should be percent encoded as UTF-8 sequence + # non_ascii_url = "http://127.0.0.1:8000/search?q=测试" + # encoded_non_ascii_url = "http://127.0.0.1:8000/search?q=%E6%B5%8B%E8%AF%95" + # + # r = requests.get(non_ascii_url) + # assert r.url == encoded_non_ascii_url - r = requests.get(encoded_non_ascii_url) - assert r.url == encoded_non_ascii_url + # r = requests.get(encoded_non_ascii_url) + # assert r.url == encoded_non_ascii_url # should be quoted url = "http://127.0.0.1:8000/e x a m p l e" @@ -216,6 +218,12 @@ def test_url_encode(server): r = requests.get(url, quote=False) assert r.url == url + # Do not unquote + url = 
"http://127.0.0.1:8000/path?token=example%7C2024-10-20T10%3A00%3A00Z" + r = requests.get(url) + print(r.url) + assert r.url == url + # empty values should be kept url = "http://127.0.0.1:8000/api?param1=value1¶m2=¶m3=value3" r = requests.get(url)