Skip to content

Commit

Permalink
api: reorganize policy generation API (#196)
Browse files Browse the repository at this point in the history
Extract policy generation API from api's __init__.py and move into
verb as nothing else uses it.

Signed-off-by: Kyle Fazzari <[email protected]>
  • Loading branch information
kyrofa authored Apr 9, 2020
1 parent 84dc386 commit 4a212b3
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 60 deletions.
47 changes: 0 additions & 47 deletions sros2/sros2/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,53 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple

HIDDEN_NODE_PREFIX = '_'

NodeName = namedtuple('NodeName', ('node', 'ns', 'fqn'))
TopicInfo = namedtuple('Topic', ('fqn', 'type'))


def get_node_names(*, node, include_hidden_nodes=False):
node_names_and_namespaces = node.get_node_names_and_namespaces()
return [
NodeName(
node=t[0],
ns=t[1],
fqn=t[1] + ('' if t[1].endswith('/') else '/') + t[0])
for t in node_names_and_namespaces
if (
include_hidden_nodes or
(t[0] and not t[0].startswith(HIDDEN_NODE_PREFIX))
)
]


def get_topics(node_name, func):
names_and_types = func(node_name.node, node_name.ns)
return [
TopicInfo(
fqn=t[0],
type=t[1])
for t in names_and_types]


def get_subscriber_info(node, node_name):
return get_topics(node_name, node.get_subscriber_names_and_types_by_node)


def get_publisher_info(node, node_name):
return get_topics(node_name, node.get_publisher_names_and_types_by_node)


def get_service_info(node, node_name):
return get_topics(node_name, node.get_service_names_and_types_by_node)


def get_client_info(node, node_name):
return get_topics(node_name, node.get_client_names_and_types_by_node)


def distribute_key(source_keystore_path, taget_keystore_path):
raise NotImplementedError()
65 changes: 52 additions & 13 deletions sros2/sros2/verb/generate_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple
import os
import sys

Expand All @@ -31,14 +32,6 @@ def FilesCompleter(*, allowednames, directories):
from ros2cli.node.strategy import add_arguments as add_strategy_node_arguments
from ros2cli.node.strategy import NodeStrategy

from sros2.api import (
get_client_info,
get_node_names,
get_publisher_info,
get_service_info,
get_subscriber_info
)

from sros2.policy import (
dump_policy,
load_policy,
Expand All @@ -48,6 +41,12 @@ def FilesCompleter(*, allowednames, directories):
from sros2.verb import VerbExtension


_HIDDEN_NODE_PREFIX = '_'

_NodeName = namedtuple('NodeName', ('node', 'ns', 'fqn'))
_TopicInfo = namedtuple('Topic', ('fqn', 'type'))


def formatTopics(topic_list, permission, topic_map):
for topic in topic_list:
topic_map[topic.name].append(permission)
Expand Down Expand Up @@ -116,7 +115,7 @@ def add_permission(
def main(self, *, args):
policy = self.get_policy(args.POLICY_FILE_PATH)
with NodeStrategy(args) as node:
node_names = get_node_names(node=node, include_hidden_nodes=False)
node_names = _get_node_names(node=node, include_hidden_nodes=False)

if not len(node_names):
print('No nodes detected in the ROS graph. No policy file was generated.',
Expand All @@ -125,23 +124,63 @@ def main(self, *, args):

for node_name in node_names:
profile = self.get_profile(policy, node_name)
subscribe_topics = get_subscriber_info(node=node, node_name=node_name)
subscribe_topics = _get_subscriber_info(node=node, node_name=node_name)
if subscribe_topics:
self.add_permission(
profile, 'topic', 'subscribe', 'ALLOW', subscribe_topics, node_name)
publish_topics = get_publisher_info(node=node, node_name=node_name)
publish_topics = _get_publisher_info(node=node, node_name=node_name)
if publish_topics:
self.add_permission(
profile, 'topic', 'publish', 'ALLOW', publish_topics, node_name)
reply_services = get_service_info(node=node, node_name=node_name)
reply_services = _get_service_info(node=node, node_name=node_name)
if reply_services:
self.add_permission(
profile, 'service', 'reply', 'ALLOW', reply_services, node_name)
request_services = get_client_info(node=node, node_name=node_name)
request_services = _get_client_info(node=node, node_name=node_name)
if request_services:
self.add_permission(
profile, 'service', 'request', 'ALLOW', request_services, node_name)

with open(args.POLICY_FILE_PATH, 'w') as stream:
dump_policy(policy, stream)
return 0


def _get_node_names(*, node, include_hidden_nodes=False):
node_names_and_namespaces = node.get_node_names_and_namespaces()
return [
_NodeName(
node=t[0],
ns=t[1],
fqn=t[1] + ('' if t[1].endswith('/') else '/') + t[0])
for t in node_names_and_namespaces
if (
include_hidden_nodes or
(t[0] and not t[0].startswith(_HIDDEN_NODE_PREFIX))
)
]


def _get_topics(node_name, func):
names_and_types = func(node_name.node, node_name.ns)
return [
_TopicInfo(
fqn=t[0],
type=t[1])
for t in names_and_types]


def _get_subscriber_info(node, node_name):
return _get_topics(node_name, node.get_subscriber_names_and_types_by_node)


def _get_publisher_info(node, node_name):
return _get_topics(node_name, node.get_publisher_names_and_types_by_node)


def _get_service_info(node, node_name):
return _get_topics(node_name, node.get_service_names_and_types_by_node)


def _get_client_info(node, node_name):
return _get_topics(node_name, node.get_client_names_and_types_by_node)

0 comments on commit 4a212b3

Please sign in to comment.