feat(ESSNTL-5389): Include platform_metadata for MQ messages from API requests (#1533)

* Add b64_identity to platform_metadata on PATCH requests; cleanup

* Reduce redundancy

* platform_metadata on delete events

* Reaper fix

* Add identity to groups API MQ, make identity.user optional

* Remove account number requirement

* Revert the _decode_id changes in queue.py

* Undo a couple refactors; add test
kruai authored Dec 5, 2023
1 parent b6222c2 commit 2081935
Showing 15 changed files with 139 additions and 59 deletions.
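The change applies one pattern throughout: serialize the caller's identity back into its base64 x-rh-identity form with the new to_auth_header helper and attach it to outgoing Kafka messages under platform_metadata.b64_identity. A minimal sketch of that pattern, assuming build_event and EventType are importable from app.queue.events (the wrapper name below is made up and the producer plumbing is elided):

from app.auth import get_current_identity
from app.auth.identity import to_auth_header          # added in this commit
from app.queue.events import EventType, build_event   # assumed import path


def build_updated_event_with_identity(serialized_host):
    # Sketch only: attach the request identity to an "updated" event the same
    # way the API endpoints in the diff below do.
    metadata = {"b64_identity": to_auth_header(get_current_identity())}
    return build_event(EventType.updated, serialized_host, platform_metadata=metadata)
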
6 changes: 4 additions & 2 deletions api/host.py
@@ -24,6 +24,7 @@
from app import RbacPermission
from app import RbacResourceType
from app.auth import get_current_identity
from app.auth.identity import to_auth_header
from app.instrumentation import get_control_rule
from app.instrumentation import log_get_host_list_failed
from app.instrumentation import log_get_host_list_succeeded
@@ -207,7 +208,7 @@ def _delete_host_list(host_id_list, rbac_filter):
deletion_count = 0

for host_id, deleted in delete_hosts(
query, current_app.event_producer, inventory_config().host_delete_chunk_size
query, current_app.event_producer, inventory_config().host_delete_chunk_size, identity=current_identity
):
if deleted:
log_host_delete_succeeded(logger, host_id, get_control_rule())
@@ -312,7 +313,8 @@ def _emit_patch_event(serialized_host, host):
host.system_profile_facts.get("host_type"),
host.system_profile_facts.get("operating_system", {}).get("name"),
)
event = build_event(EventType.updated, serialized_host)
metadata = {"b64_identity": to_auth_header(get_current_identity())}
event = build_event(EventType.updated, serialized_host, platform_metadata=metadata)
current_app.event_producer.write_event(event, str(host.id), headers, wait=True)


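Because the identity now travels with each message, a downstream consumer can recover it with the pre-existing from_auth_header helper. A hypothetical consumer-side sketch (identity_from_event is not part of this repo; the event layout follows the diff):

import json
from typing import Optional

from app.auth.identity import Identity, from_auth_header


def identity_from_event(raw_event: str) -> Optional[Identity]:
    # Return the Identity embedded in an inventory event, or None when the
    # message carries no platform_metadata (e.g. reaper-driven deletes).
    event = json.loads(raw_event)
    b64 = (event.get("platform_metadata") or {}).get("b64_identity")
    return from_auth_header(b64) if b64 else None
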
35 changes: 23 additions & 12 deletions app/auth/identity.py
@@ -1,6 +1,8 @@
import os
from base64 import b64decode
from base64 import b64encode
from enum import Enum
from json import dumps
from json import loads

import marshmallow as m
@@ -16,16 +18,6 @@
SHARED_SECRET_ENV_VAR = "INVENTORY_SHARED_SECRET"


def from_auth_header(base64):
json = b64decode(base64)
identity_dict = loads(json)
return Identity(identity_dict["identity"])


def from_bearer_token(token):
return Identity(token=token)


class AuthType(str, Enum):
BASIC = "basic-auth"
CERT = "cert-auth"
@@ -97,8 +89,11 @@ def _asdict(self):
ident["account_number"] = self.account_number

if self.identity_type == IdentityType.USER:
ident["user"] = self.user.copy()
if hasattr(self, "user"):
ident["user"] = self.user.copy()

return ident

if self.identity_type == IdentityType.SYSTEM:
ident["system"] = self.system.copy()
return ident
@@ -124,7 +119,7 @@ class IdentitySchema(IdentityBaseSchema):
org_id = m.fields.Str(required=True, validate=m.validate.Length(min=1, max=36))
type = m.fields.String(required=True, validate=m.validate.OneOf(IdentityType.__members__.values()))
auth_type = IdentityLowerString(required=True, validate=m.validate.OneOf(AuthType.__members__.values()))
account_number = m.fields.Str(validate=m.validate.Length(min=0, max=36))
account_number = m.fields.Str(allow_none=True, validate=m.validate.Length(min=0, max=36))

@m.post_load
def user_system_check(self, in_data, **kwargs):
@@ -153,3 +148,19 @@ class SystemIdentitySchema(IdentityBaseSchema):
# So this helper function creates a basic User-type identity from the host data.
def create_mock_identity_with_org_id(org_id):
return Identity({"org_id": org_id, "type": IdentityType.USER.value, "auth_type": AuthType.BASIC})


def to_auth_header(identity: Identity):
id = {"identity": identity._asdict()}
b64_id = b64encode(dumps(id).encode())
return b64_id.decode("ascii")


def from_auth_header(base64):
json = b64decode(base64)
identity_dict = loads(json)
return Identity(identity_dict["identity"])


def from_bearer_token(token):
return Identity(token=token)
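The new to_auth_header is simply the inverse of the pre-existing from_auth_header, so the two should round-trip. A small sketch, assuming Identity exposes org_id as an attribute (as the schema fields suggest); the org_id value is illustrative:

from app.auth.identity import create_mock_identity_with_org_id, from_auth_header, to_auth_header

identity = create_mock_identity_with_org_id("12345")   # minimal User-type identity
header = to_auth_header(identity)                      # base64 of the {"identity": {...}} envelope
restored = from_auth_header(header)                    # decodes straight back into an Identity
assert restored.org_id == identity.org_id

Note that _asdict now skips the user block when it is absent, which is what lets a minimal identity like this one survive the round trip.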
4 changes: 3 additions & 1 deletion app/queue/events.py
@@ -72,6 +72,7 @@ class HostDeleteEvent(Schema):
org_id = fields.Str()
insights_id = fields.Str()
request_id = fields.Str()
platform_metadata = fields.Dict()
metadata = fields.Nested(HostEventMetadataSchema())


@@ -102,7 +103,7 @@ def host_create_update_event(event_type, host, platform_metadata=None):
)


def host_delete_event(event_type, host):
def host_delete_event(event_type, host, platform_metadata=None):
delete_event = {
"timestamp": datetime.now(timezone.utc),
"type": event_type.name,
@@ -111,6 +112,7 @@ def host_delete_event(event_type, host, platform_metadata=None):
"org_id": host.org_id,
"account": host.account,
"request_id": threadctx.request_id,
"platform_metadata": platform_metadata,
"metadata": {"request_id": threadctx.request_id},
}

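For reference, a delete message produced through the updated builder would look roughly like the dict below once serialized; every value here is a placeholder rather than real data:

# Illustrative shape only; the reaper path leaves platform_metadata as None.
expected_delete_event = {
    "timestamp": "2023-12-05T00:00:00+00:00",
    "type": "delete",
    "id": "<host id>",
    "org_id": "<org id>",
    "account": None,
    "insights_id": "<insights id>",
    "request_id": "<request id>",
    "platform_metadata": {"b64_identity": "<base64 x-rh-identity header>"},
    "metadata": {"request_id": "<request id>"},
}
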
4 changes: 3 additions & 1 deletion lib/group_repository.py
@@ -3,6 +3,7 @@
from api.host_query import staleness_timestamps
from api.staleness_query import get_staleness_obj
from app.auth import get_current_identity
from app.auth.identity import to_auth_header
from app.exceptions import InventoryException
from app.instrumentation import get_control_rule
from app.instrumentation import log_get_group_list_failed
@@ -41,6 +42,7 @@ def _produce_host_update_events(event_producer, host_id_list, group_id_list=[],
Host.query.filter(Host.id.in_(host_id_list)).update({"groups": serialized_groups}, synchronize_session="fetch")
db.session.commit()
host_list = get_host_list_by_id_list_from_db(host_id_list)
metadata = {"b64_identity": to_auth_header(get_current_identity())}

# Send messages
for host in host_list:
@@ -53,7 +55,7 @@
host.system_profile_facts.get("host_type"),
host.system_profile_facts.get("operating_system", {}).get("name"),
)
event = build_event(EventType.updated, serialized_host)
event = build_event(EventType.updated, serialized_host, platform_metadata=metadata)
event_producer.write_event(event, serialized_host["id"], headers, wait=True)


10 changes: 6 additions & 4 deletions lib/host_delete.py
@@ -1,6 +1,7 @@
from confluent_kafka import KafkaException
from sqlalchemy.orm.base import instance_state

from app.auth.identity import to_auth_header
from app.logging import get_logger
from app.models import Host
from app.models import HostGroupAssoc
@@ -16,21 +17,21 @@
logger = get_logger(__name__)


def delete_hosts(select_query, event_producer, chunk_size, interrupt=lambda: False):
def delete_hosts(select_query, event_producer, chunk_size, interrupt=lambda: False, identity=None):
with session_guard(select_query.session):
while select_query.count():
for host in select_query.limit(chunk_size):
host_id = host.id
with delete_host_processing_time.time():
host_deleted = _delete_host(select_query.session, event_producer, host)
host_deleted = _delete_host(select_query.session, event_producer, host, identity)

yield host_id, host_deleted

if interrupt():
return


def _delete_host(session, event_producer, host):
def _delete_host(session, event_producer, host, identity=None):
assoc_delete_query = session.query(HostGroupAssoc).filter(HostGroupAssoc.host_id == host.id)
host_delete_query = session.query(Host).filter(Host.id == host.id)
if kafka_available():
@@ -39,7 +40,8 @@ def _delete_host(session, event_producer, host):
host_deleted = _deleted_by_this_query(host)
if host_deleted:
delete_host_count.inc()
event = build_event(EventType.delete, host)
metadata = {"b64_identity": to_auth_header(identity)} if identity else None
event = build_event(EventType.delete, host, platform_metadata=metadata)
headers = message_headers(
EventType.delete,
host.canonical_facts.get("insights_id"),
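Callers opt in by passing the request identity through delete_hosts; the reaper has no request identity, omits the argument, and its delete events keep platform_metadata as None. A sketch of an API-style caller (the function name and chunk_size default here are made up):

from app.auth import get_current_identity
from lib.host_delete import delete_hosts


def delete_selected_hosts(select_query, event_producer, chunk_size=50):
    # Forwarding identity= makes every delete event produced below carry
    # platform_metadata.b64_identity for the current request.
    deleted_ids = []
    for host_id, deleted in delete_hosts(
        select_query, event_producer, chunk_size, identity=get_current_identity()
    ):
        if deleted:
            deleted_ids.append(host_id)
    return deleted_ids
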
25 changes: 22 additions & 3 deletions tests/helpers/mq_utils.py
@@ -8,9 +8,12 @@

from confluent_kafka import TopicPartition

from app.auth.identity import Identity
from app.auth.identity import to_auth_header
from app.serialization import serialize_facts
from app.utils import Tag
from tests.helpers.test_utils import minimal_host
from tests.helpers.test_utils import USER_IDENTITY


MockFutureCallback = namedtuple("MockFutureCallback", ("method", "args", "kwargs", "extra_arg"))
@@ -93,12 +96,24 @@ def assert_mq_host_data(actual_id, actual_event, expected_results, host_keys_to_
assert actual_event["host"][key] == expected_results["host"][key]


def assert_delete_event_is_valid(event_producer, host, timestamp, expected_request_id=None, expected_metadata=None):
def assert_delete_event_is_valid(
event_producer, host, timestamp, expected_request_id=None, expected_metadata=None, identity=USER_IDENTITY
):
event = json.loads(event_producer.event)

assert isinstance(event, dict)

expected_keys = {"timestamp", "type", "id", "account", "org_id", "insights_id", "request_id", "metadata"}
expected_keys = {
"timestamp",
"type",
"id",
"account",
"org_id",
"insights_id",
"request_id",
"platform_metadata",
"metadata",
}
assert set(event.keys()) == expected_keys

assert timestamp.replace(tzinfo=timezone.utc).isoformat() == event["timestamp"]
@@ -117,6 +132,9 @@ def assert_delete_event_is_valid(event_producer, host, timestamp, expected_reque
host.system_profile_facts.get("operating_system", {}).get("name"),
)

if identity:
assert event["platform_metadata"] == {"b64_identity": to_auth_header(Identity(obj=identity))}

if expected_request_id:
assert event["request_id"] == expected_request_id

@@ -132,6 +150,7 @@ def assert_patch_event_is_valid(
display_name="patch_event_test",
stale_timestamp=None,
reporter=None,
identity=USER_IDENTITY,
):
stale_timestamp = (host.modified_on.astimezone(timezone.utc) + timedelta(seconds=104400)).isoformat()
stale_warning_timestamp = (host.modified_on.astimezone(timezone.utc) + timedelta(seconds=604800)).isoformat()
@@ -170,7 +189,7 @@
"provider_id": host.canonical_facts.get("provider_id"),
"provider_type": host.canonical_facts.get("provider_type"),
},
"platform_metadata": None,
"platform_metadata": {"b64_identity": to_auth_header(Identity(obj=identity))},
"metadata": {"request_id": expected_request_id},
"timestamp": expected_timestamp.isoformat(),
}
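With the new identity parameter the helper defaults to asserting USER_IDENTITY against platform_metadata, and identity=None opts out of that check. Two illustrative callers (names and fixtures are placeholders, not tests from this repo):

from tests.helpers.mq_utils import assert_delete_event_is_valid


def check_api_delete_event(event_producer, host, timestamp):
    # API-driven deletes should carry the caller's identity; the default
    # identity=USER_IDENTITY asserts the new platform_metadata field.
    assert_delete_event_is_valid(event_producer, host, timestamp)


def check_reaper_delete_event(event_producer, host, timestamp):
    # Reaper-driven deletes have no request identity, so skip that assertion.
    assert_delete_event_is_valid(event_producer, host, timestamp, identity=None)
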
44 changes: 19 additions & 25 deletions tests/test_api_auth.py
@@ -1,6 +1,9 @@
from base64 import b64encode
from copy import deepcopy
from json import dumps

import pytest

from app.auth.identity import Identity
from app.auth.identity import IdentityType
from tests.helpers.api_utils import build_token_auth_header
@@ -34,30 +37,13 @@ def invalid_payloads(identity_type):
return payloads


def valid_identity(identity_type):
"""
Provides a valid Identity object.
"""
if identity_type == IdentityType.USER:
return Identity(USER_IDENTITY)
elif identity_type == IdentityType.SYSTEM:
return Identity(SYSTEM_IDENTITY)


def create_identity_payload(identity):
dict_ = {"identity": identity._asdict()}
# Load into Identity object for validation, then return to dict
dict_ = {"identity": Identity(identity)._asdict()}
json = dumps(dict_)
return b64encode(json.encode())


def valid_payload(identity_type):
"""
Builds a valid HTTP header payload – Base64 encoded JSON string with valid data.
"""
identity = valid_identity(identity_type)
return create_identity_payload(identity)


def test_validate_missing_identity(flask_client):
"""
Identity header is not present.
@@ -74,11 +60,19 @@ def test_validate_invalid_identity(flask_client):
assert 401 == response.status_code


def test_validate_valid_user_identity(flask_client):
@pytest.mark.parametrize(
"remove_account_number",
[True, False],
)
def test_validate_valid_user_identity(flask_client, remove_account_number):
"""
Identity header is valid – non-empty in this case
"""
payload = valid_payload(IdentityType.USER)
identity = deepcopy(USER_IDENTITY)
if remove_account_number:
del identity["account_number"]

payload = create_identity_payload(identity)
response = flask_client.get(HOST_URL, headers={"x-rh-identity": payload})
assert 200 == response.status_code # OK

@@ -87,8 +81,8 @@ def test_validate_non_admin_user_identity(flask_client):
"""
Identity header is valid and user is provided, but is not an Admin
"""
identity = valid_identity(IdentityType.USER)
identity.user["username"] = "[email protected]"
identity = deepcopy(USER_IDENTITY)
identity["user"]["username"] = "[email protected]"
payload = create_identity_payload(identity)
response = flask_client.post(
f"{SYSTEM_PROFILE_URL}/validate_schema?repo_branch=master&days=1", headers={"x-rh-identity": payload}
@@ -100,7 +94,7 @@ def test_validate_non_user_admin_endpoint(flask_client):
"""
Identity header is valid and user is provided, but is not an Admin
"""
payload = valid_payload(IdentityType.SYSTEM)
payload = create_identity_payload(SYSTEM_IDENTITY)
response = flask_client.post(
f"{SYSTEM_PROFILE_URL}/validate_schema?repo_branch=master&days=1", headers={"x-rh-identity": payload}
)
@@ -111,7 +105,7 @@ def test_validate_valid_system_identity(flask_client):
"""
Identity header is valid – non-empty in this case
"""
payload = valid_payload(IdentityType.SYSTEM)
payload = create_identity_payload(SYSTEM_IDENTITY)
response = flask_client.get(HOST_URL, headers={"x-rh-identity": payload})
assert 200 == response.status_code # OK

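The parametrized test above exercises the relaxed account_number requirement through the API; the same relaxation can be checked directly on Identity, since the refactored create_identity_payload relies on Identity(dict) validating via the marshmallow schema. A sketch under that assumption:

from copy import deepcopy

from app.auth.identity import Identity
from tests.helpers.test_utils import USER_IDENTITY

ident = deepcopy(USER_IDENTITY)
ident.pop("account_number", None)
Identity(ident)                  # no account_number at all: still validates

ident["account_number"] = None
Identity(ident)                  # allow_none=True now also accepts an explicit null
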
7 changes: 6 additions & 1 deletion tests/test_api_groups_create.py
@@ -4,12 +4,15 @@
import pytest
from dateutil import parser

from app.auth.identity import Identity
from app.auth.identity import to_auth_header
from tests.helpers.api_utils import assert_group_response
from tests.helpers.api_utils import assert_response_status
from tests.helpers.api_utils import create_mock_rbac_response
from tests.helpers.api_utils import GROUP_WRITE_PROHIBITED_RBAC_RESPONSE_FILES
from tests.helpers.test_utils import generate_uuid
from tests.helpers.test_utils import SYSTEM_IDENTITY
from tests.helpers.test_utils import USER_IDENTITY


def test_create_group_with_empty_host_list(api_create_group, db_get_group_by_name, event_producer, mocker):
@@ -60,11 +63,13 @@ def test_create_group_with_hosts(
assert_group_response(response_data, retrieved_group)
assert event_producer.write_event.call_count == 2
for call_arg in event_producer.write_event.call_args_list:
host = json.loads(call_arg[0][0])["host"]
event = json.loads(call_arg[0][0])
host = event["host"]
assert host["id"] in host_id_list
assert host["groups"][0]["name"] == group_data["name"]
assert host["groups"][0]["id"] == str(retrieved_group.id)
assert parser.isoparse(host["updated"]) == db_get_host(host["id"]).modified_on
assert event["platform_metadata"] == {"b64_identity": to_auth_header(Identity(obj=USER_IDENTITY))}


@pytest.mark.parametrize(
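The platform_metadata assertion above compares the full re-encoded header; an alternative is to decode the header and check individual fields, which can make failures easier to read. A sketch, not a test from this repo:

import json
from base64 import b64decode

from tests.helpers.test_utils import USER_IDENTITY


def assert_event_identity_org_id(event: dict):
    # Decode platform_metadata.b64_identity and compare the field we care
    # about instead of comparing the whole re-encoded header byte-for-byte.
    decoded = json.loads(b64decode(event["platform_metadata"]["b64_identity"]))
    assert decoded["identity"]["org_id"] == USER_IDENTITY["org_id"]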