From 1ba42951054cdb9c59df675fb581f2e814215531 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Thu, 15 Feb 2024 19:14:27 +0100 Subject: [PATCH] Execute queries over GET --- ariadne/asgi/handlers/http.py | 56 +++++++++++++++++++++++++++++++---- ariadne/graphql.py | 33 +++++++++++++++++++++ 2 files changed, 84 insertions(+), 5 deletions(-) diff --git a/ariadne/asgi/handlers/http.py b/ariadne/asgi/handlers/http.py index 8b49802b2..561b507b3 100644 --- a/ariadne/asgi/handlers/http.py +++ b/ariadne/asgi/handlers/http.py @@ -35,6 +35,7 @@ class GraphQLHTTPHandler(GraphQLHttpHandlerBase): def __init__( self, + execute_get_queries: bool = False, extensions: Optional[Extensions] = None, middleware: Optional[Middlewares] = None, middleware_manager_class: Optional[Type[MiddlewareManager]] = None, @@ -43,6 +44,9 @@ def __init__( # Optional arguments + `execute_get_queries`: a `bool` that controls if `query` operations + sent using the `GET` method should be executed. Defaults to `False`. + `extensions`: an `Extensions` list or callable returning a list of extensions server should use during query execution. Defaults to no extensions. @@ -58,6 +62,7 @@ def __init__( """ super().__init__() + self.execute_get_queries = execute_get_queries self.extensions = extensions self.middleware = middleware self.middleware_manager_class = middleware_manager_class or MiddlewareManager @@ -114,9 +119,12 @@ async def handle_request(self, request: Request) -> Response: `request`: the `Request` instance from Starlette or FastAPI. """ - if request.method == "GET" and self.introspection and self.explorer: - # only render explorer when introspection is enabled - return await self.render_explorer(request, self.explorer) + if request.method == "GET": + if self.execute_get_queries and request.query_params.get("query"): + return await self.graphql_http_server(request) + elif self.introspection and self.explorer: + # only render explorer when introspection is enabled + return await self.render_explorer(request, self.explorer) if request.method == "POST": return await self.graphql_http_server(request) @@ -182,6 +190,12 @@ async def extract_data_from_request(self, request: Request): return await self.extract_data_from_json_request(request) if content_type == DATA_TYPE_MULTIPART: return await self.extract_data_from_multipart_request(request) + if ( + request.method == "GET" + and self.execute_get_queries + and request.query_params.get("query") + ): + return await self.extract_data_from_get_request(request) raise HttpBadRequestError( "Posted content must be of type {} or {}".format( @@ -189,7 +203,7 @@ async def extract_data_from_request(self, request: Request): ) ) - async def extract_data_from_json_request(self, request: Request): + async def extract_data_from_json_request(self, request: Request) -> dict: """Extracts GraphQL data from JSON request. Returns a `dict` with GraphQL query data that was not yet validated. @@ -203,7 +217,9 @@ async def extract_data_from_json_request(self, request: Request): except (TypeError, ValueError) as ex: raise HttpBadRequestError("Request body is not a valid JSON") from ex - async def extract_data_from_multipart_request(self, request: Request): + async def extract_data_from_multipart_request( + self, request: Request + ) -> dict | list: """Extracts GraphQL data from `multipart/form-data` request. Returns an unvalidated `dict` with GraphQL query data. @@ -240,6 +256,35 @@ async def extract_data_from_multipart_request(self, request: Request): return combine_multipart_data(operations, files_map, request_files) + async def extract_data_from_get_request(self, request: Request) -> dict: + """Extracts GraphQL data from GET request's querystring. + + Returns a `dict` with GraphQL query data that was not yet validated. + + # Required arguments + + `request`: the `Request` instance from Starlette or FastAPI. + """ + query = request.query_params["query"].strip() + operation_name = request.query_params.get("operationName", "").strip() + variables = request.query_params.get("variables", "").strip() + + clean_variables = None + + if variables: + try: + clean_variables = json.loads(variables) + except (TypeError, ValueError) as ex: + raise HttpBadRequestError( + "Variables query arg is not a valid JSON" + ) from ex + + return { + "query": query, + "operationName": operation_name or None, + "variables": clean_variables, + } + async def execute_graphql_query( self, request: Any, @@ -284,6 +329,7 @@ async def execute_graphql_query( query_validator=self.query_validator, query_document=query_document, validation_rules=self.validation_rules, + require_query=request.method == "GET", debug=self.debug, introspection=self.introspection, logger=self.logger, diff --git a/ariadne/graphql.py b/ariadne/graphql.py index f3d25a733..1c5ee1428 100644 --- a/ariadne/graphql.py +++ b/ariadne/graphql.py @@ -22,6 +22,7 @@ GraphQLError, GraphQLSchema, MiddlewareManager, + OperationDefinitionNode, TypeInfo, execute, execute_sync, @@ -71,6 +72,7 @@ async def graphql( introspection: bool = True, logger: Union[None, str, Logger, LoggerAdapter] = None, validation_rules: Optional[ValidationRules] = None, + require_query: bool = False, error_formatter: ErrorFormatter = format_error, middleware: MiddlewareList = None, middleware_manager_class: Optional[Type[MiddlewareManager]] = None, @@ -123,6 +125,9 @@ async def graphql( `validation_rules`: a `list` of or callable returning list of custom validation rules to use to validate query before it's executed. + `require_query`: a `bool` controlling if GraphQL operation to execute must be + a query (vs. mutation or subscription). + `error_formatter`: an `ErrorFormatter` callable to use to convert GraphQL errors encountered during query execution to JSON-serializable format. @@ -169,6 +174,7 @@ async def graphql( enable_introspection=introspection, query_validator=query_validator, ) + if validation_errors: return handle_graphql_errors( validation_errors, @@ -178,6 +184,9 @@ async def graphql( extension_manager=extension_manager, ) + if require_query and not validation_errors: + validate_operation_is_query(document, operation_name) + if callable(root_value): try: root_value = root_value( # type: ignore @@ -639,3 +648,27 @@ def validate_variables(variables) -> None: def validate_operation_name(operation_name) -> None: if operation_name is not None and not isinstance(operation_name, str): raise GraphQLError('"%s" is not a valid operation name.' % operation_name) + + +def validate_operation_is_query(document_ast: DocumentNode, operation_name: str): + query_operations: List[Optional[str]] = [] + for definition in document_ast.definitions: + if ( + isinstance(definition, OperationDefinitionNode) + and definition.operation.name == "QUERY" + ): + if definition.name: + query_operations.append(definition.name.value) + else: + query_operations.append(None) + + if operation_name: + if query_operations not in query_operations: + raise GraphQLError( + f"Operation '{operation_name}' can't be executed using the GET " + "HTTP method. Use POST instead." + ) + elif len(query_operations) != 1: + raise GraphQLError( + f"'operationName' is required if 'query' defines multiple operations." + )