diff --git a/ariadne/asgi/graphql.py b/ariadne/asgi/graphql.py index 797ccb107..6f0e6045f 100644 --- a/ariadne/asgi/graphql.py +++ b/ariadne/asgi/graphql.py @@ -1,5 +1,6 @@ +from collections.abc import Awaitable from logging import Logger, LoggerAdapter -from typing import Any, Awaitable, Optional, Type, Union +from typing import Any, Optional, Union from graphql import ExecutionContext, GraphQLSchema from starlette.requests import Request @@ -45,7 +46,7 @@ def __init__( explorer: Optional[Explorer] = None, logger: Union[None, str, Logger, LoggerAdapter] = None, error_formatter: ErrorFormatter = format_error, - execution_context_class: Optional[Type[ExecutionContext]] = None, + execution_context_class: Optional[type[ExecutionContext]] = None, http_handler: Optional[GraphQLHTTPHandler] = None, websocket_handler: Optional[GraphQLWebsocketHandler] = None, ) -> None: @@ -182,7 +183,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): elif scope["type"] == "websocket": await self.websocket_handler.handle(scope=scope, receive=receive, send=send) else: - raise ValueError("Unknown scope type: %r" % (scope["type"],)) + raise ValueError("Unknown scope type: {!r}".format(scope["type"])) async def handle_request(self, request: Request) -> Response: """Shortcut for `graphql_app.http_handler.handle_request(...)`.""" diff --git a/ariadne/asgi/handlers/__init__.py b/ariadne/asgi/handlers/__init__.py index 38eedc1b9..1134611ed 100644 --- a/ariadne/asgi/handlers/__init__.py +++ b/ariadne/asgi/handlers/__init__.py @@ -1,8 +1,7 @@ from .base import GraphQLHandler, GraphQLHttpHandlerBase, GraphQLWebsocketHandler -from .http import GraphQLHTTPHandler from .graphql_transport_ws import GraphQLTransportWSHandler from .graphql_ws import GraphQLWSHandler - +from .http import GraphQLHTTPHandler __all__ = [ "GraphQLHandler", diff --git a/ariadne/asgi/handlers/base.py b/ariadne/asgi/handlers/base.py index 18f5fb554..9737afac1 100644 --- a/ariadne/asgi/handlers/base.py +++ b/ariadne/asgi/handlers/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from inspect import isawaitable from logging import Logger, LoggerAdapter -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union from graphql import DocumentNode, ExecutionContext, GraphQLSchema, MiddlewareManager from starlette.types import Receive, Scope, Send @@ -41,8 +41,8 @@ def __init__(self) -> None: self.query_validator: Optional[QueryValidator] = None self.validation_rules: Optional[ValidationRules] = None self.execute_get_queries: bool = False - self.execution_context_class: Optional[Type[ExecutionContext]] = None - self.middleware_manager_class: Optional[Type[MiddlewareManager]] = None + self.execution_context_class: Optional[type[ExecutionContext]] = None + self.middleware_manager_class: Optional[type[MiddlewareManager]] = None @abstractmethod async def handle(self, scope: Scope, receive: Receive, send: Send): @@ -86,7 +86,7 @@ def configure( explorer: Optional[Explorer] = None, logger: Union[None, str, Logger, LoggerAdapter] = None, error_formatter: ErrorFormatter = format_error, - execution_context_class: Optional[Type[ExecutionContext]] = None, + execution_context_class: Optional[type[ExecutionContext]] = None, ): """Configures the handler with options from the ASGI application. diff --git a/ariadne/asgi/handlers/graphql_transport_ws.py b/ariadne/asgi/handlers/graphql_transport_ws.py index 60ee82b9d..9f272c2f4 100644 --- a/ariadne/asgi/handlers/graphql_transport_ws.py +++ b/ariadne/asgi/handlers/graphql_transport_ws.py @@ -1,15 +1,16 @@ import asyncio +from collections.abc import AsyncGenerator from contextlib import suppress from datetime import timedelta from inspect import isawaitable -from typing import Any, AsyncGenerator, Dict, List, Optional, cast +from typing import Any, Optional, cast from graphql import GraphQLError from graphql.language import OperationType from starlette.types import Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState -from ...graphql import subscribe, parse_query, validate_data +from ...graphql import parse_query, subscribe, validate_data from ...logger import log_error from ...types import ( ExecutionResult, @@ -24,8 +25,8 @@ def __init__(self) -> None: self.connection_acknowledged: bool = False self.connection_init_timeout_task: Optional[asyncio.Task] = None self.connection_init_received: bool = False - self.operations: Dict[str, Operation] = {} - self.operation_tasks: Dict[str, asyncio.Task] = {} + self.operations: dict[str, Operation] = {} + self.operation_tasks: dict[str, asyncio.Task] = {} self.websocket: WebSocket @@ -374,7 +375,7 @@ async def get_results(): if not success: if not isinstance(results_producer, list): - error_payload = cast(List[dict], [results_producer]) + error_payload = cast(list[dict], [results_producer]) else: error_payload = results_producer @@ -524,7 +525,8 @@ async def observe_async_results( } ) except asyncio.CancelledError: # pylint: disable=W0706 - # if asyncio Task is cancelled then CancelledError is thrown in the coroutine + # if asyncio Task is cancelled then CancelledError + # is thrown in the coroutine raise except Exception as error: if not isinstance(error, GraphQLError): diff --git a/ariadne/asgi/handlers/graphql_ws.py b/ariadne/asgi/handlers/graphql_ws.py index 1d5f0e3e7..a40416b45 100644 --- a/ariadne/asgi/handlers/graphql_ws.py +++ b/ariadne/asgi/handlers/graphql_ws.py @@ -1,6 +1,7 @@ import asyncio +from collections.abc import AsyncGenerator from inspect import isawaitable -from typing import Any, AsyncGenerator, Dict, List, Optional, cast +from typing import Any, Optional, cast from graphql import DocumentNode, GraphQLError from graphql.language import OperationType @@ -104,7 +105,7 @@ async def handle_websocket(self, websocket: WebSocket): `websocket`: the `WebSocket` instance from Starlette or FastAPI. """ - operations: Dict[str, Operation] = {} + operations: dict[str, Operation] = {} await websocket.accept("graphql-ws") try: while WebSocketState.DISCONNECTED not in ( @@ -134,7 +135,7 @@ async def handle_websocket_message( self, websocket: WebSocket, message: dict, - operations: Dict[str, Operation], + operations: dict[str, Operation], ): """Handles new message from websocket connection. @@ -167,7 +168,7 @@ async def process_single_message( websocket: WebSocket, data: Any, operation_id: str, - operations: Dict[str, Operation], + operations: dict[str, Operation], ) -> None: """Processes websocket message containing new GraphQL operation. @@ -290,7 +291,7 @@ async def start_websocket_operation( context_value: Any, query_document: DocumentNode, operation_id: str, - operations: Dict[str, Operation], + operations: dict[str, Operation], ): if self.schema is None: raise TypeError("schema is not set, call configure method to initialize it") @@ -310,7 +311,7 @@ async def start_websocket_operation( ) if not success: - results = cast(List[dict], results) + results = cast(list[dict], results) await websocket.send_json( { "type": GraphQLWSHandler.GQL_ERROR, diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 9432feccd..03658aa7d 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -1,7 +1,7 @@ import json from http import HTTPStatus from inspect import isawaitable -from typing import Any, Optional, Type, Union, cast +from typing import Any, Optional, Union, cast from graphql import DocumentNode, MiddlewareManager from starlette.datastructures import UploadFile @@ -38,7 +38,7 @@ def __init__( self, extensions: Optional[Extensions] = None, middleware: Optional[Middlewares] = None, - middleware_manager_class: Optional[Type[MiddlewareManager]] = None, + middleware_manager_class: Optional[type[MiddlewareManager]] = None, ) -> None: """Initializes the HTTP handler. diff --git a/ariadne/contrib/federation/__init__.py b/ariadne/contrib/federation/__init__.py index 514bfd881..e8c469d2f 100644 --- a/ariadne/contrib/federation/__init__.py +++ b/ariadne/contrib/federation/__init__.py @@ -7,3 +7,5 @@ from .interfaces import FederatedInterfaceType from .objects import FederatedObjectType from .schema import make_federated_schema + +__all__ = ["FederatedInterfaceType", "FederatedObjectType", "make_federated_schema"] diff --git a/ariadne/contrib/federation/schema.py b/ariadne/contrib/federation/schema.py index 20d724b27..c86b456f1 100644 --- a/ariadne/contrib/federation/schema.py +++ b/ariadne/contrib/federation/schema.py @@ -1,6 +1,6 @@ -import re import os -from typing import Dict, List, Optional, Type, Union, cast +import re +from typing import Optional, Union, cast from graphql import extend_schema, parse from graphql.language import DocumentNode @@ -13,12 +13,12 @@ from ...executable_schema import ( SchemaBindables, - make_executable_schema, join_type_defs, + make_executable_schema, ) +from ...load_schema import load_schema_from_path from ...schema_names import SchemaNameConverter from ...schema_visitor import SchemaDirectiveVisitor -from ...load_schema import load_schema_from_path from .utils import get_entity_types, purge_schema_directives, resolve_entities base_federation_service_type_defs = """ @@ -54,9 +54,9 @@ def has_query_type(type_defs: str) -> bool: def make_federated_schema( - type_defs: Union[str, List[str]], + type_defs: Union[str, list[str]], *bindables: SchemaBindables, - directives: Optional[Dict[str, Type[SchemaDirectiveVisitor]]] = None, + directives: Optional[dict[str, type[SchemaDirectiveVisitor]]] = None, convert_names_case: Union[bool, SchemaNameConverter] = False, ) -> GraphQLSchema: if isinstance(type_defs, list): diff --git a/ariadne/contrib/federation/utils.py b/ariadne/contrib/federation/utils.py index a93acc45e..c169886a4 100644 --- a/ariadne/contrib/federation/utils.py +++ b/ariadne/contrib/federation/utils.py @@ -1,7 +1,7 @@ # pylint: disable=cell-var-from-loop from inspect import isawaitable -from typing import Any, List, Tuple, cast +from typing import Any, cast from graphql import ( DirectiveDefinitionNode, @@ -11,14 +11,13 @@ ) from graphql.language import DirectiveNode from graphql.type import ( - GraphQLNamedType, GraphQLInputObjectType, + GraphQLNamedType, GraphQLObjectType, GraphQLResolveInfo, GraphQLSchema, ) - _allowed_directives = [ "skip", # Default directive as per specs. "include", # Default directive as per specs. @@ -41,7 +40,7 @@ ] -def _purge_directive_nodes(nodes: Tuple[Node, ...]) -> Tuple[Node, ...]: +def _purge_directive_nodes(nodes: tuple[Node, ...]) -> tuple[Node, ...]: return tuple( node for node in nodes @@ -58,11 +57,11 @@ def _purge_type_directives(definition: Node): if isinstance(value, tuple): # Remove directive nodes from the tuple # e.g. doc -> definitions [DirectiveDefinitionNode] - next_value = _purge_directive_nodes(cast(Tuple[Node, ...], value)) + next_value = _purge_directive_nodes(cast(tuple[Node, ...], value)) for item in next_value: if isinstance(item, Node): - # Look for directive nodes on sub-nodes - # e.g. doc -> definitions [ObjectTypeDefinitionNode] -> fields -> directives + # Look for directive nodes on sub-nodes, e.g.: doc -> + # definitions [ObjectTypeDefinitionNode] -> fields -> directives _purge_type_directives(item) setattr(definition, key, next_value) elif isinstance(value, Node): @@ -111,7 +110,7 @@ async def add_typename_to_async_return(obj: Any, typename: str) -> Any: return add_typename_to_possible_return(await obj, typename) -def get_entity_types(schema: GraphQLSchema) -> List[GraphQLNamedType]: +def get_entity_types(schema: GraphQLSchema) -> list[GraphQLNamedType]: """Get all types that include the @key directive.""" schema_types = schema.type_map.values() @@ -135,9 +134,9 @@ def includes_directive( def gather_directives( type_object: GraphQLNamedType, -) -> List[DirectiveNode]: +) -> list[DirectiveNode]: """Get all directive attached to a type.""" - directives: List[DirectiveNode] = [] + directives: list[DirectiveNode] = [] if hasattr(type_object, "extension_ast_nodes") and type_object.extension_ast_nodes: for ast_node in type_object.extension_ast_nodes: diff --git a/ariadne/contrib/sse.py b/ariadne/contrib/sse.py index cbc5e506c..06aab6ca2 100644 --- a/ariadne/contrib/sse.py +++ b/ariadne/contrib/sse.py @@ -2,32 +2,27 @@ import json import logging from asyncio import Lock +from collections.abc import AsyncGenerator, Awaitable from functools import partial from http import HTTPStatus from io import StringIO from typing import ( Any, + Callable, + Literal, Optional, cast, - AsyncGenerator, - List, - Literal, get_args, - Dict, - Callable, - Awaitable, - Type, ) from anyio import ( - get_cancelled_exc_class, CancelScope, - sleep, - move_on_after, create_task_group, + get_cancelled_exc_class, + move_on_after, + sleep, ) -from graphql import DocumentNode -from graphql import MiddlewareManager +from graphql import DocumentNode, MiddlewareManager from starlette.requests import Request from starlette.responses import Response from starlette.types import Receive, Scope, Send @@ -66,7 +61,8 @@ def __init__( `event`: the type of the event. Either "next" or "complete" # Optional arguments - `result`: an `ExecutionResult` or a `dict` that represents the result of the operation + `result`: an `ExecutionResult` or a `dict` that represents + the result of the operation """ assert event in get_args(EVENT_TYPES), f"Invalid event type: {event}" self.event = event @@ -111,7 +107,7 @@ def encode_execution_result(self) -> str: Returns the JSON string representation of the execution result """ - payload: Dict[str, Any] = {} + payload: dict[str, Any] = {} if self.result is not None and self.result.data is not None: payload["data"] = self.result.data if self.result is not None and self.result.errors is not None: @@ -124,8 +120,8 @@ def encode_execution_result(self) -> str: class ServerSentEventResponse(Response): - """Sends GraphQL SSE events using the EventSource protocol using Starlette's Response class - based on the implementation https://github.com/sysid/sse-starlette/ + """Sends GraphQL SSE events using the EventSource protocol using + Starlette's Response class based on the implementation https://github.com/sysid/sse-starlette/ """ # Sends a ping event to the client every 15 seconds to overcome proxy timeout issues @@ -137,7 +133,7 @@ def __init__( generator: AsyncGenerator[GraphQLServerSentEvent, Any], send_timeout: Optional[int] = None, ping_interval: Optional[int] = None, - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, **kwargs, ): """Initializes an SSE Response that sends events generated by an async generator @@ -147,8 +143,8 @@ def __init__( # Optional arguments `send_timeout`: the timeout in seconds to send an event to the client - `ping_interval`: the interval in seconds to send a ping event to the client, overrides - the DEFAULT_PING_INTERVAL of 15 seconds + `ping_interval`: the interval in seconds to send a ping event to the client, + overrides the DEFAULT_PING_INTERVAL of 15 seconds `headers`: a dictionary of headers to be sent with the response `encoding`: the encoding to use for the response """ @@ -159,7 +155,7 @@ def __init__( self.ping_interval = ping_interval or self.DEFAULT_PING_INTERVAL self.body = None # type: ignore - _headers: Dict[str, str] = {} + _headers: dict[str, str] = {} if headers is not None: _headers.update(headers) # mandatory for servers-sent events headers @@ -175,7 +171,8 @@ def __init__( self._send_lock = Lock() async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """The main entrypoint for the ServerSentEventResponse which is called by starlette + """The main entrypoint for the ServerSentEventResponse + which is called by starlette # Required arguments `scope`: the starlette Scope object @@ -198,7 +195,8 @@ async def wrap_cancelling(func: Callable[[], Awaitable[None]]) -> None: async def _ping(self, send: Send) -> None: """Sends a ping event to the client every `ping_interval` seconds gets - cancelled if the client disconnects through the anyio CancelScope of the TaskGroup + cancelled if the client disconnects through the anyio + CancelScope of the TaskGroup # Required arguments `send`: the starlette Send object @@ -211,7 +209,7 @@ async def _ping(self, send: Send) -> None: "type": "http.response.body", # always encode as utf-8 as per # https://html.spec.whatwg.org/multipage/server-sent-events.html#sse-processing-model - "body": ":\r\n\r\n".encode("utf-8"), + "body": b":\r\n\r\n", "more_body": True, } ) @@ -258,8 +256,8 @@ async def send_events(self, send: Send) -> None: @staticmethod async def listen_for_disconnect(receive: Receive) -> None: - """Listens for the client disconnect event and stops the streaming by exiting the infinite - loop. This triggers the anyio CancelScope to cancel the TaskGroup + """Listens for the client disconnect event and stops the streaming by exiting + the infinite loop. This triggers the anyio CancelScope to cancel the TaskGroup # Required arguments `receive`: the starlette Receive object @@ -283,10 +281,10 @@ def encode_event(event: GraphQLServerSentEvent) -> bytes: class GraphQLHTTPSSEHandler(GraphQLHTTPHandler): - """Extension to the default GraphQLHTTPHandler to also handle Server-Sent Events as per - the GraphQL SSE Protocol specification. This handler only supports the defined - `Distinct connections mode` due to its statelessness. This implementation is based on - the specification as of commit 80cf75b5952d1a065c95bdbd6a74304c90dbe2c5. + """Extension to the default GraphQLHTTPHandler to also handle Server-Sent Events + as per the GraphQL SSE Protocol specification. This handler only supports the + defined `Distinct connections mode` due to its statelessness. This implementation + is based on the specification as of commit 80cf75b5952d1a065c95bdbd6a74304c90dbe2c5. For more information see the specification (https://github.com/enisdenjo/graphql-sse/blob/master/PROTOCOL.md) """ @@ -295,10 +293,10 @@ def __init__( self, extensions: Optional[Extensions] = None, middleware: Optional[Middlewares] = None, - middleware_manager_class: Optional[Type[MiddlewareManager]] = None, + middleware_manager_class: Optional[type[MiddlewareManager]] = None, send_timeout: Optional[int] = None, ping_interval: Optional[int] = None, - default_response_headers: Optional[Dict[str, str]] = None, + default_response_headers: Optional[dict[str, str]] = None, ): super().__init__(extensions, middleware, middleware_manager_class) self.send_timeout = send_timeout @@ -321,7 +319,8 @@ async def handle_request_override(self, request: Request) -> Optional[Response]: return None async def handle_sse_request(self, request: Request) -> Response: - """Handles the HTTP request with GraphQL Subscription query using Server-Sent Events. + """Handles the HTTP request with GraphQL Subscription query using Server-Sent + Events. # Required arguments @@ -412,7 +411,7 @@ async def sse_subscribe_to_graphql( if not success: if not isinstance(results, list): - error_payload = cast(List[dict], [results]) + error_payload = cast(list[dict], [results]) else: error_payload = results @@ -443,7 +442,7 @@ async def sse_subscribe_to_graphql( @staticmethod async def sse_generate_error_response( - errors: List[GraphQLError], + errors: list[GraphQLError], ) -> AsyncGenerator[GraphQLServerSentEvent, Any]: """A Server-Sent Event response generator for the errors To be passed to a ServerSentEventResponse instance diff --git a/ariadne/contrib/tracing/apollotracing.py b/ariadne/contrib/tracing/apollotracing.py index e803b8c39..9bdb7d7c0 100644 --- a/ariadne/contrib/tracing/apollotracing.py +++ b/ariadne/contrib/tracing/apollotracing.py @@ -1,5 +1,5 @@ from inspect import iscoroutinefunction -from typing import Any, List, Optional, cast +from typing import Any, Optional, cast from graphql import GraphQLResolveInfo from graphql.pyutils import is_awaitable @@ -40,7 +40,7 @@ def __init__(self, trace_default_resolver: bool = False) -> None: self.trace_default_resolver = trace_default_resolver self.start_date: Optional[datetime] = None self.start_timestamp: int = 0 - self.resolvers: List[dict] = [] + self.resolvers: list[dict] = [] self._totals = None diff --git a/ariadne/contrib/tracing/opentelemetry.py b/ariadne/contrib/tracing/opentelemetry.py index 16e168fdc..26f4da705 100644 --- a/ariadne/contrib/tracing/opentelemetry.py +++ b/ariadne/contrib/tracing/opentelemetry.py @@ -1,6 +1,6 @@ from functools import partial from inspect import iscoroutinefunction -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Optional, Union from graphql import GraphQLResolveInfo from graphql.pyutils import is_awaitable @@ -9,8 +9,7 @@ from ...types import ContextValue, Extension, Resolver from .utils import copy_args_for_tracing, format_path, should_trace - -ArgFilter = Callable[[Dict[str, Any], GraphQLResolveInfo], Dict[str, Any]] +ArgFilter = Callable[[dict[str, Any], GraphQLResolveInfo], dict[str, Any]] RootSpanName = Union[str, Callable[[ContextValue], str]] DEFAULT_OPERATION_NAME = "GraphQL Operation" @@ -128,8 +127,8 @@ async def await_sync_result(): return result def filter_resolver_args( - self, args: Dict[str, Any], info: GraphQLResolveInfo - ) -> Dict[str, Any]: + self, args: dict[str, Any], info: GraphQLResolveInfo + ) -> dict[str, Any]: args_to_trace = copy_args_for_tracing(args) if not self._arg_filter: diff --git a/ariadne/contrib/tracing/opentracing.py b/ariadne/contrib/tracing/opentracing.py index f3f82dc31..1a66adb58 100644 --- a/ariadne/contrib/tracing/opentracing.py +++ b/ariadne/contrib/tracing/opentracing.py @@ -1,6 +1,6 @@ from functools import partial from inspect import iscoroutinefunction -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Optional, Union from graphql import GraphQLResolveInfo from graphql.pyutils import is_awaitable @@ -10,8 +10,7 @@ from ...types import ContextValue, Extension, Resolver from .utils import copy_args_for_tracing, format_path, should_trace - -ArgFilter = Callable[[Dict[str, Any], GraphQLResolveInfo], Dict[str, Any]] +ArgFilter = Callable[[dict[str, Any], GraphQLResolveInfo], dict[str, Any]] RootSpanName = Union[str, Callable[[ContextValue], str]] @@ -107,8 +106,8 @@ async def await_sync_result(): return result def filter_resolver_args( - self, args: Dict[str, Any], info: GraphQLResolveInfo - ) -> Dict[str, Any]: + self, args: dict[str, Any], info: GraphQLResolveInfo + ) -> dict[str, Any]: args_to_trace = copy_args_for_tracing(args) if not self._arg_filter: diff --git a/ariadne/enums.py b/ariadne/enums.py index 83c942a6b..ea44a56d0 100644 --- a/ariadne/enums.py +++ b/ariadne/enums.py @@ -1,13 +1,11 @@ import enum +import warnings from typing import ( Any, - Dict, Optional, - Type, Union, cast, ) -import warnings from graphql.type import GraphQLEnumType, GraphQLNamedType, GraphQLSchema @@ -51,7 +49,7 @@ class EnumType(SchemaBindable): "MEMBER": 0, "MODERATOR": 1, "ADMIN": 2, - } + }, ) ``` """ @@ -59,7 +57,7 @@ class EnumType(SchemaBindable): def __init__( self, name: str, - values: Union[Dict[str, Any], Type[enum.Enum], Type[enum.IntEnum]], + values: Union[dict[str, Any], type[enum.Enum], type[enum.IntEnum]], ) -> None: """Initializes the `EnumType` with `name` and `values` mapping. @@ -72,7 +70,7 @@ def __init__( enum's in Python logic. """ self.name = name - self.values = cast(Dict[str, Any], getattr(values, "__members__", values)) + self.values = cast(dict[str, Any], getattr(values, "__members__", values)) def bind_to_schema(self, schema: GraphQLSchema) -> None: """Binds this `EnumType` instance to the instance of GraphQL schema.""" @@ -82,9 +80,7 @@ def bind_to_schema(self, schema: GraphQLSchema) -> None: for key, value in self.values.items(): if key not in graphql_type.values: - raise ValueError( - "Value %s is not defined on enum %s" % (key, self.name) - ) + raise ValueError(f"Value {key} is not defined on enum {self.name}") graphql_type.values[key].value = value def bind_to_default_values(self, _schema: GraphQLSchema) -> None: @@ -113,9 +109,9 @@ def validate_graphql_type(self, graphql_type: Optional[GraphQLNamedType]) -> Non """Validates that schema's GraphQL type associated with this `EnumType` is an `enum`.""" if not graphql_type: - raise ValueError("Enum %s is not defined in the schema" % self.name) + raise ValueError(f"Enum {self.name} is not defined in the schema") if not isinstance(graphql_type, GraphQLEnumType): raise ValueError( - "%s is defined in the schema, but it is instance of %s (expected %s)" - % (self.name, type(graphql_type).__name__, GraphQLEnumType.__name__) + f"{self.name} is defined in the schema, but it is instance of " + f"{type(graphql_type).__name__} (expected {GraphQLEnumType.__name__})" ) diff --git a/ariadne/enums_default_values.py b/ariadne/enums_default_values.py index 555e90357..8f8adc629 100644 --- a/ariadne/enums_default_values.py +++ b/ariadne/enums_default_values.py @@ -7,7 +7,6 @@ GraphQLSchemaEnumsValuesVisitor, ) - __all__ = [ "repair_schema_default_enum_values", "validate_schema_default_enum_values", diff --git a/ariadne/enums_values_visitor.py b/ariadne/enums_values_visitor.py index 075118a18..48fa8fc50 100644 --- a/ariadne/enums_values_visitor.py +++ b/ariadne/enums_values_visitor.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Optional, Union, cast +from typing import Any, Optional, Union, cast from graphql import ( EnumValueNode, @@ -10,8 +10,8 @@ GraphQLInputField, GraphQLInputObjectType, GraphQLInterfaceType, - GraphQLNonNull, GraphQLList, + GraphQLNonNull, GraphQLObjectType, GraphQLSchema, GraphQLType, @@ -23,7 +23,7 @@ class GraphQLEnumsValuesVisitor: schema: GraphQLSchema - enum_values: Dict[str, Dict[str, Any]] + enum_values: dict[str, dict[str, Any]] def __init__(self, schema: GraphQLSchema): self.enum_values = {} @@ -451,7 +451,7 @@ def visit_input_value( value_def: GraphQLInputObjectType, value_ast: ObjectValueNode, ) -> None: - value: Dict[str, Any] = { + value: dict[str, Any] = { field.name.value: field.value for field in value_ast.fields } diff --git a/ariadne/executable_schema.py b/ariadne/executable_schema.py index 138c628f2..2903570eb 100644 --- a/ariadne/executable_schema.py +++ b/ariadne/executable_schema.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Dict, List, Optional, Type, Union +from typing import Optional, Union from graphql import ( GraphQLSchema, @@ -19,15 +19,15 @@ SchemaBindables = Union[ SchemaBindable, - Type[Enum], - List[Union[SchemaBindable, Type[Enum]]], + type[Enum], + list[Union[SchemaBindable, type[Enum]]], ] def make_executable_schema( - type_defs: Union[str, List[str]], + type_defs: Union[str, list[str]], *bindables: SchemaBindables, - directives: Optional[Dict[str, Type[SchemaDirectiveVisitor]]] = None, + directives: Optional[dict[str, type[SchemaDirectiveVisitor]]] = None, convert_names_case: Union[bool, SchemaNameConverter] = False, ) -> GraphQLSchema: """Create a `GraphQLSchema` instance that can be used to execute queries. @@ -95,7 +95,13 @@ def make_executable_schema( ```python from dataclasses import dataclass from enum import Enum - from ariadne import ObjectType, QueryType, UnionType, graphql_sync, make_executable_schema + from ariadne import ( + ObjectType, + QueryType, + UnionType, + graphql_sync, + make_executable_schema, + ) # Define some types representing database models in real applications class UserLevel(str, Enum): @@ -361,14 +367,14 @@ def uppercase_resolved_value(*args, **kwargs): return schema -def join_type_defs(type_defs: List[str]) -> str: +def join_type_defs(type_defs: list[str]) -> str: return "\n\n".join(t.strip() for t in type_defs) def normalize_bindables( *bindables: SchemaBindables, -) -> List[SchemaBindable]: - normal_bindables: List[SchemaBindable] = [] +) -> list[SchemaBindable]: + normal_bindables: list[SchemaBindable] = [] for bindable in flatten_bindables(*bindables): if isinstance(bindable, SchemaBindable): normal_bindables.append(bindable) @@ -382,7 +388,7 @@ def normalize_bindables( def flatten_bindables( *bindables: SchemaBindables, -) -> List[Union[SchemaBindable, Type[Enum]]]: +) -> list[Union[SchemaBindable, type[Enum]]]: new_bindables = [] for bindable in bindables: diff --git a/ariadne/explorer/explorer.py b/ariadne/explorer/explorer.py index 39502ce70..69f8a5b30 100644 --- a/ariadne/explorer/explorer.py +++ b/ariadne/explorer/explorer.py @@ -1,4 +1,5 @@ -from typing import Any, Awaitable, Optional, Union +from collections.abc import Awaitable +from typing import Any, Optional, Union class Explorer: diff --git a/ariadne/explorer/playground.py b/ariadne/explorer/playground.py index fdc88528d..c30b7fdfe 100644 --- a/ariadne/explorer/playground.py +++ b/ariadne/explorer/playground.py @@ -1,12 +1,12 @@ import json -from typing import Dict, Optional, Union +from typing import Optional, Union from .explorer import Explorer from .template import read_template, render_template PLAYGROUND_HTML = read_template("playground.html") -SettingsDict = Dict[str, Union[str, int, bool, Dict[str, str]]] +SettingsDict = dict[str, Union[str, int, bool, dict[str, str]]] class ExplorerPlayground(Explorer): @@ -24,7 +24,7 @@ def __init__( prettier_tab_width: Optional[int] = None, prettier_use_tabs: Optional[bool] = None, request_credentials: Optional[str] = None, - request_global_headers: Optional[Dict[str, str]] = None, + request_global_headers: Optional[dict[str, str]] = None, schema_polling_enable: Optional[bool] = None, schema_polling_endpoint_filter: Optional[str] = None, schema_polling_interval: Optional[int] = None, @@ -63,7 +63,7 @@ def __init__( }, ) - def build_settings( + def build_settings( # noqa: C901 self, editor_cursor_shape: Optional[str] = None, editor_font_family: Optional[str] = None, @@ -75,7 +75,7 @@ def build_settings( prettier_tab_width: Optional[int] = None, prettier_use_tabs: Optional[bool] = None, request_credentials: Optional[str] = None, - request_global_headers: Optional[Dict[str, str]] = None, + request_global_headers: Optional[dict[str, str]] = None, schema_polling_enable: Optional[bool] = None, schema_polling_endpoint_filter: Optional[str] = None, schema_polling_interval: Optional[int] = None, diff --git a/ariadne/explorer/template.py b/ariadne/explorer/template.py index 0de6a18c8..918f6461f 100644 --- a/ariadne/explorer/template.py +++ b/ariadne/explorer/template.py @@ -12,15 +12,14 @@ import html from enum import IntEnum from os import path -from typing import List, Optional, Tuple - +from typing import Optional TEMPLATE_DIR = path.join(path.dirname(path.abspath(__file__)), "templates") def read_template(template: str) -> str: template_path = path.join(TEMPLATE_DIR, template) - with open(template_path, "r", encoding="utf-8") as fp: + with open(template_path, encoding="utf-8") as fp: return fp.read() @@ -34,7 +33,7 @@ class Token(IntEnum): ENDIF = 6 -TokenBlock = Tuple[Token, Optional[str]] +TokenBlock = tuple[Token, Optional[str]] def render_template(template: str, template_vars: Optional[dict] = None) -> str: @@ -47,8 +46,8 @@ def parse_template(template: str): return build_template_ast(tokens) -def tokenize_template(template: str) -> List[TokenBlock]: - tokens: List[TokenBlock] = [] +def tokenize_template(template: str) -> list[TokenBlock]: + tokens: list[TokenBlock] = [] cursor = 0 limit = len(template) @@ -78,7 +77,7 @@ def tokenize_template(template: str) -> List[TokenBlock]: return tokens -def tokenize_var(template: str, cursor: int) -> Tuple[TokenBlock, int]: +def tokenize_var(template: str, cursor: int) -> tuple[TokenBlock, int]: end = template.find("}}", cursor) if end == -1: raise ValueError( @@ -94,7 +93,7 @@ def tokenize_var(template: str, cursor: int) -> Tuple[TokenBlock, int]: return (Token.VAR, var_name), end + 2 -def tokenize_block(template: str, cursor: int) -> Tuple[TokenBlock, int]: +def tokenize_block(template: str, cursor: int) -> tuple[TokenBlock, int]: token: TokenBlock end = template.find("%}", cursor) @@ -161,13 +160,13 @@ def tokenize_block(template: str, cursor: int) -> Tuple[TokenBlock, int]: return token, end + 2 -def build_template_ast(tokens: List[TokenBlock]) -> "TemplateDocument": +def build_template_ast(tokens: list[TokenBlock]) -> "TemplateDocument": nodes = ast_to_nodes(tokens) return TemplateDocument(nodes) -def ast_to_nodes(tokens: List[TokenBlock]) -> List["TemplateNode"]: - nodes: List[TemplateNode] = [] +def ast_to_nodes(tokens: list[TokenBlock]) -> list["TemplateNode"]: + nodes: list[TemplateNode] = [] i = 0 limit = len(tokens) while i < limit: @@ -193,7 +192,7 @@ def ast_to_nodes(tokens: List[TokenBlock]) -> List["TemplateNode"]: if_block_args = token_args.split(" ") nesting = 0 - children: List[TokenBlock] = [] + children: list[TokenBlock] = [] if_not = token_type == Token.IF_NOT has_else = False for child in tokens[i + 1 :]: @@ -270,7 +269,7 @@ def render(self, _) -> str: class TemplateIfBlock(TemplateNode): - def __init__(self, args: List[str], nodes, if_not: bool = False) -> None: + def __init__(self, args: list[str], nodes, if_not: bool = False) -> None: self.args = args self.nodes = nodes self.if_not = if_not diff --git a/ariadne/extensions.py b/ariadne/extensions.py index cb2957a84..cdb3d1294 100644 --- a/ariadne/extensions.py +++ b/ariadne/extensions.py @@ -1,10 +1,10 @@ from contextlib import contextmanager -from typing import List, Optional, Type +from typing import Optional from graphql import GraphQLError from graphql.execution import MiddlewareManager -from .types import MiddlewareList, ContextValue, ExtensionList +from .types import ContextValue, ExtensionList, MiddlewareList class ExtensionManager: @@ -45,7 +45,7 @@ def __init__( def as_middleware_manager( self, middleware: MiddlewareList = None, - manager_class: Optional[Type[MiddlewareManager]] = None, + manager_class: Optional[type[MiddlewareManager]] = None, ) -> Optional[MiddlewareManager]: """Creates middleware manager instance combining middleware and extensions. @@ -83,7 +83,7 @@ def request(self): for ext in self.extensions_reversed: ext.request_finished(self.context) - def has_errors(self, errors: List[GraphQLError]): + def has_errors(self, errors: list[GraphQLError]): """Propagates GraphQL errors returned by GraphQL server to extensions. Should be called only when there are errors. diff --git a/ariadne/file_uploads.py b/ariadne/file_uploads.py index b77cc249d..3e364357d 100644 --- a/ariadne/file_uploads.py +++ b/ariadne/file_uploads.py @@ -1,4 +1,5 @@ from typing import Optional, Union + from typing_extensions import Protocol from .exceptions import HttpBadRequestError @@ -66,18 +67,18 @@ def combine_multipart_data( if not isinstance(operations, (dict, list)): raise HttpBadRequestError( - "Invalid type for the 'operations' multipart field ({}).".format(SPEC_URL) + f"Invalid type for the 'operations' multipart field ({SPEC_URL})." ) if not isinstance(files_map, dict): raise HttpBadRequestError( - "Invalid type for the 'map' multipart field ({}).".format(SPEC_URL) + f"Invalid type for the 'map' multipart field ({SPEC_URL})." ) files_map = inverse_files_map(files_map, files) if isinstance(operations, list): for i, operation in enumerate(operations): add_files_to_variables( - operation.get("variables"), "{}.variables".format(i), files_map + operation.get("variables"), f"{i}.variables", files_map ) if isinstance(operations, dict): add_files_to_variables(operations.get("variables"), "variables", files_map) @@ -89,28 +90,26 @@ def inverse_files_map(files_map: dict, files: FilesDict) -> dict: for field_name, paths in files_map.items(): if not isinstance(paths, list): raise HttpBadRequestError( - ( + "Invalid type for the 'map' multipart field entry " - "key '{}' array ({})." - ).format(field_name, SPEC_URL) + f"key '{field_name}' array ({SPEC_URL})." + ) for i, path in enumerate(paths): if not isinstance(path, str): raise HttpBadRequestError( - ( + "Invalid type for the 'map' multipart field entry key " - "'{}' array index '{}' value ({})." - ).format(field_name, i, SPEC_URL) + f"'{field_name}' array index '{i}' value ({SPEC_URL})." + ) try: inverted_map[path] = files[field_name] except KeyError as ex: raise HttpBadRequestError( - ("File data was missing for entry key '{}' ({}).").format( - field_name, SPEC_URL - ) + f"File data was missing for entry key '{field_name}' ({SPEC_URL})." ) from ex return inverted_map @@ -121,7 +120,7 @@ def add_files_to_variables( ): if isinstance(variables, dict): for variable, value in variables.items(): - variable_path = "{}.{}".format(path, variable) + variable_path = f"{path}.{variable}" if isinstance(value, (dict, list)): add_files_to_variables(value, variable_path, files_map) elif value is None: @@ -129,7 +128,7 @@ def add_files_to_variables( if isinstance(variables, list): for i, value in enumerate(variables): - variable_path = "{}.{}".format(path, i) + variable_path = f"{path}.{i}" if isinstance(value, (dict, list)): add_files_to_variables(value, variable_path, files_map) elif value is None: diff --git a/ariadne/format_error.py b/ariadne/format_error.py index f880a2ed7..8bc5e7d8b 100644 --- a/ariadne/format_error.py +++ b/ariadne/format_error.py @@ -1,7 +1,6 @@ from reprlib import repr # pylint: disable=redefined-builtin from traceback import format_exception - -from typing import List, Optional, cast +from typing import Optional, cast from graphql import GraphQLError @@ -56,7 +55,7 @@ def get_error_extension(error: GraphQLError) -> Optional[dict]: } -def get_formatted_error_traceback(error: Exception) -> List[str]: +def get_formatted_error_traceback(error: Exception) -> list[str]: """Get JSON-serializable stacktrace from `Exception`. Returns list of strings, with every item being separate line from stacktrace. diff --git a/ariadne/graphql.py b/ariadne/graphql.py index 474581acc..817a60596 100644 --- a/ariadne/graphql.py +++ b/ariadne/graphql.py @@ -1,17 +1,12 @@ from asyncio import ensure_future +from collections.abc import AsyncGenerator, Awaitable, Collection, Sequence from inspect import isawaitable from logging import Logger, LoggerAdapter from typing import ( Any, - AsyncGenerator, - Awaitable, - Collection, - List, Optional, - Sequence, - Type, - cast, Union, + cast, ) from warnings import warn @@ -27,6 +22,8 @@ execute, execute_sync, parse, +) +from graphql import ( subscribe as _subscribe, ) from graphql.validation import specified_rules, validate @@ -76,9 +73,9 @@ async def graphql( require_query: bool = False, error_formatter: ErrorFormatter = format_error, middleware: MiddlewareList = None, - middleware_manager_class: Optional[Type[MiddlewareManager]] = None, + middleware_manager_class: Optional[type[MiddlewareManager]] = None, extensions: Optional[ExtensionList] = None, - execution_context_class: Optional[Type[ExecutionContext]] = None, + execution_context_class: Optional[type[ExecutionContext]] = None, **kwargs, ) -> GraphQLResult: """Execute GraphQL query asynchronously. @@ -166,7 +163,7 @@ async def graphql( if callable(validation_rules): validation_rules = cast( - Optional[Collection[Type[ASTValidationRule]]], + Optional[Collection[type[ASTValidationRule]]], validation_rules(context_value, document, data), ) @@ -267,9 +264,9 @@ def graphql_sync( require_query: bool = False, error_formatter: ErrorFormatter = format_error, middleware: MiddlewareList = None, - middleware_manager_class: Optional[Type[MiddlewareManager]] = None, + middleware_manager_class: Optional[type[MiddlewareManager]] = None, extensions: Optional[ExtensionList] = None, - execution_context_class: Optional[Type[ExecutionContext]] = None, + execution_context_class: Optional[type[ExecutionContext]] = None, **kwargs, ) -> GraphQLResult: """Execute GraphQL query synchronously. @@ -357,7 +354,7 @@ def graphql_sync( if callable(validation_rules): validation_rules = cast( - Optional[Collection[Type[ASTValidationRule]]], + Optional[Collection[type[ASTValidationRule]]], validation_rules(context_value, document, data), ) @@ -529,7 +526,7 @@ async def subscribe( if callable(validation_rules): validation_rules = cast( - Optional[Collection[Type[ASTValidationRule]]], + Optional[Collection[type[ASTValidationRule]]], validation_rules(context_value, document, data), ) @@ -574,7 +571,7 @@ async def subscribe( return False, [error_formatter(error, debug)] if isinstance(result, ExecutionResult): - errors = cast(List[GraphQLError], result.errors) + errors = cast(list[GraphQLError], result.errors) for error_ in errors: # mypy issue #5080 log_error(error_, logger) return False, [error_formatter(error, debug) for error in errors] @@ -642,12 +639,12 @@ def add_extensions_to_response(extension_manager: ExtensionManager, response: di def validate_query( schema: GraphQLSchema, document_ast: DocumentNode, - rules: Optional[Collection[Type[ASTValidationRule]]] = None, + rules: Optional[Collection[type[ASTValidationRule]]] = None, max_errors: Optional[int] = None, type_info: Optional[TypeInfo] = None, enable_introspection: bool = True, query_validator: Optional[QueryValidator] = None, -) -> List[GraphQLError]: +) -> list[GraphQLError]: validate_fn: QueryValidator = query_validator or validate if not enable_introspection: @@ -690,13 +687,13 @@ def validate_variables(variables) -> None: def validate_operation_name(operation_name) -> None: if operation_name is not None and not isinstance(operation_name, str): - raise GraphQLError('"%s" is not a valid operation name.' % operation_name) + raise GraphQLError(f'"{operation_name}" is not a valid operation name.') def validate_operation_is_query( document_ast: DocumentNode, operation_name: Optional[str] ): - query_operations: List[Optional[str]] = [] + query_operations: list[Optional[str]] = [] for definition in document_ast.definitions: if ( isinstance(definition, OperationDefinitionNode) @@ -745,7 +742,7 @@ def validate_named_operation_is_not_subscription( def validate_anonymous_operation_is_not_subscription(document_ast: DocumentNode): - operations: List[OperationDefinitionNode] = [] + operations: list[OperationDefinitionNode] = [] for definition in document_ast.definitions: if isinstance(definition, OperationDefinitionNode): operations.append(definition) diff --git a/ariadne/inputs.py b/ariadne/inputs.py index 53d3dd155..dccb2343a 100644 --- a/ariadne/inputs.py +++ b/ariadne/inputs.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, cast +from typing import Optional, cast from graphql import GraphQLInputObjectType, GraphQLSchema from graphql.type.definition import GraphQLInputFieldOutType, GraphQLNamedType @@ -132,13 +132,13 @@ def resolve_repr(*_, input: ExampleInput): """ _out_type: Optional[GraphQLInputFieldOutType] - _out_names: Optional[Dict[str, str]] + _out_names: Optional[dict[str, str]] def __init__( self, name: str, out_type: Optional[GraphQLInputFieldOutType] = None, - out_names: Optional[Dict[str, str]] = None, + out_names: Optional[dict[str, str]] = None, ) -> None: """Initializes the `InputType` with a `name` and optionally out type and out names. @@ -192,13 +192,8 @@ def validate_graphql_type(self, graphql_type: Optional[GraphQLNamedType]) -> Non """Validates that schema's GraphQL type associated with this `InputType` is an `input`.""" if not graphql_type: - raise ValueError("Type %s is not defined in the schema" % self.name) + raise ValueError(f"Type {self.name} is not defined in the schema") if not isinstance(graphql_type, GraphQLInputObjectType): raise ValueError( - "%s is defined in the schema, but it is instance of %s (expected %s)" - % ( - self.name, - type(graphql_type).__name__, - GraphQLInputObjectType.__name__, - ) + f"{self.name} is defined in the schema, but it is instance of {type(graphql_type).__name__} (expected {GraphQLInputObjectType.__name__})" ) diff --git a/ariadne/load_schema.py b/ariadne/load_schema.py index dac3bc8b5..1c807fae3 100644 --- a/ariadne/load_schema.py +++ b/ariadne/load_schema.py @@ -1,6 +1,7 @@ import os +from collections.abc import Generator from pathlib import Path -from typing import Generator, Union +from typing import Union from graphql import parse from graphql.error import GraphQLSyntaxError @@ -40,7 +41,7 @@ def walk_graphql_files(path: Union[str, os.PathLike]) -> Generator[str, None, No def read_graphql_file(path: Union[str, os.PathLike]) -> str: - with open(path, "r", encoding="utf-8") as graphql_file: + with open(path, encoding="utf-8") as graphql_file: schema = graphql_file.read() try: parse(schema) diff --git a/ariadne/objects.py b/ariadne/objects.py index d47c1323a..65534ce9e 100644 --- a/ariadne/objects.py +++ b/ariadne/objects.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Optional, cast +from typing import Callable, Optional, cast from graphql.type import GraphQLNamedType, GraphQLObjectType, GraphQLSchema @@ -151,7 +151,7 @@ def resolve_user_full_name(obj: UserModel, *_): ``` """ - _resolvers: Dict[str, Resolver] + _resolvers: dict[str, Resolver] def __init__(self, name: str) -> None: """Initializes the `ObjectType` with a `name`. @@ -237,11 +237,10 @@ def validate_graphql_type(self, graphql_type: Optional[GraphQLNamedType]) -> Non """Validates that schema's GraphQL type associated with this `ObjectType` is a `type`.""" if not graphql_type: - raise ValueError("Type %s is not defined in the schema" % self.name) + raise ValueError(f"Type {self.name} is not defined in the schema") if not isinstance(graphql_type, GraphQLObjectType): raise ValueError( - "%s is defined in the schema, but it is instance of %s (expected %s)" - % (self.name, type(graphql_type).__name__, GraphQLObjectType.__name__) + f"{self.name} is defined in the schema, but it is instance of {type(graphql_type).__name__} (expected {GraphQLObjectType.__name__})" ) def bind_resolvers_to_graphql_type(self, graphql_type, replace_existing=True): @@ -249,7 +248,7 @@ def bind_resolvers_to_graphql_type(self, graphql_type, replace_existing=True): for field, resolver in self._resolvers.items(): if field not in graphql_type.fields: raise ValueError( - "Field %s is not defined on type %s" % (field, self.name) + f"Field {field} is not defined on type {self.name}" ) if graphql_type.fields[field].resolve is None or replace_existing: graphql_type.fields[field].resolve = resolver diff --git a/ariadne/scalars.py b/ariadne/scalars.py index 0a786d470..ae44bb73e 100644 --- a/ariadne/scalars.py +++ b/ariadne/scalars.py @@ -313,9 +313,8 @@ def validate_graphql_type(self, graphql_type: Optional[GraphQLNamedType]) -> Non """Validates that schema's GraphQL type associated with this `ScalarType` is a `scalar`.""" if not graphql_type: - raise ValueError("Scalar %s is not defined in the schema" % self.name) + raise ValueError(f"Scalar {self.name} is not defined in the schema") if not isinstance(graphql_type, GraphQLScalarType): raise ValueError( - "%s is defined in the schema, but it is instance of %s (expected %s)" - % (self.name, type(graphql_type).__name__, GraphQLScalarType.__name__) + f"{self.name} is defined in the schema, but it is instance of {type(graphql_type).__name__} (expected {GraphQLScalarType.__name__})" ) diff --git a/ariadne/schema_names.py b/ariadne/schema_names.py index 73441b6e8..f11b4839c 100644 --- a/ariadne/schema_names.py +++ b/ariadne/schema_names.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Tuple +from typing import Callable, Optional from graphql import ( GraphQLField, @@ -26,7 +26,7 @@ Returns a string with the Python name to use. """ -SchemaNameConverter = Callable[[str, GraphQLSchema, Tuple[str, ...]], str] +SchemaNameConverter = Callable[[str, GraphQLSchema, tuple[str, ...]], str] GRAPHQL_SPEC_TYPES = ( "__Directive", diff --git a/ariadne/schema_visitor.py b/ariadne/schema_visitor.py index 7d76d7dcf..c8db8e21c 100644 --- a/ariadne/schema_visitor.py +++ b/ariadne/schema_visitor.py @@ -1,13 +1,9 @@ +from collections.abc import Mapping from types import FunctionType from typing import ( Any, Callable, - Dict, - List, - Mapping, Optional, - Tuple, - Type, TypeVar, Union, cast, @@ -49,8 +45,8 @@ GraphQLEnumValue, ] V = TypeVar("V", bound=VisitableSchemaType) -VisitableMap = Dict[str, V] -IndexedObject = Union[VisitableMap, Tuple[V, ...]] +VisitableMap = dict[str, V] +IndexedObject = Union[VisitableMap, tuple[V, ...]] Callback = Callable[..., Any] @@ -71,7 +67,7 @@ def update_each_key(object_map: VisitableMap, callback: Callback): the key from the array or object, or a non-null V to replace the value. """ - keys_to_remove: List[str] = [] + keys_to_remove: list[str] = [] for key, value in object_map.copy().items(): result = callback(value, key) @@ -162,7 +158,7 @@ def visit_input_field_definition( def visit_schema( schema: GraphQLSchema, visitor_selector: Callable[ - [VisitableSchemaType, str], List["SchemaDirectiveVisitor"] + [VisitableSchemaType, str], list["SchemaDirectiveVisitor"] ], ) -> GraphQLSchema: """ @@ -445,8 +441,8 @@ def get_directive_declaration( def get_declared_directives( cls, schema: GraphQLSchema, - directive_visitors: Dict[str, Type["SchemaDirectiveVisitor"]], - ) -> Dict[str, GraphQLDirective]: + directive_visitors: dict[str, type["SchemaDirectiveVisitor"]], + ) -> dict[str, GraphQLDirective]: """Get GraphQL directives declaration from GraphQL schema by their names. Returns a `dict` where keys are strings with directive names in schema @@ -461,7 +457,7 @@ def get_declared_directives( declaration from. """ - declared_directives: Dict[str, GraphQLDirective] = {} + declared_directives: dict[str, GraphQLDirective] = {} def _add_directive(decl): declared_directives[decl.name] = decl @@ -490,7 +486,7 @@ def _get_overriden_directive(visitor_class, directive_name): each(directive_visitors, _get_overriden_directive) def _rest(decl, name): - if not name in directive_visitors: + if name not in directive_visitors: # SchemaDirectiveVisitors.visit_schema_directives might be called # multiple times with partial directive_visitors maps, so it's not # necessarily an error for directive_visitors to be missing an @@ -524,10 +520,10 @@ def _location_check(loc): def visit_schema_directives( cls, schema: GraphQLSchema, - directive_visitors: Dict[str, Type["SchemaDirectiveVisitor"]], + directive_visitors: dict[str, type["SchemaDirectiveVisitor"]], *, - context: Optional[Dict[str, Any]] = None, - ) -> Mapping[str, List["SchemaDirectiveVisitor"]]: + context: Optional[dict[str, Any]] = None, + ) -> Mapping[str, list["SchemaDirectiveVisitor"]]: """Apply directives to the GraphQL schema. Applied directives mutate the GraphQL schema in place. @@ -551,14 +547,14 @@ def visit_schema_directives( # Map from directive names to lists of SchemaDirectiveVisitor instances # created while visiting the schema. - created_visitors: Dict[str, List["SchemaDirectiveVisitor"]] = { + created_visitors: dict[str, list[SchemaDirectiveVisitor]] = { k: [] for k in directive_visitors } def _visitor_selector( type_: VisitableSchemaType, method_name: str - ) -> List["SchemaDirectiveVisitor"]: - visitors: List["SchemaDirectiveVisitor"] = [] + ) -> list["SchemaDirectiveVisitor"]: + visitors: list[SchemaDirectiveVisitor] = [] directive_nodes = type_.ast_node.directives if type_.ast_node else None if directive_nodes is None: return visitors @@ -577,7 +573,7 @@ def _visitor_selector( decl = declared_directives[directive_name] - args: Dict[str, Any] = {} + args: dict[str, Any] = {} if decl: # If this directive was explicitly declared, use the declared @@ -613,7 +609,7 @@ def _visitor_selector( return created_visitors -NamedTypeMap = Dict[str, GraphQLNamedType] +NamedTypeMap = dict[str, GraphQLNamedType] def heal_schema(schema: GraphQLSchema) -> GraphQLSchema: diff --git a/ariadne/subscriptions.py b/ariadne/subscriptions.py index d9d24a277..bfba226cf 100644 --- a/ariadne/subscriptions.py +++ b/ariadne/subscriptions.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict +from typing import Callable from graphql.type import GraphQLSchema @@ -107,7 +107,7 @@ async def resolve_post( using Apollo-Client subscriptions. """ - _subscribers: Dict[str, Subscriber] + _subscribers: dict[str, Subscriber] def __init__(self) -> None: """Initializes the `SubscriptionType` with a GraphQL name set to `Subscription`.""" @@ -180,7 +180,7 @@ def bind_subscribers_to_graphql_type(self, graphql_type): for field, subscriber in self._subscribers.items(): if field not in graphql_type.fields: raise ValueError( - "Field %s is not defined on type %s" % (field, self.name) + f"Field {field} is not defined on type {self.name}" ) graphql_type.fields[field].subscribe = subscriber diff --git a/ariadne/types.py b/ariadne/types.py index f77211d59..01f2b6364 100644 --- a/ariadne/types.py +++ b/ariadne/types.py @@ -1,18 +1,12 @@ +from collections.abc import AsyncGenerator, Collection, Sequence from dataclasses import dataclass from typing import ( Any, - AsyncGenerator, Callable, - Collection, - Dict, - List, Optional, - Sequence, - Tuple, - Type, Union, + runtime_checkable, ) -from typing_extensions import Protocol, runtime_checkable from graphql import ( DocumentNode, @@ -23,8 +17,8 @@ ) from graphql.utilities.type_info import TypeInfo from graphql.validation.rules import ASTValidationRule - from starlette.websockets import WebSocket +from typing_extensions import Protocol __all__ = [ "Resolver", @@ -85,7 +79,7 @@ `dict`: JSON-serializable query result. """ -GraphQLResult = Tuple[bool, dict] +GraphQLResult = tuple[bool, dict] """Result type for `subscribe` function. @@ -96,8 +90,8 @@ `dict or generator`: JSON-serializable query result or asynchronous generator with subscription's results. Depends if query was success or not. """ -SubscriptionResult = Tuple[ - bool, Union[List[dict], AsyncGenerator[ExecutionResult, None]] +SubscriptionResult = tuple[ + bool, Union[list[dict], AsyncGenerator[ExecutionResult, None]] ] """Type for subscription source functions. @@ -327,7 +321,7 @@ def add_query(self, operation_name: str, query: str): ) ``` """ -QueryParser = Callable[[ContextValue, Dict[str, Any]], DocumentNode] +QueryParser = Callable[[ContextValue, dict[str, Any]], DocumentNode] class QueryValidator(Protocol): @@ -381,10 +375,10 @@ def __call__( self, schema: GraphQLSchema, document_ast: DocumentNode, - rules: Optional[Collection[Type[ASTValidationRule]]] = None, + rules: Optional[Collection[type[ASTValidationRule]]] = None, max_errors: Optional[int] = None, type_info: Optional[TypeInfo] = None, - ) -> List[GraphQLError]: ... + ) -> list[GraphQLError]: ... """Type of `validation_rules` option of GraphQL servers. @@ -403,15 +397,15 @@ def __call__( `dict`: a GraphQL request's data. """ ValidationRules = Union[ - Collection[Type[ASTValidationRule]], + Collection[type[ASTValidationRule]], Callable[ [Optional[Any], DocumentNode, dict], - Optional[Collection[Type[ASTValidationRule]]], + Optional[Collection[type[ASTValidationRule]]], ], ] """List of extensions to use during GraphQL query execution.""" -ExtensionList = Optional[List[Union[Type["Extension"], Callable[[], "Extension"]]]] +ExtensionList = Optional[list[Union[type["Extension"], Callable[[], "Extension"]]]] """Type of `extensions` option of GraphQL servers. @@ -603,7 +597,7 @@ async def async_my_extension(): """ return next_(obj, info, **kwargs) - def has_errors(self, errors: List[GraphQLError], context: ContextValue) -> None: + def has_errors(self, errors: list[GraphQLError], context: ContextValue) -> None: """Extension hook executed when GraphQL encountered errors.""" def format(self, context: ContextValue) -> Optional[dict]: diff --git a/ariadne/unions.py b/ariadne/unions.py index 6edb42658..8122a3f71 100644 --- a/ariadne/unions.py +++ b/ariadne/unions.py @@ -1,6 +1,6 @@ from typing import Optional, cast -from graphql.type import GraphQLNamedType, GraphQLUnionType, GraphQLSchema +from graphql.type import GraphQLNamedType, GraphQLSchema, GraphQLUnionType from .types import Resolver, SchemaBindable @@ -179,9 +179,8 @@ def validate_graphql_type(self, graphql_type: Optional[GraphQLNamedType]) -> Non """Validates that schema's GraphQL type associated with this `UnionType` is an `union`.""" if not graphql_type: - raise ValueError("Type %s is not defined in the schema" % self.name) + raise ValueError(f"Type {self.name} is not defined in the schema") if not isinstance(graphql_type, GraphQLUnionType): raise ValueError( - "%s is defined in the schema, but it is instance of %s (expected %s)" - % (self.name, type(graphql_type).__name__, GraphQLUnionType.__name__) + f"{self.name} is defined in the schema, but it is instance of {type(graphql_type).__name__} (expected {GraphQLUnionType.__name__})" ) diff --git a/ariadne/utils.py b/ariadne/utils.py index 81b267fa0..f581d8e7f 100644 --- a/ariadne/utils.py +++ b/ariadne/utils.py @@ -1,11 +1,11 @@ import asyncio from collections.abc import Mapping from functools import wraps -from typing import Optional, Union, Callable, Dict, Any, cast +from typing import Any, Callable, Optional, Union, cast from warnings import warn -from graphql.language import DocumentNode, OperationDefinitionNode, OperationType from graphql import GraphQLError, GraphQLType, parse +from graphql.language import DocumentNode, OperationDefinitionNode, OperationType def convert_camel_case_to_snake(graphql_name: str) -> str: @@ -179,8 +179,8 @@ def convert_kwargs_to_snake_case(func: Callable) -> Callable: the `convert_schema_names` option on `make_executable_schema`. """ - def convert_to_snake_case(m: Mapping) -> Dict: - converted: Dict = {} + def convert_to_snake_case(m: Mapping) -> dict: + converted: dict = {} for k, v in m.items(): if isinstance(v, Mapping): v = convert_to_snake_case(v) diff --git a/ariadne/validation/query_cost.py b/ariadne/validation/query_cost.py index d5b29a681..329cfeab4 100644 --- a/ariadne/validation/query_cost.py +++ b/ariadne/validation/query_cost.py @@ -1,6 +1,6 @@ from functools import reduce from operator import add, mul -from typing import Any, Dict, List, Optional, Type, Union, cast +from typing import Any, Optional, Union, cast from graphql import ( GraphQLError, @@ -46,8 +46,8 @@ class CostValidator(ValidationRule): maximum_cost: int default_cost: int = 0 default_complexity: int = 1 - variables: Optional[Dict] = None - cost_map: Optional[Dict[str, Dict[str, Any]]] = None + variables: Optional[dict] = None + cost_map: Optional[dict[str, dict[str, Any]]] = None def __init__( self, @@ -56,8 +56,8 @@ def __init__( *, default_cost: int = 0, default_complexity: int = 1, - variables: Optional[Dict] = None, - cost_map: Optional[Dict[str, Dict[str, Any]]] = None, + variables: Optional[dict] = None, + cost_map: Optional[dict[str, dict[str, Any]]] = None, ) -> None: super().__init__(context) @@ -67,7 +67,7 @@ def __init__( self.default_cost = default_cost self.default_complexity = default_complexity self.cost = 0 - self.operation_multipliers: List[Any] = [] + self.operation_multipliers: list[Any] = [] def compute_node_cost(self, node: CostAwareNode, type_def, parent_multipliers=None): if parent_multipliers is None: @@ -87,7 +87,7 @@ def compute_node_cost(self, node: CostAwareNode, type_def, parent_multipliers=No continue field_type = get_named_type(field.type) try: - field_args: Dict[str, Any] = get_argument_values( + field_args: dict[str, Any] = get_argument_values( field, child_node, self.variables ) except Exception as e: @@ -195,10 +195,10 @@ def compute_cost(self, multipliers=None, use_multipliers=True, complexity=None): return complexity def get_args_from_cost_map( - self, node: FieldNode, parent_type: str, field_args: Dict + self, node: FieldNode, parent_type: str, field_args: dict ): cost_args = None - cost_map = cast(Dict[Any, Dict], self.cost_map) + cost_map = cast(dict[Any, dict], self.cost_map) if parent_type in cost_map: cost_args = cost_map[parent_type].get(node.name.value) if not cost_args: @@ -249,7 +249,7 @@ def get_args_from_directives(self, directives, field_args): ) multipliers = ( self.get_multipliers_from_list_node( - cast(List[Node], multipliers_arg.value.values), field_args + cast(list[Node], multipliers_arg.value.values), field_args ) if multipliers_arg and multipliers_arg.value @@ -271,7 +271,7 @@ def get_args_from_directives(self, directives, field_args): return None - def get_multipliers_from_list_node(self, multipliers: List[Node], field_args): + def get_multipliers_from_list_node(self, multipliers: list[Node], field_args): multipliers = [ node.value # type: ignore for node in multipliers @@ -279,7 +279,7 @@ def get_multipliers_from_list_node(self, multipliers: List[Node], field_args): ] return self.get_multipliers_from_string(multipliers, field_args) # type: ignore - def get_multipliers_from_string(self, multipliers: List[str], field_args): + def get_multipliers_from_string(self, multipliers: list[str], field_args): accessors = [s.split(".") for s in multipliers] multipliers = [] for accessor in accessors: @@ -308,7 +308,7 @@ def get_cost_exceeded_error(self) -> GraphQLError: ) -def validate_cost_map(cost_map: Dict[str, Dict[str, Any]], schema: GraphQLSchema): +def validate_cost_map(cost_map: dict[str, dict[str, Any]], schema: GraphQLSchema): for type_name, type_fields in cost_map.items(): if type_name not in schema.type_map: raise GraphQLError( @@ -347,9 +347,9 @@ def cost_validator( *, default_cost: int = 0, default_complexity: int = 1, - variables: Optional[Dict] = None, - cost_map: Optional[Dict[str, Dict[str, Any]]] = None, -) -> Type[ASTValidationRule]: + variables: Optional[dict] = None, + cost_map: Optional[dict[str, dict[str, Any]]] = None, +) -> type[ASTValidationRule]: class _CostValidator(CostValidator): def __init__(self, context: ValidationContext) -> None: super().__init__( @@ -361,4 +361,4 @@ def __init__(self, context: ValidationContext) -> None: cost_map=cost_map, ) - return cast(Type[ASTValidationRule], _CostValidator) + return cast(type[ASTValidationRule], _CostValidator) diff --git a/ariadne/wsgi.py b/ariadne/wsgi.py index ec652b0fd..57564b8c6 100644 --- a/ariadne/wsgi.py +++ b/ariadne/wsgi.py @@ -1,6 +1,6 @@ import json from inspect import isawaitable -from typing import Any, Callable, Dict, List, Optional, Type, Union, cast +from typing import Any, Callable, Optional, Union, cast from urllib.parse import parse_qsl from graphql import ( @@ -77,8 +77,8 @@ def __init__( execute_get_queries: bool = False, extensions: Optional[Extensions] = None, middleware: Optional[Middlewares] = None, - middleware_manager_class: Optional[Type[MiddlewareManager]] = None, - execution_context_class: Optional[Type[ExecutionContext]] = None, + middleware_manager_class: Optional[type[MiddlewareManager]] = None, + execution_context_class: Optional[type[ExecutionContext]] = None, ) -> None: """Initializes the WSGI app. @@ -167,7 +167,7 @@ def __init__( else: self.explorer = ExplorerGraphiQL() - def __call__(self, environ: dict, start_response: Callable) -> List[bytes]: + def __call__(self, environ: dict, start_response: Callable) -> list[bytes]: """An entrypoint to the WSGI application. Returns list of bytes with response body. @@ -191,7 +191,7 @@ def __call__(self, environ: dict, start_response: Callable) -> List[bytes]: def handle_graphql_error( self, error: GraphQLError, start_response: Callable - ) -> List[bytes]: + ) -> list[bytes]: """Handles a `GraphQLError` raised from `handle_request` and returns an error response to the client. @@ -211,7 +211,7 @@ def handle_graphql_error( def handle_http_error( self, error: HttpError, start_response: Callable - ) -> List[bytes]: + ) -> list[bytes]: """Handles a `HttpError` raised from `handle_request` and returns an error response to the client. @@ -227,7 +227,7 @@ def handle_http_error( response_body = error.message or error.status return [str(response_body).encode("utf-8")] - def handle_request(self, environ: dict, start_response: Callable) -> List[bytes]: + def handle_request(self, environ: dict, start_response: Callable) -> list[bytes]: """Handles WSGI HTTP request and returns a a response to the client. Returns list of bytes with response body. @@ -245,7 +245,7 @@ def handle_request(self, environ: dict, start_response: Callable) -> List[bytes] return self.handle_not_allowed_method(environ, start_response) - def handle_get(self, environ: dict, start_response) -> List[bytes]: + def handle_get(self, environ: dict, start_response) -> list[bytes]: """Handles WSGI HTTP GET request and returns a response to the client. Returns list of bytes with response body. @@ -266,7 +266,7 @@ def handle_get(self, environ: dict, start_response) -> List[bytes]: def handle_get_query( self, environ: dict, start_response, query_params: dict - ) -> List[bytes]: + ) -> list[bytes]: data = self.extract_data_from_get(query_params) result = self.execute_query(environ, data) return self.return_response_from_result(start_response, result) @@ -300,7 +300,7 @@ def extract_data_from_get(self, query_params: dict) -> dict: "variables": clean_variables, } - def handle_get_explorer(self, environ: dict, start_response) -> List[bytes]: + def handle_get_explorer(self, environ: dict, start_response) -> list[bytes]: """Handles WSGI HTTP GET explorer request and returns a response to the client. Returns list of bytes with response body. @@ -322,7 +322,7 @@ def handle_get_explorer(self, environ: dict, start_response) -> List[bytes]: ) return [cast(str, explorer_html).encode("utf-8")] - def handle_post(self, environ: dict, start_response: Callable) -> List[bytes]: + def handle_post(self, environ: dict, start_response: Callable) -> list[bytes]: """Handles WSGI HTTP POST request and returns a a response to the client. Returns list of bytes with response body. @@ -544,7 +544,7 @@ def get_middleware_for_request( def return_response_from_result( self, start_response: Callable, result: GraphQLResult - ) -> List[bytes]: + ) -> list[bytes]: """Returns WSGI response from GraphQL result. Returns a list of bytes with response body. @@ -566,7 +566,7 @@ def return_response_from_result( def handle_not_allowed_method( self, environ: dict, start_response: Callable - ) -> List[bytes]: + ) -> list[bytes]: """Handles request for unsupported HTTP method. Returns 200 response for `OPTIONS` request and 405 response for other @@ -636,7 +636,7 @@ def __init__( "application callable" ) - def __call__(self, environ: dict, start_response: Callable) -> List[bytes]: + def __call__(self, environ: dict, start_response: Callable) -> list[bytes]: """An entrypoint to the WSGI middleware. Returns list of bytes with response body. @@ -697,8 +697,8 @@ class FormData: """ charset: str - fields: Dict[str, Any] - files: Dict[str, Any] + fields: dict[str, Any] + files: dict[str, Any] def __init__(self, content_type: Optional[str]): """Initializes form data instance. diff --git a/benchmark/database.py b/benchmark/database.py index 0ecc4706a..4c8bf6b8c 100644 --- a/benchmark/database.py +++ b/benchmark/database.py @@ -17,7 +17,7 @@ class Database: data: dict def __init__(self): - with open(DATABASE_PATH, "r") as fp: + with open(DATABASE_PATH) as fp: raw_data = json.load(fp) self.data = {} diff --git a/benchmark/generate_data.py b/benchmark/generate_data.py index 5a5895de4..2fac801ca 100644 --- a/benchmark/generate_data.py +++ b/benchmark/generate_data.py @@ -8,10 +8,10 @@ except ImportError as exc: raise ImportError("Faker is required! Run pip install Faker") from exc -from ariadne import load_schema_from_path, make_executable_schema - from conf import DATABASE_PATH, SCHEMA_PATH +from ariadne import load_schema_from_path, make_executable_schema + fake = Faker() GROUPS_COUNT = 10 @@ -78,7 +78,6 @@ def main(): with open(DATABASE_PATH, "w") as fp: json.dump(database, fp, indent=2) - print("New testing data generated!") def generate_group(database: dict, db_id: int): diff --git a/benchmark/models.py b/benchmark/models.py index 8a582ffb5..296dbc869 100644 --- a/benchmark/models.py +++ b/benchmark/models.py @@ -1,7 +1,7 @@ from dataclasses import dataclass -from enum import Enum -from typing import List, Optional from datetime import datetime +from enum import Enum +from typing import Optional class RoleEnum(str, Enum): @@ -24,7 +24,7 @@ class GroupModel: name: str slug: str title: Optional[str] - roles: List[RoleEnum] + roles: list[RoleEnum] @dataclass @@ -36,8 +36,8 @@ class UserModel: title: Optional[str] email: str group_id: int - groups: List[int] - avatar_images: List[dict] + groups: list[int] + avatar_images: list[dict] status: UserStatusEnum posts: int joined_at: datetime @@ -77,5 +77,5 @@ class PostModel: poster_id: Optional[int] poster_name: str posted_at: datetime - content: List[dict] + content: list[dict] edits: int diff --git a/benchmark/rotate_results.py b/benchmark/rotate_results.py index f5e1093ad..5f355b2c7 100644 --- a/benchmark/rotate_results.py +++ b/benchmark/rotate_results.py @@ -27,7 +27,6 @@ def rotate_results(results_dir: Path): json_files = sorted(json_files, reverse=True) if len(json_files) > RESULTS_LIMIT: for file_to_delete in json_files[RESULTS_LIMIT:]: - print(f"Removed: {results_dir / file_to_delete}") file_path = results_dir / file_to_delete file_path.unlink() diff --git a/benchmark/schema.py b/benchmark/schema.py index 04a2b7c05..e4fa8b1dd 100644 --- a/benchmark/schema.py +++ b/benchmark/schema.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import List, Optional +from typing import Optional from ariadne import ( EnumType, @@ -22,7 +22,6 @@ UserStatusEnum, ) - datetime_scalar = ScalarType("DateTime") @@ -42,7 +41,7 @@ async def resolve_category(*_, id: str) -> Optional[CategoryModel]: @query_type.field("categories") -async def resolve_categories(*_, id: Optional[List[str]] = None) -> List[CategoryModel]: +async def resolve_categories(*_, id: Optional[list[str]] = None) -> list[CategoryModel]: if id: categories_ids = [int(i) for i in id] return await database.fetch_all("category", id__in=categories_ids) @@ -60,7 +59,7 @@ async def resolve_threads( *_, category: Optional[str] = None, starter: Optional[str] = None, -) -> List[ThreadModel]: +) -> list[ThreadModel]: filters = {} if category: filters["category_id"] = int(category) @@ -76,7 +75,7 @@ async def resolve_post(*_, id: str) -> Optional[PostModel]: @query_type.field("groups") -async def resolve_groups(*_) -> List[GroupModel]: +async def resolve_groups(*_) -> list[GroupModel]: return await database.fetch_all("group") @@ -86,7 +85,7 @@ async def resolve_group(*_, id: str) -> Optional[GroupModel]: @query_type.field("users") -async def resolve_user(*_, id: str) -> List[UserModel]: +async def resolve_user(*_, id: str) -> list[UserModel]: return await database.fetch_all("user") @@ -107,7 +106,7 @@ async def resolve_category_parent(obj: CategoryModel, info) -> Optional[Category @category_type.field("children") -async def resolve_category_children(obj: CategoryModel, info) -> List[CategoryModel]: +async def resolve_category_children(obj: CategoryModel, info) -> list[CategoryModel]: return await database.fetch_all("category", parent_id=obj.id) @@ -143,7 +142,7 @@ async def resolve_thread_last_poster(obj: ThreadModel, info) -> Optional[UserMod @thread_type.field("replies") -async def resolve_thread_replies(obj: ThreadModel, info) -> List[PostModel]: +async def resolve_thread_replies(obj: ThreadModel, info) -> list[PostModel]: return await database.fetch_all("post", thread_id=obj.id, parent_id=None) @@ -180,7 +179,7 @@ async def resolve_post_parent(obj: PostModel, info) -> Optional[PostModel]: @post_type.field("replies") -async def resolve_post_replies(obj: PostModel, info) -> List[PostModel]: +async def resolve_post_replies(obj: PostModel, info) -> list[PostModel]: return await database.fetch_all("post", parent_id=obj.id) @@ -188,7 +187,7 @@ async def resolve_post_replies(obj: PostModel, info) -> List[PostModel]: @group_type.field("members") -async def resolve_group_members(obj: GroupModel, info) -> List[UserModel]: +async def resolve_group_members(obj: GroupModel, info) -> list[UserModel]: return await database.fetch_all("user", group_id=obj.id) @@ -203,7 +202,7 @@ async def resolve_user_group(obj: UserModel, info) -> GroupModel: @user_type.field("groups") -async def resolve_user_groups(obj: UserModel, info) -> List[GroupModel]: +async def resolve_user_groups(obj: UserModel, info) -> list[GroupModel]: return await database.fetch_all("group", id__in=obj.groups) diff --git a/generate_reference.py b/generate_reference.py index 2680618e6..bf4721040 100644 --- a/generate_reference.py +++ b/generate_reference.py @@ -415,7 +415,7 @@ def visit_node(ast_node, module): if names_set.intersection(imported_names): import_name = get_import_name(module, ast_node.module, ast_node.level) module = import_module(import_name) - with open(module.__file__, "r") as fp: + with open(module.__file__) as fp: module_ast = ast.parse(fp.read()) visit_node(module_ast, module) @@ -431,7 +431,7 @@ def visit_node(ast_node, module): if name in names_set and name not in definitions: definitions[name] = ast_node - with open(root_module.__file__, "r") as fp: + with open(root_module.__file__) as fp: module_ast = ast.parse(fp.read()) visit_node(module_ast, root_module) @@ -460,7 +460,7 @@ def get_class_reference(obj, obj_ast: ast.ClassDef): bases = [base.id for base in obj_ast.bases] if bases: - reference += "(%s)" % (", ".join(bases)) + reference += "({})".format(", ".join(bases)) reference += ":\n ...\n```" diff --git a/pyproject.toml b/pyproject.toml index 66b6f32d8..93836c949 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,11 +15,11 @@ classifiers = [ "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Software Development :: Libraries :: Python Modules", ] dependencies = [ @@ -30,25 +30,11 @@ dependencies = [ [project.optional-dependencies] dev = ["black", "mypy", "pylint"] -test = [ - "pytest", - "pytest-asyncio", - "pytest-benchmark", - "pytest-cov", - "pytest-mock", - "freezegun", - "syrupy", - "werkzeug", - "httpx", - "opentracing", - "opentelemetry-api", - "python-multipart>=0.0.5", - "aiodataloader", - "graphql-sync-dataloaders;python_version>\"3.7\"", -] asgi-file-uploads = ["python-multipart>=0.0.5"] tracing = ["opentracing"] telemetry = ["opentelemetry-api"] +types = ["mypy[faster-cache]>=1.0.0"] + [project.urls] "Homepage" = "https://ariadnegraphql.org/" @@ -68,35 +54,123 @@ exclude = [ "tests", ] +# Environment configuration + +## Default environment + [tool.hatch.envs.default] -features = ["dev", "test"] +features = ["dev", "types"] [tool.hatch.envs.default.scripts] -test = "coverage run -m pytest" +check = [ + "hatch fmt", + "hatch test -a -p", + "hatch test --cover", + "hatch run types:check", +] -[tool.black] -line-length = 88 -target-version = ['py36', 'py37', 'py38'] -include = '\.pyi?$' -exclude = ''' -/( - \.eggs - | \.git - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | _build - | buck-out - | build - | dist - | snapshots -)/ -''' +## Types environment + +[tool.hatch.envs.types.scripts] +check = "mypy --install-types --non-interactive" + +## Test environment + +[tool.hatch.envs.hatch-test] +extra-dependencies = [ + "pytest", + "pytest-asyncio", + "pytest-benchmark", + "pytest-cov", + "pytest-mock", + "freezegun", + "syrupy", + "werkzeug", + "httpx", + "opentracing", + "opentelemetry-api", + "python-multipart>=0.0.5", + "aiodataloader", + "graphql-sync-dataloaders;python_version>\"3.7\"", +] +extra-args = [] + +[[tool.hatch.envs.hatch-test.matrix]] +python = ["3.9", "3.10", "3.11", "3.12", "3.13"] + +# Tool configuration + +## Pytest configuration [tool.pytest.ini_options] asyncio_mode = "strict" +asyncio_default_fixture_loop_scope = "function" testpaths = ["tests"] +## Types configuration + +[tool.mypy] +python_version = "3.9" +files = ["ariadne", "tests_mypy"] +check_untyped_defs = true +# disallow_untyped_defs = true +ignore_missing_imports = true +# warn_redundant_casts = true +# warn_unused_ignores = true +# disallow_any_generics = true +no_implicit_reexport = true +# strict = true +disable_error_code = ["import-untyped"] + +## Coverage configuration + [tool.coverage.run] source = ["ariadne", "tests"] + +[tool.coverage.report] +exclude_also = [ + "def __repr__", + "if self.debug:", + "if settings.DEBUG", + "raise AssertionError", + "raise NotImplementedError", + "if 0:", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", + "class .*\\bProtocol\\):", + "@(abc\\.)?abstractmethod", +] +omit = ["*/__about__.py", "*/__main__.py", "*/cli/__init__.py"] +fail_under = 90 + +## Ruff configuration + +[tool.ruff] +line-length = 88 +target-version = "py39" + +[tool.ruff.format] +docstring-code-format = true +docstring-code-line-length = 99 + +[tool.ruff.lint] +select = ["E", "F", "G", "I", "N", "Q", "UP", "C90", "T20", "TID"] +ignore = ["TID252"] +task-tags = ["NOTE", "TODO", "FIXME", "HACK", "XXX"] + +[tool.ruff.lint.pycodestyle] +ignore-overlong-task-comments = true + + +# [tool.ruff.lint.flake8-tidy-imports] +# ban-relative-imports = "all" + +[tool.ruff.lint.mccabe] +max-complexity = 15 + +[tool.ruff.lint.isort] +known-first-party = ["ariadne"] + +[tool.ruff.lint.flake8-pytest-style] +fixture-parentheses = false +mark-parentheses = false diff --git a/tests/asgi/conftest.py b/tests/asgi/conftest.py index a7034662a..de3809d50 100644 --- a/tests/asgi/conftest.py +++ b/tests/asgi/conftest.py @@ -1,5 +1,4 @@ import pytest - from starlette.testclient import TestClient from ariadne.asgi import GraphQL diff --git a/tests/asgi/test_query_execution.py b/tests/asgi/test_query_execution.py index e70376f01..47f8556a6 100644 --- a/tests/asgi/test_query_execution.py +++ b/tests/asgi/test_query_execution.py @@ -182,7 +182,7 @@ def test_query_is_executed_for_multipart_form_request_with_file( ), "map": json.dumps({"0": ["variables.file"]}), }, - files={"0": ("test.txt", "hello".encode("utf-8"))}, + files={"0": ("test.txt", b"hello")}, ) assert response.status_code == HTTPStatus.OK assert snapshot == response.json() diff --git a/tests/asgi/test_request_data_reading.py b/tests/asgi/test_request_data_reading.py index 38986c34a..fc5da24fb 100644 --- a/tests/asgi/test_request_data_reading.py +++ b/tests/asgi/test_request_data_reading.py @@ -47,7 +47,7 @@ def test_multipart_form_request_fails_if_operations_is_not_valid_json(client, sn "operations": "not a valid json", "map": json.dumps({"0": ["variables.file"]}), }, - files={"0": ("test.txt", "hello".encode("utf-8"))}, + files={"0": ("test.txt", b"hello")}, ) assert response.status_code == HTTPStatus.BAD_REQUEST assert snapshot == response.content @@ -65,7 +65,7 @@ def test_multipart_form_request_fails_if_map_is_not_valid_json(client, snapshot) ), "map": "not a valid json", }, - files={"0": ("test.txt", "hello".encode("utf-8"))}, + files={"0": ("test.txt", b"hello")}, ) assert response.status_code == HTTPStatus.BAD_REQUEST assert snapshot == response.content diff --git a/tests/asgi/test_sse.py b/tests/asgi/test_sse.py index 05a67257b..65042c215 100644 --- a/tests/asgi/test_sse.py +++ b/tests/asgi/test_sse.py @@ -1,10 +1,10 @@ import json from http import HTTPStatus -from typing import List, Dict, Any +from typing import Any from unittest.mock import Mock import pytest -from graphql import parse, GraphQLError +from graphql import GraphQLError, parse from httpx import Response from starlette.testclient import TestClient @@ -14,7 +14,7 @@ SSE_HEADER = {"Accept": "text/event-stream"} -def get_sse_events(response: Response) -> List[Dict[str, Any]]: +def get_sse_events(response: Response) -> list[dict[str, Any]]: events = [] for event in response.text.split("\r\n\r\n"): if len(event.strip()) == 0: @@ -96,7 +96,6 @@ def test_custom_query_parser_is_used_for_subscription_over_sse(schema): response = client.post("/", json={"query": "subscription { testRoot }"}) events = get_sse_events(response) - print(response) assert len(events) == 2 assert events[0]["data"]["data"] == {"testContext": "I'm context"} assert events[1]["event"] == "complete" diff --git a/tests/asgi/test_websockets_graphql_transport_ws.py b/tests/asgi/test_websockets_graphql_transport_ws.py index 5820d6e1d..c5c5cf222 100644 --- a/tests/asgi/test_websockets_graphql_transport_ws.py +++ b/tests/asgi/test_websockets_graphql_transport_ws.py @@ -2,7 +2,7 @@ from unittest.mock import Mock import pytest -from graphql import parse, GraphQLError +from graphql import GraphQLError, parse from graphql.language import OperationType from starlette.testclient import TestClient from starlette.websockets import WebSocketDisconnect @@ -11,6 +11,7 @@ from ariadne.asgi.handlers import GraphQLTransportWSHandler from ariadne.exceptions import WebSocketConnectionError from ariadne.utils import get_operation_type + from .websocket_utils import wait_for_condition diff --git a/tests/asgi/test_websockets_graphql_ws.py b/tests/asgi/test_websockets_graphql_ws.py index 48c47055d..938bd1115 100644 --- a/tests/asgi/test_websockets_graphql_ws.py +++ b/tests/asgi/test_websockets_graphql_ws.py @@ -2,12 +2,13 @@ from unittest.mock import Mock import pytest -from graphql import parse, GraphQLError +from graphql import GraphQLError, parse from starlette.testclient import TestClient from ariadne.asgi import GraphQL from ariadne.asgi.handlers import GraphQLWSHandler from ariadne.exceptions import WebSocketConnectionError + from .websocket_utils import wait_for_condition diff --git a/tests/conftest.py b/tests/conftest.py index 5dd06e353..8cf94cfdd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,7 +44,7 @@ def type_defs(): def resolve_hello(*_, name): - return "Hello, %s!" % name + return f"Hello, {name}!" def resolve_status(*_): @@ -76,7 +76,7 @@ def resolvers(): async def async_resolve_hello(*_, name): - return "Hello, %s!" % name + return f"Hello, {name}!" async def async_resolve_status(*_): @@ -145,7 +145,7 @@ def resolve_upload(*_, file): def resolve_echo(*_, text): - return "Echo: %s" % text + return f"Echo: {text}" @pytest.fixture diff --git a/tests/federation/test_interfaces.py b/tests/federation/test_interfaces.py index 57873891f..54013439c 100644 --- a/tests/federation/test_interfaces.py +++ b/tests/federation/test_interfaces.py @@ -1,6 +1,5 @@ -from graphql import graphql_sync - import pytest +from graphql import graphql_sync from ariadne.contrib.federation import ( FederatedInterfaceType, diff --git a/tests/federation/test_objects.py b/tests/federation/test_objects.py index dd6637158..498cd8f9e 100644 --- a/tests/federation/test_objects.py +++ b/tests/federation/test_objects.py @@ -1,6 +1,5 @@ -from graphql import graphql_sync - import pytest +from graphql import graphql_sync from ariadne.contrib.federation import ( FederatedObjectType, diff --git a/tests/federation/test_schema.py b/tests/federation/test_schema.py index 3a1f1dad5..5829556f6 100644 --- a/tests/federation/test_schema.py +++ b/tests/federation/test_schema.py @@ -226,7 +226,7 @@ def test_federated_schema_type_with_multiple_keys(): price: String } """ - product = FederatedObjectType("Product") + FederatedObjectType("Product") schema = make_federated_schema(type_defs) assert sic(print_object(schema.get_type("Product"))) == sic( diff --git a/tests/test_custom_scalars.py b/tests/test_custom_scalars.py index 9328af732..623414e22 100644 --- a/tests/test_custom_scalars.py +++ b/tests/test_custom_scalars.py @@ -102,14 +102,14 @@ def test_python_date_is_serialized_by_scalar(): def test_literal_with_valid_date_str_is_deserialized_to_python_date(): test_input = TEST_DATE_SERIALIZED - result = graphql_sync(schema, '{ testInput(value: "%s") }' % test_input) + result = graphql_sync(schema, f'{{ testInput(value: "{test_input}") }}') assert result.errors is None assert result.data == {"testInput": True} def test_attempt_deserialize_str_literal_without_valid_date_raises_error(): test_input = "invalid string" - result = graphql_sync(schema, '{ testInput(value: "%s") }' % test_input) + result = graphql_sync(schema, f'{{ testInput(value: "{test_input}") }}') assert result.errors is not None assert str(result.errors[0]).splitlines()[:1] == [ "Expected value of type 'DateInput!', found \"invalid string\"; " @@ -119,7 +119,7 @@ def test_attempt_deserialize_str_literal_without_valid_date_raises_error(): def test_attempt_deserialize_wrong_type_literal_raises_error(): test_input = 123 - result = graphql_sync(schema, "{ testInput(value: %s) }" % test_input) + result = graphql_sync(schema, f"{{ testInput(value: {test_input}) }}") assert result.errors is not None assert str(result.errors[0]).splitlines()[:1] == [ "Expected value of type 'DateInput!', found 123; " @@ -132,7 +132,7 @@ def test_default_literal_parser_is_used_to_extract_value_str_from_ast_node(): schema = make_executable_schema(type_defs, query, dateinput) result = graphql_sync( - schema, """{ testInput(value: "%s") }""" % TEST_DATE_SERIALIZED + schema, f"""{{ testInput(value: "{TEST_DATE_SERIALIZED}") }}""" ) assert result.errors is None assert result.data == {"testInput": True} diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py index 3f0c1ace3..2248a4cbd 100644 --- a/tests/test_dataloaders.py +++ b/tests/test_dataloaders.py @@ -1,6 +1,5 @@ import pytest from aiodataloader import DataLoader as AsyncDataLoader - from graphql_sync_dataloaders import DeferredExecutionContext, SyncDataLoader from ariadne import QueryType, graphql, graphql_sync, make_executable_schema diff --git a/tests/test_enums.py b/tests/test_enums.py index c41aa826d..15f9c1e57 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -1,7 +1,7 @@ from enum import Enum, IntEnum import pytest -from graphql import graphql_sync, build_schema +from graphql import build_schema, graphql_sync from ariadne import EnumType, QueryType, make_executable_schema @@ -51,7 +51,7 @@ def test_successful_enum_value_passed_as_argument(): query.set_field("testEnum", lambda *_, value: True) schema = make_executable_schema([enum_definition, enum_param], query) - result = graphql_sync(schema, "{ testEnum(value: %s) }" % "NEWHOPE") + result = graphql_sync(schema, "{{ testEnum(value: {}) }}".format("NEWHOPE")) assert result.errors is None, result.errors @@ -72,7 +72,7 @@ def test_unsuccessful_invalid_enum_value_passed_as_argument(): query.set_field("testEnum", lambda *_, value: True) schema = make_executable_schema([enum_definition, enum_param], query) - result = graphql_sync(schema, "{ testEnum(value: %s) }" % "INVALID") + result = graphql_sync(schema, "{{ testEnum(value: {}) }}".format("INVALID")) assert result.errors is not None @@ -118,7 +118,7 @@ def test_dict_enum_arg_is_transformed_to_internal_value(): query.set_field("testEnum", lambda *_, value: value == 1977) schema = make_executable_schema([enum_definition, enum_param], [query, dict_enum]) - result = graphql_sync(schema, "{ testEnum(value: %s) }" % "NEWHOPE") + result = graphql_sync(schema, "{{ testEnum(value: {}) }}".format("NEWHOPE")) assert result.data["testEnum"] is True diff --git a/tests/test_error_formatting.py b/tests/test_error_formatting.py index 5e893378b..5176afcfa 100644 --- a/tests/test_error_formatting.py +++ b/tests/test_error_formatting.py @@ -24,11 +24,6 @@ def erroring_resolvers(failing_repr_mock): @query.field("hello") def resolve_hello_with_context_and_attribute_error(*_): # pylint: disable=undefined-variable, unused-variable - test_int = 123 - test_str = "test" - test_dict = {"test": "dict"} - test_obj = query - test_failing_repr = failing_repr_mock test_undefined.error() # trigger attr not found error return query diff --git a/tests/test_extensions.py b/tests/test_extensions.py index b458d63e9..f1c12ad91 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -5,7 +5,6 @@ from ariadne import ExtensionManager, graphql from ariadne.types import Extension - context = {} exception = ValueError() diff --git a/tests/test_fallback_resolvers.py b/tests/test_fallback_resolvers.py index b1514e062..2b284de6c 100644 --- a/tests/test_fallback_resolvers.py +++ b/tests/test_fallback_resolvers.py @@ -1,5 +1,5 @@ import pytest -from graphql import graphql_sync, build_schema +from graphql import build_schema, graphql_sync from ariadne import ObjectType, fallback_resolvers, snake_case_fallback_resolvers diff --git a/tests/test_objects.py b/tests/test_objects.py index eaf5b3e4f..11fa2e447 100644 --- a/tests/test_objects.py +++ b/tests/test_objects.py @@ -1,5 +1,5 @@ import pytest -from graphql import graphql_sync, build_schema +from graphql import build_schema, graphql_sync from ariadne import ObjectType diff --git a/tests/test_query_cost_validation.py b/tests/test_query_cost_validation.py index 79403a8ab..ec3abf88c 100644 --- a/tests/test_query_cost_validation.py +++ b/tests/test_query_cost_validation.py @@ -1,5 +1,4 @@ import pytest - from graphql import GraphQLError from graphql.language import parse from graphql.validation import validate diff --git a/tests/test_schema_visitor.py b/tests/test_schema_visitor.py index 527706abd..7a5c36119 100644 --- a/tests/test_schema_visitor.py +++ b/tests/test_schema_visitor.py @@ -1,4 +1,3 @@ -from typing import List from graphql.type import GraphQLObjectType, GraphQLSchema @@ -67,7 +66,7 @@ def test_visitor(): class SimpleVisitor(SchemaVisitor): visitCount = 0 - names: List[str] = [] + names: list[str] = [] def __init__(self, schema: GraphQLSchema): self.schema = schema diff --git a/tests/wsgi/conftest.py b/tests/wsgi/conftest.py index fe543a9a0..7910635c2 100644 --- a/tests/wsgi/conftest.py +++ b/tests/wsgi/conftest.py @@ -1,9 +1,9 @@ import json from io import StringIO from unittest.mock import Mock -from werkzeug.test import Client import pytest +from werkzeug.test import Client from ariadne.wsgi import GraphQL, GraphQLMiddleware diff --git a/tests/wsgi/test_http_error_handling.py b/tests/wsgi/test_http_error_handling.py index c41ca27a4..797054572 100644 --- a/tests/wsgi/test_http_error_handling.py +++ b/tests/wsgi/test_http_error_handling.py @@ -1,7 +1,7 @@ from unittest.mock import Mock from ariadne.constants import HttpStatusResponse -from ariadne.exceptions import HttpError, HttpBadRequestError +from ariadne.exceptions import HttpBadRequestError, HttpError def test_http_errors_raised_in_handle_request_are_passed_to_http_error_handler( diff --git a/tests_integrations/fastapi/test_websocket_connection.py b/tests_integrations/fastapi/test_websocket_connection.py index 68537b172..c7599a00e 100644 --- a/tests_integrations/fastapi/test_websocket_connection.py +++ b/tests_integrations/fastapi/test_websocket_connection.py @@ -1,10 +1,10 @@ -from ariadne import SubscriptionType, make_executable_schema -from ariadne.asgi import GraphQL -from ariadne.asgi.handlers import GraphQLTransportWSHandler from fastapi import FastAPI from fastapi.websockets import WebSocket from starlette.testclient import TestClient +from ariadne import SubscriptionType, make_executable_schema +from ariadne.asgi import GraphQL +from ariadne.asgi.handlers import GraphQLTransportWSHandler subscription_type = SubscriptionType() diff --git a/tests_integrations/starlette/test_websocket_connection.py b/tests_integrations/starlette/test_websocket_connection.py index 4e86b3151..5171aca6e 100644 --- a/tests_integrations/starlette/test_websocket_connection.py +++ b/tests_integrations/starlette/test_websocket_connection.py @@ -1,10 +1,10 @@ -from ariadne import SubscriptionType, make_executable_schema -from ariadne.asgi import GraphQL -from ariadne.asgi.handlers import GraphQLTransportWSHandler from starlette.applications import Starlette from starlette.routing import Mount, WebSocketRoute from starlette.testclient import TestClient +from ariadne import SubscriptionType, make_executable_schema +from ariadne.asgi import GraphQL +from ariadne.asgi.handlers import GraphQLTransportWSHandler subscription_type = SubscriptionType() diff --git a/tests_mypy/middlewares.py b/tests_mypy/middlewares.py index 34cb8b141..9e6a96adc 100644 --- a/tests_mypy/middlewares.py +++ b/tests_mypy/middlewares.py @@ -1,4 +1,5 @@ from typing import Any + from graphql import GraphQLResolveInfo from ariadne import make_executable_schema @@ -6,7 +7,6 @@ from ariadne.asgi.handlers import GraphQLHTTPHandler from ariadne.wsgi import GraphQL as GraphQLWSGI - schema = make_executable_schema( """ type Query {