Skip to content

Commit

Permalink
Refactor through new function DocTestController.source_baseline
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthias Koeppe committed Dec 20, 2023
1 parent 48cde67 commit d52858c
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 32 deletions.
28 changes: 28 additions & 0 deletions src/sage/doctest/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,34 @@ def sort_key(source):
return -self.stats.get(basename, default).get('walltime', 0), basename
self.sources = sorted(self.sources, key=sort_key)

def source_baseline(self, source):
r"""
Return the ``baseline_stats`` value of ``source``.
INPUT:
- ``source`` -- a :class:`DocTestSource` instance
OUTPUT:
A dictionary.
EXAMPLES::
sage: from sage.doctest.control import DocTestDefaults, DocTestController
sage: from sage.env import SAGE_SRC
sage: import os
sage: filename = os.path.join(SAGE_SRC,'sage','doctest','util.py')
sage: DD = DocTestDefaults()
sage: DC = DocTestController(DD, [filename])
sage: DC.source_baseline(DC.sources[0])
{}
"""
if self.baseline_stats:
basename = source.basename
return self.baseline_stats.get(basename, {})
return {}

def run_doctests(self):
"""
Actually runs the doctests.
Expand Down
35 changes: 13 additions & 22 deletions src/sage/doctest/forker.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,7 @@ def __init__(self, *args, **kwds):
- ``optionflags`` -- Controls the comparison with the expected
output. See :mod:`testmod` for more information.
- ``baseline`` -- ``None`` or a dictionary, the ``baseline_stats``
value
- ``baseline`` -- dictionary, the ``baseline_stats`` value
EXAMPLES::
Expand All @@ -538,9 +537,7 @@ def __init__(self, *args, **kwds):
O = kwds.pop('outtmpfile', None)
self.msgfile = kwds.pop('msgfile', None)
self.options = kwds.pop('sage_options')
self.baseline = kwds.pop('baseline', None)
if self.baseline is None:
self.baseline = {}
self.baseline = kwds.pop('baseline', {})
doctest.DocTestRunner.__init__(self, *args, **kwds)
self._fakeout = SageSpoofInOut(O)
if self.msgfile is None:
Expand Down Expand Up @@ -1727,20 +1724,16 @@ def serial_dispatch(self):
"""
for source in self.controller.sources:
heading = self.controller.reporter.report_head(source)
basename = source.basename
if self.controller.baseline_stats:
the_baseline_stats = self.controller.baseline_stats.get(basename, {})
else:
the_baseline_stats = {}
if the_baseline_stats.get('failed', False):
baseline = self.controller.source_baseline(source)
if baseline.get('failed', False):
heading += " # [failed in baseline]"
if not self.controller.options.only_errors:
self.controller.log(heading)

with tempfile.TemporaryFile() as outtmpfile:
result = DocTestTask(source)(self.controller.options,
outtmpfile, self.controller.logger,
baseline=the_baseline_stats)
baseline=baseline)
outtmpfile.seek(0)
output = bytes_to_str(outtmpfile.read())

Expand Down Expand Up @@ -1998,16 +1991,12 @@ def sel_exit():
# Start a new worker.
import copy
worker_options = copy.copy(opt)
basename = source.basename
if self.controller.baseline_stats:
the_baseline_stats = self.controller.baseline_stats.get(basename, {})
else:
the_baseline_stats = {}
baseline = self.controller.source_baseline(source)
if target_endtime is not None:
worker_options.target_walltime = (target_endtime - now) / (max(1, pending_tests / opt.nthreads))
w = DocTestWorker(source, options=worker_options, baseline=the_baseline_stats, funclist=[sel_exit])
w = DocTestWorker(source, options=worker_options, funclist=[sel_exit], baseline=baseline)
heading = self.controller.reporter.report_head(w.source)
if the_baseline_stats.get('failed', False):
if baseline.get('failed', False):
heading += " # [failed in baseline]"
if not self.controller.options.only_errors:
w.messages = heading + "\n"
Expand Down Expand Up @@ -2148,6 +2137,8 @@ class should be accessed by the child process.
- ``funclist`` -- a list of callables to be called at the start of
the child process.
- ``baseline`` -- dictionary, the ``baseline_stats`` value
EXAMPLES::
sage: from sage.doctest.forker import DocTestWorker, DocTestTask
Expand Down Expand Up @@ -2263,7 +2254,8 @@ def run(self):
os.close(self.rmessages)
msgpipe = os.fdopen(self.wmessages, "w")
try:
task(self.options, self.outtmpfile, msgpipe, self.result_queue, baseline=self.baseline)
task(self.options, self.outtmpfile, msgpipe, self.result_queue,
baseline=self.baseline)
finally:
msgpipe.close()
self.outtmpfile.close()
Expand Down Expand Up @@ -2547,8 +2539,7 @@ def __call__(self, options, outtmpfile=None, msgfile=None, result_queue=None, *,
- ``result_queue`` -- an instance of :class:`multiprocessing.Queue`
to store the doctest result. For testing, this can also be None.
- ``baseline`` -- ``None`` or a dictionary, the ``baseline_stats``
value.
- ``baseline`` -- a dictionary, the ``baseline_stats`` value.
OUTPUT:
Expand Down
17 changes: 7 additions & 10 deletions src/sage/doctest/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,7 @@ def report(self, source, timeout, return_code, results, output, pid=None):
postscript = self.postscript
stats = self.stats
basename = source.basename
if self.controller.baseline_stats:
the_baseline_stats = self.controller.baseline_stats.get(basename, {})
else:
the_baseline_stats = {}
baseline = self.controller.source_baseline(source)
cmd = self.report_head(source)
try:
ntests, result_dict = results
Expand All @@ -423,14 +420,14 @@ def report(self, source, timeout, return_code, results, output, pid=None):
fail_msg += " (and interrupt failed)"
else:
fail_msg += " (with %s after interrupt)" % signal_name(sig)
if the_baseline_stats.get('failed', False):
if baseline.get('failed', False):
fail_msg += " [failed in baseline]"
log(" %s\n%s\nTests run before %s timed out:" % (fail_msg, "*"*70, process_name))
log(output)
log("*"*70)
postscript['lines'].append(cmd + " # %s" % fail_msg)
stats[basename] = {"failed": True, "walltime": 1e6, "ntests": ntests}
if not the_baseline_stats.get('failed', False):
if not baseline.get('failed', False):
self.error_status |= 4
elif return_code:
if return_code > 0:
Expand All @@ -439,14 +436,14 @@ def report(self, source, timeout, return_code, results, output, pid=None):
fail_msg = "Killed due to %s" % signal_name(-return_code)
if ntests > 0:
fail_msg += " after testing finished"
if the_baseline_stats.get('failed', False):
if baseline.get('failed', False):
fail_msg += " [failed in baseline]"
log(" %s\n%s\nTests run before %s failed:" % (fail_msg,"*"*70, process_name))
log(output)
log("*"*70)
postscript['lines'].append(cmd + " # %s" % fail_msg)
stats[basename] = {"failed": True, "walltime": 1e6, "ntests": ntests}
if not the_baseline_stats.get('failed', False):
if not baseline.get('failed', False):
self.error_status |= (8 if return_code > 0 else 16)
else:
if hasattr(result_dict, 'walltime') and hasattr(result_dict.walltime, '__len__') and len(result_dict.walltime) > 0:
Expand Down Expand Up @@ -509,10 +506,10 @@ def report(self, source, timeout, return_code, results, output, pid=None):
f = result_dict.failures
if f:
fail_msg = "%s failed" % (count_noun(f, "doctest"))
if the_baseline_stats.get('failed', False):
if baseline.get('failed', False):
fail_msg += " [failed in baseline]"
postscript['lines'].append(cmd + " # %s" % fail_msg)
if not the_baseline_stats.get('failed', False):
if not baseline.get('failed', False):
self.error_status |= 1
if f or result_dict.err == 'tab':
stats[basename] = {"failed": True, "walltime": wall, "ntests": ntests}
Expand Down

0 comments on commit d52858c

Please sign in to comment.