Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api: reorganize policy generation API #196

Merged
merged 1 commit into from
Apr 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 0 additions & 47 deletions sros2/sros2/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,53 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple

HIDDEN_NODE_PREFIX = '_'

NodeName = namedtuple('NodeName', ('node', 'ns', 'fqn'))
TopicInfo = namedtuple('Topic', ('fqn', 'type'))


def get_node_names(*, node, include_hidden_nodes=False):
node_names_and_namespaces = node.get_node_names_and_namespaces()
return [
NodeName(
node=t[0],
ns=t[1],
fqn=t[1] + ('' if t[1].endswith('/') else '/') + t[0])
for t in node_names_and_namespaces
if (
include_hidden_nodes or
(t[0] and not t[0].startswith(HIDDEN_NODE_PREFIX))
)
]


def get_topics(node_name, func):
names_and_types = func(node_name.node, node_name.ns)
return [
TopicInfo(
fqn=t[0],
type=t[1])
for t in names_and_types]


def get_subscriber_info(node, node_name):
return get_topics(node_name, node.get_subscriber_names_and_types_by_node)


def get_publisher_info(node, node_name):
return get_topics(node_name, node.get_publisher_names_and_types_by_node)


def get_service_info(node, node_name):
return get_topics(node_name, node.get_service_names_and_types_by_node)


def get_client_info(node, node_name):
return get_topics(node_name, node.get_client_names_and_types_by_node)


def distribute_key(source_keystore_path, taget_keystore_path):
raise NotImplementedError()
65 changes: 52 additions & 13 deletions sros2/sros2/verb/generate_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple
import os
import sys

Expand All @@ -31,14 +32,6 @@ def FilesCompleter(*, allowednames, directories):
from ros2cli.node.strategy import add_arguments as add_strategy_node_arguments
from ros2cli.node.strategy import NodeStrategy

from sros2.api import (
get_client_info,
get_node_names,
get_publisher_info,
get_service_info,
get_subscriber_info
)

from sros2.policy import (
dump_policy,
load_policy,
Expand All @@ -48,6 +41,12 @@ def FilesCompleter(*, allowednames, directories):
from sros2.verb import VerbExtension


_HIDDEN_NODE_PREFIX = '_'

_NodeName = namedtuple('NodeName', ('node', 'ns', 'fqn'))
_TopicInfo = namedtuple('Topic', ('fqn', 'type'))


def formatTopics(topic_list, permission, topic_map):
for topic in topic_list:
topic_map[topic.name].append(permission)
Expand Down Expand Up @@ -116,7 +115,7 @@ def add_permission(
def main(self, *, args):
policy = self.get_policy(args.POLICY_FILE_PATH)
with NodeStrategy(args) as node:
node_names = get_node_names(node=node, include_hidden_nodes=False)
node_names = _get_node_names(node=node, include_hidden_nodes=False)

if not len(node_names):
print('No nodes detected in the ROS graph. No policy file was generated.',
Expand All @@ -125,23 +124,63 @@ def main(self, *, args):

for node_name in node_names:
profile = self.get_profile(policy, node_name)
subscribe_topics = get_subscriber_info(node=node, node_name=node_name)
subscribe_topics = _get_subscriber_info(node=node, node_name=node_name)
if subscribe_topics:
self.add_permission(
profile, 'topic', 'subscribe', 'ALLOW', subscribe_topics, node_name)
publish_topics = get_publisher_info(node=node, node_name=node_name)
publish_topics = _get_publisher_info(node=node, node_name=node_name)
if publish_topics:
self.add_permission(
profile, 'topic', 'publish', 'ALLOW', publish_topics, node_name)
reply_services = get_service_info(node=node, node_name=node_name)
reply_services = _get_service_info(node=node, node_name=node_name)
if reply_services:
self.add_permission(
profile, 'service', 'reply', 'ALLOW', reply_services, node_name)
request_services = get_client_info(node=node, node_name=node_name)
request_services = _get_client_info(node=node, node_name=node_name)
if request_services:
self.add_permission(
profile, 'service', 'request', 'ALLOW', request_services, node_name)

with open(args.POLICY_FILE_PATH, 'w') as stream:
dump_policy(policy, stream)
return 0


def _get_node_names(*, node, include_hidden_nodes=False):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think these are more generally useful. They're very specific to this verb. Other packages will consume node names/namespaces on their own.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed

node_names_and_namespaces = node.get_node_names_and_namespaces()
return [
_NodeName(
node=t[0],
ns=t[1],
fqn=t[1] + ('' if t[1].endswith('/') else '/') + t[0])
for t in node_names_and_namespaces
if (
include_hidden_nodes or
(t[0] and not t[0].startswith(_HIDDEN_NODE_PREFIX))
)
]


def _get_topics(node_name, func):
names_and_types = func(node_name.node, node_name.ns)
return [
_TopicInfo(
fqn=t[0],
type=t[1])
for t in names_and_types]


def _get_subscriber_info(node, node_name):
return _get_topics(node_name, node.get_subscriber_names_and_types_by_node)


def _get_publisher_info(node, node_name):
return _get_topics(node_name, node.get_publisher_names_and_types_by_node)


def _get_service_info(node, node_name):
return _get_topics(node_name, node.get_service_names_and_types_by_node)


def _get_client_info(node, node_name):
return _get_topics(node_name, node.get_client_names_and_types_by_node)