diff --git a/scripts/idl/lint/lint_rules_parser.py b/scripts/idl/lint/lint_rules_parser.py index 7e1d69cc200061..4ba475ecc193c5 100755 --- a/scripts/idl/lint/lint_rules_parser.py +++ b/scripts/idl/lint/lint_rules_parser.py @@ -12,12 +12,12 @@ import traceback try: - from .types import RequiredAttributesRule, AttributeRequirement, ClusterRequirement + from .types import RequiredAttributesRule, AttributeRequirement, ClusterRequirement, RequiredCommandsRule, ClusterCommandRequirement except: import sys sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), "..", "..")) - from idl.lint.types import RequiredAttributesRule, AttributeRequirement, ClusterRequirement + from idl.lint.types import RequiredAttributesRule, AttributeRequirement, ClusterRequirement, RequiredCommandsRule, ClusterCommandRequirement def parseNumberString(n): @@ -33,11 +33,18 @@ class RequiredAttribute: code: int +@dataclass +class RequiredCommand: + name: str + code: int + + @dataclass class DecodedCluster: name: str code: int required_attributes: List[RequiredAttribute] + required_commands: List[RequiredCommand] def DecodeClusterFromXml(element: xml.etree.ElementTree.Element): @@ -53,6 +60,7 @@ def DecodeClusterFromXml(element: xml.etree.ElementTree.Element): try: name = element.find('name').text.replace(' ', '') required_attributes = [] + required_commands = [] for attr in element.findall('attribute'): if attr.attrib['side'] != 'server': @@ -61,16 +69,34 @@ def DecodeClusterFromXml(element: xml.etree.ElementTree.Element): if 'optional' in attr.attrib and attr.attrib['optional'] == 'true': continue + # when introducing access controls, the content of attributes may either be: + # myName + # or + # myName... + attr_name = attr.text + if attr.find('description') is not None: + attr_name = attr.find('description').text + required_attributes.append( RequiredAttribute( - name=attr.text, + name=attr_name, code=parseNumberString(attr.attrib['code']) )) + for cmd in element.findall('command'): + if cmd.attrib['source'] != 'client': + continue + + if 'optional' in cmd.attrib and cmd.attrib['optional'] == 'true': + continue + + required_commands.append(RequiredCommand(name=cmd.attrib["name"], code=parseNumberString(cmd.attrib['code']))) + return DecodedCluster( name=name, code=parseNumberString(element.find('code').text), required_attributes=required_attributes, + required_commands=required_commands ) except Exception as e: logging.exception("Failed to decode cluster %r" % element) @@ -98,16 +124,17 @@ class LintRulesContext: """ def __init__(self): - self._linter_rule = RequiredAttributesRule("Rules file") + self._required_attributes_rule = RequiredAttributesRule("Required attributes") + self._required_commands_rule = RequiredCommandsRule("Required commands") # Map cluster names to the underlying code self._cluster_codes: Mapping[str, int] = {} def GetLinterRules(self): - return [self._linter_rule] + return [self._required_attributes_rule, self._required_commands_rule] def RequireAttribute(self, r: AttributeRequirement): - self._linter_rule.RequireAttribute(r) + self._required_attributes_rule.RequireAttribute(r) def RequireClusterInEndpoint(self, name: str, code: int): """Mark that a specific cluster is always required in the given endpoint @@ -117,14 +144,14 @@ def RequireClusterInEndpoint(self, name: str, code: int): logging.error("Known names: %s" % (",".join(self._cluster_codes.keys()), )) return - self._linter_rule.RequireClusterInEndpoint(ClusterRequirement( + self._required_attributes_rule.RequireClusterInEndpoint(ClusterRequirement( endpoint_id=code, - cluster_id=self._cluster_codes[name], + cluster_code=self._cluster_codes[name], cluster_name=name, )) def LoadXml(self, path: str): - """Load XML data from the given path and add it to + """Load XML data from the given path and add it to internal processing. Adds attribute requirement rules as needed. """ @@ -137,15 +164,21 @@ def LoadXml(self, path: str): self._cluster_codes[decoded.name] = decoded.code for attr in decoded.required_attributes: - self._linter_rule.RequireAttribute(AttributeRequirement( + self._required_attributes_rule.RequireAttribute(AttributeRequirement( code=attr.code, name=attr.name, filter_cluster=decoded.code)) - # TODO: add cluster ID to internal registry + for cmd in decoded.required_commands: + self._required_commands_rule.RequireCommand( + ClusterCommandRequirement( + cluster_code=decoded.code, + command_code=cmd.code, + command_name=cmd.name + )) class LintRulesTransformer(Transformer): """ - A transformer capable to transform data parsed by Lark according to + A transformer capable to transform data parsed by Lark according to lint_rules_grammar.lark. """ diff --git a/scripts/idl/lint/types.py b/scripts/idl/lint/types.py index 5222a4a0bc083e..85dfbdeacd0bc3 100644 --- a/scripts/idl/lint/types.py +++ b/scripts/idl/lint/types.py @@ -71,16 +71,45 @@ class AttributeRequirement: @dataclass class ClusterRequirement: endpoint_id: int - cluster_id: int + cluster_code: int cluster_name: str -class RequiredAttributesRule(LintRule): +class ErrorAccumulatingRule(LintRule): + """Contains a lint error list and helps helpers to add to such a list of rules.""" + def __init__(self, name): - super(RequiredAttributesRule, self).__init__(name) + super(ErrorAccumulatingRule, self).__init__(name) self._lint_errors = [] self._idl = None + def _AddLintError(self, text, location): + self._lint_errors.append(LintError("%s: %s" % (self.name, text), location)) + + def _ParseLocation(self, meta: Optional[ParseMetaData]) -> Optional[LocationInFile]: + """Create a location in the current file that is being parsed. """ + if not meta or not self._idl.parse_file_name: + return None + return LocationInFile(self._idl.parse_file_name, meta) + + def LintIdl(self, idl: Idl) -> List[LintError]: + self._idl = idl + self._lint_errors = [] + self._LintImpl() + return self._lint_errors + + @abstractmethod + def _LintImpl(self): + """Implements actual linting of the IDL. + + Uses the underlying _idl for validation. + """ + pass + + +class RequiredAttributesRule(ErrorAccumulatingRule): + def __init__(self, name): + super(RequiredAttributesRule, self).__init__(name) # Map attribute code to name self._mandatory_attributes: List[AttributeRequirement] = [] self._mandatory_clusters: List[ClusterRequirement] = [] @@ -93,6 +122,11 @@ def __repr__(self): for attr in self._mandatory_attributes: result += " - %r\n" % attr + if self._mandatory_clusters: + result += " mandatory_clusters:\n" + for cluster in self._mandatory_clusters: + result += " - %r\n" % cluster + result += "}" return result @@ -103,15 +137,6 @@ def RequireAttribute(self, attr: AttributeRequirement): def RequireClusterInEndpoint(self, requirement: ClusterRequirement): self._mandatory_clusters.append(requirement) - def _ParseLocation(self, meta: Optional[ParseMetaData]) -> Optional[LocationInFile]: - """Create a location in the current file that is being parsed. """ - if not meta or not self._idl.parse_file_name: - return None - return LocationInFile(self._idl.parse_file_name, meta) - - def _AddLintError(self, text, location): - self._lint_errors.append(LintError("%s: %s" % (self.name, text), location)) - def _ServerClusterDefinition(self, name: str, location: Optional[LocationInFile]): """Finds the server cluster definition with the given name. @@ -173,12 +198,62 @@ def _LintImpl(self): if requirement.endpoint_id != endpoint.number: continue - if requirement.cluster_id not in cluster_codes: + if requirement.cluster_code not in cluster_codes: self._AddLintError("Endpoint %d does not expose cluster %s (%d)" % - (requirement.endpoint_id, requirement.cluster_name, requirement.cluster_id), location=None) + (requirement.endpoint_id, requirement.cluster_name, requirement.cluster_code), location=None) - def LintIdl(self, idl: Idl) -> List[LintError]: - self._idl = idl - self._lint_errors = [] - self._LintImpl() - return self._lint_errors + +@dataclass +class ClusterCommandRequirement: + cluster_code: int + command_code: int + command_name: str + + +class RequiredCommandsRule(ErrorAccumulatingRule): + def __init__(self, name): + super(RequiredCommandsRule, self).__init__(name) + + # Maps cluster id to mandatory cluster requirement + self._mandatory_commands: Maping[int, List[ClusterCommandRequirement]] = {} + + def __repr__(self): + result = "RequiredCommandsRule{\n" + + if self._mandatory_commands: + result += " mandatory_commands:\n" + for key, value in self._mandatory_commands.items(): + result += " - cluster %d:\n" % key + for requirement in value: + result += " - %r\n" % requirement + + result += "}" + return result + + def RequireCommand(self, cmd: ClusterCommandRequirement): + """Mark a command required""" + + if cmd.cluster_code in self._mandatory_commands: + self._mandatory_commands[cmd.cluster_code].append(cmd) + else: + self._mandatory_commands[cmd.cluster_code] = [cmd] + + def _LintImpl(self): + for cluster in self._idl.clusters: + if cluster.side != ClusterSide.SERVER: + continue # only validate server-side: + + if cluster.code not in self._mandatory_commands: + continue # no known mandatory commands + + defined_commands = set([c.code for c in cluster.commands]) + + for requirement in self._mandatory_commands[cluster.code]: + if requirement.command_code in defined_commands: + continue # command exists + + self._AddLintError( + "Cluster %s does not define mandatory command %s(%d)" % ( + cluster.name, requirement.command_name, requirement.command_code), + self._ParseLocation(cluster.parse_meta) + )