Skip to content

Commit

Permalink
Use flake8-bugbear; use async def coroutines
Browse files Browse the repository at this point in the history
  • Loading branch information
sloria committed Jul 9, 2018
1 parent cd55fe7 commit e11b7c0
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 85 deletions.
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pre-commit==1.10.3

# Syntax checking
flake8==3.5.0
flake8-bugbear=18.2.0

# Install this package in development mode
-e '.'
9 changes: 3 additions & 6 deletions examples/aiohttp_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@
hello_args = {"name": fields.Str(missing="Friend")}


@asyncio.coroutine
@use_args(hello_args)
def index(request, args):
async def index(request, args):
"""A welcome page.
"""
return json_response({"message": "Welcome, {}!".format(args["name"])})
Expand All @@ -34,9 +33,8 @@ def index(request, args):
add_args = {"x": fields.Float(required=True), "y": fields.Float(required=True)}


@asyncio.coroutine
@use_kwargs(add_args)
def add(request, x, y):
async def add(request, x, y):
"""An addition endpoint."""
return json_response({"result": x + y})

Expand All @@ -48,9 +46,8 @@ def add(request, x, y):
}


@asyncio.coroutine
@use_kwargs(dateadd_args)
def dateadd(request, value, addend, unit):
async def dateadd(request, value, addend, unit):
"""A datetime adder endpoint."""
value = value or dt.datetime.utcnow()
if unit == "minutes":
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ universal = 1

[flake8]
ignore = E203, E266, E501, W503, E302
max-line-length = 80
max-line-length = 100
max-complexity = 18
select = B,C,E,F,W,T4,B9
exclude = .git,.ropeproject,.tox,build,env,venv,__pycache__
87 changes: 33 additions & 54 deletions tests/apps/aiohttp_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,140 +23,119 @@ class HelloSchema(ma.Schema):
##### Handlers #####


@asyncio.coroutine
def echo(request):
parsed = yield from parser.parse(hello_args, request)
async def echo(request):
parsed = await parser.parse(hello_args, request)
return json_response(parsed)


@asyncio.coroutine
def echo_query(request):
parsed = yield from parser.parse(hello_args, request, locations=("query",))
async def echo_query(request):
parsed = await parser.parse(hello_args, request, locations=("query",))
return json_response(parsed)


@asyncio.coroutine
@use_args(hello_args)
def echo_use_args(request, args):
async def echo_use_args(request, args):
return json_response(args)


@asyncio.coroutine
@use_kwargs(hello_args)
def echo_use_kwargs(request, name):
async def echo_use_kwargs(request, name):
return json_response({"name": name})


@asyncio.coroutine
@use_args({"value": fields.Int()}, validate=lambda args: args["value"] > 42)
def echo_use_args_validated(request, args):
async def echo_use_args_validated(request, args):
return json_response(args)


@asyncio.coroutine
def echo_multi(request):
parsed = yield from parser.parse(hello_multiple, request)
async def echo_multi(request):
parsed = await parser.parse(hello_multiple, request)
return json_response(parsed)


@asyncio.coroutine
def echo_many_schema(request):
parsed = yield from parser.parse(hello_many_schema, request, locations=("json",))
async def echo_many_schema(request):
parsed = await parser.parse(hello_many_schema, request, locations=("json",))
return json_response(parsed)


@asyncio.coroutine
@use_args({"value": fields.Int()})
def echo_use_args_with_path_param(request, args):
async def echo_use_args_with_path_param(request, args):
return json_response(args)


@asyncio.coroutine
@use_kwargs({"value": fields.Int()})
def echo_use_kwargs_with_path_param(request, value):
async def echo_use_kwargs_with_path_param(request, value):
return json_response({"value": value})


@asyncio.coroutine
def always_error(request):
async def always_error(request):
def always_fail(value):
raise ValidationError("something went wrong")

args = {"text": fields.Str(validate=always_fail)}
parsed = yield from parser.parse(args, request)
parsed = await parser.parse(args, request)
return json_response(parsed)


@asyncio.coroutine
def error400(request):
async def error400(request):
def always_fail(value):
raise ValidationError("something went wrong", status_code=400)

args = {"text": fields.Str(validate=always_fail)}
parsed = yield from parser.parse(args, request)
parsed = await parser.parse(args, request)
return json_response(parsed)


@asyncio.coroutine
def error_invalid(request):
async def error_invalid(request):
def always_fail(value):
raise ValidationError("something went wrong", status_code=12345)

args = {"text": fields.Str(validate=always_fail)}
parsed = yield from parser.parse(args, request)
parsed = await parser.parse(args, request)
return json_response(parsed)


@asyncio.coroutine
def echo_headers(request):
parsed = yield from parser.parse(hello_args, request, locations=("headers",))
async def echo_headers(request):
parsed = await parser.parse(hello_args, request, locations=("headers",))
return json_response(parsed)


@asyncio.coroutine
def echo_cookie(request):
parsed = yield from parser.parse(hello_args, request, locations=("cookies",))
async def echo_cookie(request):
parsed = await parser.parse(hello_args, request, locations=("cookies",))
return json_response(parsed)


@asyncio.coroutine
def echo_nested(request):
async def echo_nested(request):
args = {"name": fields.Nested({"first": fields.Str(), "last": fields.Str()})}
parsed = yield from parser.parse(args, request)
parsed = await parser.parse(args, request)
return json_response(parsed)


@asyncio.coroutine
def echo_multiple_args(request):
async def echo_multiple_args(request):
args = {"first": fields.Str(), "last": fields.Str()}
parsed = yield from parser.parse(args, request)
parsed = await parser.parse(args, request)
return json_response(parsed)


@asyncio.coroutine
def echo_nested_many(request):
async def echo_nested_many(request):
args = {
"users": fields.Nested({"id": fields.Int(), "name": fields.Str()}, many=True)
}
parsed = yield from parser.parse(args, request)
parsed = await parser.parse(args, request)
return json_response(parsed)


@asyncio.coroutine
def echo_nested_many_data_key(request):
async def echo_nested_many_data_key(request):
data_key_kwarg = {
"load_from" if (MARSHMALLOW_VERSION_INFO[0] < 3) else "data_key": "X-Field"
}
args = {"x_field": fields.Nested({"id": fields.Int()}, many=True, **data_key_kwarg)}
parsed = yield from parser.parse(args, request)
parsed = await parser.parse(args, request)
return json_response(parsed)


@asyncio.coroutine
def echo_match_info(request):
parsed = yield from parser.parse(
{"mymatch": fields.Int(location="match_info")}, request
)
async def echo_match_info(request):
parsed = await parser.parse({"mymatch": fields.Int(location="match_info")}, request)
return json_response(parsed)


Expand Down
11 changes: 4 additions & 7 deletions webargs/aiohttpparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def index(request, args):
app = web.Application()
app.router.add_route('GET', '/', index)
"""
import asyncio
import json
import warnings

Expand Down Expand Up @@ -85,22 +84,20 @@ def parse_querystring(self, req, name, field):
"""Pull a querystring value from the request."""
return core.get_value(req.query, name, field)

@asyncio.coroutine
def parse_form(self, req, name, field):
async def parse_form(self, req, name, field):
"""Pull a form value from the request."""
post_data = self._cache.get("post")
if post_data is None:
self._cache["post"] = yield from req.post()
self._cache["post"] = await req.post()
return core.get_value(self._cache["post"], name, field)

@asyncio.coroutine
def parse_json(self, req, name, field):
async def parse_json(self, req, name, field):
"""Pull a json value from the request."""
json_data = self._cache.get("json")
if json_data is None:
if not (req.body_exists and is_json_request(req)):
return core.missing
self._cache["json"] = json_data = yield from req.json()
self._cache["json"] = json_data = await req.json()
return core.get_value(json_data, name, field, allow_many_nested=True)

def parse_headers(self, req, name, field):
Expand Down
32 changes: 15 additions & 17 deletions webargs/asyncparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,15 @@ class AsyncParser(core.Parser):
either coroutines or regular methods.
"""

@asyncio.coroutine
def _parse_request(self, schema, req, locations):
async def _parse_request(self, schema, req, locations):
if schema.many:
assert (
"json" in locations
), "schema.many=True is only supported for JSON location"
# The ad hoc Nested field is more like a workaround or a helper, and it servers its
# purpose fine. However, if somebody has a desire to re-design the support of
# bulk-type arguments, go ahead.
parsed = yield from self.parse_arg(
parsed = await self.parse_arg(
name="json",
field=ma.fields.Nested(schema, many=True),
req=req,
Expand All @@ -38,27 +37,28 @@ def _parse_request(self, schema, req, locations):
parsed = {}
for argname, field_obj in argdict.items():
if core.MARSHMALLOW_VERSION_INFO[0] < 3:
parsed_value = yield from self.parse_arg(
parsed_value = await self.parse_arg(
argname, field_obj, req, locations
)
# If load_from is specified on the field, try to parse from that key
if parsed_value is missing and field_obj.load_from:
parsed_value = yield from self.parse_arg(
parsed_value = await self.parse_arg(
field_obj.load_from, field_obj, req, locations
)
argname = field_obj.load_from
else:
argname = field_obj.data_key or argname
parsed_value = yield from self.parse_arg(
parsed_value = await self.parse_arg(
argname, field_obj, req, locations
)
if parsed_value is not missing:
parsed[argname] = parsed_value
return parsed

# TODO: Lots of duplication from core.Parser here. Rethink.
@asyncio.coroutine
def parse(self, argmap, req=None, locations=None, validate=None, force_all=False):
async def parse(
self, argmap, req=None, locations=None, validate=None, force_all=False
):
"""Coroutine variant of `webargs.core.Parser`.
Receives the same arguments as `webargs.core.Parser.parse`.
Expand All @@ -69,7 +69,7 @@ def parse(self, argmap, req=None, locations=None, validate=None, force_all=False
validators = core._ensure_list_of_callables(validate)
schema = self._get_schema(argmap, req)
try:
parsed = yield from self._parse_request(
parsed = await self._parse_request(
schema=schema, req=req, locations=locations
)
result = schema.load(parsed)
Expand Down Expand Up @@ -139,7 +139,7 @@ def wrapper(*args, **kwargs):
if not req_obj:
req_obj = self.get_request_from_view_args(func, args, kwargs)
# NOTE: At this point, argmap may be a Schema, callable, or dict
parsed_args = yield from self.parse(
parsed_args = yield from self.parse( # noqa: B901
argmap,
req=req_obj,
locations=locations,
Expand All @@ -148,7 +148,7 @@ def wrapper(*args, **kwargs):
)
if as_kwargs:
kwargs.update(parsed_args)
return func(*args, **kwargs)
return func(*args, **kwargs) # noqa: B901
else:
# Add parsed_args after other positional arguments
new_args = args + (parsed_args,)
Expand All @@ -167,23 +167,21 @@ def use_kwargs(self, *args, **kwargs):
"""
return super().use_kwargs(*args, **kwargs)

@asyncio.coroutine
def parse_arg(self, name, field, req, locations=None):
async def parse_arg(self, name, field, req, locations=None):
location = field.metadata.get("location")
if location:
locations_to_check = self._validated_locations([location])
else:
locations_to_check = self._validated_locations(locations or self.locations)

for location in locations_to_check:
value = yield from self._get_value(name, field, req=req, location=location)
value = await self._get_value(name, field, req=req, location=location)
# Found the value; validate and return it
if value is not core.missing:
return value
return core.missing

@asyncio.coroutine
def _get_value(self, name, argobj, req, location):
async def _get_value(self, name, argobj, req, location):
# Parsing function to call
# May be a method name (str) or a function
func = self.__location_map__.get(location)
Expand All @@ -193,7 +191,7 @@ def _get_value(self, name, argobj, req, location):
else:
function = getattr(self, func)
if asyncio.iscoroutinefunction(function):
value = yield from function(req, name, argobj)
value = await function(req, name, argobj)
else:
value = function(req, name, argobj)
else:
Expand Down

0 comments on commit e11b7c0

Please sign in to comment.