diff --git a/tests/conftest.py b/tests/conftest.py index 724ca65d3..4db3ae018 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,12 @@ from __future__ import annotations import functools -from typing import Any, Callable, Literal +from typing import Any, Literal import pytest from starlette.testclient import TestClient - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory @pytest.fixture diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 2176404d8..3ad1751a2 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -5,7 +5,6 @@ from typing import ( Any, AsyncGenerator, - Callable, Generator, ) @@ -23,8 +22,7 @@ from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket - -TestClientFactory = Callable[[ASGIApp], TestClient] +from tests.types import TestClientFactory class CustomMiddleware(BaseHTTPMiddleware): diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 09ec9513f..630361243 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -1,15 +1,10 @@ -from typing import Callable - from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route -from starlette.testclient import TestClient -from starlette.types import ASGIApp - -TestClientFactory = Callable[[ASGIApp], TestClient] +from tests.types import TestClientFactory def test_cors_allow_all( diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py index a2dbabd8a..e32f406ae 100644 --- a/tests/middleware/test_errors.py +++ b/tests/middleware/test_errors.py @@ -1,4 +1,4 @@ -from typing import Any, Callable +from typing import Any import pytest @@ -8,10 +8,8 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import Route -from starlette.testclient import TestClient from starlette.types import Receive, Scope, Send - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def test_handler( diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index 5bfecadb7..b6f68296d 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -1,15 +1,10 @@ -from typing import Callable - from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.gzip import GZipMiddleware from starlette.requests import Request from starlette.responses import ContentStream, PlainTextResponse, StreamingResponse from starlette.routing import Route -from starlette.testclient import TestClient -from starlette.types import ASGIApp - -TestClientFactory = Callable[[ASGIApp], TestClient] +from tests.types import TestClientFactory def test_gzip_responses(test_client_factory: TestClientFactory) -> None: diff --git a/tests/middleware/test_https_redirect.py b/tests/middleware/test_https_redirect.py index 9195694a3..22dfc14b6 100644 --- a/tests/middleware/test_https_redirect.py +++ b/tests/middleware/test_https_redirect.py @@ -1,14 +1,10 @@ -from typing import Callable - from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route -from starlette.testclient import TestClient - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def test_https_redirect_middleware(test_client_factory: TestClientFactory) -> None: diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index 4fbeec88c..9a0d70a0d 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -1,5 +1,4 @@ import re -from typing import Callable from starlette.applications import Starlette from starlette.middleware import Middleware @@ -8,8 +7,7 @@ from starlette.responses import JSONResponse from starlette.routing import Mount, Route from starlette.testclient import TestClient - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def view_session(request: Request) -> JSONResponse: diff --git a/tests/middleware/test_trusted_host.py b/tests/middleware/test_trusted_host.py index 466302210..ddff46c48 100644 --- a/tests/middleware/test_trusted_host.py +++ b/tests/middleware/test_trusted_host.py @@ -1,14 +1,10 @@ -from typing import Callable - from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route -from starlette.testclient import TestClient - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def test_trusted_host_middleware(test_client_factory: TestClientFactory) -> None: diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index 69842d3ad..58696bb65 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -5,10 +5,9 @@ from starlette._utils import collapse_excgroups from starlette.middleware.wsgi import WSGIMiddleware, build_environ -from starlette.testclient import TestClient +from tests.types import TestClientFactory WSGIResponse = Iterable[bytes] -TestClientFactory = Callable[..., TestClient] StartResponse = Callable[..., Any] Environment = Dict[str, Any] diff --git a/tests/test_applications.py b/tests/test_applications.py index af3cdda6c..20da7ea81 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,7 +1,7 @@ import os from contextlib import asynccontextmanager from pathlib import Path -from typing import AsyncGenerator, AsyncIterator, Callable, Generator +from typing import AsyncGenerator, AsyncIterator, Generator import anyio import pytest @@ -20,8 +20,7 @@ from starlette.testclient import TestClient from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory async def error_500(request: Request, exc: HTTPException) -> JSONResponse: diff --git a/tests/test_authentication.py b/tests/test_authentication.py index ecddda75e..35c1110d1 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -21,10 +21,9 @@ from starlette.requests import HTTPConnection, Request from starlette.responses import JSONResponse, Response from starlette.routing import Route, WebSocketRoute -from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect +from tests.types import TestClientFactory -TestClientFactory = Callable[..., TestClient] AsyncEndpoint = Callable[..., Awaitable[Response]] SyncEndpoint = Callable[..., Response] diff --git a/tests/test_background.py b/tests/test_background.py index 846deecfd..cbffcc06a 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -1,13 +1,9 @@ -from typing import Callable - import pytest from starlette.background import BackgroundTask, BackgroundTasks from starlette.responses import Response -from starlette.testclient import TestClient from starlette.types import Receive, Scope, Send - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def test_async_task(test_client_factory: TestClientFactory) -> None: diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index aba3ceb1a..bac6814e4 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -1,5 +1,5 @@ from contextvars import ContextVar -from typing import Callable, Iterator +from typing import Iterator import anyio import pytest @@ -9,9 +9,7 @@ from starlette.requests import Request from starlette.responses import Response from starlette.routing import Route -from starlette.testclient import TestClient - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory @pytest.mark.anyio diff --git a/tests/test_convertors.py b/tests/test_convertors.py index 72ee17a82..520c98767 100644 --- a/tests/test_convertors.py +++ b/tests/test_convertors.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Callable, Iterator +from typing import Iterator import pytest @@ -8,9 +8,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Route, Router -from starlette.testclient import TestClient - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory @pytest.fixture(scope="module", autouse=True) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index eeb0f2322..8f201e25b 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterator +from typing import Iterator import pytest @@ -8,8 +8,7 @@ from starlette.routing import Route, Router from starlette.testclient import TestClient from starlette.websockets import WebSocket - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory class Homepage(HTTPEndpoint): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 401ad8212..f4e91ad87 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, Generator +from typing import Generator import pytest @@ -10,8 +10,7 @@ from starlette.routing import Route, Router, WebSocketRoute from starlette.testclient import TestClient from starlette.types import Receive, Scope, Send - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def raise_runtime_error(request: Request) -> None: diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index ed2226878..8d97a0ba7 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -13,10 +13,8 @@ from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Mount -from starlette.testclient import TestClient from starlette.types import ASGIApp, Receive, Scope, Send - -TestClientFactory = typing.Callable[..., TestClient] +from tests.types import TestClientFactory class ForceMultipartDict(typing.Dict[typing.Any, typing.Any]): diff --git a/tests/test_requests.py b/tests/test_requests.py index c52ebc141..02f29ee35 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from typing import Any, Callable, Iterator +from typing import Any, Iterator import anyio import pytest @@ -9,10 +9,8 @@ from starlette.datastructures import Address, State from starlette.requests import ClientDisconnect, Request from starlette.responses import JSONResponse, PlainTextResponse, Response -from starlette.testclient import TestClient from starlette.types import Message, Receive, Scope, Send - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def test_request_url(test_client_factory: TestClientFactory) -> None: diff --git a/tests/test_responses.py b/tests/test_responses.py index f05529dd7..791e9b3ac 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -5,7 +5,7 @@ import time from http.cookies import SimpleCookie from pathlib import Path -from typing import AsyncIterator, Callable, Iterator +from typing import AsyncIterator, Iterator import anyio import pytest @@ -23,8 +23,7 @@ ) from starlette.testclient import TestClient from starlette.types import Message, Receive, Scope, Send - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def test_text_response(test_client_factory: TestClientFactory) -> None: diff --git a/tests/test_routing.py b/tests/test_routing.py index 03c31c67f..1490723b4 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -17,8 +17,7 @@ from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect - -TestClientFactory = typing.Callable[..., TestClient] +from tests.types import TestClientFactory def homepage(request: Request) -> Response: diff --git a/tests/test_schemas.py b/tests/test_schemas.py index e00b2b8de..f4a5b4ad9 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -1,20 +1,16 @@ -from typing import Callable - from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint from starlette.requests import Request from starlette.responses import Response from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.schemas import SchemaGenerator -from starlette.testclient import TestClient from starlette.websockets import WebSocket +from tests.types import TestClientFactory schemas = SchemaGenerator( {"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}} ) -TestClientFactory = Callable[..., TestClient] - def ws(session: WebSocket) -> None: """ws""" diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 085301302..65d71b97b 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -16,9 +16,7 @@ from starlette.responses import Response from starlette.routing import Mount from starlette.staticfiles import StaticFiles -from starlette.testclient import TestClient - -TestClientFactory = typing.Callable[..., TestClient] +from tests.types import TestClientFactory def test_staticfiles(tmpdir: Path, test_client_factory: TestClientFactory) -> None: diff --git a/tests/test_templates.py b/tests/test_templates.py index 10a1366bc..8e344f331 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,7 +1,6 @@ from __future__ import annotations import os -import typing from pathlib import Path from unittest import mock @@ -16,9 +15,7 @@ from starlette.responses import Response from starlette.routing import Route from starlette.templating import Jinja2Templates -from starlette.testclient import TestClient - -TestClientFactory = typing.Callable[..., TestClient] +from tests.types import TestClientFactory def test_templates(tmpdir: Path, test_client_factory: TestClientFactory) -> None: diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 4ed1ced9a..77de3d976 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -4,7 +4,7 @@ import sys from asyncio import Task, current_task as asyncio_current_task from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, Callable +from typing import Any, AsyncGenerator import anyio import anyio.lowlevel @@ -20,8 +20,7 @@ from starlette.testclient import ASGIInstance, TestClient from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def mock_service_endpoint(request: Request) -> JSONResponse: @@ -212,7 +211,7 @@ async def inner(receive: Receive, send: Send) -> None: return inner - client = test_client_factory(app) + client = test_client_factory(app) # type: ignore response = client.get("/") assert response.text == "Hello, world!" @@ -252,7 +251,7 @@ async def asgi(receive: Receive, send: Send) -> None: return asgi - client = test_client_factory(app) + client = test_client_factory(app) # type: ignore with client.websocket_connect("/") as websocket: data = websocket.receive_json() assert data == {"message": "test"} @@ -268,7 +267,7 @@ async def asgi(receive: Receive, send: Send) -> None: return asgi - client = test_client_factory(app) + client = test_client_factory(app) # type: ignore with client.websocket_connect("/") as websocket: ... assert websocket.should_close.is_set() diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 95ffbdbe7..16d2d0f1f 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -1,5 +1,5 @@ import sys -from typing import Any, Callable, MutableMapping +from typing import Any, MutableMapping import anyio import pytest @@ -7,11 +7,10 @@ from starlette import status from starlette.responses import Response -from starlette.testclient import TestClient, WebSocketDenialResponse +from starlette.testclient import WebSocketDenialResponse from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def test_websocket_url(test_client_factory: TestClientFactory) -> None: diff --git a/tests/types.py b/tests/types.py new file mode 100644 index 000000000..1cbacf107 --- /dev/null +++ b/tests/types.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol + +import httpx + +from starlette.testclient import TestClient +from starlette.types import ASGIApp + +if TYPE_CHECKING: + + class TestClientFactory(Protocol): # pragma: no cover + def __call__( + self, + app: ASGIApp, + base_url: str = "http://testserver", + raise_server_exceptions: bool = True, + root_path: str = "", + cookies: httpx._types.CookieTypes | None = None, + headers: dict[str, str] | None = None, + follow_redirects: bool = True, + ) -> TestClient: ... +else: # pragma: no cover + + class TestClientFactory: + __test__ = False