diff --git a/allspice/allspice.py b/allspice/allspice.py index 81cddc0..c5820e6 100644 --- a/allspice/allspice.py +++ b/allspice/allspice.py @@ -1,6 +1,6 @@ import json import logging -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Mapping, Optional, Union import requests import urllib3 @@ -92,7 +92,7 @@ def __get_url(self, endpoint): self.logger.debug("Url: %s" % url) return url - def __get(self, endpoint: str, params=frozendict()) -> requests.Response: + def __get(self, endpoint: str, params: Mapping = frozendict()) -> requests.Response: request = self.requests.get(self.__get_url(endpoint), headers=self.headers, params=params) if request.status_code not in [200, 201]: message = f"Received status code: {request.status_code} ({request.url})" @@ -116,7 +116,7 @@ def parse_result(result) -> Dict: return json.loads(result.text) return {} - def requests_get(self, endpoint: str, params=frozendict(), sudo=None): + def requests_get(self, endpoint: str, params: Mapping = frozendict(), sudo=None): combined_params = {} combined_params.update(params) if sudo: @@ -202,7 +202,9 @@ def requests_post( :return: The JSON response parsed as a dict """ - args = { + # This should ideally be a TypedDict of the type of arguments taken by + # `requests.post`. + args: dict[str, Any] = { "headers": self.headers.copy(), } if data is not None: @@ -258,14 +260,14 @@ def get_users(self) -> List[User]: results = self.requests_get(AllSpice.GET_USERS_ADMIN) return [User.parse_response(self, result) for result in results] - def get_user_by_email(self, email: str) -> User: + def get_user_by_email(self, email: str) -> Optional[User]: users = self.get_users() for user in users: if user.email == email or email in user.emails: return user return None - def get_user_by_name(self, username: str) -> User: + def get_user_by_name(self, username: str) -> Optional[User]: users = self.get_users() for user in users: if user.username == username: diff --git a/allspice/apiobject.py b/allspice/apiobject.py index afdf66e..955242e 100644 --- a/allspice/apiobject.py +++ b/allspice/apiobject.py @@ -10,6 +10,7 @@ Any, ClassVar, Dict, + FrozenSet, List, Literal, Optional, @@ -19,6 +20,11 @@ Union, ) +try: + from typing_extensions import Self +except ImportError: + from typing import Self + from .baseapiobject import ApiObject, ReadonlyApiObject from .exceptions import ConflictException, NotFoundException @@ -70,7 +76,7 @@ def __hash__(self): return hash(self.allspice_client) ^ hash(self.name) @classmethod - def request(cls, allspice_client, name: str) -> "Organization": + def request(cls, allspice_client, name: str) -> Self: return cls._request(allspice_client, {"name": name}) @classmethod @@ -131,7 +137,7 @@ def create_repo( else: self.allspice_client.logger.error(result["message"]) raise Exception("Repository not created... (gitea: %s)" % result["message"]) - return Repository.parse_response(self, result) + return Repository.parse_response(self.allspice_client, result) def get_repositories(self) -> List["Repository"]: results = self.allspice_client.requests_get_paginated( @@ -359,7 +365,7 @@ def create_repo( else: self.allspice_client.logger.error(result["message"]) raise Exception("Repository not created... (gitea: %s)" % result["message"]) - return Repository.parse_response(self, result) + return Repository.parse_response(self.allspice_client, result) def get_repositories(self) -> List["Repository"]: """Get all Repositories owned by this User.""" @@ -381,7 +387,7 @@ def get_teams(self) -> List["Team"]: def get_accessible_repos(self) -> List["Repository"]: """Get all Repositories accessible by the logged in User.""" results = self.allspice_client.requests_get("/user/repos", sudo=self) - return [Repository.parse_response(self, result) for result in results] + return [Repository.parse_response(self.allspice_client, result) for result in results] def __request_emails(self): result = self.allspice_client.requests_get(User.USER_MAIL % self.login) @@ -562,7 +568,7 @@ def __hash__(self): if r["email"] == "" else User.parse_response(allspice_client, r) ), - "updated_at": lambda allspice_client, t: Util.convert_time(t), + "updated_at": lambda _, t: Util.convert_time(t), } @classmethod @@ -747,7 +753,6 @@ def get_issues( setattr(issue, "_repository", self) # This is mostly for compatibility with an older implementation Issue._add_read_property("repo", self, issue) - Issue._add_read_property("owner", self.owner, issue) issues.append(issue) return issues @@ -909,7 +914,7 @@ def create_design_review( :return: The created Design Review """ - data = { + data: dict[str, Any] = { "title": title, } @@ -1002,11 +1007,11 @@ def remove_collaborator(self, user_name: str): def transfer_ownership( self, - new_owner: Union["User", "Organization"], - new_teams: Set["Team"] = frozenset(), + new_owner: Union[User, Organization], + new_teams: Set[Team] | FrozenSet[Team] = frozenset(), ): url = Repository.REPO_TRANSFER.format(owner=self.owner.username, repo=self.name) - data = {"new_owner": new_owner.username} + data: dict[str, Any] = {"new_owner": new_owner.username} if isinstance(new_owner, Organization): new_team_ids = [team.id for team in new_teams if team in new_owner.get_teams()] data["team_ids"] = new_team_ids @@ -1014,10 +1019,10 @@ def transfer_ownership( # TODO: make sure this instance is either updated or discarded def get_git_content( - self: Optional[str] = None, + self, ref: Optional["Ref"] = None, commit: "Optional[Commit]" = None, - ) -> List["Content"]: + ) -> List[Content]: """ Get a list of all files in the repository. @@ -1430,8 +1435,8 @@ def __hash__(self): return hash(self.allspice_client) ^ hash(self.id) _fields_to_parsers: ClassVar[dict] = { - "closed_at": lambda allspice_client, t: Util.convert_time(t), - "due_on": lambda allspice_client, t: Util.convert_time(t), + "closed_at": lambda _, t: Util.convert_time(t), + "due_on": lambda _, t: Util.convert_time(t), } _patchable_fields: ClassVar[set[str]] = { @@ -1508,10 +1513,10 @@ def __init__(self, allspice_client): def __eq__(self, other): if not isinstance(other, Comment): return False - return self.repo == other.repo and self.id == other.id + return self.repository == other.repository and self.id == other.id def __hash__(self): - return hash(self.repo) ^ hash(self.id) + return hash(self.repository) ^ hash(self.id) @classmethod def request(cls, allspice_client, owner: str, repo: str, id: str) -> "Comment": @@ -1580,7 +1585,7 @@ def create_attachment(self, file: IO, name: Optional[str] = None) -> Attachment: :return: The created attachment. """ - args = { + args: dict[str, Any] = { "files": {"attachment": file}, } if name is not None: @@ -1702,6 +1707,9 @@ def get_statuses(self) -> List[CommitStatus]: @cached_property def _fields_for_path(self) -> dict[str, str]: matches = self.URL_REGEXP.search(self.url) + if not matches: + raise ValueError(f"Invalid commit URL: {self.url}") + return { "owner": matches.group(1), "repo": matches.group(2), @@ -1717,7 +1725,7 @@ class CommitStatusState(Enum): WARNING = "warning" @classmethod - def try_init(cls, value: str) -> Union[CommitStatusState, str]: + def try_init(cls, value: str) -> CommitStatusState | str: """ Try converting a string to the enum, and if that fails, return the string itself. @@ -1726,7 +1734,7 @@ def try_init(cls, value: str) -> Union[CommitStatusState, str]: try: return cls(value) except ValueError: - value + return value class CommitStatus(ReadonlyApiObject): @@ -1838,10 +1846,10 @@ def __init__(self, allspice_client): def __eq__(self, other): if not isinstance(other, Issue): return False - return self.repo == other.repo and self.id == other.id + return self.repository == other.repository and self.id == other.id def __hash__(self): - return hash(self.repo) ^ hash(self.id) + return hash(self.repository) ^ hash(self.id) _fields_to_parsers: ClassVar[dict] = { "milestone": lambda allspice_client, m: Milestone.parse_response(allspice_client, m), @@ -1880,9 +1888,10 @@ def request(cls, allspice_client, owner: str, repo: str, number: str): api_object = cls._request(allspice_client, {"owner": owner, "repo": repo, "index": number}) # The repository in the response is a RepositoryMeta object, so request # the full repository object and add it to the issue object. - repo = Repository.request(allspice_client, owner, repo) - setattr(api_object, "_repository", repo) - cls._add_read_property("repo", repo, api_object) + repository = Repository.request(allspice_client, owner, repo) + setattr(api_object, "_repository", repository) + # For legacy reasons + cls._add_read_property("repo", repository, api_object) return api_object @classmethod @@ -1895,9 +1904,13 @@ def create_issue(cls, allspice_client, repo: Repository, title: str, body: str = cls._add_read_property("repo", repo, issue) return issue + @property + def owner(self) -> Organization | User: + return self.repository.owner + def get_time_sum(self, user: User) -> int: results = self.allspice_client.requests_get( - Issue.GET_TIME % (self.owner.username, self.repo.name, self.number) + Issue.GET_TIME % (self.owner.username, self.repository.name, self.number) ) return sum(result["time"] for result in results if result and result["user_id"] == user.id) @@ -1921,7 +1934,7 @@ def get_comments(self) -> List[Comment]: results = self.allspice_client.requests_get( self.GET_COMMENTS.format( - owner=self.owner.username, repo=self.repo.name, index=self.number + owner=self.owner.username, repo=self.repository.name, index=self.number ) ) @@ -1931,7 +1944,7 @@ def create_comment(self, body: str) -> Comment: """https://hub.allspice.io/api/swagger#/issue/issueCreateComment""" path = self.GET_COMMENTS.format( - owner=self.owner.username, repo=self.repo.name, index=self.number + owner=self.owner.username, repo=self.repository.name, index=self.number ) response = self.allspice_client.requests_post(path, data={"body": body}) @@ -2004,10 +2017,10 @@ def __init__(self, allspice_client): def __eq__(self, other): if not isinstance(other, DesignReview): return False - return self.repo == other.repo and self.id == other.id + return self.repository == other.repository and self.id == other.id def __hash__(self): - return hash(self.repo) ^ hash(self.id) + return hash(self.repository) ^ hash(self.id) @classmethod def parse_response(cls, allspice_client, result) -> "DesignReview": @@ -2224,6 +2237,7 @@ class Release(ApiObject): prerelease: bool published_at: str repo: Optional["Repository"] + repository: Optional["Repository"] tag_name: str tarball_url: str target_commitish: str @@ -2262,6 +2276,8 @@ def __hash__(self): @classmethod def parse_response(cls, allspice_client, result, repo) -> Release: release = super().parse_response(allspice_client, result) + Release._add_read_property("repository", repo, release) + # For legacy reasons Release._add_read_property("repo", repo, release) setattr( release, @@ -2283,8 +2299,8 @@ def request( ) -> Release: args = {"owner": owner, "repo": repo, "id": id} release_response = cls._get_gitea_api_object(allspice_client, args) - repo = Repository.request(allspice_client, owner, repo) - release = cls.parse_response(allspice_client, release_response, repo) + repository = Repository.request(allspice_client, owner, repo) + release = cls.parse_response(allspice_client, release_response, repository) return release def commit(self): @@ -2302,7 +2318,7 @@ def create_asset(self, file: IO, name: Optional[str] = None) -> ReleaseAsset: :return: The created asset. """ - args = {"files": {"attachment": file}} + args: dict[str, Any] = {"files": {"attachment": file}} if name is not None: args["params"] = {"name": name} @@ -2453,12 +2469,13 @@ def __init__(self, allspice_client): super().__init__(allspice_client) def __eq__(self, other): - if not isinstance(other, Team): + if not isinstance(other, Content): return False - return self.repo == self.repo and self.sha == other.sha and self.name == other.name + + return self.sha == other.sha and self.name == other.name def __hash__(self): - return hash(self.repo) ^ hash(self.sha) ^ hash(self.name) + return hash(self.sha) ^ hash(self.name) Ref = Union[Branch, Commit, str] diff --git a/allspice/baseapiobject.py b/allspice/baseapiobject.py index 75e7b04..06b4935 100644 --- a/allspice/baseapiobject.py +++ b/allspice/baseapiobject.py @@ -1,4 +1,14 @@ -from typing import ClassVar, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar, Mapping, Optional + +try: + from typing_extensions import Self +except ImportError: + from typing import Self + +if TYPE_CHECKING: + from allspice.allspice import AllSpice from .exceptions import MissingEqualityImplementation, ObjectIsInvalid, RawRequestEndpointMissing @@ -8,46 +18,51 @@ def __init__(self, allspice_client): self.allspice_client = allspice_client self.deleted = False # set if .delete was called, so that an exception is risen - def __str__(self): + def __str__(self) -> str: return "AllSpiceObject (%s):" % (type(self)) - def __eq__(self, other): + def __eq__(self, other) -> bool: """Compare only fields that are part of the gitea-data identity""" raise MissingEqualityImplementation() - def __hash__(self): + def __hash__(self) -> int: """Hash only fields that are part of the gitea-data identity""" raise MissingEqualityImplementation() _fields_to_parsers: ClassVar[dict] = {} + # TODO: This should probably be made an abstract function as all children + # redefine it. @classmethod - def request(cls, allspice_client): - if hasattr("API_OBJECT", cls): - return cls._request(allspice_client) - else: - raise RawRequestEndpointMissing() + def request(cls, allspice_client: AllSpice) -> Self: + # This never would've worked, so maybe we should remove this function + # outright. + return cls._request(allspice_client) @classmethod - def _request(cls, allspice_client, args): + def _request(cls, allspice_client: AllSpice, args: Mapping) -> Self: result = cls._get_gitea_api_object(allspice_client, args) api_object = cls.parse_response(allspice_client, result) return api_object @classmethod - def _get_gitea_api_object(cls, allspice_client, args): + def _get_gitea_api_object(cls, allspice_client: AllSpice, args: Mapping) -> Mapping: """Retrieving an object always as GET_API_OBJECT""" - return allspice_client.requests_get(cls.API_OBJECT.format(**args)) + if hasattr(cls, "API_OBJECT"): + raw_request_endpoint = getattr(cls, "API_OBJECT") + return allspice_client.requests_get(raw_request_endpoint.format(**args)) + else: + raise RawRequestEndpointMissing() @classmethod - def parse_response(cls, allspice_client, result) -> "ReadonlyApiObject": + def parse_response(cls, allspice_client: AllSpice, result: Mapping) -> Self: # allspice_client.logger.debug("Found api object of type %s (id: %s)" % (type(cls), id)) api_object = cls(allspice_client) cls._initialize(allspice_client, api_object, result) return api_object @classmethod - def _initialize(cls, allspice_client, api_object, result): + def _initialize(cls, allspice_client: AllSpice, api_object: Self, result: Mapping): for name, value in result.items(): if name in cls._fields_to_parsers and value is not None: parse_func = cls._fields_to_parsers[name] @@ -59,7 +74,7 @@ def _initialize(cls, allspice_client, api_object, result): cls._add_read_property(name, None, api_object) @classmethod - def _add_read_property(cls, name, value, api_object): + def _add_read_property(cls, name: str, value: Any, api_object: ReadonlyApiObject): if not hasattr(api_object, name): setattr(api_object, "_" + name, value) prop = property((lambda n: lambda self: self._get_var(n))(name)) @@ -67,7 +82,7 @@ def _add_read_property(cls, name, value, api_object): else: raise AttributeError(f"Attribute {name} already exists on api object.") - def _get_var(self, name): + def _get_var(self, name: str) -> Any: if self.deleted: raise ObjectIsInvalid() return getattr(self, "_" + name) @@ -76,21 +91,23 @@ def _get_var(self, name): class ApiObject(ReadonlyApiObject): _patchable_fields: ClassVar[set[str]] = set() - def __init__(self, allspice_client): + def __init__(self, allspice_client: AllSpice): super().__init__(allspice_client) self._dirty_fields = set() - def _commit(self, route_fields: dict, dirty_fields: Optional[dict] = None): + def _commit(self, route_fields: dict, dirty_fields: Optional[Mapping] = None): if self.deleted: raise ObjectIsInvalid() if not hasattr(self, "API_OBJECT"): raise RawRequestEndpointMissing() + raw_request_endpoint = getattr(self, "API_OBJECT") + if dirty_fields is None: dirty_fields = self.get_dirty_fields() self.allspice_client.requests_patch( - self.API_OBJECT.format(**route_fields), + raw_request_endpoint.format(**route_fields), dirty_fields, ) self._dirty_fields = set() @@ -100,7 +117,7 @@ def commit(self): _parsers_to_fields: ClassVar[dict] = {} - def get_dirty_fields(self): + def get_dirty_fields(self) -> dict[str, Any]: dirty_fields_values = {} for field in self._dirty_fields: value = getattr(self, field) @@ -111,13 +128,13 @@ def get_dirty_fields(self): return dirty_fields_values @classmethod - def _initialize(cls, allspice_client, api_object, result): + def _initialize(cls, allspice_client: AllSpice, api_object: Self, result: Mapping): super()._initialize(allspice_client, api_object, result) for name in cls._patchable_fields: cls._add_write_property(name, None, api_object) @classmethod - def _add_write_property(cls, name, value, api_object): + def _add_write_property(cls, name: str, value: Any, api_object: Self): if not hasattr(api_object, "_" + name): setattr(api_object, "_" + name, value) prop = property( @@ -126,8 +143,8 @@ def _add_write_property(cls, name, value, api_object): ) setattr(cls, name, prop) - def __set_var(self, name, i): + def __set_var(self, name: str, value: Any): if self.deleted: raise ObjectIsInvalid() self._dirty_fields.add(name) - setattr(self, "_" + name, i) + setattr(self, "_" + name, value) diff --git a/allspice/utils/bom_generation.py b/allspice/utils/bom_generation.py index 8695932..b1a85db 100644 --- a/allspice/utils/bom_generation.py +++ b/allspice/utils/bom_generation.py @@ -399,7 +399,7 @@ def generate_bom_for_system_capture( def _get_first_matching_key_value( alternatives: Union[list[str], str], - attributes: dict[str, str | None], + attributes: dict[str, str], ) -> Optional[str]: """ Search for a series of alternative keys in a dictionary, and return the @@ -437,7 +437,7 @@ def _map_attributes( def _group_entries( components: list[BomEntry], - group_by: list[str], + group_by: Optional[list[str]], columns_mapping: dict[str, ColumnConfig], ) -> list[BomEntry]: """ @@ -449,7 +449,7 @@ def _group_entries( # If grouping is off, we just add a quantity of 1 to each component and # return early. - if group_by is None: + if group_by is None or len(group_by) == 0: for component in components: component[QUANTITY_COLUMN_NAME] = "1" return components @@ -496,7 +496,7 @@ def _group_entries( return rows -def _remove_non_bom_components(components: list[dict[str, str]]) -> list[dict[str, str]]: +def _remove_non_bom_components(components: list[ComponentAttributes]) -> list[ComponentAttributes]: """ Filter out components of types that should not be included in the BOM. """ diff --git a/allspice/utils/list_components.py b/allspice/utils/list_components.py index 9511a2c..da0afde 100644 --- a/allspice/utils/list_components.py +++ b/allspice/utils/list_components.py @@ -28,7 +28,7 @@ # Maps a sheet name to a list of tuples, where each tuple is a child sheet and # the number of repetitions of that child sheet in the parent sheet. SchdocHierarchy = dict[str, list[tuple[str, int]]] -ComponentAttributes = dict[str, str | None] +ComponentAttributes = dict[str, str] class SupportedTool(Enum): @@ -52,9 +52,9 @@ def list_components( repository: Repository, source_file: str, variant: Optional[str] = None, - ref: str = "main", + ref: Ref = "main", combine_multi_part: bool = False, -) -> list[dict[str, str]]: +) -> list[ComponentAttributes]: """ Get a list of all components in a schematic. @@ -128,9 +128,9 @@ def list_components_for_altium( repository: Repository, prjpcb_file: str, variant: Optional[str] = None, - ref: str = "main", + ref: Ref = "main", combine_multi_part: bool = False, -) -> list[dict[str, str]]: +) -> list[ComponentAttributes]: """ Get a list of all components in an Altium project. @@ -170,6 +170,9 @@ def list_components_for_altium( f"Variant {variant} not found in PrjPcb file. " "Please check the name of the variant." ) + else: + # Ensuring variant_details is always bound, even if it is not used. + variant_details = None schdoc_files_in_proj = _extract_schdoc_list_from_prjpcb(prjpcb_ini) allspice_client.logger.info("Found %d SchDoc files", len(schdoc_files_in_proj)) @@ -213,6 +216,10 @@ def list_components_for_altium( components = _combine_multi_part_components_for_altium(components) if variant is not None: + if variant_details is None: + # This should never happen, but mypy doesn't know that. + raise ValueError(f"Variant {variant} not found in PrjPcb file.") + components = _apply_variations(components, variant_details, allspice_client.logger) return components @@ -222,9 +229,9 @@ def list_components_for_orcad( allspice_client: AllSpice, repository: Repository, dsn_path: str, - ref: str = "main", + ref: Ref = "main", combine_multi_part: bool = False, -) -> list[dict[str, str]]: +) -> list[ComponentAttributes]: """ Get a list of all components in an OrCAD DSN schematic. @@ -256,7 +263,7 @@ def list_components_for_system_capture( repository: Repository, sdax_path: str, ref: Ref = "main", -) -> list[dict[str, str]]: +) -> list[ComponentAttributes]: """ Get a list of all components in a System Capture SDAX schematic. @@ -310,7 +317,7 @@ def _list_components_multi_page_schematic( return components -def _fetch_generated_json(repo: Repository, file_path: str, ref: str) -> dict: +def _fetch_generated_json(repo: Repository, file_path: str, ref: Ref) -> dict: attempts = 0 while attempts < 5: try: @@ -635,10 +642,10 @@ def _extract_variations( def _apply_variations( - components: list[dict[str, str | None]], + components: list[dict[str, str]], variant_details: configparser.SectionProxy, logger: Logger, -) -> list[dict[str, str | None]]: +) -> list[dict[str, str]]: """ Apply the variations of a specific variant to the components. This should be done before the components are mapped to columns or grouped. @@ -661,7 +668,7 @@ def _apply_variations( patch_component_unique_id: dict[str, str] = {} # The keys are the same as above, and the values are a key-value of the # parameter to patch and the value to patch it to. - components_to_patch: dict[tuple[str, str], tuple[str, str]] = {} + components_to_patch: dict[tuple[str, str], list[tuple[str, str]]] = {} for key, value in variant_details.items(): # Note that this is in lowercase, as configparser stores all keys in diff --git a/allspice/utils/netlist_generation.py b/allspice/utils/netlist_generation.py index 6b55687..3a4f228 100644 --- a/allspice/utils/netlist_generation.py +++ b/allspice/utils/netlist_generation.py @@ -28,7 +28,7 @@ class NetlistEntry: pins: list[str] -Netlist = list[str] +Netlist = dict[NetlistEntry, set[str]] def generate_netlist( @@ -53,7 +53,17 @@ def generate_netlist( allspice_client.logger.info(f"Generating netlist for {repository.name=} on {ref=}") allspice_client.logger.info(f"Fetching {pcb_file=}") - pcb_components = _extract_all_pcb_components(allspice_client.logger, repository, ref, pcb_file) + if isinstance(pcb_file, Content): + pcb_file_path = pcb_file.path + else: + pcb_file_path = pcb_file + + pcb_components = _extract_all_pcb_components( + allspice_client.logger, + repository, + ref, + pcb_file_path, + ) return _group_netlist_entries(pcb_components) @@ -96,7 +106,7 @@ def _extract_all_pcb_components( return components -def _group_netlist_entries(components: list[PcbComponent]) -> dict[NetlistEntry]: +def _group_netlist_entries(components: list[PcbComponent]) -> dict[NetlistEntry, set[str]]: """ Group connected pins by the net """ diff --git a/requirements-test.txt b/requirements-test.txt index 9171e24..0ce2ded 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -3,7 +3,10 @@ diff_cover~=9.1 libcst~=1.4.0 MonkeyType~=23.3 pdoc~=14.5 +pyright~=1.1 pytest-recording~=0.13 pytest~=8.2 ruff~=0.5 -syrupy~=4.6 \ No newline at end of file +syrupy~=4.6 +types-requests~=2.32 +typing_extensions~=4.11; python_version < "3.11" diff --git a/scripts/generate_attribute_types.py b/scripts/generate_attribute_types.py index ad2680c..59ec9b2 100755 --- a/scripts/generate_attribute_types.py +++ b/scripts/generate_attribute_types.py @@ -28,6 +28,7 @@ import monkeytype.stubs import monkeytype.typing import pytest +from libcst import codemod from libcst.codemod.visitors import ApplyTypeAnnotationsVisitor HOOKED_FUNCTION_NAME = "parse_response" @@ -305,7 +306,7 @@ def main(): stub_module = cst.parse_module(module_stub.render()) source_module = cst.parse_module(source) - context = cst.codemod.CodemodContext() + context = codemod.CodemodContext() CustomApplyTypeAnnotationsVisitor.store_stub_in_context( context, stub_module, diff --git a/scripts/typecheck.sh b/scripts/typecheck.sh new file mode 100755 index 0000000..f271eda --- /dev/null +++ b/scripts/typecheck.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e +set -u + +git ls-files | grep '\.py' | xargs pyright