diff --git a/src/python_testing/TC_DeviceBasicComposition.py b/src/python_testing/TC_DeviceBasicComposition.py index 2288c7c5a4b596..bb85b045923d06 100644 --- a/src/python_testing/TC_DeviceBasicComposition.py +++ b/src/python_testing/TC_DeviceBasicComposition.py @@ -15,108 +15,25 @@ # limitations under the License. # -import base64 -import copy -import functools -import json import logging -import pathlib -import sys -from collections import defaultdict -from dataclasses import dataclass, field -from pprint import pprint -from typing import Any, Callable, Optional +from dataclasses import dataclass +from typing import Any, Callable import chip.clusters as Clusters import chip.clusters.ClusterObjects import chip.tlv +from basic_composition_support import BasicCompositionTests from chip.clusters.Attribute import ValueDecodeFailure from chip.tlv import uint from conformance_support import ConformanceDecision, conformance_allowed +from global_attribute_ids import GlobalAttributeIds from matter_testing_support import (AttributePathLocation, ClusterPathLocation, CommandPathLocation, MatterBaseTest, async_test_body, default_matter_test_main) from mobly import asserts from spec_parsing_support import CommandType, build_xml_clusters - -ATTRIBUTE_LIST_ID = 0xFFFB -ACCEPTED_COMMAND_LIST_ID = 0xFFF9 -GENERATED_COMMAND_LIST_ID = 0xFFF8 -FEATURE_MAP_ID = 0xFFFC -CLUSTER_REVISION_ID = 0xFFFD - - -def MatterTlvToJson(tlv_data: dict[int, Any]) -> dict[str, Any]: - """Given TLV data for a specific cluster instance, convert to the Matter JSON format.""" - - matter_json_dict = {} - - key_type_mappings = { - chip.tlv.uint: "UINT", - int: "INT", - bool: "BOOL", - list: "ARRAY", - dict: "STRUCT", - chip.tlv.float32: "FLOAT", - float: "DOUBLE", - bytes: "BYTES", - str: "STRING", - ValueDecodeFailure: "ERROR", - type(None): "NULL", - } - - def ConvertValue(value) -> Any: - if isinstance(value, ValueDecodeFailure): - raise ValueError(f"Bad Value: {str(value)}") - - if isinstance(value, bytes): - return base64.b64encode(value).decode("UTF-8") - elif isinstance(value, list): - value = [ConvertValue(item) for item in value] - elif isinstance(value, dict): - value = MatterTlvToJson(value) - - return value - - for key in tlv_data: - value_type = type(tlv_data[key]) - value = copy.deepcopy(tlv_data[key]) - - element_type: str = key_type_mappings[value_type] - sub_element_type = "" - - try: - new_value = ConvertValue(value) - except ValueError as e: - new_value = str(e) - - if element_type: - if element_type == "ARRAY": - if len(new_value): - sub_element_type = key_type_mappings[type(tlv_data[key][0])] - else: - sub_element_type = "?" 
- - new_key = "" - if element_type: - if sub_element_type: - new_key = f"{str(key)}:{element_type}-{sub_element_type}" - else: - new_key = f"{str(key)}:{element_type}" - else: - new_key = str(key) - - matter_json_dict[new_key] = new_value - - return matter_json_dict - - -@dataclass -class TagProblem: - root: int - missing_attribute: bool - missing_feature: bool - duplicates: set[int] - same_tag: set[int] = field(default_factory=set) +from taglist_and_topology_test_support import (create_device_type_list_for_root, create_device_type_lists, find_tag_list_problems, + find_tree_roots, get_all_children, get_direct_children_of_root, parts_list_cycles, + separate_endpoint_types) def check_int_in_range(min_value: int, max_value: int, allow_null: bool = False) -> Callable: @@ -186,232 +103,11 @@ def check_no_duplicates(obj: Any) -> None: raise ValueError(f"Value {str(obj)} contains duplicate values") -def separate_endpoint_types(endpoint_dict: dict[int, Any]) -> tuple[list[int], list[int]]: - """Returns a tuple containing the list of flat endpoints and a list of tree endpoints""" - flat = [] - tree = [] - for endpoint_id, endpoint in endpoint_dict.items(): - if endpoint_id == 0: - continue - aggregator_id = 0x000e - content_app_id = 0x0024 - device_types = [d.deviceType for d in endpoint[Clusters.Descriptor][Clusters.Descriptor.Attributes.DeviceTypeList]] - if aggregator_id in device_types: - flat.append(endpoint_id) - else: - if content_app_id in device_types: - continue - tree.append(endpoint_id) - return (flat, tree) - - -def get_all_children(endpoint_id, endpoint_dict: dict[int, Any]) -> set[int]: - """Returns all the children (include subchildren) of the given endpoint - This assumes we've already checked that there are no cycles, so we can do the dumb things and just trace the tree - """ - children = set() - - def add_children(endpoint_id, children): - immediate_children = endpoint_dict[endpoint_id][Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList] - if not immediate_children: - return - children.update(set(immediate_children)) - for child in immediate_children: - add_children(child, children) - - add_children(endpoint_id, children) - return children - - -def find_tree_roots(tree_endpoints: list[int], endpoint_dict: dict[int, Any]) -> set[int]: - """Returns a set of all the endpoints in tree_endpoints that are roots for a tree (not include singletons)""" - tree_roots = set() - - def find_tree_root(current_id): - for endpoint_id, endpoint in endpoint_dict.items(): - if endpoint_id not in tree_endpoints: - continue - if current_id in endpoint[Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList]: - # this is not the root, move up - return find_tree_root(endpoint_id) - return current_id - - for endpoint_id in tree_endpoints: - root = find_tree_root(endpoint_id) - if root != endpoint_id: - tree_roots.add(root) - return tree_roots - - -def parts_list_cycles(tree_endpoints: list[int], endpoint_dict: dict[int, Any]) -> list[int]: - """Returns a list of all the endpoints in the tree_endpoints list that contain cycles""" - def parts_list_cycle_detect(visited: set, current_id: int) -> bool: - if current_id in visited: - return True - visited.add(current_id) - for child in endpoint_dict[current_id][Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList]: - child_has_cycles = parts_list_cycle_detect(visited, child) - if child_has_cycles: - return True - return False - - cycles = [] - # This is quick enough that we can do all the endpoints wihtout searching for the roots - for 
endpoint_id in tree_endpoints: - visited = set() - if parts_list_cycle_detect(visited, endpoint_id): - cycles.append(endpoint_id) - return cycles - - -def create_device_type_lists(roots: list[int], endpoint_dict: dict[int, Any]) -> dict[int, dict[int, set[int]]]: - """Returns a list of endpoints per device type for each root in the list""" - device_types = {} - for root in roots: - tree_device_types = defaultdict(set) - eps = get_all_children(root, endpoint_dict) - eps.add(root) - for ep in eps: - for d in endpoint_dict[ep][Clusters.Descriptor][Clusters.Descriptor.Attributes.DeviceTypeList]: - tree_device_types[d.deviceType].add(ep) - device_types[root] = tree_device_types - - return device_types - - -def get_direct_children_of_root(endpoint_dict: dict[int, Any]) -> set[int]: - root_children = set(endpoint_dict[0][Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList]) - direct_children = root_children - for ep in root_children: - ep_children = set(endpoint_dict[ep][Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList]) - direct_children = direct_children - ep_children - return direct_children - - -def create_device_type_list_for_root(direct_children, endpoint_dict: dict[int, Any]) -> dict[int, set[int]]: - device_types = defaultdict(set) - for ep in direct_children: - for d in endpoint_dict[ep][Clusters.Descriptor][Clusters.Descriptor.Attributes.DeviceTypeList]: - device_types[d.deviceType].add(ep) - return device_types - - -def cmp_tag_list(a: Clusters.Descriptor.Structs.SemanticTagStruct, b: Clusters.Descriptor.Structs.SemanticTagStruct): - if a.mfgCode != b.mfgCode: - return -1 if a.mfgCode < b.mfgCode else 1 - if a.namespaceID != b.namespaceID: - return -1 if a.namespaceID < b.namespaceID else 1 - if a.tag != b.tag: - return -1 if a.tag < b.tag else 1 - if a.label != b.label: - return -1 if a.label < b.label else 1 - return 0 - - -def find_tag_list_problems(roots: list[int], device_types: dict[int, dict[int, set[int]]], endpoint_dict: dict[int, Any]) -> dict[int, TagProblem]: - """Checks for non-spec compliant tag lists""" - tag_problems = {} - for root in roots: - for _, endpoints in device_types[root].items(): - if len(endpoints) < 2: - continue - for endpoint in endpoints: - missing_feature = not bool(endpoint_dict[endpoint][Clusters.Descriptor] - [Clusters.Descriptor.Attributes.FeatureMap] & Clusters.Descriptor.Bitmaps.Feature.kTagList) - if Clusters.Descriptor.Attributes.TagList not in endpoint_dict[endpoint][Clusters.Descriptor] or endpoint_dict[endpoint][Clusters.Descriptor][Clusters.Descriptor.Attributes.TagList] == []: - tag_problems[endpoint] = TagProblem(root=root, missing_attribute=True, - missing_feature=missing_feature, duplicates=endpoints) - continue - # Check that this tag isn't the same as the other tags in the endpoint list - duplicate_tags = set() - for other in endpoints: - if other == endpoint: - continue - # The OTHER endpoint is missing a tag list attribute - ignore this here, we'll catch that when we assess this endpoint as the primary - if Clusters.Descriptor.Attributes.TagList not in endpoint_dict[other][Clusters.Descriptor]: - continue - - if sorted(endpoint_dict[endpoint][Clusters.Descriptor][Clusters.Descriptor.Attributes.TagList], key=functools.cmp_to_key(cmp_tag_list)) == sorted(endpoint_dict[other][Clusters.Descriptor][Clusters.Descriptor.Attributes.TagList], key=functools.cmp_to_key(cmp_tag_list)): - duplicate_tags.add(other) - if len(duplicate_tags) != 0: - duplicate_tags.add(endpoint) - tag_problems[endpoint] = TagProblem(root=root, 
missing_attribute=False, missing_feature=missing_feature, - duplicates=endpoints, same_tag=duplicate_tags) - continue - if missing_feature: - tag_problems[endpoint] = TagProblem(root=root, missing_attribute=False, - missing_feature=missing_feature, duplicates=endpoints) - - return tag_problems - - -class TC_DeviceBasicComposition(MatterBaseTest): +class TC_DeviceBasicComposition(MatterBaseTest, BasicCompositionTests): @async_test_body async def setup_class(self): super().setup_class() - dev_ctrl = self.default_controller - self.problems = [] - - do_test_over_pase = self.user_params.get("use_pase_only", True) - dump_device_composition_path: Optional[str] = self.user_params.get("dump_device_composition_path", None) - - if do_test_over_pase: - info = self.get_setup_payload_info() - - commissionable_nodes = dev_ctrl.DiscoverCommissionableNodes( - info.filter_type, info.filter_value, stopOnFirst=True, timeoutSecond=15) - logging.info(f"Commissionable nodes: {commissionable_nodes}") - # TODO: Support BLE - if commissionable_nodes is not None and len(commissionable_nodes) > 0: - commissionable_node = commissionable_nodes[0] - instance_name = f"{commissionable_node.instanceName}._matterc._udp.local" - vid = f"{commissionable_node.vendorId}" - pid = f"{commissionable_node.productId}" - address = f"{commissionable_node.addresses[0]}" - logging.info(f"Found instance {instance_name}, VID={vid}, PID={pid}, Address={address}") - - node_id = 1 - dev_ctrl.EstablishPASESessionIP(address, info.passcode, node_id) - else: - asserts.fail("Failed to find the DUT according to command line arguments.") - else: - # Using the already commissioned node - node_id = self.dut_node_id - - wildcard_read = (await dev_ctrl.Read(node_id, [()])) - endpoints_tlv = wildcard_read.tlvAttributes - - node_dump_dict = {endpoint_id: MatterTlvToJson(endpoints_tlv[endpoint_id]) for endpoint_id in endpoints_tlv} - logging.debug(f"Raw TLV contents of Node: {json.dumps(node_dump_dict, indent=2)}") - - if dump_device_composition_path is not None: - with open(pathlib.Path(dump_device_composition_path).with_suffix(".json"), "wt+") as outfile: - json.dump(node_dump_dict, outfile, indent=2) - with open(pathlib.Path(dump_device_composition_path).with_suffix(".txt"), "wt+") as outfile: - pprint(wildcard_read.attributes, outfile, indent=1, width=200, compact=True) - - logging.info("###########################################################") - logging.info("Start of actual tests") - logging.info("###########################################################") - - # ======= State kept for use by all tests ======= - - # All endpoints in "full object" indexing format - self.endpoints = wildcard_read.attributes - - # All endpoints in raw TLV format - self.endpoints_tlv = wildcard_read.tlvAttributes - - def get_test_name(self) -> str: - """Return the function name of the caller. 
Used to create logging entries.""" - return sys._getframe().f_back.f_code.co_name - - def fail_current_test(self, msg: Optional[str] = None): - if not msg: - # Without a message, just log the last problem seen - asserts.fail(msg=self.problems[-1].problem) - else: - asserts.fail(msg) + await self.setup_class_helper() # ======= START OF ACTUAL TESTS ======= def test_TC_SM_1_1(self): @@ -482,15 +178,17 @@ class RequiredMandatoryAttribute: validators: list[Callable] ATTRIBUTES_TO_CHECK = [ - RequiredMandatoryAttribute(id=CLUSTER_REVISION_ID, name="ClusterRevision", validators=[check_int_in_range(1, 0xFFFF)]), - RequiredMandatoryAttribute(id=FEATURE_MAP_ID, name="FeatureMap", validators=[check_int_in_range(0, 0xFFFF_FFFF)]), - RequiredMandatoryAttribute(id=ATTRIBUTE_LIST_ID, name="AttributeList", + RequiredMandatoryAttribute(id=GlobalAttributeIds.CLUSTER_REVISION_ID, name="ClusterRevision", + validators=[check_int_in_range(1, 0xFFFF)]), + RequiredMandatoryAttribute(id=GlobalAttributeIds.FEATURE_MAP_ID, name="FeatureMap", + validators=[check_int_in_range(0, 0xFFFF_FFFF)]), + RequiredMandatoryAttribute(id=GlobalAttributeIds.ATTRIBUTE_LIST_ID, name="AttributeList", validators=[check_non_empty_list_of_ints_in_range(0, 0xFFFF_FFFF), check_no_duplicates]), # TODO: Check for EventList # RequiredMandatoryAttribute(id=0xFFFA, name="EventList", validator=check_list_of_ints_in_range(0, 0xFFFF_FFFF)), - RequiredMandatoryAttribute(id=ACCEPTED_COMMAND_LIST_ID, name="AcceptedCommandList", + RequiredMandatoryAttribute(id=GlobalAttributeIds.ACCEPTED_COMMAND_LIST_ID, name="AcceptedCommandList", validators=[check_list_of_ints_in_range(0, 0xFFFF_FFFF), check_no_duplicates]), - RequiredMandatoryAttribute(id=GENERATED_COMMAND_LIST_ID, name="GeneratedCommandList", + RequiredMandatoryAttribute(id=GlobalAttributeIds.GENERATED_COMMAND_LIST_ID, name="GeneratedCommandList", validators=[check_list_of_ints_in_range(0, 0xFFFF_FFFF), check_no_duplicates]), ] @@ -532,7 +230,7 @@ class RequiredMandatoryAttribute: if success: for endpoint_id, endpoint in self.endpoints_tlv.items(): for cluster_id, cluster in endpoint.items(): - attribute_list = cluster[ATTRIBUTE_LIST_ID] + attribute_list = cluster[GlobalAttributeIds.ATTRIBUTE_LIST_ID] for attribute_id in attribute_list: location = AttributePathLocation(endpoint_id, cluster_id, attribute_id) has_attribute = attribute_id in cluster @@ -579,7 +277,7 @@ class RequiredMandatoryAttribute: mei_range_min = 0x0001_0000 for endpoint_id, endpoint in self.endpoints_tlv.items(): for cluster_id, cluster in endpoint.items(): - globals = [a for a in cluster[ATTRIBUTE_LIST_ID] if a >= global_range_min and a < mei_range_min] + globals = [a for a in cluster[GlobalAttributeIds.ATTRIBUTE_LIST_ID] if a >= global_range_min and a < mei_range_min] unexpected_globals = sorted(list(set(globals) - set(allowed_globals))) for unexpected in unexpected_globals: location = AttributePathLocation(endpoint_id=endpoint_id, cluster_id=cluster_id, attribute_id=unexpected) @@ -593,7 +291,8 @@ class RequiredMandatoryAttribute: if cluster_id not in chip.clusters.ClusterObjects.ALL_ATTRIBUTES: # Skip clusters that are not part of the standard generated corpus (e.g. 
MS clusters) continue - standard_attributes = [a for a in cluster[ATTRIBUTE_LIST_ID] if a <= attribute_standard_range_max] + standard_attributes = [a for a in cluster[GlobalAttributeIds.ATTRIBUTE_LIST_ID] + if a <= attribute_standard_range_max] allowed_standard_attributes = chip.clusters.ClusterObjects.ALL_ATTRIBUTES[cluster_id] unexpected_standard_attributes = sorted(list(set(standard_attributes) - set(allowed_standard_attributes))) for unexpected in unexpected_standard_attributes: @@ -606,7 +305,7 @@ class RequiredMandatoryAttribute: # This is de-facto already covered in the check above, assuming the spec hasn't defined any values in this range, but we should make sure for endpoint_id, endpoint in self.endpoints_tlv.items(): for cluster_id, cluster in endpoint.items(): - bad_range_values = [a for a in cluster[ATTRIBUTE_LIST_ID] if a > + bad_range_values = [a for a in cluster[GlobalAttributeIds.ATTRIBUTE_LIST_ID] if a > attribute_standard_range_max and a < global_range_min] for bad in bad_range_values: location = AttributePathLocation(endpoint_id=endpoint_id, cluster_id=cluster_id, attribute_id=bad) @@ -620,8 +319,10 @@ class RequiredMandatoryAttribute: for cluster_id, cluster in endpoint.items(): if cluster_id not in chip.clusters.ClusterObjects.ALL_CLUSTERS: continue - standard_accepted_commands = [a for a in cluster[ACCEPTED_COMMAND_LIST_ID] if a <= command_standard_range_max] - standard_generated_commands = [a for a in cluster[GENERATED_COMMAND_LIST_ID] if a <= command_standard_range_max] + standard_accepted_commands = [ + a for a in cluster[GlobalAttributeIds.ACCEPTED_COMMAND_LIST_ID] if a <= command_standard_range_max] + standard_generated_commands = [ + a for a in cluster[GlobalAttributeIds.GENERATED_COMMAND_LIST_ID] if a <= command_standard_range_max] if cluster_id in chip.clusters.ClusterObjects.ALL_ACCEPTED_COMMANDS: allowed_accepted_commands = [a for a in chip.clusters.ClusterObjects.ALL_ACCEPTED_COMMANDS[cluster_id]] else: @@ -658,8 +359,9 @@ class RequiredMandatoryAttribute: bad_prefix_min = 0xFFF1_0000 for endpoint_id, endpoint in self.endpoints_tlv.items(): for cluster_id, cluster in endpoint.items(): - attr_prefixes = [a & 0xFFFF_0000 for a in cluster[ATTRIBUTE_LIST_ID]] - cmd_values = cluster[ACCEPTED_COMMAND_LIST_ID] + cluster[GENERATED_COMMAND_LIST_ID] + attr_prefixes = [a & 0xFFFF_0000 for a in cluster[GlobalAttributeIds.ATTRIBUTE_LIST_ID]] + cmd_values = cluster[GlobalAttributeIds.ACCEPTED_COMMAND_LIST_ID] + \ + cluster[GlobalAttributeIds.GENERATED_COMMAND_LIST_ID] cmd_prefixes = [a & 0xFFFF_0000 for a in cmd_values] bad_attrs = [a for a in attr_prefixes if a >= bad_prefix_min] bad_cmds = [a for a in cmd_prefixes if a >= bad_prefix_min] @@ -679,7 +381,7 @@ class RequiredMandatoryAttribute: suffix_mask = 0x0000_FFFF for endpoint_id, endpoint in self.endpoints_tlv.items(): for cluster_id, cluster in endpoint.items(): - manufacturer_range_values = [a for a in cluster[ATTRIBUTE_LIST_ID] if a > mei_range_min] + manufacturer_range_values = [a for a in cluster[GlobalAttributeIds.ATTRIBUTE_LIST_ID] if a > mei_range_min] for manufacturer_value in manufacturer_range_values: suffix = manufacturer_value & suffix_mask location = AttributePathLocation(endpoint_id=endpoint_id, cluster_id=cluster_id, @@ -697,8 +399,10 @@ class RequiredMandatoryAttribute: for endpoint_id, endpoint in self.endpoints_tlv.items(): for cluster_id, cluster in endpoint.items(): - accepted_manufacturer_range_values = [a for a in cluster[ACCEPTED_COMMAND_LIST_ID] if a > mei_range_min] - 
generated_manufacturer_range_values = [a for a in cluster[GENERATED_COMMAND_LIST_ID] if a > mei_range_min] + accepted_manufacturer_range_values = [ + a for a in cluster[GlobalAttributeIds.ACCEPTED_COMMAND_LIST_ID] if a > mei_range_min] + generated_manufacturer_range_values = [ + a for a in cluster[GlobalAttributeIds.GENERATED_COMMAND_LIST_ID] if a > mei_range_min] all_command_manufacturer_range_values = accepted_manufacturer_range_values + generated_manufacturer_range_values for manufacturer_value in all_command_manufacturer_range_values: suffix = manufacturer_value & suffix_mask @@ -743,7 +447,7 @@ class RequiredMandatoryAttribute: for cluster_id, cluster in endpoint.items(): if cluster_id not in chip.clusters.ClusterObjects.ALL_CLUSTERS: continue - feature_map = cluster[FEATURE_MAP_ID] + feature_map = cluster[GlobalAttributeIds.FEATURE_MAP_ID] feature_mask = 0 try: feature_map_enum = chip.clusters.ClusterObjects.ALL_CLUSTERS[cluster_id].Bitmaps.Feature @@ -1045,14 +749,16 @@ def conformance_str(conformance: Callable, feature_map: uint, feature_dict: dict problem='Standard cluster found on device, but is not present in spec data') continue - feature_map = cluster[FEATURE_MAP_ID] - attribute_list = cluster[ATTRIBUTE_LIST_ID] - all_command_list = cluster[ACCEPTED_COMMAND_LIST_ID] + cluster[GENERATED_COMMAND_LIST_ID] + feature_map = cluster[GlobalAttributeIds.FEATURE_MAP_ID] + attribute_list = cluster[GlobalAttributeIds.ATTRIBUTE_LIST_ID] + all_command_list = cluster[GlobalAttributeIds.ACCEPTED_COMMAND_LIST_ID] + \ + cluster[GlobalAttributeIds.GENERATED_COMMAND_LIST_ID] # Feature conformance checking feature_masks = [1 << i for i in range(32) if feature_map & (1 << i)] for f in feature_masks: - location = AttributePathLocation(endpoint_id=endpoint_id, cluster_id=cluster_id, attribute_id=FEATURE_MAP_ID) + location = AttributePathLocation(endpoint_id=endpoint_id, cluster_id=cluster_id, + attribute_id=GlobalAttributeIds.FEATURE_MAP_ID) if f not in clusters[cluster_id].features.keys(): self.record_error(self.get_test_name(), location=location, problem=f'Unknown feature with mask 0x{f:02x}') success = False @@ -1102,7 +808,7 @@ def conformance_str(conformance: Callable, feature_map: uint, feature_dict: dict def check_spec_conformance_for_commands(command_type: CommandType) -> bool: success = True - global_attribute_id = ACCEPTED_COMMAND_LIST_ID if command_type == CommandType.ACCEPTED else GENERATED_COMMAND_LIST_ID + global_attribute_id = GlobalAttributeIds.ACCEPTED_COMMAND_LIST_ID if command_type == CommandType.ACCEPTED else GlobalAttributeIds.GENERATED_COMMAND_LIST_ID xml_commands_dict = clusters[cluster_id].accepted_commands if command_type == CommandType.ACCEPTED else clusters[cluster_id].generated_commands command_list = cluster[global_attribute_id] for command_id in command_list: diff --git a/src/python_testing/TestMatterTestingSupport.py b/src/python_testing/TestMatterTestingSupport.py index f69182ad6b5350..e713f4f0e71a61 100644 --- a/src/python_testing/TestMatterTestingSupport.py +++ b/src/python_testing/TestMatterTestingSupport.py @@ -25,9 +25,9 @@ from matter_testing_support import (MatterBaseTest, async_test_body, compare_time, default_matter_test_main, get_wait_seconds_from_set_time, parse_pics, type_matches, utc_time_in_matter_epoch) from mobly import asserts, signals -from TC_DeviceBasicComposition import (TagProblem, create_device_type_list_for_root, create_device_type_lists, - find_tag_list_problems, find_tree_roots, get_all_children, get_direct_children_of_root, - 
parts_list_cycles, separate_endpoint_types) +from taglist_and_topology_test_support import (TagProblem, create_device_type_list_for_root, create_device_type_lists, + find_tag_list_problems, find_tree_roots, get_all_children, + get_direct_children_of_root, parts_list_cycles, separate_endpoint_types) def get_raw_type_list(): diff --git a/src/python_testing/basic_composition_support.py b/src/python_testing/basic_composition_support.py new file mode 100644 index 00000000000000..523c41c223875d --- /dev/null +++ b/src/python_testing/basic_composition_support.py @@ -0,0 +1,163 @@ +# +# Copyright (c) 2023 Project CHIP Authors +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import base64 +import copy +import json +import logging +import pathlib +import sys +from pprint import pprint +from typing import Any, Optional + +import chip.clusters.ClusterObjects +import chip.tlv +from chip.clusters.Attribute import ValueDecodeFailure +from mobly import asserts + + +def MatterTlvToJson(tlv_data: dict[int, Any]) -> dict[str, Any]: + """Given TLV data for a specific cluster instance, convert to the Matter JSON format.""" + + matter_json_dict = {} + + key_type_mappings = { + chip.tlv.uint: "UINT", + int: "INT", + bool: "BOOL", + list: "ARRAY", + dict: "STRUCT", + chip.tlv.float32: "FLOAT", + float: "DOUBLE", + bytes: "BYTES", + str: "STRING", + ValueDecodeFailure: "ERROR", + type(None): "NULL", + } + + def ConvertValue(value) -> Any: + if isinstance(value, ValueDecodeFailure): + raise ValueError(f"Bad Value: {str(value)}") + + if isinstance(value, bytes): + return base64.b64encode(value).decode("UTF-8") + elif isinstance(value, list): + value = [ConvertValue(item) for item in value] + elif isinstance(value, dict): + value = MatterTlvToJson(value) + + return value + + for key in tlv_data: + value_type = type(tlv_data[key]) + value = copy.deepcopy(tlv_data[key]) + + element_type: str = key_type_mappings[value_type] + sub_element_type = "" + + try: + new_value = ConvertValue(value) + except ValueError as e: + new_value = str(e) + + if element_type: + if element_type == "ARRAY": + if len(new_value): + sub_element_type = key_type_mappings[type(tlv_data[key][0])] + else: + sub_element_type = "?" 
+ + new_key = "" + if element_type: + if sub_element_type: + new_key = f"{str(key)}:{element_type}-{sub_element_type}" + else: + new_key = f"{str(key)}:{element_type}" + else: + new_key = str(key) + + matter_json_dict[new_key] = new_value + + return matter_json_dict + + +class BasicCompositionTests: + async def setup_class_helper(self): + dev_ctrl = self.default_controller + self.problems = [] + + do_test_over_pase = self.user_params.get("use_pase_only", True) + dump_device_composition_path: Optional[str] = self.user_params.get("dump_device_composition_path", None) + + if do_test_over_pase: + info = self.get_setup_payload_info() + + commissionable_nodes = dev_ctrl.DiscoverCommissionableNodes( + info.filter_type, info.filter_value, stopOnFirst=True, timeoutSecond=15) + logging.info(f"Commissionable nodes: {commissionable_nodes}") + # TODO: Support BLE + if commissionable_nodes is not None and len(commissionable_nodes) > 0: + commissionable_node = commissionable_nodes[0] + instance_name = f"{commissionable_node.instanceName}._matterc._udp.local" + vid = f"{commissionable_node.vendorId}" + pid = f"{commissionable_node.productId}" + address = f"{commissionable_node.addresses[0]}" + logging.info(f"Found instance {instance_name}, VID={vid}, PID={pid}, Address={address}") + + node_id = 1 + dev_ctrl.EstablishPASESessionIP(address, info.passcode, node_id) + else: + asserts.fail("Failed to find the DUT according to command line arguments.") + else: + # Using the already commissioned node + node_id = self.dut_node_id + + wildcard_read = (await dev_ctrl.Read(node_id, [()])) + endpoints_tlv = wildcard_read.tlvAttributes + + node_dump_dict = {endpoint_id: MatterTlvToJson(endpoints_tlv[endpoint_id]) for endpoint_id in endpoints_tlv} + logging.debug(f"Raw TLV contents of Node: {json.dumps(node_dump_dict, indent=2)}") + + if dump_device_composition_path is not None: + with open(pathlib.Path(dump_device_composition_path).with_suffix(".json"), "wt+") as outfile: + json.dump(node_dump_dict, outfile, indent=2) + with open(pathlib.Path(dump_device_composition_path).with_suffix(".txt"), "wt+") as outfile: + pprint(wildcard_read.attributes, outfile, indent=1, width=200, compact=True) + + logging.info("###########################################################") + logging.info("Start of actual tests") + logging.info("###########################################################") + + # ======= State kept for use by all tests ======= + + # All endpoints in "full object" indexing format + self.endpoints = wildcard_read.attributes + + # All endpoints in raw TLV format + self.endpoints_tlv = wildcard_read.tlvAttributes + + def get_test_name(self) -> str: + """Return the function name of the caller. Used to create logging entries.""" + return sys._getframe().f_back.f_code.co_name + + def fail_current_test(self, msg: Optional[str] = None): + if not msg: + # Without a message, just log the last problem seen + asserts.fail(msg=self.problems[-1].problem) + else: + asserts.fail(msg) diff --git a/src/python_testing/global_attribute_ids.py b/src/python_testing/global_attribute_ids.py new file mode 100644 index 00000000000000..63d37952a00e2c --- /dev/null +++ b/src/python_testing/global_attribute_ids.py @@ -0,0 +1,28 @@ +# +# Copyright (c) 2023 Project CHIP Authors +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This file should be removed once we have a good way to get this from the codegen or XML
+
+from enum import IntEnum
+
+
+class GlobalAttributeIds(IntEnum):
+    ATTRIBUTE_LIST_ID = 0xFFFB
+    ACCEPTED_COMMAND_LIST_ID = 0xFFF9
+    GENERATED_COMMAND_LIST_ID = 0xFFF8
+    FEATURE_MAP_ID = 0xFFFC
+    CLUSTER_REVISION_ID = 0xFFFD
diff --git a/src/python_testing/taglist_and_topology_test_support.py b/src/python_testing/taglist_and_topology_test_support.py
new file mode 100644
index 00000000000000..af3bb05bea21b6
--- /dev/null
+++ b/src/python_testing/taglist_and_topology_test_support.py
@@ -0,0 +1,191 @@
+#
+# Copyright (c) 2023 Project CHIP Authors
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import functools
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Any
+
+import chip.clusters as Clusters
+
+
+@dataclass
+class TagProblem:
+    root: int
+    missing_attribute: bool
+    missing_feature: bool
+    duplicates: set[int]
+    same_tag: set[int] = field(default_factory=set)
+
+
+def separate_endpoint_types(endpoint_dict: dict[int, Any]) -> tuple[list[int], list[int]]:
+    """Returns a tuple containing the list of flat endpoints and a list of tree endpoints"""
+    flat = []
+    tree = []
+    for endpoint_id, endpoint in endpoint_dict.items():
+        if endpoint_id == 0:
+            continue
+        aggregator_id = 0x000e
+        content_app_id = 0x0024
+        device_types = [d.deviceType for d in endpoint[Clusters.Descriptor][Clusters.Descriptor.Attributes.DeviceTypeList]]
+        if aggregator_id in device_types:
+            flat.append(endpoint_id)
+        else:
+            if content_app_id in device_types:
+                continue
+            tree.append(endpoint_id)
+    return (flat, tree)
+
+
+def get_all_children(endpoint_id, endpoint_dict: dict[int, Any]) -> set[int]:
+    """Returns all the children (including subchildren) of the given endpoint
+    This assumes we've already checked that there are no cycles, so we can do the dumb thing and just trace the tree
+    """
+    children = set()
+
+    def add_children(endpoint_id, children):
+        immediate_children = endpoint_dict[endpoint_id][Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList]
+        if not immediate_children:
+            return
+        children.update(set(immediate_children))
+        for child in immediate_children:
+            add_children(child, children)
+
+    add_children(endpoint_id, children)
+    return children
+
+
+def find_tree_roots(tree_endpoints: list[int], endpoint_dict: dict[int, Any]) -> set[int]:
+    """Returns a set of all the endpoints in tree_endpoints that are roots for a tree (not including singletons)"""
+    tree_roots = set()
+
+    def find_tree_root(current_id):
+        for 
endpoint_id, endpoint in endpoint_dict.items(): + if endpoint_id not in tree_endpoints: + continue + if current_id in endpoint[Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList]: + # this is not the root, move up + return find_tree_root(endpoint_id) + return current_id + + for endpoint_id in tree_endpoints: + root = find_tree_root(endpoint_id) + if root != endpoint_id: + tree_roots.add(root) + return tree_roots + + +def parts_list_cycles(tree_endpoints: list[int], endpoint_dict: dict[int, Any]) -> list[int]: + """Returns a list of all the endpoints in the tree_endpoints list that contain cycles""" + def parts_list_cycle_detect(visited: set, current_id: int) -> bool: + if current_id in visited: + return True + visited.add(current_id) + for child in endpoint_dict[current_id][Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList]: + child_has_cycles = parts_list_cycle_detect(visited, child) + if child_has_cycles: + return True + return False + + cycles = [] + # This is quick enough that we can do all the endpoints without searching for the roots + for endpoint_id in tree_endpoints: + visited = set() + if parts_list_cycle_detect(visited, endpoint_id): + cycles.append(endpoint_id) + return cycles + + +def create_device_type_lists(roots: list[int], endpoint_dict: dict[int, Any]) -> dict[int, dict[int, set[int]]]: + """Returns a list of endpoints per device type for each root in the list""" + device_types = {} + for root in roots: + tree_device_types = defaultdict(set) + eps = get_all_children(root, endpoint_dict) + eps.add(root) + for ep in eps: + for d in endpoint_dict[ep][Clusters.Descriptor][Clusters.Descriptor.Attributes.DeviceTypeList]: + tree_device_types[d.deviceType].add(ep) + device_types[root] = tree_device_types + + return device_types + + +def get_direct_children_of_root(endpoint_dict: dict[int, Any]) -> set[int]: + root_children = set(endpoint_dict[0][Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList]) + direct_children = root_children + for ep in root_children: + ep_children = set(endpoint_dict[ep][Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList]) + direct_children = direct_children - ep_children + return direct_children + + +def create_device_type_list_for_root(direct_children, endpoint_dict: dict[int, Any]) -> dict[int, set[int]]: + device_types = defaultdict(set) + for ep in direct_children: + for d in endpoint_dict[ep][Clusters.Descriptor][Clusters.Descriptor.Attributes.DeviceTypeList]: + device_types[d.deviceType].add(ep) + return device_types + + +def cmp_tag_list(a: Clusters.Descriptor.Structs.SemanticTagStruct, b: Clusters.Descriptor.Structs.SemanticTagStruct): + if a.mfgCode != b.mfgCode: + return -1 if a.mfgCode < b.mfgCode else 1 + if a.namespaceID != b.namespaceID: + return -1 if a.namespaceID < b.namespaceID else 1 + if a.tag != b.tag: + return -1 if a.tag < b.tag else 1 + if a.label != b.label: + return -1 if a.label < b.label else 1 + return 0 + + +def find_tag_list_problems(roots: list[int], device_types: dict[int, dict[int, set[int]]], endpoint_dict: dict[int, Any]) -> dict[int, TagProblem]: + """Checks for non-spec compliant tag lists""" + tag_problems = {} + for root in roots: + for _, endpoints in device_types[root].items(): + if len(endpoints) < 2: + continue + for endpoint in endpoints: + missing_feature = not bool(endpoint_dict[endpoint][Clusters.Descriptor] + [Clusters.Descriptor.Attributes.FeatureMap] & Clusters.Descriptor.Bitmaps.Feature.kTagList) + if Clusters.Descriptor.Attributes.TagList not in 
endpoint_dict[endpoint][Clusters.Descriptor] or endpoint_dict[endpoint][Clusters.Descriptor][Clusters.Descriptor.Attributes.TagList] == []: + tag_problems[endpoint] = TagProblem(root=root, missing_attribute=True, + missing_feature=missing_feature, duplicates=endpoints) + continue + # Check that this tag isn't the same as the other tags in the endpoint list + duplicate_tags = set() + for other in endpoints: + if other == endpoint: + continue + # The OTHER endpoint is missing a tag list attribute - ignore this here, we'll catch that when we assess this endpoint as the primary + if Clusters.Descriptor.Attributes.TagList not in endpoint_dict[other][Clusters.Descriptor]: + continue + + if sorted(endpoint_dict[endpoint][Clusters.Descriptor][Clusters.Descriptor.Attributes.TagList], key=functools.cmp_to_key(cmp_tag_list)) == sorted(endpoint_dict[other][Clusters.Descriptor][Clusters.Descriptor.Attributes.TagList], key=functools.cmp_to_key(cmp_tag_list)): + duplicate_tags.add(other) + if len(duplicate_tags) != 0: + duplicate_tags.add(endpoint) + tag_problems[endpoint] = TagProblem(root=root, missing_attribute=False, missing_feature=missing_feature, + duplicates=endpoints, same_tag=duplicate_tags) + continue + if missing_feature: + tag_problems[endpoint] = TagProblem(root=root, missing_attribute=False, + missing_feature=missing_feature, duplicates=endpoints) + + return tag_problems
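
Note for reviewers: the point of this refactor is that other test modules can now reuse the composition and topology helpers without importing `TC_DeviceBasicComposition` itself (as `TestMatterTestingSupport.py` now does). Below is a minimal sketch of what a new consumer looks like. The test-case name and the specific assertions are hypothetical and not part of this change; the imports, the `setup_class` pattern, the `setup_class_helper()` flow, and the helper signatures are the ones introduced above.

```python
import logging

from basic_composition_support import BasicCompositionTests
from global_attribute_ids import GlobalAttributeIds
from matter_testing_support import MatterBaseTest, async_test_body, default_matter_test_main
from mobly import asserts
from taglist_and_topology_test_support import find_tree_roots, parts_list_cycles, separate_endpoint_types


class TC_TopologySketch(MatterBaseTest, BasicCompositionTests):  # hypothetical test case
    @async_test_body
    async def setup_class(self):
        super().setup_class()
        # Does the wildcard read (over PASE by default) and populates
        # self.endpoints (full-object format) and self.endpoints_tlv (raw TLV format).
        await self.setup_class_helper()

    def test_topology_sketch(self):
        # The topology helpers take the full-object endpoint dict produced by the
        # helper, keyed as {endpoint_id: {Cluster: {Attribute: value}}}.
        flat, tree = separate_endpoint_types(self.endpoints)
        cycles = parts_list_cycles(tree, self.endpoints)
        asserts.assert_equal(cycles, [], f"Endpoints with PartsList cycles: {cycles}")
        logging.info(f"Flat endpoints: {flat}, tree roots: {find_tree_roots(tree, self.endpoints)}")

        # The raw TLV dict is keyed by numeric IDs, which is where GlobalAttributeIds applies.
        for endpoint_id, endpoint in self.endpoints_tlv.items():
            for cluster_id, cluster in endpoint.items():
                asserts.assert_true(GlobalAttributeIds.ATTRIBUTE_LIST_ID in cluster,
                                    f"AttributeList missing on endpoint {endpoint_id} cluster {cluster_id}")


if __name__ == "__main__":
    default_matter_test_main()
```

`BasicCompositionTests` deliberately contributes only `setup_class_helper()` and the shared state (`self.endpoints`, `self.endpoints_tlv`, `self.problems`), while `MatterBaseTest` remains the sole Mobly base class, so the mixin keeps the MRO simple.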
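For reference, `MatterTlvToJson` (unchanged here, just moved into `basic_composition_support.py`) renders each TLV key as `"<id>:<TYPE>"`, with arrays tagged by their element type. A tiny hypothetical example, assuming `chip.tlv.uint` values of the kind a wildcard read produces:

```python
from basic_composition_support import MatterTlvToJson
from chip.tlv import uint

# Hypothetical cluster instance: FeatureMap (0xFFFC) plus an AttributeList (0xFFFB).
cluster_tlv = {0xFFFC: uint(0), 0xFFFB: [uint(0xFFF8), uint(0xFFF9), uint(0xFFFB)]}

# Keys come back as "<id>:<TYPE>", arrays as "<id>:ARRAY-<element type>", so this
# prints roughly: {'65532:UINT': 0, '65531:ARRAY-UINT': [65528, 65529, 65531]}
print(MatterTlvToJson(cluster_tlv))
```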