From 5480ba64bcabbda2ca7ebb721f3e1ba63bdb88d8 Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Mon, 24 Jul 2023 04:45:02 +0000 Subject: [PATCH] 'Refactored by Sourcery' --- scripts/generate_identifier_pattern.py | 5 +-- src/jinja2/async_utils.py | 5 +-- src/jinja2/compiler.py | 36 ++++++----------- src/jinja2/debug.py | 6 +-- src/jinja2/environment.py | 54 +++++++++----------------- src/jinja2/exceptions.py | 7 ++-- src/jinja2/ext.py | 11 +----- src/jinja2/filters.py | 47 ++++++---------------- src/jinja2/idtracking.py | 16 ++------ src/jinja2/lexer.py | 12 ++---- src/jinja2/loaders.py | 13 +++---- src/jinja2/nativetypes.py | 5 +-- src/jinja2/nodes.py | 27 +++++-------- src/jinja2/parser.py | 27 ++++--------- src/jinja2/runtime.py | 46 +++++----------------- src/jinja2/sandbox.py | 22 +++++------ src/jinja2/tests.py | 4 +- src/jinja2/utils.py | 26 +++---------- src/jinja2/visitor.py | 5 +-- tests/test_api.py | 12 +++--- tests/test_async.py | 4 +- tests/test_core_tags.py | 2 +- tests/test_ext.py | 12 ++---- tests/test_filters.py | 2 +- tests/test_regression.py | 8 ++-- 25 files changed, 127 insertions(+), 287 deletions(-) diff --git a/scripts/generate_identifier_pattern.py b/scripts/generate_identifier_pattern.py index 7fc64aed0..4e4b9accc 100755 --- a/scripts/generate_identifier_pattern.py +++ b/scripts/generate_identifier_pattern.py @@ -19,7 +19,7 @@ def get_characters(): for cp in range(sys.maxunicode + 1): s = chr(cp) - if ("a" + s).isidentifier() and not re.match(r"\w", s): + if f"a{s}".isidentifier() and not re.match(r"\w", s): yield s @@ -45,8 +45,7 @@ def build_pattern(ranges): if a == b: # single char out.append(a) elif ord(b) - ord(a) == 1: # two chars, range is redundant - out.append(a) - out.append(b) + out.extend((a, b)) else: out.append(f"{a}-{b}") diff --git a/src/jinja2/async_utils.py b/src/jinja2/async_utils.py index e65219e49..36f6c3a5a 100644 --- a/src/jinja2/async_utils.py +++ b/src/jinja2/async_utils.py @@ -39,10 +39,7 @@ def wrapper(*args, **kwargs): # type: ignore if need_eval_context: args = args[1:] - if b: - return async_func(*args, **kwargs) - - return normal_func(*args, **kwargs) + return async_func(*args, **kwargs) if b else normal_func(*args, **kwargs) if need_eval_context: wrapper = pass_eval_context(wrapper) diff --git a/src/jinja2/compiler.py b/src/jinja2/compiler.py index 7dfac0a71..a3091b1c1 100644 --- a/src/jinja2/compiler.py +++ b/src/jinja2/compiler.py @@ -90,7 +90,7 @@ def visitor(self: "CodeGenerator", node: nodes.UnaryExpr, frame: Frame) -> None: self.write(f"environment.call_unop(context, {op!r}, ") self.visit(node.node, frame) else: - self.write("(" + op) + self.write(f"({op}") self.visit(node.node, frame) self.write(")") @@ -116,10 +116,7 @@ def generate( ) generator.visit(node) - if stream is None: - return generator.stream.getvalue() # type: ignore - - return None + return generator.stream.getvalue() if stream is None else None def has_safe_repr(value: t.Any) -> bool: @@ -596,10 +593,7 @@ def enter_frame(self, frame: Frame) -> None: def leave_frame(self, frame: Frame, with_python_scope: bool = False) -> None: if not with_python_scope: - undefs = [] - for target in frame.symbols.loads: - undefs.append(target) - if undefs: + if undefs := list(frame.symbols.loads): self.writeline(f"{' = '.join(undefs)} = missing") def choose_async(self, async_value: str = "async ", sync_value: str = "") -> str: @@ -761,18 +755,14 @@ def get_context_ref(self) -> str: def get_resolve_func(self) -> str: target = self._context_reference_stack[-1] - if target == "context": - return "resolve" - return f"{target}.resolve" + return "resolve" if target == "context" else f"{target}.resolve" def derive_context(self, frame: Frame) -> str: return f"{self.get_context_ref()}.derived({self.dump_local_context(frame)})" def parameter_is_undeclared(self, target: str) -> bool: """Checks if a given target is an undeclared parameter.""" - if not self._param_def_block: - return False - return target in self._param_def_block[-1] + return target in self._param_def_block[-1] if self._param_def_block else False def push_assign_tracking(self) -> None: """Pushes a new layer for assignment tracking.""" @@ -909,7 +899,7 @@ def visit_Template( # at this point we now have the blocks collected and can visit them too. for name, block in self.blocks.items(): self.writeline( - f"{self.func('block_' + name)}(context, missing=missing{envenv}):", + f"{self.func(f'block_{name}')}(context, missing=missing{envenv}):", block, 1, ) @@ -954,11 +944,7 @@ def visit_Block(self, node: nodes.Block, frame: Frame) -> None: self.indent() level += 1 - if node.scoped: - context = self.derive_context(frame) - else: - context = self.get_context_ref() - + context = self.derive_context(frame) if node.scoped else self.get_context_ref() if node.required: self.writeline(f"if len(context.blocks[{node.name!r}]) <= 1:", node) self.indent() @@ -1530,9 +1516,9 @@ def visit_Output(self, node: nodes.Output, frame: Frame) -> None: val = self._output_const_repr(item) if frame.buffer is None: - self.writeline("yield " + val) + self.writeline(f"yield {val}") else: - self.writeline(val + ",") + self.writeline(f"{val},") else: if frame.buffer is None: self.writeline("yield ", item) @@ -1860,7 +1846,7 @@ def visit_Call( self.write("))") def visit_Keyword(self, node: nodes.Keyword, frame: Frame) -> None: - self.write(node.key + "=") + self.write(f"{node.key}=") self.visit(node.value, frame) # -- Unused nodes for extensions @@ -1880,7 +1866,7 @@ def visit_MarkSafeIfAutoescape( def visit_EnvironmentAttribute( self, node: nodes.EnvironmentAttribute, frame: Frame ) -> None: - self.write("environment." + node.name) + self.write(f"environment.{node.name}") def visit_ExtensionAttribute( self, node: nodes.ExtensionAttribute, frame: Frame diff --git a/src/jinja2/debug.py b/src/jinja2/debug.py index 7ed7e9297..a67ca21c2 100644 --- a/src/jinja2/debug.py +++ b/src/jinja2/debug.py @@ -154,11 +154,7 @@ def get_template_locals(real_locals: t.Mapping[str, t.Any]) -> t.Dict[str, t.Any # Start with the current template context. ctx: "t.Optional[Context]" = real_locals.get("context") - if ctx is not None: - data: t.Dict[str, t.Any] = ctx.get_all().copy() - else: - data = {} - + data = ctx.get_all().copy() if ctx is not None else {} # Might be in a derived context that only sets local variables # rather than pushing a context. Local variables follow the scheme # l_depth_name. Find the highest-depth local that has a value for diff --git a/src/jinja2/environment.py b/src/jinja2/environment.py index c9afae311..78692e46c 100644 --- a/src/jinja2/environment.py +++ b/src/jinja2/environment.py @@ -84,10 +84,7 @@ def create_cache( if size == 0: return None - if size < 0: - return {} - - return LRUCache(size) # type: ignore + return {} if size < 0 else LRUCache(size) def copy_cache( @@ -99,10 +96,7 @@ def copy_cache( if cache is None: return None - if type(cache) is dict: - return {} - - return LRUCache(cache.capacity) # type: ignore + return {} if type(cache) is dict else LRUCache(cache.capacity) def load_extensions( @@ -535,11 +529,7 @@ def _filter_test_common( args.insert(0, context) elif pass_arg is _PassArg.eval_context: if eval_ctx is None: - if context is not None: - eval_ctx = context.eval_ctx - else: - eval_ctx = EvalContext(self) - + eval_ctx = context.eval_ctx if context is not None else EvalContext(self) args.insert(0, eval_ctx) elif pass_arg is _PassArg.environment: args.insert(0, self) @@ -633,7 +623,7 @@ def lex( of the extensions to be applied you have to filter source through the :meth:`preprocess` method. """ - source = str(source) + source = source try: return self.lexer.tokeniter(source, name, filename) except TemplateSyntaxError: @@ -652,7 +642,7 @@ def preprocess( return reduce( lambda s, e: e.preprocess(s, name, filename), self.iter_extensions(), - str(source), + source, ) def _tokenize( @@ -1439,9 +1429,7 @@ def _get_default_module(self, ctx: t.Optional[Context] = None) -> "TemplateModul raise RuntimeError("Module is not available in async mode.") if ctx is not None: - keys = ctx.globals_keys - self.globals.keys() - - if keys: + if keys := ctx.globals_keys - self.globals.keys(): return self.make_module({k: ctx.parent[k] for k in keys}) if self._module is None: @@ -1453,9 +1441,7 @@ async def _get_default_module_async( self, ctx: t.Optional[Context] = None ) -> "TemplateModule": if ctx is not None: - keys = ctx.globals_keys - self.globals.keys() - - if keys: + if keys := ctx.globals_keys - self.globals.keys(): return await self.make_module_async({k: ctx.parent[k] for k in keys}) if self._module is None: @@ -1483,17 +1469,19 @@ def get_corresponding_lineno(self, lineno: int) -> int: """Return the source line number of a line number in the generated bytecode as they are not in sync. """ - for template_line, code_line in reversed(self.debug_info): - if code_line <= lineno: - return template_line - return 1 + return next( + ( + template_line + for template_line, code_line in reversed(self.debug_info) + if code_line <= lineno + ), + 1, + ) @property def is_up_to_date(self) -> bool: """If this variable is `False` there is a newer version available.""" - if self._uptodate is None: - return True - return self._uptodate() + return True if self._uptodate is None else self._uptodate() @property def debug_info(self) -> t.List[t.Tuple[int, int]]: @@ -1507,10 +1495,7 @@ def debug_info(self) -> t.List[t.Tuple[int, int]]: return [] def __repr__(self) -> str: - if self.name is None: - name = f"memory:{id(self):x}" - else: - name = repr(self.name) + name = f"memory:{id(self):x}" if self.name is None else repr(self.name) return f"<{type(self).__name__} {name}>" @@ -1547,10 +1532,7 @@ def __str__(self) -> str: return concat(self._body_stream) def __repr__(self) -> str: - if self.__name__ is None: - name = f"memory:{id(self):x}" - else: - name = repr(self.__name__) + name = f"memory:{id(self):x}" if self.__name__ is None else repr(self.__name__) return f"<{type(self).__name__} {name}>" diff --git a/src/jinja2/exceptions.py b/src/jinja2/exceptions.py index 082ebe8f2..20f978947 100644 --- a/src/jinja2/exceptions.py +++ b/src/jinja2/exceptions.py @@ -112,10 +112,9 @@ def __str__(self) -> str: # otherwise attach some stuff location = f"line {self.lineno}" - name = self.filename or self.name - if name: + if name := self.filename or self.name: location = f'File "{name}", {location}' - lines = [t.cast(str, self.message), " " + location] + lines = [t.cast(str, self.message), f" {location}"] # if the source is set, add the line to the output if self.source is not None: @@ -124,7 +123,7 @@ def __str__(self) -> str: except IndexError: pass else: - lines.append(" " + line.strip()) + lines.append(f" {line.strip()}") return "\n".join(lines) diff --git a/src/jinja2/ext.py b/src/jinja2/ext.py index 354b4063d..b0ef9a05d 100644 --- a/src/jinja2/ext.py +++ b/src/jinja2/ext.py @@ -355,12 +355,9 @@ def parse(self, parser: "Parser") -> t.Union[nodes.Node, t.List[nodes.Node]]: """Parse a translatable tag.""" lineno = next(parser.stream).lineno - context = None context_token = parser.stream.next_if("string") - if context_token is not None: - context = context_token.value - + context = context_token.value if context_token is not None else None # find all the variables referenced. Additionally a variable can be # defined in the body of the trans block too, but this is checked at # a later state. @@ -712,11 +709,7 @@ def extract_from_ast( if not out: continue else: - if len(strings) == 1: - out = strings[0] - else: - out = tuple(strings) - + out = strings[0] if len(strings) == 1 else tuple(strings) yield node.lineno, node.node.name, out diff --git a/src/jinja2/filters.py b/src/jinja2/filters.py index f4479ff3d..bf708b117 100644 --- a/src/jinja2/filters.py +++ b/src/jinja2/filters.py @@ -46,10 +46,7 @@ def __html__(self) -> str: def ignore_case(value: V) -> V: """For use as a postprocessor for :func:`make_attrgetter`. Converts strings to lowercase and returns other types as-is.""" - if isinstance(value, str): - return t.cast(V, value.lower()) - - return value + return t.cast(V, value.lower()) if isinstance(value, str) else value def make_attrgetter( @@ -95,11 +92,7 @@ def make_multi_attrgetter( Examples of attribute: "attr1,attr2", "attr1.inner1.0,attr2.inner2.0", etc. """ - if isinstance(attribute, str): - split: t.Sequence[t.Union[str, int, None]] = attribute.split(",") - else: - split = [attribute] - + split = attribute.split(",") if isinstance(attribute, str) else [attribute] parts = [_prepare_attribute_parts(item) for item in split] def attrgetter(item: t.Any) -> t.List[t.Any]: @@ -162,11 +155,7 @@ def do_urlencode( if isinstance(value, str) or not isinstance(value, abc.Iterable): return url_quote(value) - if isinstance(value, dict): - items: t.Iterable[t.Tuple[str, t.Any]] = value.items() - else: - items = value # type: ignore - + items = value.items() if isinstance(value, dict) else value return "&".join( f"{url_quote(k, for_qs=True)}={url_quote(v, for_qs=True)}" for k, v in items ) @@ -194,7 +183,7 @@ def do_replace( count = -1 if not eval_ctx.autoescape: - return str(s).replace(str(old), str(new), count) + return s.replace(old, new, count) if ( hasattr(old, "__html__") @@ -281,7 +270,7 @@ def do_xmlattr( ) if autospace and rv: - rv = " " + rv + rv = f" {rv}" if eval_ctx.autoescape: rv = Markup(rv) @@ -568,7 +557,7 @@ def sync_do_join( # no automatic escaping? joining is a lot easier then if not eval_ctx.autoescape: - return str(d).join(map(str, value)) + return d.join(map(str, value)) # if the delimiter doesn't have an html representation we check # if any of the items has. If yes we do a coercion to Markup @@ -582,11 +571,7 @@ def sync_do_join( else: value[idx] = str(item) - if do_escape: - d = escape(d) - else: - d = str(d) - + d = escape(d) if do_escape else d return d.join(value) # no html involved, to normal joining @@ -799,11 +784,7 @@ def do_indent( Rename the ``indentfirst`` argument to ``first``. """ - if isinstance(width, str): - indention = width - else: - indention = " " * width - + indention = width if isinstance(width, str) else " " * width newline = "\n" if isinstance(s, Markup): @@ -951,10 +932,7 @@ def do_int(value: t.Any, default: int = 0, base: int = 10) -> int: The base is ignored for decimal numbers and non-string values. """ try: - if isinstance(value, str): - return int(value, base) - - return int(value) + return int(value, base) if isinstance(value, str) else int(value) except (TypeError, ValueError): # this quirk is necessary so that "42.23"|int gives 42. try: @@ -1039,8 +1017,7 @@ def sync_do_slice( """ seq = list(value) length = len(seq) - items_per_slice = length // slices - slices_with_extra = length % slices + items_per_slice, slices_with_extra = divmod(length, slices) offset = 0 for slice_number in range(slices): @@ -1338,7 +1315,7 @@ def do_mark_safe(value: str) -> Markup: def do_mark_unsafe(value: str) -> str: """Mark a value as unsafe. This is the reverse operation for :func:`safe`.""" - return str(value) + return value @typing.overload @@ -1380,7 +1357,7 @@ def do_attr( See :ref:`Notes on subscriptions ` for more details. """ try: - name = str(name) + name = name except UnicodeError: pass else: diff --git a/src/jinja2/idtracking.py b/src/jinja2/idtracking.py index 995ebaa0c..7db5a49b8 100644 --- a/src/jinja2/idtracking.py +++ b/src/jinja2/idtracking.py @@ -32,11 +32,7 @@ def __init__( self, parent: t.Optional["Symbols"] = None, level: t.Optional[int] = None ) -> None: if level is None: - if parent is None: - level = 0 - else: - level = parent.level + 1 - + level = 0 if parent is None else parent.level + 1 self.level: int = level self.parent = parent self.refs: t.Dict[str, str] = {} @@ -60,19 +56,13 @@ def find_load(self, target: str) -> t.Optional[t.Any]: if target in self.loads: return self.loads[target] - if self.parent is not None: - return self.parent.find_load(target) - - return None + return self.parent.find_load(target) if self.parent is not None else None def find_ref(self, name: str) -> t.Optional[str]: if name in self.refs: return self.refs[name] - if self.parent is not None: - return self.parent.find_ref(name) - - return None + return self.parent.find_ref(name) if self.parent is not None else None def ref(self, name: str) -> str: rv = self.find_ref(name) diff --git a/src/jinja2/lexer.py b/src/jinja2/lexer.py index 16ca73e1f..7d1a4913e 100644 --- a/src/jinja2/lexer.py +++ b/src/jinja2/lexer.py @@ -282,10 +282,7 @@ def test(self, expr: str) -> bool: if self.type == expr: return True - if ":" in expr: - return expr.split(":", 1) == [self.type, self.value] - - return False + return expr.split(":", 1) == [self.type, self.value] if ":" in expr else False def test_any(self, *iterable: str) -> bool: """Test against multiple token expressions.""" @@ -366,10 +363,7 @@ def next_if(self, expr: str) -> t.Optional[Token]: """Perform the token test and return the token if it matched. Otherwise the return value is `None`. """ - if self.current.test(expr): - return next(self) - - return None + return next(self) if self.current.test(expr) else None def skip_if(self, expr: str) -> bool: """Like :meth:`next_if` but only returns `True` or `False`.""" @@ -690,7 +684,7 @@ def tokeniter( if state is not None and state != "root": assert state in ("variable", "block"), "invalid state" - stack.append(state + "_begin") + stack.append(f"{state}_begin") statetokens = self.rules[stack[-1]] source_length = len(source) diff --git a/src/jinja2/loaders.py b/src/jinja2/loaders.py index 9b479be4e..765b35e87 100644 --- a/src/jinja2/loaders.py +++ b/src/jinja2/loaders.py @@ -463,10 +463,7 @@ def get_source( if rv is None: raise TemplateNotFound(template) - if isinstance(rv, str): - return rv, None, None - - return rv + return (rv, None, None) if isinstance(rv, str) else rv class PrefixLoader(BaseLoader): @@ -527,8 +524,10 @@ def load( def list_templates(self) -> t.List[str]: result = [] for prefix, loader in self.mapping.items(): - for template in loader.list_templates(): - result.append(prefix + self.delimiter + template) + result.extend( + prefix + self.delimiter + template + for template in loader.list_templates() + ) return result @@ -632,7 +631,7 @@ def get_template_key(name: str) -> str: @staticmethod def get_module_filename(name: str) -> str: - return ModuleLoader.get_template_key(name) + ".py" + return f"{ModuleLoader.get_template_key(name)}.py" @internalcode def load( diff --git a/src/jinja2/nativetypes.py b/src/jinja2/nativetypes.py index 71db8cc31..e19328cfb 100644 --- a/src/jinja2/nativetypes.py +++ b/src/jinja2/nativetypes.py @@ -67,10 +67,7 @@ def _output_child_to_const( if not has_safe_repr(const): raise nodes.Impossible() - if isinstance(node, nodes.TemplateData): - return const - - return finalize.const(const) # type: ignore + return const if isinstance(node, nodes.TemplateData) else finalize.const(const) def _output_child_pre( self, node: nodes.Expr, frame: Frame, finalize: CodeGenerator._FinalizeInfo diff --git a/src/jinja2/nodes.py b/src/jinja2/nodes.py index 00365ed83..dfdafb8c5 100644 --- a/src/jinja2/nodes.py +++ b/src/jinja2/nodes.py @@ -54,7 +54,7 @@ class NodeType(type): inheritance. fields and attributes from the parent class are automatically forwarded to the child.""" - def __new__(mcs, name, bases, d): # type: ignore + def __new__(cls, name, bases, d): # type: ignore for attr in "fields", "attributes": storage: t.List[t.Any] = [] storage.extend(getattr(bases[0] if bases else object, attr, ())) @@ -63,7 +63,7 @@ def __new__(mcs, name, bases, d): # type: ignore assert len(storage) == len(set(storage)), "layout conflict" d[attr] = tuple(storage) d.setdefault("abstract", False) - return type.__new__(mcs, name, bases, d) + return type.__new__(cls, name, bases, d) class EvalContext: @@ -220,8 +220,8 @@ def set_lineno(self, lineno: int, override: bool = False) -> "Node": todo = deque([self]) while todo: node = todo.popleft() - if "lineno" in node.attributes: - if node.lineno is None or override: + if node.lineno is None or override: + if "lineno" in node.attributes: node.lineno = lineno todo.extend(node.iter_child_nodes()) return self @@ -613,9 +613,7 @@ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> str: eval_ctx = get_eval_context(self, eval_ctx) if eval_ctx.volatile: raise Impossible() - if eval_ctx.autoescape: - return Markup(self.data) - return self.data + return Markup(self.data) if eval_ctx.autoescape else self.data class Tuple(Literal): @@ -633,10 +631,7 @@ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> t.Tuple[t.Any, . return tuple(x.as_const(eval_ctx) for x in self.items) def can_assign(self) -> bool: - for item in self.items: - if not item.can_assign(): - return False - return True + return all(item.can_assign() for item in self.items) class List(Literal): @@ -727,7 +722,7 @@ def args_as_const( if node.dyn_kwargs is not None: try: - kwargs.update(node.dyn_kwargs.as_const(eval_ctx)) + kwargs |= node.dyn_kwargs.as_const(eval_ctx) except Exception as e: raise Impossible() from e @@ -886,9 +881,7 @@ def as_const(self, eval_ctx: t.Optional[EvalContext] = None) -> slice: eval_ctx = get_eval_context(self, eval_ctx) def const(obj: t.Optional[Expr]) -> t.Optional[t.Any]: - if obj is None: - return None - return obj.as_const(eval_ctx) + return None if obj is None else obj.as_const(eval_ctx) return slice(const(self.start), const(self.stop), const(self.step)) @@ -1106,9 +1099,7 @@ def as_const( if eval_ctx.volatile: raise Impossible() expr = self.expr.as_const(eval_ctx) - if eval_ctx.autoescape: - return Markup(expr) - return expr + return Markup(expr) if eval_ctx.autoescape else expr class ContextReference(Expr): diff --git a/src/jinja2/parser.py b/src/jinja2/parser.py index 206e49523..286620495 100644 --- a/src/jinja2/parser.py +++ b/src/jinja2/parser.py @@ -239,9 +239,7 @@ def parse_for(self) -> nodes.For: iter = self.parse_tuple( with_condexpr=False, extra_end_rules=("name:recursive",) ) - test = None - if self.stream.skip_if("name:if"): - test = self.parse_expression() + test = self.parse_expression() if self.stream.skip_if("name:if") else None recursive = self.stream.skip_if("name:recursive") body = self.parse_statements(("name:endfor", "name:else")) if next(self.stream).value == "endfor": @@ -320,7 +318,7 @@ def parse_block(self) -> nodes.Block: ): self.fail("Required blocks can only contain comments or whitespace") - self.stream.skip_if("name:" + node.name) + self.stream.skip_if(f"name:{node.name}") return node def parse_extends(self) -> nodes.Extends: @@ -517,9 +515,7 @@ def parse_expression(self, with_condexpr: bool = True) -> nodes.Expr: the optional `with_condexpr` parameter is set to `False` conditional expressions are not parsed. """ - if with_condexpr: - return self.parse_condexpr() - return self.parse_or() + return self.parse_condexpr() if with_condexpr else self.parse_or() def parse_condexpr(self) -> nodes.Expr: lineno = self.stream.current.lineno @@ -528,10 +524,7 @@ def parse_condexpr(self) -> nodes.Expr: while self.stream.skip_if("name:if"): expr2 = self.parse_or() - if self.stream.skip_if("name:else"): - expr3 = self.parse_condexpr() - else: - expr3 = None + expr3 = self.parse_condexpr() if self.stream.skip_if("name:else") else None expr1 = nodes.CondExpr(expr2, expr1, expr3, lineno=lineno) lineno = self.stream.current.lineno return expr1 @@ -579,9 +572,7 @@ def parse_compare(self) -> nodes.Expr: else: break lineno = self.stream.current.lineno - if not ops: - return expr - return nodes.Compare(expr, ops, lineno=lineno) + return expr if not ops else nodes.Compare(expr, ops, lineno=lineno) def parse_math1(self) -> nodes.Expr: lineno = self.stream.current.lineno @@ -600,9 +591,7 @@ def parse_concat(self) -> nodes.Expr: while self.stream.current.type == "tilde": next(self.stream) args.append(self.parse_math2()) - if len(args) == 1: - return args[0] - return nodes.Concat(args, lineno=lineno) + return args[0] if len(args) == 1 else nodes.Concat(args, lineno=lineno) def parse_math2(self) -> nodes.Expr: lineno = self.stream.current.lineno @@ -773,10 +762,8 @@ def parse_dict(self) -> nodes.Dict: def parse_postfix(self, node: nodes.Expr) -> nodes.Expr: while True: token_type = self.stream.current.type - if token_type == "dot" or token_type == "lbracket": + if token_type in ["dot", "lbracket"]: node = self.parse_subscript(node) - # calls are valid both after postfix expressions (getattr - # and getitem) as well as filters and tests elif token_type == "lparen": node = self.parse_call(node) else: diff --git a/src/jinja2/runtime.py b/src/jinja2/runtime.py index a90d15f48..36545ea6d 100644 --- a/src/jinja2/runtime.py +++ b/src/jinja2/runtime.py @@ -100,10 +100,7 @@ def new_context( """Internal helper for context creation.""" if vars is None: vars = {} - if shared: - parent = vars - else: - parent = dict(globals or (), **vars) + parent = vars if shared else dict(globals or (), **vars) if locals: # if the parent is shared a copy should be created because # we don't want to modify the dict passed @@ -219,10 +216,7 @@ def resolve(self, key: str) -> t.Union[t.Any, "Undefined"]: """ rv = self.resolve_or_missing(key) - if rv is missing: - return self.environment.undefined(name=key) - - return rv + return self.environment.undefined(name=key) if rv is missing else rv def resolve_or_missing(self, key: str) -> t.Any: """Look up a variable by name, or return a ``missing`` sentinel @@ -237,10 +231,7 @@ def resolve_or_missing(self, key: str) -> t.Any: if key in self.vars: return self.vars[key] - if key in self.parent: - return self.parent[key] - - return missing + return self.parent[key] if key in self.parent else missing def get_exported(self) -> t.Dict[str, t.Any]: """Get a new dict with the exported variables.""" @@ -253,9 +244,7 @@ def get_all(self) -> t.Dict[str, t.Any]: """ if not self.vars: return self.parent - if not self.parent: - return self.vars - return dict(self.parent, **self.vars) + return self.vars if not self.parent else dict(self.parent, **self.vars) @internalcode def call( @@ -369,10 +358,7 @@ async def _async_call(self) -> str: [x async for x in self._stack[self._depth](self._context)] # type: ignore ) - if self._context.eval_ctx.autoescape: - return Markup(rv) - - return rv + return Markup(rv) if self._context.eval_ctx.autoescape else rv @internalcode def __call__(self) -> str: @@ -381,10 +367,7 @@ def __call__(self) -> str: rv = concat(self._stack[self._depth](self._context)) - if self._context.eval_ctx.autoescape: - return Markup(rv) - - return rv + return Markup(rv) if self._context.eval_ctx.autoescape else rv class LoopContext: @@ -523,10 +506,7 @@ def nextitem(self) -> t.Union[t.Any, "Undefined"]: """ rv = self._peek_next() - if rv is missing: - return self._undefined("there is no next item") - - return rv + return self._undefined("there is no next item") if rv is missing else rv def cycle(self, *args: V) -> V: """Return a value from the given args, cycling through based on @@ -634,10 +614,7 @@ async def last(self) -> bool: # type: ignore async def nextitem(self) -> t.Union[t.Any, "Undefined"]: rv = await self._peek_next() - if rv is missing: - return self._undefined("there is no next item") - - return rv + return self._undefined("there is no next item") if rv is missing else rv def __aiter__(self) -> "AsyncLoopContext": return self @@ -768,10 +745,7 @@ def __call__(self, *args: t.Any, **kwargs: t.Any) -> str: async def _async_invoke(self, arguments: t.List[t.Any], autoescape: bool) -> str: rv = await self._func(*arguments) # type: ignore - if autoescape: - return Markup(rv) - - return rv # type: ignore + return Markup(rv) if autoescape else rv def _invoke(self, arguments: t.List[t.Any], autoescape: bool) -> str: if self._environment.is_async: @@ -856,7 +830,7 @@ def _fail_with_undefined_error( @internalcode def __getattr__(self, name: str) -> t.Any: - if name[:2] == "__": + if name.startswith("__"): raise AttributeError(name) return self._fail_with_undefined_error() diff --git a/src/jinja2/sandbox.py b/src/jinja2/sandbox.py index 153f42ec2..14fc70b0b 100644 --- a/src/jinja2/sandbox.py +++ b/src/jinja2/sandbox.py @@ -88,10 +88,7 @@ def inspect_format_method(callable: t.Callable[..., t.Any]) -> t.Optional[str]: obj = callable.__self__ - if isinstance(obj, str): - return obj - - return None + return obj if isinstance(obj, str) else None def safe_range(*args: int) -> range: @@ -178,10 +175,14 @@ def modifies_known_mutable(obj: t.Any, attr: str) -> bool: >>> modifies_known_mutable("foo", "upper") False """ - for typespec, unsafe in _mutable_spec: - if isinstance(obj, typespec): - return attr in unsafe - return False + return next( + ( + attr in unsafe + for typespec, unsafe in _mutable_spec + if isinstance(obj, typespec) + ), + False, + ) class SandboxedEnvironment(Environment): @@ -417,10 +418,7 @@ def get_field( first, rest = formatter_field_name_split(field_name) obj = self.get_value(first, args, kwargs) for is_attr, i in rest: - if is_attr: - obj = self._env.getattr(obj, i) - else: - obj = self._env.getitem(obj, i) + obj = self._env.getattr(obj, i) if is_attr else self._env.getitem(obj, i) return obj, first diff --git a/src/jinja2/tests.py b/src/jinja2/tests.py index 0d29f9475..5013633f7 100644 --- a/src/jinja2/tests.py +++ b/src/jinja2/tests.py @@ -137,12 +137,12 @@ def test_float(value: t.Any) -> bool: def test_lower(value: str) -> bool: """Return true if the variable is lowercased.""" - return str(value).islower() + return value.islower() def test_upper(value: str) -> bool: """Return true if the variable is uppercased.""" - return str(value).isupper() + return value.isupper() def test_string(value: t.Any) -> bool: diff --git a/src/jinja2/utils.py b/src/jinja2/utils.py index 4b4720f6d..b8097fe3c 100644 --- a/src/jinja2/utils.py +++ b/src/jinja2/utils.py @@ -80,10 +80,7 @@ class _PassArg(enum.Enum): @classmethod def from_obj(cls, obj: F) -> t.Optional["_PassArg"]: - if hasattr(obj, "jinja_pass_arg"): - return obj.jinja_pass_arg # type: ignore - - return None + return obj.jinja_pass_arg if hasattr(obj, "jinja_pass_arg") else None def internalcode(f: F) -> F: @@ -111,8 +108,6 @@ def default(var, default=''): def consume(iterable: t.Iterable[t.Any]) -> None: """Consumes an iterable without doing anything with it.""" - for _ in iterable: - pass def clear_caches() -> None: @@ -156,10 +151,7 @@ def open_if_exists(filename: str, mode: str = "rb") -> t.Optional[t.IO[t.Any]]: """Returns a file descriptor for the filename if that file exists, otherwise ``None``. """ - if not os.path.isfile(filename): - return None - - return open(filename, mode) + return None if not os.path.isfile(filename) else open(filename, mode) def object_type_repr(obj: t.Any) -> str: @@ -259,10 +251,7 @@ def urlize( if trim_url_limit is not None: def trim_url(x: str) -> str: - if len(x) > trim_url_limit: - return f"{x[:trim_url_limit]}..." - - return x + return f"{x[:trim_url_limit]}..." if len(x) > trim_url_limit else x else: @@ -381,7 +370,7 @@ def generate_lorem_ipsum( p_str = " ".join(p) if p_str.endswith(","): - p_str = p_str[:-1] + "." + p_str = f"{p_str[:-1]}." elif not p_str.endswith("."): p_str += "." @@ -618,9 +607,7 @@ def autoescape(template_name: t.Optional[str]) -> bool: template_name = template_name.lower() if template_name.endswith(enabled_patterns): return True - if template_name.endswith(disabled_patterns): - return False - return default + return False if template_name.endswith(disabled_patterns) else default return autoescape @@ -712,9 +699,8 @@ def next(self) -> t.Any: """Return the current item, then advance :attr:`current` to the next item. """ - rv = self.current self.pos = (self.pos + 1) % len(self.items) - return rv + return self.current __next__ = next diff --git a/src/jinja2/visitor.py b/src/jinja2/visitor.py index 17c6aaba5..4d32c4193 100644 --- a/src/jinja2/visitor.py +++ b/src/jinja2/visitor.py @@ -86,7 +86,4 @@ def visit_list(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.List[Node]: """ rv = self.visit(node, *args, **kwargs) - if not isinstance(rv, list): - return [rv] - - return rv + return [rv] if not isinstance(rv, list) else rv diff --git a/tests/test_api.py b/tests/test_api.py index 4db3b4a96..fdef721b4 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -134,9 +134,7 @@ def test_get_template_undefined(self, env): def test_autoescape_autoselect(self, env): def select_autoescape(name): - if name is None or "." not in name: - return False - return name.endswith(".html") + return False if name is None or "." not in name else name.endswith(".html") env = Environment( autoescape=select_autoescape, @@ -183,7 +181,7 @@ def test_find_refererenced_templates(self, env): i = meta.find_referenced_templates(ast) assert next(i) == "layout.html" assert next(i) is None - assert list(i) == [] + assert not list(i) ast = env.parse( '{% extends "layout.html" %}' @@ -422,9 +420,13 @@ class CustomEnvironment(Environment): assert tmpl.render() == "bar" def test_custom_context(self): + + + class CustomContext(Context): def resolve_or_missing(self, key): - return "resolve-" + key + return f"resolve-{key}" + class CustomEnvironment(Environment): context_class = CustomContext diff --git a/tests/test_async.py b/tests/test_async.py index c9ba70c3e..4a74642b1 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -311,9 +311,7 @@ def test_empty_blocks(self, test_env_async): ) assert tmpl.render() == "<>" - @pytest.mark.parametrize( - "transform", [lambda x: x, iter, reversed, lambda x: (i for i in x), auto_aiter] - ) + @pytest.mark.parametrize("transform", [lambda x: x, iter, reversed, lambda x: iter(x), auto_aiter]) def test_context_vars(self, test_env_async, transform): t = test_env_async.from_string( "{% for item in seq %}{{ loop.index }}|{{ loop.index0 }}" diff --git a/tests/test_core_tags.py b/tests/test_core_tags.py index 4bb95e024..5f2cd771f 100644 --- a/tests/test_core_tags.py +++ b/tests/test_core_tags.py @@ -31,7 +31,7 @@ def test_empty_blocks(self, env): def test_context_vars(self, env): slist = [42, 24] - for seq in [slist, iter(slist), reversed(slist), (_ for _ in slist)]: + for seq in [slist, iter(slist), reversed(slist), iter(slist)]: tmpl = env.from_string( """{% for item in seq -%} {{ loop.index }}|{{ loop.index0 }}|{{ loop.revindex }}|{{ diff --git a/tests/test_ext.py b/tests/test_ext.py index 2e842e0ab..7fe0adce0 100644 --- a/tests/test_ext.py +++ b/tests/test_ext.py @@ -71,10 +71,7 @@ def _get_with_context(value, ctx=None): - if isinstance(value, dict): - return value.get(ctx, value) - - return value + return value.get(ctx, value) if isinstance(value, dict) else value @pass_context @@ -161,7 +158,7 @@ def parse(self, parser): [ nodes.EnvironmentAttribute("sandboxed"), self.attr("ext_attr"), - nodes.ImportedName(__name__ + ".importable_object"), + nodes.ImportedName(f"{__name__}.importable_object"), self.context_reference_node_cls(), ], ) @@ -200,8 +197,7 @@ def interpolate(self, token): match = _gettext_re.search(token.value, pos) if match is None: break - value = token.value[pos : match.start()] - if value: + if value := token.value[pos : match.start()]: yield Token(lineno, "data", value) lineno += count_newlines(token.value) yield Token(lineno, "variable_begin", None) @@ -271,7 +267,7 @@ def test_contextreference_node_can_pass_locals(self): assert tmpl.render() == "False|42|23|{}|test_content" def test_identifier(self): - assert ExampleExtension.identifier == __name__ + ".ExampleExtension" + assert ExampleExtension.identifier == f"{__name__}.ExampleExtension" def test_rebinding(self): original = Environment(extensions=[ExampleExtension]) diff --git a/tests/test_filters.py b/tests/test_filters.py index 32897c546..3a2308b64 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -252,7 +252,7 @@ def test_lower(self, env): assert out == "foo" def test_items(self, env): - d = {i: c for i, c in enumerate("abc")} + d = dict(enumerate("abc")) tmpl = env.from_string("""{{ d|items|list }}""") out = tmpl.render(d=d) assert out == "[(0, 'a'), (1, 'b'), (2, 'c')]" diff --git a/tests/test_regression.py b/tests/test_regression.py index 46e492bdd..b9c029ba7 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -364,8 +364,10 @@ def test_callable_defaults(self): def test_macro_escaping(self): env = Environment(autoescape=lambda x: False) - template = "{% macro m() %}{% endmacro %}" - template += "{% autoescape true %}{{ m() }}{% endautoescape %}" + template = ( + "{% macro m() %}{% endmacro %}" + + "{% autoescape true %}{{ m() }}{% endautoescape %}" + ) assert env.from_string(template).render() def test_macro_scoping(self, env): @@ -601,7 +603,7 @@ def test_markup_and_chainable_undefined(self): from markupsafe import Markup from jinja2.runtime import ChainableUndefined - assert str(Markup(ChainableUndefined())) == "" + assert not str(Markup(ChainableUndefined())) def test_scoped_block_loop_vars(self, env): tmpl = env.from_string(