Skip to content

Commit

Permalink
fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
underchemist committed Sep 27, 2021
1 parent dd0a1d3 commit 9702533
Show file tree
Hide file tree
Showing 15 changed files with 48 additions and 40 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ repos:
rev: v0.910
hooks:
- id: mypy
language_version: python
language_version: python
args: [--install-types, --non-interactive]
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@

class DataSourceManager(ABC):
@abstractmethod
def get_data_sources() -> List[Type[DataSource]]:
def get_data_sources(self) -> List[Type[DataSource]]:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from hashlib import sha256
from logging import getLogger
from os import path
from typing import Any, Coroutine, Dict, Final, List, Type
from typing import Any, Awaitable, Callable, Dict, Final, List, Type

# geoalchemy import required for sa.MetaData reflection, even though unused in module
import geoalchemy2 as ga # noqa: F401
Expand Down Expand Up @@ -51,7 +51,7 @@ class PostgresqlDataSource(DataSource):
def __init__(
self,
connection_name: str,
connection_tester: Coroutine[None, None, Database],
connection_tester: Callable[[Database, str], Awaitable[None]],
):
super().__init__(f"{self.DATA_SOURCE_NAME}:{connection_name}")
self.db = Database(settings.url(connection_name))
Expand Down Expand Up @@ -155,7 +155,7 @@ def get_if_available(
async def get_feature_set_provider(
self,
layer: PostgresqlLayer,
constraints: ItemConstraints = None,
constraints: ItemConstraints,
ast: Type[Node] = None,
) -> Type[FeatureSetProvider]:
filters = (
Expand Down Expand Up @@ -189,7 +189,9 @@ async def get_feature_set_provider(
await self.db.fetch_one(
sa.select([sa.func.count()]).select_from(layer.model).where(filters)
)
)[0]
)[
0
] # type: ignore

return PostgresqlFeatureSetProvider(self.db, id_set, layer, total_count)

Expand Down Expand Up @@ -223,7 +225,7 @@ async def _get_derived_layers(self) -> Dict[str, PostgresqlLayer]:
bboxes=[table_spatial_extents[qualified_layer_name]],
intervals=table_temporal_extents[qualified_layer_name]
if qualified_layer_name in table_temporal_extents
else [[None, None]],
else [[None, None]], # type: ignore
data_source_id=self.id,
schema_name=tables[qualified_layer_name]["schema_name"],
table_name=tables[qualified_layer_name]["table_name"],
Expand Down Expand Up @@ -415,7 +417,7 @@ async def _get_table_temporal_extents( # noqa: C901
table_models: Dict[str, sa.Table],
table_temporal_fields: Dict[str, List[TemporalInstant]],
) -> Dict[str, List[datetime]]:
table_temporal_extents = {}
table_temporal_extents = {} # type: ignore
for qualified_table_name, temporal_fields in table_temporal_fields.items():
start = None
end = None
Expand Down
34 changes: 18 additions & 16 deletions oaff/app/oaff/app/request_handlers/collection_items.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime, tzinfo
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, cast

from pygeofilter.ast import (
And,
Expand Down Expand Up @@ -82,7 +82,7 @@ def _get_page_link_retriever(

frontend_config = get_frontend_configuration()

def retriever(total_count: int, result_count: int) -> List[Link]:
def retriever(total_count: int, result_count: int) -> Dict[PageLinkRel, Link]:
links = {}
if request.offset > 0:
links[PageLinkRel.PREV] = Link(
Expand All @@ -107,7 +107,7 @@ def _collect_ast(
self,
bbox: BBox,
datetime: Any,
) -> Type[Node]:
) -> Optional[Type[Node]]:
if bbox is None and datetime is None:
return None
else:
Expand All @@ -120,21 +120,23 @@ def _collect_ast(

async def _spatial_bounds_to_node(
self,
spatial_bounds: Optional[
Union[
Tuple[float, float, float, float],
Tuple[float, float, float, float, float, float],
]
spatial_bounds: Union[
Tuple[float, float, float, float],
Tuple[float, float, float, float, float, float],
],
spatial_bounds_crs: str,
data_source: DataSource,
layer: Layer,
) -> BBox:
x_min, y_min, x_max, y_max = (
spatial_bounds
if len(spatial_bounds) == 4
else (spatial_bounds[i] for i in [0, 1, 3, 4])
)
# recommended usage for Union of types
# https://github.com/python/mypy/issues/1178#issuecomment-176185607
if len(spatial_bounds) == 4:
a, b, c, d = cast(Tuple[float, float, float, float], spatial_bounds)
else:
a, b, _, c, d, _ = cast(
Tuple[float, float, float, float, float, float], spatial_bounds
)
x_min, y_min, x_max, y_max = (a, b, c, d)
transformer = Transformer.from_crs(
"EPSG:4326", # True until Features API spec part 2 is implemented
f"{layer.geometry_crs_auth_name}:{layer.geometry_crs_auth_code}",
Expand All @@ -155,13 +157,13 @@ async def _datetime_to_node( # noqa: C901
self,
temporal_bounds: Union[Tuple[datetime], Tuple[datetime, datetime]],
layer: Layer,
) -> Type[Node]:
) -> Optional[Type[Node]]:
if len(list(filter(lambda bound: bound is not None, temporal_bounds))) == 0:
return None
nodes = []
for data_field in layer.temporal_attributes:
if len(temporal_bounds) == 2:
query_start, query_end = temporal_bounds
query_start, query_end = cast(Tuple[datetime, datetime], temporal_bounds)
if data_field.__class__ is TemporalInstant:
if query_start is not None and query_end is not None:
nodes.append(
Expand Down Expand Up @@ -333,6 +335,6 @@ def _match_query_time_to_end_field(
)

def _match_query_time_to(
self, query_time: datetime, tz_aware: bool, tz: Type[tzinfo]
self, query_time: datetime, tz_aware: bool, tz: tzinfo
) -> datetime:
return query_time if tz_aware else query_time.astimezone(tz).replace(tzinfo=None)
2 changes: 1 addition & 1 deletion oaff/app/oaff/app/request_handlers/collections_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
LOGGER: Final = getLogger(__file__)


class CollectionsList(RequestHandler):
class CollectionsList(RequestHandler): # type: ignore
@classmethod
def type_name(cls) -> str:
return CollectionsList.__name__
Expand Down
2 changes: 1 addition & 1 deletion oaff/app/oaff/app/request_handlers/conformance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from oaff.app.responses.response_format import ResponseFormat


class Conformance(RequestHandler):
class Conformance(RequestHandler): # type: ignore
@classmethod
def type_name(cls) -> str:
return Conformance.__name__
Expand Down
2 changes: 1 addition & 1 deletion oaff/app/oaff/app/request_handlers/landing_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from oaff.app.settings import OPENAPI_OGC_TYPE


class LandingPage(RequestHandler):
class LandingPage(RequestHandler): # type: ignore
@classmethod
def type_name(cls) -> str:
return LandingPage.__name__
Expand Down
2 changes: 1 addition & 1 deletion oaff/app/oaff/app/responses/response_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from oaff.app.responses.response_type import ResponseType


class ResponseFormat(dict, Enum):
class ResponseFormat(dict, Enum): # type: ignore
html = {
ResponseType.DATA: "text/html",
ResponseType.METADATA: "text/html",
Expand Down
2 changes: 1 addition & 1 deletion oaff/app/oaff/app/responses/templates/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def get_rendered_html(template_name: str, data: object, locale: Locales) -> str:
loader=PackageLoader("oaff.app", path.join("responses", "templates", "html")),
autoescape=select_autoescape(["html"]),
)
env.install_gettext_translations(get_translations_for_locale(locale))
env.install_gettext_translations(get_translations_for_locale(locale)) # type: ignore
frontend_config = get_frontend_configuration()
return env.get_template(f"{template_name}.jinja2").render(
response=data,
Expand Down
4 changes: 3 additions & 1 deletion oaff/fastapi/api/openapi/openapi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Callable

from fastapi.applications import FastAPI
from fastapi.requests import Request

from oaff.fastapi.api.openapi.vnd_response import VndResponse


def get_openapi_handler(app: FastAPI) -> VndResponse:
def get_openapi_handler(app: FastAPI) -> Callable[[Request], VndResponse]:
def handler(_: Request):
# OpenAPI spec must be modified because FastAPI doesn't support
# encoding style: https://github.com/tiangolo/fastapi/issues/283
Expand Down
7 changes: 3 additions & 4 deletions oaff/fastapi/api/routes/collections.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from datetime import datetime
from logging import getLogger
from typing import Final, Optional, Tuple, Union
from typing import Final, Optional, Tuple, Union, cast
from urllib.parse import quote

import iso8601
import pytz
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Query
from fastapi.param_functions import Depends
from fastapi.params import Query
from fastapi.requests import Request

from oaff.app.requests.collection import Collection as CollectionRequestType
Expand Down Expand Up @@ -216,7 +215,7 @@ def parse_datetime(datetime_str: str) -> datetime:
raise HTTPException(
status_code=400, detail="datetime start cannot be after end"
)
return tuple(result)
return cast(Union[Tuple[datetime], Tuple[datetime, datetime]], tuple(result))


def _get_safe_url(path_template: str, request: Request, root: str) -> str:
Expand Down
4 changes: 2 additions & 2 deletions oaff/fastapi/api/routes/common/common_parameters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from re import compile, search, sub
from typing import Final, Optional, Set
from typing import Final, List, Optional

from fastapi import Header, Query
from fastapi.requests import Request
Expand Down Expand Up @@ -66,7 +66,7 @@ async def populate(
)

@classmethod
def _header_options_by_preference(cls, header_value: str) -> Set[str]:
def _header_options_by_preference(cls, header_value: str) -> List[str]:
options = list(
filter(
lambda option: len(option) > 0,
Expand Down
3 changes: 2 additions & 1 deletion oaff/fastapi/api/routes/common/parameter_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from oaff.fastapi.api.routes.common.common_parameters import COMMON_QUERY_PARAMS


def strict(request: Request, permitted: Optional[List[str]] = []) -> None:
def strict(request: Request, permitted: Optional[List[str]] = None) -> None:
permitted = list() if permitted is None else permitted
excessive = set(request.query_params.keys()).difference(
set(permitted + COMMON_QUERY_PARAMS)
)
Expand Down
4 changes: 2 additions & 2 deletions oaff/fastapi/api/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def _change_page(url: str, forward: bool) -> str:
for key, value in {
**parameters,
**{
"offset": max(offset + limit * (1 if forward else -1), 0),
"limit": limit,
"offset": str(max(offset + limit * (1 if forward else -1), 0)),
"limit": str(limit),
},
}.items()
]
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@ exclude="^(tests|testing)$"
namespace_packages=true
explicit_package_bases=true
ignore_missing_imports=true
ignore_errors=true
no_warn_no_return=true
no_strict_optional=true

0 comments on commit 9702533

Please sign in to comment.