From 3b66ae55835cacd9085eb37014a95e4a6aca3515 Mon Sep 17 00:00:00 2001 From: Svein Seldal Date: Wed, 10 Jul 2024 12:34:14 +0200 Subject: [PATCH] Add mypy support and fixup project to give no errors * Permissive mypy configuration as starting point * Add minimal type annotations to get no mypy errors * Add runtime test for self.network before using the network * Network.add_node() doesn't accept LocalNode * PeriodicMessageTask.update() don't stop the task unless its running * Variable.desc ensure that the object is int * Variable.read() fail with ValueError unless a valid fmt is used * Variable.write() ensure the description is a string * BaseNode.__init__() fail if no node_id is provided * ObjectDictionary.__getitem__() when splitting "." only return if the object is not an ODVariable * ODRecord.__eq__(), ODArray.__eq__() and ODVariable.__eq__() test type of other before comparing * ODVariable.encode_raw(), .decode_phys(), .encode_phys() add type tests of ensure the input is of correct type * PdoMap various methods: ensure necessary attributes are set --- canopen/emcy.py | 7 ++- canopen/network.py | 54 +++++++++++------- canopen/nmt.py | 23 ++++++-- canopen/node/base.py | 11 ++-- canopen/node/local.py | 28 ++++++++-- canopen/node/remote.py | 8 +-- canopen/objectdictionary/__init__.py | 83 ++++++++++++++++++---------- canopen/pdo/base.py | 33 ++++++++++- canopen/sdo/base.py | 13 +++-- canopen/variable.py | 18 ++++-- pyproject.toml | 9 +++ 11 files changed, 204 insertions(+), 83 deletions(-) diff --git a/canopen/emcy.py b/canopen/emcy.py index 1bbdeb75..266d8015 100644 --- a/canopen/emcy.py +++ b/canopen/emcy.py @@ -1,3 +1,4 @@ +from __future__ import annotations import struct import logging import threading @@ -52,7 +53,7 @@ def reset(self): def wait( self, emcy_code: Optional[int] = None, timeout: float = 10 - ) -> "EmcyError": + ) -> Optional[EmcyError]: """Wait for a new EMCY to arrive. :param emcy_code: EMCY code to wait for @@ -86,10 +87,14 @@ def __init__(self, cob_id: int): self.cob_id = cob_id def send(self, code: int, register: int = 0, data: bytes = b""): + if self.network is None: + raise RuntimeError("A Network is required") payload = EMCY_STRUCT.pack(code, register, data) self.network.send_message(self.cob_id, payload) def reset(self, register: int = 0, data: bytes = b""): + if self.network is None: + raise RuntimeError("A Network is required") payload = EMCY_STRUCT.pack(0, register, data) self.network.send_message(self.cob_id, payload) diff --git a/canopen/network.py b/canopen/network.py index f6446208..cebc1e6d 100644 --- a/canopen/network.py +++ b/canopen/network.py @@ -3,18 +3,21 @@ from collections.abc import MutableMapping import logging import threading -from typing import Callable, Dict, Iterator, List, Optional, Union +from typing import Callable, Dict, Iterator, List, Optional, Union, TYPE_CHECKING, TextIO try: import can from can import Listener from can import CanError except ImportError: - # Do not fail if python-can is not installed - can = None - CanError = Exception - class Listener: - """ Dummy listener """ + # Type checkers don't like this conditional logic, so it is only run when + # not type checking + if not TYPE_CHECKING: + # Do not fail if python-can is not installed + can = None + CanError = Exception + class Listener: + """ Dummy listener """ from canopen.node import RemoteNode, LocalNode from canopen.sync import SyncProducer @@ -24,6 +27,9 @@ class Listener: from canopen.objectdictionary.eds import import_from_node from canopen.objectdictionary import ObjectDictionary +if TYPE_CHECKING: + from can.typechecking import CanData + logger = logging.getLogger(__name__) Callback = Callable[[int, bytearray, float], None] @@ -45,7 +51,7 @@ def __init__(self, bus: Optional[can.BusABC] = None): #: List of :class:`can.Listener` objects. #: Includes at least MessageListener. self.listeners = [MessageListener(self)] - self.notifier = None + self.notifier: Optional[can.Notifier] = None self.nodes: Dict[int, Union[RemoteNode, LocalNode]] = {} self.subscribers: Dict[int, List[Callback]] = {} self.send_lock = threading.Lock() @@ -138,15 +144,15 @@ def __exit__(self, type, value, traceback): def add_node( self, - node: Union[int, RemoteNode, LocalNode], - object_dictionary: Union[str, ObjectDictionary, None] = None, + node: Union[int, RemoteNode], + object_dictionary: Union[str, ObjectDictionary, TextIO, None] = None, upload_eds: bool = False, ) -> RemoteNode: """Add a remote node to the network. :param node: Can be either an integer representing the node ID, a - :class:`canopen.RemoteNode` or :class:`canopen.LocalNode` object. + :class:`canopen.RemoteNode` object. :param object_dictionary: Can be either a string for specifying the path to an Object Dictionary file or a @@ -161,14 +167,16 @@ def add_node( if upload_eds: logger.info("Trying to read EDS from node %d", node) object_dictionary = import_from_node(node, self) - node = RemoteNode(node, object_dictionary) - self[node.id] = node - return node + nodeobj = RemoteNode(node, object_dictionary) + else: + nodeobj = node + self[nodeobj.id] = nodeobj + return nodeobj def create_node( self, node: int, - object_dictionary: Union[str, ObjectDictionary, None] = None, + object_dictionary: Union[str, ObjectDictionary, TextIO, None] = None, ) -> LocalNode: """Create a local node in the network. @@ -183,11 +191,13 @@ def create_node( The Node object that was added. """ if isinstance(node, int): - node = LocalNode(node, object_dictionary) - self[node.id] = node - return node + nodeobj = LocalNode(node, object_dictionary) + else: + nodeobj = node + self[nodeobj.id] = nodeobj + return nodeobj - def send_message(self, can_id: int, data: bytes, remote: bool = False) -> None: + def send_message(self, can_id: int, data: CanData, remote: bool = False) -> None: """Send a raw CAN message to the network. This method may be overridden in a subclass if you need to integrate @@ -215,7 +225,7 @@ def send_message(self, can_id: int, data: bytes, remote: bool = False) -> None: self.check() def send_periodic( - self, can_id: int, data: bytes, period: float, remote: bool = False + self, can_id: int, data: CanData, period: float, remote: bool = False ) -> PeriodicMessageTask: """Start sending a message periodically. @@ -295,7 +305,7 @@ class PeriodicMessageTask: def __init__( self, can_id: int, - data: bytes, + data: CanData, period: float, bus, remote: bool = False, @@ -335,10 +345,12 @@ def update(self, data: bytes) -> None: old_data = self.msg.data self.msg.data = new_data if hasattr(self._task, "modify_data"): + assert self._task is not None # This will never be None, but mypy needs this self._task.modify_data(self.msg) elif new_data != old_data: # Stop and start (will mess up period unfortunately) - self._task.stop() + if self._task is not None: + self._task.stop() self._start() diff --git a/canopen/nmt.py b/canopen/nmt.py index 8ce737ea..d9e54a20 100644 --- a/canopen/nmt.py +++ b/canopen/nmt.py @@ -2,7 +2,10 @@ import logging import struct import time -from typing import Callable, Optional +from typing import Callable, Optional, List, TYPE_CHECKING + +if TYPE_CHECKING: + from canopen.network import Network, PeriodicMessageTask logger = logging.getLogger(__name__) @@ -45,7 +48,7 @@ class NmtBase: def __init__(self, node_id: int): self.id = node_id - self.network = None + self.network: Optional[Network] = None self._state = 0 def on_command(self, can_id, data, timestamp): @@ -107,11 +110,11 @@ class NmtMaster(NmtBase): def __init__(self, node_id: int): super(NmtMaster, self).__init__(node_id) self._state_received = None - self._node_guarding_producer = None + self._node_guarding_producer: Optional[PeriodicMessageTask] = None #: Timestamp of last heartbeat message self.timestamp: Optional[float] = None self.state_update = threading.Condition() - self._callbacks = [] + self._callbacks: List[Callable[[int], None]] = [] def on_heartbeat(self, can_id, data, timestamp): with self.state_update: @@ -139,6 +142,8 @@ def send_command(self, code: int): super(NmtMaster, self).send_command(code) logger.info( "Sending NMT command 0x%X to node %d", code, self.id) + if self.network is None: + raise RuntimeError("A Network is required") self.network.send_message(0, [code, self.id]) def wait_for_heartbeat(self, timeout: float = 10): @@ -181,7 +186,9 @@ def start_node_guarding(self, period: float): Period (in seconds) at which the node guarding should be advertised to the slave node. """ if self._node_guarding_producer : self.stop_node_guarding() - self._node_guarding_producer = self.network.send_periodic(0x700 + self.id, None, period, True) + if self.network is None: + raise RuntimeError("A Network is required") + self._node_guarding_producer = self.network.send_periodic(0x700 + self.id, [], period, True) def stop_node_guarding(self): """Stops the node guarding mechanism.""" @@ -197,7 +204,7 @@ class NmtSlave(NmtBase): def __init__(self, node_id: int, local_node): super(NmtSlave, self).__init__(node_id) - self._send_task = None + self._send_task: Optional[PeriodicMessageTask] = None self._heartbeat_time_ms = 0 self._local_node = local_node @@ -216,6 +223,8 @@ def send_command(self, code: int) -> None: if self._state == 0: logger.info("Sending boot-up message") + if self.network is None: + raise RuntimeError("A Network is required") self.network.send_message(0x700 + self.id, [0]) # The heartbeat service should start on the transition @@ -246,6 +255,8 @@ def start_heartbeat(self, heartbeat_time_ms: int): self.stop_heartbeat() if heartbeat_time_ms > 0: logger.info("Start the heartbeat timer, interval is %d ms", self._heartbeat_time_ms) + if self.network is None: + raise RuntimeError("A network is required") self._send_task = self.network.send_periodic( 0x700 + self.id, [self._state], heartbeat_time_ms / 1000.0) diff --git a/canopen/node/base.py b/canopen/node/base.py index bf72d959..9d182398 100644 --- a/canopen/node/base.py +++ b/canopen/node/base.py @@ -1,4 +1,4 @@ -from typing import TextIO, Union +from typing import TextIO, Union, Optional from canopen.objectdictionary import ObjectDictionary, import_od @@ -14,8 +14,8 @@ class BaseNode: def __init__( self, - node_id: int, - object_dictionary: Union[ObjectDictionary, str, TextIO], + node_id: Optional[int], + object_dictionary: Union[ObjectDictionary, str, TextIO, None], ): self.network = None @@ -23,4 +23,7 @@ def __init__( object_dictionary = import_od(object_dictionary, node_id) self.object_dictionary = object_dictionary - self.id = node_id or self.object_dictionary.node_id + node_id = node_id or self.object_dictionary.node_id + if node_id is None: + raise ValueError("Node ID must be specified") + self.id: int = node_id diff --git a/canopen/node/local.py b/canopen/node/local.py index eb74b98d..148a3ede 100644 --- a/canopen/node/local.py +++ b/canopen/node/local.py @@ -1,29 +1,45 @@ import logging -from typing import Dict, Union +from typing import Dict, Union, List, Protocol, TextIO, Optional from canopen.node.base import BaseNode from canopen.sdo import SdoServer, SdoAbortedError from canopen.pdo import PDO, TPDO, RPDO from canopen.nmt import NmtSlave from canopen.emcy import EmcyProducer -from canopen.objectdictionary import ObjectDictionary +from canopen.objectdictionary import ObjectDictionary, ODVariable from canopen import objectdictionary logger = logging.getLogger(__name__) +class WriteCallback(Protocol): + """LocalNode Write Callback Protocol""" + def __call__(self, *, index: int, subindex: int, + od: ODVariable, + data: bytes) -> None: + ''' Write Callback ''' + + +class ReadCallback(Protocol): + """LocalNode Read Callback Protocol""" + def __call__(self, *, index: int, subindex: int, + od: ODVariable + ) -> Union[bool, int, float, str, bytes, None]: + ''' Read Callback ''' + + class LocalNode(BaseNode): def __init__( self, - node_id: int, - object_dictionary: Union[ObjectDictionary, str], + node_id: Optional[int], + object_dictionary: Union[ObjectDictionary, str, TextIO, None], ): super(LocalNode, self).__init__(node_id, object_dictionary) self.data_store: Dict[int, Dict[int, bytes]] = {} - self._read_callbacks = [] - self._write_callbacks = [] + self._read_callbacks: List[ReadCallback] = [] + self._write_callbacks: List[WriteCallback] = [] self.sdo = SdoServer(0x600 + self.id, 0x580 + self.id, self) self.tpdo = TPDO(self) diff --git a/canopen/node/remote.py b/canopen/node/remote.py index 4f3281db..9301a92d 100644 --- a/canopen/node/remote.py +++ b/canopen/node/remote.py @@ -1,5 +1,5 @@ import logging -from typing import Union, TextIO +from typing import Union, TextIO, List, Optional from canopen.sdo import SdoClient, SdoCommunicationError, SdoAbortedError from canopen.nmt import NmtMaster @@ -26,8 +26,8 @@ class RemoteNode(BaseNode): def __init__( self, - node_id: int, - object_dictionary: Union[ObjectDictionary, str, TextIO], + node_id: Optional[int], + object_dictionary: Union[ObjectDictionary, str, TextIO, None], load_od: bool = False, ): super(RemoteNode, self).__init__(node_id, object_dictionary) @@ -35,7 +35,7 @@ def __init__( #: Enable WORKAROUND for reversed PDO mapping entries self.curtis_hack = False - self.sdo_channels = [] + self.sdo_channels: List[SdoClient] = [] self.sdo = self.add_sdo(0x600 + self.id, 0x580 + self.id) self.tpdo = TPDO(self) self.rpdo = RPDO(self) diff --git a/canopen/objectdictionary/__init__.py b/canopen/objectdictionary/__init__.py index 1e80283b..45429129 100644 --- a/canopen/objectdictionary/__init__.py +++ b/canopen/objectdictionary/__init__.py @@ -4,7 +4,7 @@ from __future__ import annotations import struct -from typing import Dict, Iterator, List, Optional, TextIO, Union +from typing import Dict, Iterator, List, Optional, TextIO, Union, cast from collections.abc import MutableMapping, Mapping import logging @@ -67,7 +67,7 @@ def export_od( finally: # If dest is opened in this fn, it should be closed if opened_here: - dest.close() + cast(TextIO, dest).close() # The cast is needed to help the type checker def import_od( @@ -90,7 +90,7 @@ def import_od( return ObjectDictionary() if hasattr(source, "read"): # File like object - filename = source.name + filename = cast(TextIO, source).name elif hasattr(source, "tag"): # XML tree, probably from an EPF file filename = "od.epf" @@ -135,7 +135,9 @@ def __getitem__( if item is None: if isinstance(index, str) and '.' in index: idx, sub = index.split('.', maxsplit=1) - return self[idx][sub] + var = self[idx] + if not isinstance(var, ODVariable): + return var[sub] raise KeyError(f"{pretty_index(index)} was not found in Object Dictionary") return item @@ -156,7 +158,7 @@ def __iter__(self) -> Iterator[int]: def __len__(self) -> int: return len(self.indices) - def __contains__(self, index: Union[int, str]): + def __contains__(self, index: object): return index in self.names or index in self.indices def add_object(self, obj: Union[ODArray, ODRecord, ODVariable]) -> None: @@ -184,6 +186,7 @@ def get_variable( return obj elif isinstance(obj, (ODRecord, ODArray)): return obj.get(subindex) + return None class ODRecord(MutableMapping): @@ -203,14 +206,17 @@ def __init__(self, name: str, index: int): self.name = name #: Storage location of index self.storage_location = None - self.subindices = {} - self.names = {} + self.subindices: Dict[int, ODVariable] = {} + self.names: Dict[str, ODVariable] = {} def __repr__(self) -> str: return f"<{type(self).__qualname__} {self.name!r} at {pretty_index(self.index)}>" def __getitem__(self, subindex: Union[int, str]) -> ODVariable: - item = self.names.get(subindex) or self.subindices.get(subindex) + if isinstance(subindex, str): + item = self.names.get(subindex) + else: + item = self.subindices.get(subindex) if item is None: raise KeyError(f"Subindex {pretty_index(None, subindex)} was not found") return item @@ -230,11 +236,11 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[int]: return iter(sorted(self.subindices)) - def __contains__(self, subindex: Union[int, str]) -> bool: + def __contains__(self, subindex: object) -> bool: return subindex in self.names or subindex in self.subindices - def __eq__(self, other: ODRecord) -> bool: - return self.index == other.index + def __eq__(self, other: object) -> bool: + return isinstance(other, type(self)) and self.index == other.index def add_member(self, variable: ODVariable) -> None: """Adds a :class:`~canopen.objectdictionary.ODVariable` to the record.""" @@ -262,14 +268,17 @@ def __init__(self, name: str, index: int): self.name = name #: Storage location of index self.storage_location = None - self.subindices = {} - self.names = {} + self.subindices: Dict[int, ODVariable] = {} + self.names: Dict[str, ODVariable] = {} def __repr__(self) -> str: return f"<{type(self).__qualname__} {self.name!r} at {pretty_index(self.index)}>" def __getitem__(self, subindex: Union[int, str]) -> ODVariable: - var = self.names.get(subindex) or self.subindices.get(subindex) + if isinstance(subindex, str): + var = self.names.get(subindex) + else: + var = self.subindices.get(subindex) if var is not None: # This subindex is defined pass @@ -294,8 +303,8 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[int]: return iter(sorted(self.subindices)) - def __eq__(self, other: ODArray) -> bool: - return self.index == other.index + def __eq__(self, other: object) -> bool: + return isinstance(other, type(self)) and self.index == other.index def add_member(self, variable: ODVariable) -> None: """Adds a :class:`~canopen.objectdictionary.ODVariable` to the record.""" @@ -335,7 +344,7 @@ def __init__(self, name: str, index: int, subindex: int = 0): #: The :class:`~canopen.ObjectDictionary`, #: :class:`~canopen.objectdictionary.ODRecord` or #: :class:`~canopen.objectdictionary.ODArray` owning the variable - self.parent = None + self.parent: Union[ObjectDictionary, ODArray, ODRecord, None] = None #: 16-bit address of the object in the dictionary self.index = index #: 8-bit sub-index of the object in the dictionary @@ -383,8 +392,9 @@ def qualname(self) -> str: return f"{self.parent.name}.{self.name}" return self.name - def __eq__(self, other: ODVariable) -> bool: - return (self.index == other.index and + def __eq__(self, other: object) -> bool: + return (isinstance(other, type(self)) and + self.index == other.index and self.subindex == other.subindex) def __len__(self) -> int: @@ -441,12 +451,21 @@ def encode_raw(self, value: Union[int, float, str, bytes, bytearray]) -> bytes: if isinstance(value, (bytes, bytearray)): return value elif self.data_type == VISIBLE_STRING: + if not isinstance(value, str): + raise TypeError(f"Value of type {type(value)!r} doesn't match VISIBLE_STRING") return value.encode("ascii") elif self.data_type == UNICODE_STRING: + if not isinstance(value, str): + raise TypeError(f"Value of type {type(value)!r} doesn't match UNICODE_STRING") return value.encode("utf_16_le") elif self.data_type in (DOMAIN, OCTET_STRING): + if not isinstance(value, (bytes, bytearray)): + t = "DOMAIN" if self.data_type == DOMAIN else "OCTET_STRING" + raise TypeError(f"Value of type {type(value)!r} doesn't match {t}") return bytes(value) elif self.data_type in self.STRUCT_TYPES: + if not isinstance(value, (bool, int, float)): + raise TypeError(f"Value of type {type(value)!r} is unexpected for numeric types") if self.data_type in INTEGER_TYPES: value = int(value) if self.data_type in NUMBER_TYPES: @@ -467,13 +486,17 @@ def encode_raw(self, value: Union[int, float, str, bytes, bytearray]) -> bytes: raise TypeError( f"Do not know how to encode {value!r} to data type 0x{self.data_type:X}") - def decode_phys(self, value: int) -> Union[int, bool, float, str, bytes]: + def decode_phys(self, value: Union[int, bool, float, str, bytes]) -> Union[int, bool, float, str, bytes]: if self.data_type in INTEGER_TYPES: + if not isinstance(value, (int, float)): + raise TypeError(f"Value of type {type(value)!r} is unexpected for numeric types") value *= self.factor return value - def encode_phys(self, value: Union[int, bool, float, str, bytes]) -> int: + def encode_phys(self, value: Union[int, bool, float, str, bytes]) -> Union[int, bool, float, str, bytes]: if self.data_type in INTEGER_TYPES: + if not isinstance(value, (int, float)): + raise TypeError(f"Value of type {type(value)!r} is unexpected for numeric types") value /= self.factor value = int(round(value)) return value @@ -498,27 +521,29 @@ def encode_desc(self, desc: str) -> int: raise ValueError( f"No value corresponds to '{desc}'. Valid values are: {valid_values}") - def decode_bits(self, value: int, bits: List[int]) -> int: + def decode_bits(self, value: int, bits: Union[range, str, List[int]]) -> int: try: - bits = self.bit_definitions[bits] + bits = self.bit_definitions[cast(str, bits)] except (TypeError, KeyError): pass mask = 0 - for bit in bits: + lbits = cast(List[int], bits) + for bit in lbits: mask |= 1 << bit - return (value & mask) >> min(bits) + return (value & mask) >> min(lbits) - def encode_bits(self, original_value: int, bits: List[int], bit_value: int): + def encode_bits(self, original_value: int, bits: Union[range, str, List[int]], bit_value: int): try: - bits = self.bit_definitions[bits] + bits = self.bit_definitions[cast(str, bits)] except (TypeError, KeyError): pass temp = original_value mask = 0 - for bit in bits: + lbits = cast(List[int], bits) + for bit in lbits: mask |= 1 << bit temp &= ~mask - temp |= bit_value << min(bits) + temp |= bit_value << min(lbits) return temp diff --git a/canopen/pdo/base.py b/canopen/pdo/base.py index f2a7d205..8af16ad1 100644 --- a/canopen/pdo/base.py +++ b/canopen/pdo/base.py @@ -338,7 +338,7 @@ def read(self, from_od=False) -> None: DCF value will be used, otherwise the EDS default will be used instead. """ - def _raw_from(param): + def _raw_from(param) -> int: if from_od: if param.od.value is not None: return param.od.value @@ -464,6 +464,10 @@ def subscribe(self) -> None: known to match what's stored on the node. """ if self.enabled: + if self.pdo_node.network is None: + raise RuntimeError("A Network is required") + if self.cob_id is None: + raise RuntimeError("A valid COB-ID is required") logger.info("Subscribing to enabled PDO 0x%X on the network", self.cob_id) self.pdo_node.network.subscribe(self.cob_id, self.on_message) @@ -511,6 +515,10 @@ def add_variable( def transmit(self) -> None: """Transmit the message once.""" + if self.pdo_node.network is None: + raise RuntimeError("A Network is required") + if self.cob_id is None: + raise RuntimeError("A valid COB-ID is required") self.pdo_node.network.send_message(self.cob_id, self.data) def start(self, period: Optional[float] = None) -> None: @@ -521,6 +529,11 @@ def start(self, period: Optional[float] = None) -> None: on the object before. :raises ValueError: When neither the argument nor the :attr:`period` is given. """ + if self.pdo_node.network is None: + raise RuntimeError("A Network is required") + if self.cob_id is None: + raise RuntimeError("A valid COB-ID is required") + # Stop an already running transmission if we have one, otherwise we # overwrite the reference and can lose our handle to shut it down self.stop() @@ -551,9 +564,13 @@ def remote_request(self) -> None: Silently ignore if not allowed. """ if self.enabled and self.rtr_allowed: + if self.pdo_node.network is None: + raise RuntimeError("A Network is required") + if self.cob_id is None: + raise RuntimeError("A valid COB-ID is required") self.pdo_node.network.send_message(self.cob_id, bytes(), remote=True) - def wait_for_reception(self, timeout: float = 10) -> float: + def wait_for_reception(self, timeout: float = 10) -> Optional[float]: """Wait for the next transmit PDO. :param float timeout: Max time to wait in seconds. @@ -581,6 +598,12 @@ def get_data(self) -> bytes: :return: PdoVariable value as :class:`bytes`. """ + # FIXME TYPING: These asserts are for type checking. More robust errors + # should be raised if these are not set. + assert self.offset is not None + assert self.pdo_parent is not None + assert self.od.data_type is not None + byte_offset, bit_offset = divmod(self.offset, 8) if bit_offset or self.length % 8: @@ -608,6 +631,12 @@ def set_data(self, data: bytes): :param data: Value for the PDO variable in the PDO message. """ + # FIXME TYPING: These asserts are for type checking. More robust errors + # should be raised if these are not set. + assert self.offset is not None + assert self.pdo_parent is not None + assert self.od.data_type is not None + byte_offset, bit_offset = divmod(self.offset, 8) logger.debug("Updating %s to %s in %s", self.name, binascii.hexlify(data), self.pdo_parent.name) diff --git a/canopen/sdo/base.py b/canopen/sdo/base.py index 0bb068b4..bfae583b 100644 --- a/canopen/sdo/base.py +++ b/canopen/sdo/base.py @@ -1,7 +1,7 @@ from __future__ import annotations import binascii -from typing import Iterator, Optional, Union +from typing import Iterator, Optional, Union, cast from collections.abc import Mapping from canopen import objectdictionary @@ -63,7 +63,7 @@ def __iter__(self) -> Iterator[int]: def __len__(self) -> int: return len(self.od) - def __contains__(self, key: Union[int, str]) -> bool: + def __contains__(self, key: object) -> bool: return key in self.od def get_variable( @@ -78,6 +78,7 @@ def get_variable( return obj elif isinstance(obj, (SdoRecord, SdoArray)): return obj.get(subindex) + return None def upload(self, index: int, subindex: int) -> bytes: raise NotImplementedError() @@ -110,7 +111,7 @@ def __iter__(self) -> Iterator[int]: def __len__(self) -> int: return len(self.od) - def __contains__(self, subindex: Union[int, str]) -> bool: + def __contains__(self, subindex: object) -> bool: return subindex in self.od @@ -130,10 +131,10 @@ def __iter__(self) -> Iterator[int]: return iter(range(1, len(self) + 1)) def __len__(self) -> int: - return self[0].raw + return cast(int, self[0].raw) - def __contains__(self, subindex: int) -> bool: - return 0 <= subindex <= len(self) + def __contains__(self, subindex: object) -> bool: + return isinstance(subindex, int) and 0 <= subindex <= len(self) class SdoVariable(variable.Variable): diff --git a/canopen/variable.py b/canopen/variable.py index 3ec67c79..7683721f 100644 --- a/canopen/variable.py +++ b/canopen/variable.py @@ -1,5 +1,5 @@ import logging -from typing import Union +from typing import List, Union, cast from collections.abc import Mapping from canopen import objectdictionary @@ -77,7 +77,7 @@ def raw(self) -> Union[int, bool, float, str, bytes]: value = self.od.decode_raw(self.data) text = f"Value of {self.name!r} ({pretty_index(self.index, self.subindex)}) is {value!r}" if value in self.od.value_descriptions: - text += f" ({self.od.value_descriptions[value]})" + text += f" ({self.od.value_descriptions[cast(int, value)]})" logger.debug(text) return value @@ -108,7 +108,10 @@ def phys(self, value: Union[int, bool, float, str, bytes]): @property def desc(self) -> str: """Converts to and from a description of the value as a string.""" - value = self.od.decode_desc(self.raw) + raw = self.raw + if not isinstance(raw, int): + raise TypeError("Description can only be used with integer values") + value = self.od.decode_desc(raw) logger.debug("Description is '%s'", value) return value @@ -141,6 +144,7 @@ def read(self, fmt: str = "raw") -> Union[int, bool, float, str, bytes]: return self.phys elif fmt == "desc": return self.desc + raise ValueError(f"Uknown format {fmt!r}") def write( self, value: Union[int, bool, float, str, bytes], fmt: str = "raw" @@ -160,17 +164,23 @@ def write( elif fmt == "phys": self.phys = value elif fmt == "desc": + if not isinstance(value, str): + raise TypeError("Description must be a string") self.desc = value class Bits(Mapping): + # Attribute type (since not defined in __init__) + raw: int + def __init__(self, variable: Variable): self.variable = variable self.read() @staticmethod - def _get_bits(key): + def _get_bits(key: Union[int, str, slice]) -> Union[range, List[int], str]: + bits: Union[range, List[int], str] if isinstance(key, slice): bits = range(key.start, key.stop, key.step) elif isinstance(key, int): diff --git a/pyproject.toml b/pyproject.toml index a87ba589..876b8637 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,3 +48,12 @@ testpaths = [ filterwarnings = [ "ignore::DeprecationWarning", ] + +[tool.mypy] +files = "canopen" +strict = "False" +ignore_missing_imports = "True" +disable_error_code = [ + "annotation-unchecked", +] +