Skip to content

Commit

Permalink
fixes #405
Browse files Browse the repository at this point in the history
  • Loading branch information
jph00 committed Apr 11, 2022
1 parent 0dac97a commit 85c5a7f
Show file tree
Hide file tree
Showing 4 changed files with 336 additions and 63 deletions.
2 changes: 2 additions & 0 deletions fastcore/_nbdev.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@
"gather_attrs": "05_transform.ipynb",
"gather_attr_names": "05_transform.ipynb",
"Pipeline": "05_transform.ipynb",
"docstring": "06_docments.ipynb",
"parse_docstring": "06_docments.ipynb",
"empty": "06_docments.ipynb",
"docments": "06_docments.ipynb",
"test_sig": "07_meta.ipynb",
Expand Down
42 changes: 37 additions & 5 deletions fastcore/docments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,36 @@
from __future__ import annotations


__all__ = ['empty', 'docments']
__all__ = ['docstring', 'parse_docstring', 'empty', 'docments']

# Cell
#nbdev_comment from __future__ import annotations

import re
from tokenize import tokenize,COMMENT
from ast import parse,FunctionDef
from io import BytesIO
from textwrap import dedent
from types import SimpleNamespace
from inspect import getsource,isfunction,isclass,signature,Parameter
from .basics import *
from .utils import *

import re
from fastcore import docscrape
from inspect import isclass

# Cell
def docstring(sym):
"Get docstring for `sym` for functions ad classes"
if isinstance(sym, str): return sym
res = getattr(sym, "__doc__", None)
if not res and isclass(sym): res = nested_attr(sym, "__init__.__doc__")
return res or ""

# Cell
def parse_docstring(sym):
"Parse a numpy-style docstring in `sym`"
docs = docstring(sym)
return AttrDict(**docscrape.NumpyDocString(docstring(sym)))

# Cell
def _parses(s):
Expand All @@ -36,7 +53,7 @@ def _clean_comment(s):
def _param_locs(s, returns=True):
"`dict` of parameter line numbers to names"
body = _parses(s).body
if len(body)!=1or not isinstance(body[0], FunctionDef): return None
if len(body)!=1 or not isinstance(body[0], FunctionDef): return None
defn = body[0]
res = {arg.lineno:arg.arg for arg in defn.args.args}
if returns and defn.returns: res[defn.returns.lineno] = 'return'
Expand All @@ -59,21 +76,36 @@ def _get_full(anno, name, default, docs):
if anno==empty and default!=empty: anno = type(default)
return AttrDict(docment=docs.get(name), anno=anno, default=default)

# Cell
def _merge_doc(dm, npdoc):
if not npdoc: return dm
if not dm.anno or dm.anno==empty: dm.anno = npdoc.type
if not dm.docment: dm.docment = '\n'.join(npdoc.desc)
return dm

def _merge_docs(dms, npdocs):
npparams = npdocs['Parameters']
params = {nm:_merge_doc(dm,npparams.get(nm,None)) for nm,dm in dms.items()}
if 'return' in dms: params['return'] = _merge_doc(dms['return'], npdocs['Returns'])
return params

# Cell
def docments(s, full=False, returns=True, eval_str=False):
"`dict` of parameter names to 'docment-style' comments in function or string `s`"
nps = parse_docstring(s)
if isclass(s): s = s.__init__ # Constructor for a class
comments = {o.start[0]:_clean_comment(o.string) for o in _tokens(s) if o.type==COMMENT}
parms = _param_locs(s, returns=returns)
docs = {arg:_get_comment(line, arg, comments, parms) for line,arg in parms.items()}
if not full: return AttrDict(docs)

if isinstance(s,str): s = eval(s)
sig = signature(s)
res = {arg:_get_full(p.annotation, p.name, p.default, docs) for arg,p in sig.parameters.items()}
if returns: res['return'] = _get_full(sig.return_annotation, 'return', empty, docs)
res = _merge_docs(res, nps)
if eval_str:
hints = type_hints(s)
for k,v in res.items():
if k in hints: v['anno'] = hints.get(k)
if not full: res = {k:v['docment'] for k,v in res.items()}
return AttrDict(res)
37 changes: 14 additions & 23 deletions fastcore/docscrape.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
contributors may be used to endorse or promote products derived
from this software without specific prior written permission. """

import textwrap, re, copy
from warnings import warn
from collections import namedtuple
from collections.abc import Mapping

__all__ = ['Parameter', 'NumpyDocString', 'dedent_lines']

Parameter = namedtuple('Parameter', ['name', 'type', 'desc'])
Expand Down Expand Up @@ -89,14 +94,17 @@ def __str__(self):

class NumpyDocString(Mapping):
"""Parses a numpydoc string to an abstract representation """
sections = { 'Signature': '', 'Summary': [''], 'Extended': [], 'Parameters': [], 'Returns': [], 'Yields': [], 'Raises': [] }
sections = { 'Summary': [''], 'Extended': [], 'Parameters': [], 'Returns': [] }

def __init__(self, docstring, config=None):
docstring = textwrap.dedent(docstring).split('\n')
self._doc = Reader(docstring)
self._parsed_data = copy.deepcopy(self.sections)
self._parse()
if 'Parameters' in self: self['Parameters'] = {o.name:o for o in self['Parameters']}
self['Parameters'] = {o.name:o for o in self['Parameters']}
if self['Returns']: self['Returns'] = self['Returns'][0]
self['Summary'] = dedent_lines(self['Summary'], split=False)
self['Extended'] = dedent_lines(self['Extended'], split=False)

def __iter__(self): return iter(self._parsed_data)
def __len__(self): return len(self._parsed_data)
Expand Down Expand Up @@ -171,7 +179,6 @@ def _parse_summary(self):
summary_str = " ".join([s.strip() for s in summary]).strip()
compiled = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\(.*\)$')
if compiled.match(summary_str):
self['Signature'] = summary_str
if not self._is_at_section(): continue
break

Expand Down Expand Up @@ -216,16 +223,12 @@ def _obj(self):

def _error_location(self, msg, error=True):
if self._obj is not None:
# we know where the docs came from:
try: filename = inspect.getsourcefile(self._obj)
except TypeError: filename = None
# Make UserWarning more descriptive via object introspection.
# Skip if introspection fails
name = getattr(self._obj, '__name__', None)
if name is None:
name = getattr(getattr(self._obj, '__class__', None), '__name__', None)
if name is not None: msg += f" in the docstring of {name}"
msg += f" in {filename}." if filename else ""
if error: raise ValueError(msg)
else: warn(msg)

Expand All @@ -234,10 +237,6 @@ def _error_location(self, msg, error=True):
def _str_header(self, name, symbol='-'): return [name, len(name)*symbol]
def _str_indent(self, doc, indent=4): return [' '*indent + line for line in doc]

def _str_signature(self):
if self['Signature']: return [self['Signature'].replace('*', r'\*')] + ['']
return ['']

def _str_summary(self):
if self['Summary']: return self['Summary'] + ['']
return []
Expand All @@ -259,18 +258,10 @@ def _str_param_list(self, name):
out += ['']
return out

def __str__(self, func_role=''):
out = []
out += self._str_signature()
out += self._str_summary()
out += self._str_extended_summary()
for param_list in ('Parameters', 'Returns', 'Yields', 'Receives', 'Other Parameters', 'Raises', 'Warns'):
out += self._str_param_list(param_list)
for param_list in ('Attributes', 'Methods'): out += self._str_param_list(param_list)
return '\n'.join(out)


def dedent_lines(lines):
def dedent_lines(lines, split=True):
"""Deindent a list of lines maximally"""
return textwrap.dedent("\n".join(lines)).split("\n")
res = textwrap.dedent("\n".join(lines))
if split: res = res.split("\n")
return res

Loading

0 comments on commit 85c5a7f

Please sign in to comment.