diff --git a/scripts/idl/lint/lint_rules_parser.py b/scripts/idl/lint/lint_rules_parser.py
index 7e1d69cc200061..4ba475ecc193c5 100755
--- a/scripts/idl/lint/lint_rules_parser.py
+++ b/scripts/idl/lint/lint_rules_parser.py
@@ -12,12 +12,12 @@
import traceback
try:
- from .types import RequiredAttributesRule, AttributeRequirement, ClusterRequirement
+ from .types import RequiredAttributesRule, AttributeRequirement, ClusterRequirement, RequiredCommandsRule, ClusterCommandRequirement
except:
import sys
sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), "..", ".."))
- from idl.lint.types import RequiredAttributesRule, AttributeRequirement, ClusterRequirement
+ from idl.lint.types import RequiredAttributesRule, AttributeRequirement, ClusterRequirement, RequiredCommandsRule, ClusterCommandRequirement
def parseNumberString(n):
@@ -33,11 +33,18 @@ class RequiredAttribute:
code: int
+@dataclass
+class RequiredCommand:
+ name: str
+ code: int
+
+
@dataclass
class DecodedCluster:
name: str
code: int
required_attributes: List[RequiredAttribute]
+ required_commands: List[RequiredCommand]
def DecodeClusterFromXml(element: xml.etree.ElementTree.Element):
@@ -53,6 +60,7 @@ def DecodeClusterFromXml(element: xml.etree.ElementTree.Element):
try:
name = element.find('name').text.replace(' ', '')
required_attributes = []
+ required_commands = []
for attr in element.findall('attribute'):
if attr.attrib['side'] != 'server':
@@ -61,16 +69,34 @@ def DecodeClusterFromXml(element: xml.etree.ElementTree.Element):
if 'optional' in attr.attrib and attr.attrib['optional'] == 'true':
continue
+ # when introducing access controls, the content of attributes may either be:
+ # myName
+ # or
+ # myName...
+ attr_name = attr.text
+ if attr.find('description') is not None:
+ attr_name = attr.find('description').text
+
required_attributes.append(
RequiredAttribute(
- name=attr.text,
+ name=attr_name,
code=parseNumberString(attr.attrib['code'])
))
+ for cmd in element.findall('command'):
+ if cmd.attrib['source'] != 'client':
+ continue
+
+ if 'optional' in cmd.attrib and cmd.attrib['optional'] == 'true':
+ continue
+
+ required_commands.append(RequiredCommand(name=cmd.attrib["name"], code=parseNumberString(cmd.attrib['code'])))
+
return DecodedCluster(
name=name,
code=parseNumberString(element.find('code').text),
required_attributes=required_attributes,
+ required_commands=required_commands
)
except Exception as e:
logging.exception("Failed to decode cluster %r" % element)
@@ -98,16 +124,17 @@ class LintRulesContext:
"""
def __init__(self):
- self._linter_rule = RequiredAttributesRule("Rules file")
+ self._required_attributes_rule = RequiredAttributesRule("Required attributes")
+ self._required_commands_rule = RequiredCommandsRule("Required commands")
# Map cluster names to the underlying code
self._cluster_codes: Mapping[str, int] = {}
def GetLinterRules(self):
- return [self._linter_rule]
+ return [self._required_attributes_rule, self._required_commands_rule]
def RequireAttribute(self, r: AttributeRequirement):
- self._linter_rule.RequireAttribute(r)
+ self._required_attributes_rule.RequireAttribute(r)
def RequireClusterInEndpoint(self, name: str, code: int):
"""Mark that a specific cluster is always required in the given endpoint
@@ -117,14 +144,14 @@ def RequireClusterInEndpoint(self, name: str, code: int):
logging.error("Known names: %s" % (",".join(self._cluster_codes.keys()), ))
return
- self._linter_rule.RequireClusterInEndpoint(ClusterRequirement(
+ self._required_attributes_rule.RequireClusterInEndpoint(ClusterRequirement(
endpoint_id=code,
- cluster_id=self._cluster_codes[name],
+ cluster_code=self._cluster_codes[name],
cluster_name=name,
))
def LoadXml(self, path: str):
- """Load XML data from the given path and add it to
+ """Load XML data from the given path and add it to
internal processing. Adds attribute requirement rules
as needed.
"""
@@ -137,15 +164,21 @@ def LoadXml(self, path: str):
self._cluster_codes[decoded.name] = decoded.code
for attr in decoded.required_attributes:
- self._linter_rule.RequireAttribute(AttributeRequirement(
+ self._required_attributes_rule.RequireAttribute(AttributeRequirement(
code=attr.code, name=attr.name, filter_cluster=decoded.code))
- # TODO: add cluster ID to internal registry
+ for cmd in decoded.required_commands:
+ self._required_commands_rule.RequireCommand(
+ ClusterCommandRequirement(
+ cluster_code=decoded.code,
+ command_code=cmd.code,
+ command_name=cmd.name
+ ))
class LintRulesTransformer(Transformer):
"""
- A transformer capable to transform data parsed by Lark according to
+ A transformer capable to transform data parsed by Lark according to
lint_rules_grammar.lark.
"""
diff --git a/scripts/idl/lint/types.py b/scripts/idl/lint/types.py
index 5222a4a0bc083e..85dfbdeacd0bc3 100644
--- a/scripts/idl/lint/types.py
+++ b/scripts/idl/lint/types.py
@@ -71,16 +71,45 @@ class AttributeRequirement:
@dataclass
class ClusterRequirement:
endpoint_id: int
- cluster_id: int
+ cluster_code: int
cluster_name: str
-class RequiredAttributesRule(LintRule):
+class ErrorAccumulatingRule(LintRule):
+ """Contains a lint error list and helps helpers to add to such a list of rules."""
+
def __init__(self, name):
- super(RequiredAttributesRule, self).__init__(name)
+ super(ErrorAccumulatingRule, self).__init__(name)
self._lint_errors = []
self._idl = None
+ def _AddLintError(self, text, location):
+ self._lint_errors.append(LintError("%s: %s" % (self.name, text), location))
+
+ def _ParseLocation(self, meta: Optional[ParseMetaData]) -> Optional[LocationInFile]:
+ """Create a location in the current file that is being parsed. """
+ if not meta or not self._idl.parse_file_name:
+ return None
+ return LocationInFile(self._idl.parse_file_name, meta)
+
+ def LintIdl(self, idl: Idl) -> List[LintError]:
+ self._idl = idl
+ self._lint_errors = []
+ self._LintImpl()
+ return self._lint_errors
+
+ @abstractmethod
+ def _LintImpl(self):
+ """Implements actual linting of the IDL.
+
+ Uses the underlying _idl for validation.
+ """
+ pass
+
+
+class RequiredAttributesRule(ErrorAccumulatingRule):
+ def __init__(self, name):
+ super(RequiredAttributesRule, self).__init__(name)
# Map attribute code to name
self._mandatory_attributes: List[AttributeRequirement] = []
self._mandatory_clusters: List[ClusterRequirement] = []
@@ -93,6 +122,11 @@ def __repr__(self):
for attr in self._mandatory_attributes:
result += " - %r\n" % attr
+ if self._mandatory_clusters:
+ result += " mandatory_clusters:\n"
+ for cluster in self._mandatory_clusters:
+ result += " - %r\n" % cluster
+
result += "}"
return result
@@ -103,15 +137,6 @@ def RequireAttribute(self, attr: AttributeRequirement):
def RequireClusterInEndpoint(self, requirement: ClusterRequirement):
self._mandatory_clusters.append(requirement)
- def _ParseLocation(self, meta: Optional[ParseMetaData]) -> Optional[LocationInFile]:
- """Create a location in the current file that is being parsed. """
- if not meta or not self._idl.parse_file_name:
- return None
- return LocationInFile(self._idl.parse_file_name, meta)
-
- def _AddLintError(self, text, location):
- self._lint_errors.append(LintError("%s: %s" % (self.name, text), location))
-
def _ServerClusterDefinition(self, name: str, location: Optional[LocationInFile]):
"""Finds the server cluster definition with the given name.
@@ -173,12 +198,62 @@ def _LintImpl(self):
if requirement.endpoint_id != endpoint.number:
continue
- if requirement.cluster_id not in cluster_codes:
+ if requirement.cluster_code not in cluster_codes:
self._AddLintError("Endpoint %d does not expose cluster %s (%d)" %
- (requirement.endpoint_id, requirement.cluster_name, requirement.cluster_id), location=None)
+ (requirement.endpoint_id, requirement.cluster_name, requirement.cluster_code), location=None)
- def LintIdl(self, idl: Idl) -> List[LintError]:
- self._idl = idl
- self._lint_errors = []
- self._LintImpl()
- return self._lint_errors
+
+@dataclass
+class ClusterCommandRequirement:
+ cluster_code: int
+ command_code: int
+ command_name: str
+
+
+class RequiredCommandsRule(ErrorAccumulatingRule):
+ def __init__(self, name):
+ super(RequiredCommandsRule, self).__init__(name)
+
+ # Maps cluster id to mandatory cluster requirement
+ self._mandatory_commands: Maping[int, List[ClusterCommandRequirement]] = {}
+
+ def __repr__(self):
+ result = "RequiredCommandsRule{\n"
+
+ if self._mandatory_commands:
+ result += " mandatory_commands:\n"
+ for key, value in self._mandatory_commands.items():
+ result += " - cluster %d:\n" % key
+ for requirement in value:
+ result += " - %r\n" % requirement
+
+ result += "}"
+ return result
+
+ def RequireCommand(self, cmd: ClusterCommandRequirement):
+ """Mark a command required"""
+
+ if cmd.cluster_code in self._mandatory_commands:
+ self._mandatory_commands[cmd.cluster_code].append(cmd)
+ else:
+ self._mandatory_commands[cmd.cluster_code] = [cmd]
+
+ def _LintImpl(self):
+ for cluster in self._idl.clusters:
+ if cluster.side != ClusterSide.SERVER:
+ continue # only validate server-side:
+
+ if cluster.code not in self._mandatory_commands:
+ continue # no known mandatory commands
+
+ defined_commands = set([c.code for c in cluster.commands])
+
+ for requirement in self._mandatory_commands[cluster.code]:
+ if requirement.command_code in defined_commands:
+ continue # command exists
+
+ self._AddLintError(
+ "Cluster %s does not define mandatory command %s(%d)" % (
+ cluster.name, requirement.command_name, requirement.command_code),
+ self._ParseLocation(cluster.parse_meta)
+ )