diff --git a/connexion/json_schema.py b/connexion/json_schema.py index f2d67a39b..f6d4c364b 100644 --- a/connexion/json_schema.py +++ b/connexion/json_schema.py @@ -4,6 +4,7 @@ import contextlib import io +import json import os import typing as t import urllib.parse @@ -13,9 +14,11 @@ import requests import yaml -from jsonschema import Draft4Validator, RefResolver -from jsonschema.exceptions import RefResolutionError, ValidationError # noqa +from jsonschema import Draft4Validator +from jsonschema.exceptions import ValidationError from jsonschema.validators import extend +from referencing import Registry, Resource +from referencing.jsonschema import DRAFT4 from .utils import deep_get @@ -62,12 +65,27 @@ def __call__(self, uri): return yaml.load(fh, ExtendedSafeLoader) -handlers = { - "http": URLHandler(), - "https": URLHandler(), - "file": FileHandler(), - "": FileHandler(), -} +def resource_from_spec(spec: t.Dict[str, t.Any]) -> Resource: + """Create a `referencing.Resource` from a schema specification.""" + return Resource.from_contents(spec, default_specification=DRAFT4) + + +def retrieve(uri: str) -> Resource: + """Retrieve a resource from a URI. + + This function is passed to the `referencing.Registry`, + which calls it any URI is not present in the registry is accessed.""" + parsed = urllib.parse.urlsplit(uri) + if parsed.scheme in ("http", "https"): + content = URLHandler()(uri) + elif parsed.scheme in ("file", ""): + content = FileHandler()(uri) + else: # pragma: no cover + # Default branch from jsonschema.RefResolver.resolve_remote() + # for backwards compatibility. + with urllib.request.urlopen(uri) as url: + content = json.loads(url.read().decode("utf-8")) + return resource_from_spec(content) def resolve_refs(spec, store=None, base_uri=""): @@ -78,9 +96,14 @@ def resolve_refs(spec, store=None, base_uri=""): """ spec = deepcopy(spec) store = store or {} - resolver = RefResolver(base_uri, spec, store, handlers=handlers) + registry = Registry(retrieve=retrieve).with_resources( + ( + (base_uri, resource_from_spec(spec)), + *((key, resource_from_spec(value)) for key, value in store.items()), + ) + ) - def _do_resolve(node): + def _do_resolve(node, resolver): if isinstance(node, Mapping) and "$ref" in node: path = node["$ref"][2:].split("/") try: @@ -88,22 +111,22 @@ def _do_resolve(node): retrieved = deep_get(spec, path) node.update(retrieved) if isinstance(retrieved, Mapping) and "$ref" in retrieved: - node = _do_resolve(node) + node = _do_resolve(node, resolver) node.pop("$ref", None) return node except KeyError: # resolve external references - with resolver.resolving(node["$ref"]) as resolved: - return _do_resolve(resolved) + resolved = resolver.lookup(node["$ref"]) + return _do_resolve(resolved.contents, resolved.resolver) elif isinstance(node, Mapping): for k, v in node.items(): - node[k] = _do_resolve(v) + node[k] = _do_resolve(v, resolver) elif isinstance(node, (list, tuple)): for i, _ in enumerate(node): - node[i] = _do_resolve(node[i]) + node[i] = _do_resolve(node[i], resolver) return node - res = _do_resolve(spec) + res = _do_resolve(spec, registry.resolver(base_uri)) return res diff --git a/pyproject.toml b/pyproject.toml index b1ab36041..114f6749a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ Jinja2 = ">= 3.0.0" python-multipart = ">= 0.0.15" PyYAML = ">= 5.1" requests = ">= 2.27" +referencing = ">= 0.12.0" starlette = ">= 0.35" typing-extensions = ">= 4.6.1" werkzeug = ">= 2.2.1" diff --git a/tests/test_references.py b/tests/test_references.py index 8d9bc8d1a..90457992c 100644 --- a/tests/test_references.py +++ b/tests/test_references.py @@ -1,8 +1,9 @@ from unittest import mock import pytest -from connexion.json_schema import RefResolutionError, resolve_refs +from connexion.json_schema import resolve_refs from connexion.jsonifier import Jsonifier +from referencing.exceptions import Unresolvable DEFINITIONS = { "new_stack": { @@ -50,7 +51,7 @@ def test_non_existent_reference(api): } ] } - with pytest.raises(RefResolutionError) as exc_info: # type: py.code.ExceptionInfo + with pytest.raises(Unresolvable) as exc_info: # type: py.code.ExceptionInfo resolve_refs(op_spec, {}) exception = exc_info.value @@ -69,7 +70,7 @@ def test_invalid_reference(api): ] } - with pytest.raises(RefResolutionError) as exc_info: # type: py.code.ExceptionInfo + with pytest.raises(Unresolvable) as exc_info: # type: py.code.ExceptionInfo resolve_refs( op_spec, {"definitions": DEFINITIONS, "parameters": PARAMETER_DEFINITIONS} ) @@ -84,7 +85,7 @@ def test_resolve_invalid_reference(api): "parameters": [{"$ref": "/parameters/fail"}], } - with pytest.raises(RefResolutionError) as exc_info: + with pytest.raises(Unresolvable) as exc_info: resolve_refs(op_spec, {"parameters": PARAMETER_DEFINITIONS}) exception = exc_info.value