diff --git a/sros2/sros2/api/__init__.py b/sros2/sros2/api/__init__.py index 1a0343ed..88c33479 100644 --- a/sros2/sros2/api/__init__.py +++ b/sros2/sros2/api/__init__.py @@ -12,53 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple - -HIDDEN_NODE_PREFIX = '_' - -NodeName = namedtuple('NodeName', ('node', 'ns', 'fqn')) -TopicInfo = namedtuple('Topic', ('fqn', 'type')) - - -def get_node_names(*, node, include_hidden_nodes=False): - node_names_and_namespaces = node.get_node_names_and_namespaces() - return [ - NodeName( - node=t[0], - ns=t[1], - fqn=t[1] + ('' if t[1].endswith('/') else '/') + t[0]) - for t in node_names_and_namespaces - if ( - include_hidden_nodes or - (t[0] and not t[0].startswith(HIDDEN_NODE_PREFIX)) - ) - ] - - -def get_topics(node_name, func): - names_and_types = func(node_name.node, node_name.ns) - return [ - TopicInfo( - fqn=t[0], - type=t[1]) - for t in names_and_types] - - -def get_subscriber_info(node, node_name): - return get_topics(node_name, node.get_subscriber_names_and_types_by_node) - - -def get_publisher_info(node, node_name): - return get_topics(node_name, node.get_publisher_names_and_types_by_node) - - -def get_service_info(node, node_name): - return get_topics(node_name, node.get_service_names_and_types_by_node) - - -def get_client_info(node, node_name): - return get_topics(node_name, node.get_client_names_and_types_by_node) - def distribute_key(source_keystore_path, taget_keystore_path): raise NotImplementedError() diff --git a/sros2/sros2/verb/generate_policy.py b/sros2/sros2/verb/generate_policy.py index 69690abf..80811fde 100644 --- a/sros2/sros2/verb/generate_policy.py +++ b/sros2/sros2/verb/generate_policy.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import namedtuple import os import sys @@ -31,14 +32,6 @@ def FilesCompleter(*, allowednames, directories): from ros2cli.node.strategy import add_arguments as add_strategy_node_arguments from ros2cli.node.strategy import NodeStrategy -from sros2.api import ( - get_client_info, - get_node_names, - get_publisher_info, - get_service_info, - get_subscriber_info -) - from sros2.policy import ( dump_policy, load_policy, @@ -48,6 +41,12 @@ def FilesCompleter(*, allowednames, directories): from sros2.verb import VerbExtension +_HIDDEN_NODE_PREFIX = '_' + +_NodeName = namedtuple('NodeName', ('node', 'ns', 'fqn')) +_TopicInfo = namedtuple('Topic', ('fqn', 'type')) + + def formatTopics(topic_list, permission, topic_map): for topic in topic_list: topic_map[topic.name].append(permission) @@ -116,7 +115,7 @@ def add_permission( def main(self, *, args): policy = self.get_policy(args.POLICY_FILE_PATH) with NodeStrategy(args) as node: - node_names = get_node_names(node=node, include_hidden_nodes=False) + node_names = _get_node_names(node=node, include_hidden_nodes=False) if not len(node_names): print('No nodes detected in the ROS graph. No policy file was generated.', @@ -125,19 +124,19 @@ def main(self, *, args): for node_name in node_names: profile = self.get_profile(policy, node_name) - subscribe_topics = get_subscriber_info(node=node, node_name=node_name) + subscribe_topics = _get_subscriber_info(node=node, node_name=node_name) if subscribe_topics: self.add_permission( profile, 'topic', 'subscribe', 'ALLOW', subscribe_topics, node_name) - publish_topics = get_publisher_info(node=node, node_name=node_name) + publish_topics = _get_publisher_info(node=node, node_name=node_name) if publish_topics: self.add_permission( profile, 'topic', 'publish', 'ALLOW', publish_topics, node_name) - reply_services = get_service_info(node=node, node_name=node_name) + reply_services = _get_service_info(node=node, node_name=node_name) if reply_services: self.add_permission( profile, 'service', 'reply', 'ALLOW', reply_services, node_name) - request_services = get_client_info(node=node, node_name=node_name) + request_services = _get_client_info(node=node, node_name=node_name) if request_services: self.add_permission( profile, 'service', 'request', 'ALLOW', request_services, node_name) @@ -145,3 +144,43 @@ def main(self, *, args): with open(args.POLICY_FILE_PATH, 'w') as stream: dump_policy(policy, stream) return 0 + + +def _get_node_names(*, node, include_hidden_nodes=False): + node_names_and_namespaces = node.get_node_names_and_namespaces() + return [ + _NodeName( + node=t[0], + ns=t[1], + fqn=t[1] + ('' if t[1].endswith('/') else '/') + t[0]) + for t in node_names_and_namespaces + if ( + include_hidden_nodes or + (t[0] and not t[0].startswith(_HIDDEN_NODE_PREFIX)) + ) + ] + + +def _get_topics(node_name, func): + names_and_types = func(node_name.node, node_name.ns) + return [ + _TopicInfo( + fqn=t[0], + type=t[1]) + for t in names_and_types] + + +def _get_subscriber_info(node, node_name): + return _get_topics(node_name, node.get_subscriber_names_and_types_by_node) + + +def _get_publisher_info(node, node_name): + return _get_topics(node_name, node.get_publisher_names_and_types_by_node) + + +def _get_service_info(node, node_name): + return _get_topics(node_name, node.get_service_names_and_types_by_node) + + +def _get_client_info(node, node_name): + return _get_topics(node_name, node.get_client_names_and_types_by_node)