Skip to content

Commit

Permalink
Add support for validating mandatory commands in matter idl (#19117)
Browse files Browse the repository at this point in the history
* Add support for validating mandatory commands in matter idl

* Restyle

* Undo extra space on decoractors

* Fix bug in finding attribute names

* Restyle
  • Loading branch information
andy31415 authored and pull[bot] committed Jan 19, 2024
1 parent 764ce96 commit 2392172
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 31 deletions.
57 changes: 45 additions & 12 deletions scripts/idl/lint/lint_rules_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
import traceback

try:
from .types import RequiredAttributesRule, AttributeRequirement, ClusterRequirement
from .types import RequiredAttributesRule, AttributeRequirement, ClusterRequirement, RequiredCommandsRule, ClusterCommandRequirement
except:
import sys

sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), "..", ".."))
from idl.lint.types import RequiredAttributesRule, AttributeRequirement, ClusterRequirement
from idl.lint.types import RequiredAttributesRule, AttributeRequirement, ClusterRequirement, RequiredCommandsRule, ClusterCommandRequirement


def parseNumberString(n):
Expand All @@ -33,11 +33,18 @@ class RequiredAttribute:
code: int


@dataclass
class RequiredCommand:
name: str
code: int


@dataclass
class DecodedCluster:
name: str
code: int
required_attributes: List[RequiredAttribute]
required_commands: List[RequiredCommand]


def DecodeClusterFromXml(element: xml.etree.ElementTree.Element):
Expand All @@ -53,6 +60,7 @@ def DecodeClusterFromXml(element: xml.etree.ElementTree.Element):
try:
name = element.find('name').text.replace(' ', '')
required_attributes = []
required_commands = []

for attr in element.findall('attribute'):
if attr.attrib['side'] != 'server':
Expand All @@ -61,16 +69,34 @@ def DecodeClusterFromXml(element: xml.etree.ElementTree.Element):
if 'optional' in attr.attrib and attr.attrib['optional'] == 'true':
continue

# when introducing access controls, the content of attributes may either be:
# <attribute ...>myName</attribute>
# or
# <attribute ...><description>myName</description><access .../>...</attribute>
attr_name = attr.text
if attr.find('description') is not None:
attr_name = attr.find('description').text

required_attributes.append(
RequiredAttribute(
name=attr.text,
name=attr_name,
code=parseNumberString(attr.attrib['code'])
))

for cmd in element.findall('command'):
if cmd.attrib['source'] != 'client':
continue

if 'optional' in cmd.attrib and cmd.attrib['optional'] == 'true':
continue

required_commands.append(RequiredCommand(name=cmd.attrib["name"], code=parseNumberString(cmd.attrib['code'])))

return DecodedCluster(
name=name,
code=parseNumberString(element.find('code').text),
required_attributes=required_attributes,
required_commands=required_commands
)
except Exception as e:
logging.exception("Failed to decode cluster %r" % element)
Expand Down Expand Up @@ -98,16 +124,17 @@ class LintRulesContext:
"""

def __init__(self):
self._linter_rule = RequiredAttributesRule("Rules file")
self._required_attributes_rule = RequiredAttributesRule("Required attributes")
self._required_commands_rule = RequiredCommandsRule("Required commands")

# Map cluster names to the underlying code
self._cluster_codes: Mapping[str, int] = {}

def GetLinterRules(self):
return [self._linter_rule]
return [self._required_attributes_rule, self._required_commands_rule]

def RequireAttribute(self, r: AttributeRequirement):
self._linter_rule.RequireAttribute(r)
self._required_attributes_rule.RequireAttribute(r)

def RequireClusterInEndpoint(self, name: str, code: int):
"""Mark that a specific cluster is always required in the given endpoint
Expand All @@ -117,14 +144,14 @@ def RequireClusterInEndpoint(self, name: str, code: int):
logging.error("Known names: %s" % (",".join(self._cluster_codes.keys()), ))
return

self._linter_rule.RequireClusterInEndpoint(ClusterRequirement(
self._required_attributes_rule.RequireClusterInEndpoint(ClusterRequirement(
endpoint_id=code,
cluster_id=self._cluster_codes[name],
cluster_code=self._cluster_codes[name],
cluster_name=name,
))

def LoadXml(self, path: str):
"""Load XML data from the given path and add it to
"""Load XML data from the given path and add it to
internal processing. Adds attribute requirement rules
as needed.
"""
Expand All @@ -137,15 +164,21 @@ def LoadXml(self, path: str):
self._cluster_codes[decoded.name] = decoded.code

for attr in decoded.required_attributes:
self._linter_rule.RequireAttribute(AttributeRequirement(
self._required_attributes_rule.RequireAttribute(AttributeRequirement(
code=attr.code, name=attr.name, filter_cluster=decoded.code))

# TODO: add cluster ID to internal registry
for cmd in decoded.required_commands:
self._required_commands_rule.RequireCommand(
ClusterCommandRequirement(
cluster_code=decoded.code,
command_code=cmd.code,
command_name=cmd.name
))


class LintRulesTransformer(Transformer):
"""
A transformer capable to transform data parsed by Lark according to
A transformer capable to transform data parsed by Lark according to
lint_rules_grammar.lark.
"""

Expand Down
113 changes: 94 additions & 19 deletions scripts/idl/lint/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,45 @@ class AttributeRequirement:
@dataclass
class ClusterRequirement:
endpoint_id: int
cluster_id: int
cluster_code: int
cluster_name: str


class RequiredAttributesRule(LintRule):
class ErrorAccumulatingRule(LintRule):
"""Contains a lint error list and helps helpers to add to such a list of rules."""

def __init__(self, name):
super(RequiredAttributesRule, self).__init__(name)
super(ErrorAccumulatingRule, self).__init__(name)
self._lint_errors = []
self._idl = None

def _AddLintError(self, text, location):
self._lint_errors.append(LintError("%s: %s" % (self.name, text), location))

def _ParseLocation(self, meta: Optional[ParseMetaData]) -> Optional[LocationInFile]:
"""Create a location in the current file that is being parsed. """
if not meta or not self._idl.parse_file_name:
return None
return LocationInFile(self._idl.parse_file_name, meta)

def LintIdl(self, idl: Idl) -> List[LintError]:
self._idl = idl
self._lint_errors = []
self._LintImpl()
return self._lint_errors

@abstractmethod
def _LintImpl(self):
"""Implements actual linting of the IDL.
Uses the underlying _idl for validation.
"""
pass


class RequiredAttributesRule(ErrorAccumulatingRule):
def __init__(self, name):
super(RequiredAttributesRule, self).__init__(name)
# Map attribute code to name
self._mandatory_attributes: List[AttributeRequirement] = []
self._mandatory_clusters: List[ClusterRequirement] = []
Expand All @@ -93,6 +122,11 @@ def __repr__(self):
for attr in self._mandatory_attributes:
result += " - %r\n" % attr

if self._mandatory_clusters:
result += " mandatory_clusters:\n"
for cluster in self._mandatory_clusters:
result += " - %r\n" % cluster

result += "}"
return result

Expand All @@ -103,15 +137,6 @@ def RequireAttribute(self, attr: AttributeRequirement):
def RequireClusterInEndpoint(self, requirement: ClusterRequirement):
self._mandatory_clusters.append(requirement)

def _ParseLocation(self, meta: Optional[ParseMetaData]) -> Optional[LocationInFile]:
"""Create a location in the current file that is being parsed. """
if not meta or not self._idl.parse_file_name:
return None
return LocationInFile(self._idl.parse_file_name, meta)

def _AddLintError(self, text, location):
self._lint_errors.append(LintError("%s: %s" % (self.name, text), location))

def _ServerClusterDefinition(self, name: str, location: Optional[LocationInFile]):
"""Finds the server cluster definition with the given name.
Expand Down Expand Up @@ -173,12 +198,62 @@ def _LintImpl(self):
if requirement.endpoint_id != endpoint.number:
continue

if requirement.cluster_id not in cluster_codes:
if requirement.cluster_code not in cluster_codes:
self._AddLintError("Endpoint %d does not expose cluster %s (%d)" %
(requirement.endpoint_id, requirement.cluster_name, requirement.cluster_id), location=None)
(requirement.endpoint_id, requirement.cluster_name, requirement.cluster_code), location=None)

def LintIdl(self, idl: Idl) -> List[LintError]:
self._idl = idl
self._lint_errors = []
self._LintImpl()
return self._lint_errors

@dataclass
class ClusterCommandRequirement:
cluster_code: int
command_code: int
command_name: str


class RequiredCommandsRule(ErrorAccumulatingRule):
def __init__(self, name):
super(RequiredCommandsRule, self).__init__(name)

# Maps cluster id to mandatory cluster requirement
self._mandatory_commands: Maping[int, List[ClusterCommandRequirement]] = {}

def __repr__(self):
result = "RequiredCommandsRule{\n"

if self._mandatory_commands:
result += " mandatory_commands:\n"
for key, value in self._mandatory_commands.items():
result += " - cluster %d:\n" % key
for requirement in value:
result += " - %r\n" % requirement

result += "}"
return result

def RequireCommand(self, cmd: ClusterCommandRequirement):
"""Mark a command required"""

if cmd.cluster_code in self._mandatory_commands:
self._mandatory_commands[cmd.cluster_code].append(cmd)
else:
self._mandatory_commands[cmd.cluster_code] = [cmd]

def _LintImpl(self):
for cluster in self._idl.clusters:
if cluster.side != ClusterSide.SERVER:
continue # only validate server-side:

if cluster.code not in self._mandatory_commands:
continue # no known mandatory commands

defined_commands = set([c.code for c in cluster.commands])

for requirement in self._mandatory_commands[cluster.code]:
if requirement.command_code in defined_commands:
continue # command exists

self._AddLintError(
"Cluster %s does not define mandatory command %s(%d)" % (
cluster.name, requirement.command_name, requirement.command_code),
self._ParseLocation(cluster.parse_meta)
)

0 comments on commit 2392172

Please sign in to comment.