Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not lock around writing to stdout, do not flush #1411

Open
wants to merge 7 commits into
base: devel
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 66 additions & 44 deletions src/ansible_runner/display_callback/callback/awx_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,35 @@

# pylint: disable=W0212

from __future__ import (absolute_import, division, print_function)
from __future__ import (absolute_import, annotations, division, print_function)

# Python
import json
import stat
import multiprocessing
import threading
import base64
import functools
import collections
import contextlib
import datetime
import inspect
import os
import sys
import types
import typing as t
import uuid
from copy import copy

# Ansible
from ansible import __version__ as ansible_version_str
from ansible import constants as C
from ansible.plugins.callback import CallbackBase
from ansible.plugins.loader import callback_loader
from ansible.utils.display import Display
from ansible.utils.multiprocessing import context as multiprocessing_context

if t.TYPE_CHECKING:
P = t.ParamSpec('P')

DOCUMENTATION = '''
callback: awx_display
Expand Down Expand Up @@ -68,6 +74,9 @@

CENSORED = "the output has been hidden due to the fact that 'no_log: true' was specified for this result"

_ANSIBLE_VERSION = tuple(int(p) for p in ansible_version_str.split('.')[:2])
_ANSIBLE_214 = _ANSIBLE_VERSION >= (2, 14)


def current_time():
return datetime.datetime.now(datetime.timezone.utc)
Expand Down Expand Up @@ -127,7 +136,6 @@ class EventContext:
'''

def __init__(self):
self.display_lock = multiprocessing.RLock()
self._global_ctx = {}
self._local = threading.local()
if os.getenv('AWX_ISOLATED_DATA_DIR'):
Expand Down Expand Up @@ -222,31 +230,39 @@ def get_begin_dict(self):
def get_end_dict(self):
return {}

def dump(self, fileobj, data, max_width=78, flush=False):
def dump(self, fileobj, data, max_width=78):
b64data = base64.b64encode(json.dumps(data).encode('utf-8')).decode()
with self.display_lock:
# pattern corresponding to OutputEventFilter expectation
fileobj.write('\x1b[K')
for offset in range(0, len(b64data), max_width):
chunk = b64data[offset:offset + max_width]
escaped_chunk = f'{chunk}\x1b[{len(chunk)}D'
fileobj.write(escaped_chunk)
fileobj.write('\x1b[K')
if flush:
fileobj.flush()
# pattern corresponding to OutputEventFilter expectation
out = '\x1b[K'
for offset in range(0, len(b64data), max_width):
chunk = b64data[offset:offset + max_width]
out += f'{chunk}\x1b[{len(chunk)}D'
out += '\x1b[K'
fileobj.write(out)

def dump_begin(self, fileobj):
begin_dict = self.get_begin_dict()
self.cache.set(f":1:ev-{begin_dict['uuid']}", begin_dict)
self.dump(fileobj, {'uuid': begin_dict['uuid']})

def dump_end(self, fileobj):
self.dump(fileobj, self.get_end_dict(), flush=True)
self.dump(fileobj, self.get_end_dict())


event_context = EventContext()


@functools.cache
def _getsignature(f: t.Callable) -> inspect.Signature:
return inspect.signature(f)


def _getcallargs(sig: inspect.Signature, *args: P.args, **kwargs: P.kwargs) -> types.MappingProxyType:
ba = sig.bind(*args, **kwargs)
ba.apply_defaults()
return types.MappingProxyType(ba.arguments)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This stuff is unnecessary for this PR, and if desired I can remove the last commit. Looking at things like this just drove me a little crazy:

log_only = args[5] if len(args) >= 6 else kwargs.get('log_only', False)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can't safely functools.cache the bound args this way- any non-hashable value passed to a display method would cause a runtime error. There isn't really a benefit to caching this one in real-world usage anyway, since every unique combination of args is a new cache entry, so this would basically just be a memory leak bound by the size of the cache. Several of the display methods happily accept unhashable args, and last I checked there are a number of callers that blindly pass arbitrary objects to display methods.

Caching the signature this way should be fine (though you might want to verify that inspect isn't already doing so), since method objects are usually immutable/hashable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Signature objects are hashable, and as far as I have been able to confirm they are not cached within inspect.

I was also skeptical of caching the _getcallargs function, purely for the concern of how much it would cache. I'll remove that one.



def with_context(**context):
global event_context # pylint: disable=W0602

Expand Down Expand Up @@ -274,8 +290,10 @@ def with_verbosity(f):

@functools.wraps(f)
def wrapper(*args, **kwargs):
host = args[2] if len(args) >= 3 else kwargs.get('host', None)
caplevel = args[3] if len(args) >= 4 else kwargs.get('caplevel', 2)
sig = _getsignature(f)
callargs = _getcallargs(sig, *args, **kwargs)
host = callargs.get('host')
caplevel = callargs.get('caplevel')
context = {'verbose': True, 'verbosity': (caplevel + 1)}
if host is not None:
context['remote_addr'] = host
Expand All @@ -291,22 +309,27 @@ def display_with_context(f):

@functools.wraps(f)
def wrapper(*args, **kwargs):
log_only = args[5] if len(args) >= 6 else kwargs.get('log_only', False)
stderr = args[3] if len(args) >= 4 else kwargs.get('stderr', False)
if _ANSIBLE_214 and multiprocessing_context.parent_process() is not None:
# core 2.14 and newer proxy display, return if we are in a fork
return f(*args, **kwargs)

sig = _getsignature(f)
callargs = _getcallargs(sig, *args, **kwargs)
log_only = callargs.get('log_only')
stderr = callargs.get('stderr')
event_uuid = event_context.get().get('uuid', None)
with event_context.display_lock:
# If writing only to a log file or there is already an event UUID
# set (from a callback module method), skip dumping the event data.
if log_only or event_uuid:
return f(*args, **kwargs)
try:
fileobj = sys.stderr if stderr else sys.stdout
event_context.add_local(uuid=str(uuid.uuid4()))
event_context.dump_begin(fileobj)
return f(*args, **kwargs)
finally:
event_context.dump_end(fileobj)
event_context.remove_local(uuid=None)
# If writing only to a log file or there is already an event UUID
# set (from a callback module method), skip dumping the event data.
if log_only or event_uuid:
return f(*args, **kwargs)
try:
fileobj = sys.stderr if stderr else sys.stdout
event_context.add_local(uuid=str(uuid.uuid4()))
event_context.dump_begin(fileobj)
return f(*args, **kwargs)
finally:
event_context.dump_end(fileobj)
event_context.remove_local(uuid=None)

return wrapper

Expand Down Expand Up @@ -370,18 +393,17 @@ def capture_event_data(self, event, **event_data):
if isinstance(item, dict) and item.get('_ansible_no_log', False):
event_data['res']['results'][i] = {'censored': CENSORED}

with event_context.display_lock:
try:
event_context.add_local(event=event, **event_data)
if task:
self.set_task(task, local=True)
event_context.dump_begin(sys.stdout)
yield
finally:
event_context.dump_end(sys.stdout)
if task:
self.clear_task(local=True)
event_context.remove_local(event=None, **event_data)
try:
event_context.add_local(event=event, **event_data)
if task:
self.set_task(task, local=True)
event_context.dump_begin(sys.stdout)
yield
finally:
event_context.dump_end(sys.stdout)
if task:
self.clear_task(local=True)
event_context.remove_local(event=None, **event_data)

def set_playbook(self, playbook):
file_name = getattr(playbook, '_file_name', '???')
Expand Down