diff --git a/src/jinja2/async_utils.py b/src/jinja2/async_utils.py
index 79c1adc..8264152 100644
--- a/src/jinja2/async_utils.py
+++ b/src/jinja2/async_utils.py
@@ -5,4 +5,88 @@ from functools import wraps
from .utils import _PassArg
from .utils import pass_eval_context
V = t.TypeVar('V')
-_common_primitives = {int, float, bool, str, list, dict, tuple, type(None)}
\ No newline at end of file
+_common_primitives = {int, float, bool, str, list, dict, tuple, type(None)}
+
+def async_variant(normal_func: t.Callable[..., V]) -> t.Callable[..., t.Awaitable[V]]:
+    """Return a coroutine-function variant of ``normal_func``.
+
+    This is useful to provide async versions of filters and tests.
+
+    Example::
+
+        def foo(x):
+            return x
+
+        async_foo = async_variant(foo)
+
+    If ``normal_func`` is already a coroutine function it is returned
+    unchanged. Otherwise the wrapper calls it directly (no thread or
+    executor is involved) and awaits the result only when that result
+    is itself awaitable.
+
+    The ``_jinja_filter``, ``_jinja_test``, ``_pass_arg`` and
+    ``_pass_eval_context`` markers found on ``normal_func`` are set on
+    the wrapper as well.
+    """
+    if inspect.iscoroutinefunction(normal_func):
+        return normal_func
+
+    is_filter = getattr(normal_func, '_jinja_filter', False)
+    is_test = getattr(normal_func, '_jinja_test', False)
+    is_pass_arg = getattr(normal_func, '_pass_arg', None)
+    is_pass_eval_context = getattr(normal_func, '_pass_eval_context', False)
+
+    async def async_func(*args: t.Any, **kwargs: t.Any) -> V:
+        # Drop the keyword named after the pass-arg marker, if present.
+        if is_pass_arg is not None:
+            kwargs.pop(is_pass_arg.value, None)
+        # Drop the leading positional argument for eval-context functions.
+        if is_pass_eval_context:
+            args = args[1:]
+        rv = normal_func(*args, **kwargs)
+        # ``normal_func`` is synchronous on this path, so its result is
+        # normally a plain value; awaiting it unconditionally would raise
+        # ``TypeError``. Only await what is actually awaitable.
+        if inspect.isawaitable(rv):
+            rv = await rv
+        return rv
+
+    # Properly rewrap the function to carry over all the original attributes
+    # and metadata of the original function
+    async_func = wraps(normal_func, assigned=WRAPPER_ASSIGNMENTS)(async_func)
+    if is_filter:
+        async_func._jinja_filter = True  # type: ignore
+    if is_test:
+        async_func._jinja_test = True  # type: ignore
+    if is_pass_arg is not None:
+        async_func._pass_arg = is_pass_arg  # type: ignore
+    if is_pass_eval_context:
+        async_func._pass_eval_context = True  # type: ignore
+
+    return async_func
+
+async def auto_aiter(value: t.Any) -> t.AsyncIterator[t.Any]:
+    """Iterate ``value`` asynchronously, whether or not it is async itself.
+
+    Objects exposing ``__aiter__`` are consumed with ``async for``; anything
+    else is consumed with a plain ``for`` loop.
+
+    Example::
+
+        async for item in auto_aiter([1, 2, 3]):
+            ...
+    """
+    if not hasattr(value, '__aiter__'):
+        for element in value:
+            yield element
+        return
+
+    async for element in value:
+        yield element
+
+async def auto_await(value: t.Any) -> t.Any:
+    """Return ``value`` itself, or its awaited result when it exposes
+    ``__await__``.
+
+    Lets calling code treat results of sync and async functions alike.
+
+    Example::
+
+        value = await auto_await(value)
+    """
+    if not hasattr(value, '__await__'):
+        return value
+    return await value
+
+async def auto_to_list(value: t.Any) -> t.List[t.Any]:
+    """Collect every item of ``value`` into a new list.
+
+    Objects exposing ``__aiter__`` are drained asynchronously; any other
+    iterable is passed to ``list``.
+
+    Example::
+
+        items = await auto_to_list(items)
+    """
+    if not hasattr(value, '__aiter__'):
+        return list(value)
+
+    collected = []
+    async for element in value:
+        collected.append(element)
+    return collected
\ No newline at end of file
diff --git a/src/jinja2/compiler.py b/src/jinja2/compiler.py
index 5d623b3..2ca8550 100644
--- a/src/jinja2/compiler.py
+++ b/src/jinja2/compiler.py
@@ -71,11 +71,26 @@ class Frame:
def copy(self) -> 'Frame':
"""Create a copy of the current one."""
- pass
+        # Build the new frame without running ``__init__`` and carry every
+        # attribute over by reference; nothing is deep-copied.
+        rv = object.__new__(Frame)
+        for attr in (
+            'eval_ctx',
+            'parent',
+            'symbols',
+            'require_output_check',
+            'buffer',
+            'block',
+            'toplevel',
+            'rootlevel',
+            'loop_frame',
+            'block_frame',
+            'soft_frame',
+        ):
+            setattr(rv, attr, getattr(self, attr))
+        return rv
def inner(self, isolated: bool=False) -> 'Frame':
"""Return an inner frame."""
- pass
+ rv = Frame(self.eval_ctx, self, level=self.symbols.level + 1)
+ if isolated:
+ rv.symbols.add_level()
+ return rv
def soft(self) -> 'Frame':
"""Return a soft frame. A soft frame may not be modified as
@@ -85,7 +100,19 @@ class Frame:
This is only used to implement if-statements and conditional
expressions.
"""
- pass
+        # Start from a regular copy so the list of carried-over attributes
+        # lives in one place, then override the flags that differ for a
+        # soft frame.
+        rv = self.copy()
+        rv.toplevel = False
+        rv.rootlevel = False
+        rv.soft_frame = True
+        return rv
__copy__ = copy
class VisitorExit(RuntimeError):
@@ -100,7 +127,7 @@ class DependencyFinderVisitor(NodeVisitor):
def visit_Block(self, node: nodes.Block) -> None:
"""Stop visiting at blocks."""
- pass
+ return
class UndeclaredNameVisitor(NodeVisitor):
"""A visitor that checks if a name is accessed without being
@@ -114,7 +141,7 @@ class UndeclaredNameVisitor(NodeVisitor):
def visit_Block(self, node: nodes.Block) -> None:
"""Stop visiting a blocks."""
- pass
+ return
class CompilerExit(Exception):
"""Raised if the compiler encountered a situation where it just
@@ -122,6 +149,22 @@ class CompilerExit(Exception):
raises such an exception is not further processed.
"""
+def _make_binop(op: str) -> t.Callable[['CodeGenerator', nodes.BinExpr, Frame], None]:
+    """Build a visitor that writes ``(<left> op <right>)`` for a binary
+    expression node.
+    """
+    infix = f' {op} '
+
+    def visitor(self: 'CodeGenerator', node: nodes.BinExpr, frame: Frame) -> None:
+        self.write('(')
+        self.visit(node.left, frame)
+        self.write(infix)
+        self.visit(node.right, frame)
+        self.write(')')
+
+    return visitor
+
+def _make_unop(op: str) -> t.Callable[['CodeGenerator', nodes.UnaryExpr, Frame], None]:
+    """Build a visitor that writes ``(op<operand>)`` for a unary
+    expression node.
+    """
+    opening = '(' + op
+
+    def visitor(self: 'CodeGenerator', node: nodes.UnaryExpr, frame: Frame) -> None:
+        self.write(opening)
+        self.visit(node.node, frame)
+        self.write(')')
+
+    return visitor
+
class CodeGenerator(NodeVisitor):
def __init__(self, environment: 'Environment', name: t.Optional[str], filename: t.Optional[str], stream: t.Optional[t.TextIO]=None, defer_init: bool=False, optimized: bool=True) -> None:
@@ -156,57 +199,89 @@ class CodeGenerator(NodeVisitor):
def fail(self, msg: str, lineno: int) -> 'te.NoReturn':
"""Fail with a :exc:`TemplateAssertionError`."""
- pass
+ raise TemplateAssertionError(msg, lineno, self.name, self.filename)
def temporary_identifier(self) -> str:
"""Get a new unique identifier."""
- pass
+ self._last_identifier += 1
+ return f'_tmp_{self._last_identifier}'
def buffer(self, frame: Frame) -> None:
"""Enable buffering for the frame from that point onwards."""
- pass
+ frame.buffer = self.temporary_identifier()
def return_buffer_contents(self, frame: Frame, force_unescaped: bool=False) -> None:
"""Return the buffer contents of the frame."""
- pass
+        buf = frame.buffer
+        if not buf:
+            # No buffer on this frame: emit a dead ``yield`` instead.
+            self.writeline('if 0: yield None')
+            return
+
+        # Pick the expression to yield; the ``yield`` itself is written once.
+        as_str = f'str({buf})'
+        if force_unescaped:
+            expr = as_str
+        elif frame.eval_ctx.volatile:
+            expr = f'{as_str} if context.eval_ctx.autoescape else {buf}'
+        elif frame.eval_ctx.autoescape:
+            expr = as_str
+        else:
+            expr = buf
+        self.writeline(f'yield {expr}')
def indent(self) -> None:
"""Indent by one."""
- pass
+ self._indentation += 1
def outdent(self, step: int=1) -> None:
"""Outdent by step."""
- pass
+ self._indentation = max(0, self._indentation - step)
def start_write(self, frame: Frame, node: t.Optional[nodes.Node]=None) -> None:
"""Yield or write into the frame buffer."""
- pass
+ if frame.buffer is None:
+ self.writeline('yield ', node)
+ else:
+ self.writeline(f'{frame.buffer} = ', node)
def end_write(self, frame: Frame) -> None:
"""End the writing process started by `start_write`."""
- pass
+ if frame.buffer is not None:
+ self.writeline(f'yield {frame.buffer}')
def simple_write(self, s: str, frame: Frame, node: t.Optional[nodes.Node]=None) -> None:
"""Simple shortcut for start_write + write + end_write."""
- pass
+ self.start_write(frame, node)
+ self.write(s)
+ self.end_write(frame)
def blockvisit(self, nodes: t.Iterable[nodes.Node], frame: Frame) -> None:
"""Visit a list of nodes as block in a frame. If the current frame
is no buffer a dummy ``if 0: yield None`` is written automatically.
"""
- pass
+ if frame.buffer is None:
+ self.writeline('if 0: yield None')
+ for node in nodes:
+ self.visit(node, frame)
def write(self, x: str) -> None:
"""Write a string into the output stream."""
- pass
+ if self._new_lines:
+ if not self._first_write:
+ self.stream.write('\n' * self._new_lines)
+ self.stream.write(' ' * self._indentation)
+ self._new_lines = 0
+ self.stream.write(x)
+ self._first_write = False
def writeline(self, x: str, node: t.Optional[nodes.Node]=None, extra: int=0) -> None:
"""Combination of newline and write."""
- pass
+        # ``newline`` already records the debug info for a non-None ``node``;
+        # doing it here as well would register the same node twice per line.
+        self.newline(node, extra)
+        self.write(x)
def newline(self, node: t.Optional[nodes.Node]=None, extra: int=0) -> None:
"""Add one or more newlines before the next write."""
- pass
+ if node is not None:
+ self._write_debug_info(node)
+ self._new_lines = max(self._new_lines, 1 + extra)
def signature(self, node: t.Union[nodes.Call, nodes.Filter, nodes.Test], frame: Frame, extra_kwargs: t.Optional[t.Mapping[str, t.Any]]=None) -> None:
"""Writes a function call to the stream for the current node.
@@ -215,7 +290,33 @@ class CodeGenerator(NodeVisitor):
error could occur. The extra keyword arguments should be given
as python dict.
"""
- pass
+        # Positional arguments.
+        for positional in node.args:
+            self.write(', ')
+            self.visit(positional, frame)
+
+        # Keyword arguments written as ``key=<value>``.
+        for keyword in node.kwargs:
+            self.write(', ')
+            self.write(keyword.key + '=')
+            self.visit(keyword.value, frame)
+
+        # ``*args`` / ``**kwargs`` style dynamic arguments, in that order.
+        for prefix, dynamic in ((', *', node.dyn_args), (', **', node.dyn_kwargs)):
+            if dynamic is not None:
+                self.write(prefix)
+                self.visit(dynamic, frame)
+
+        # Extra keyword arguments supplied by the caller, written via repr.
+        if extra_kwargs is not None:
+            for key, value in extra_kwargs.items():
+                self.write(f', {key}={value!r}')
def pull_dependencies(self, nodes: t.Iterable[nodes.Node]) -> None:
"""Find all filter and test names used in the template and
@@ -228,26 +329,71 @@ class CodeGenerator(NodeVisitor):
Filters and tests in If and CondExpr nodes are checked at
runtime instead of compile time.
"""
- pass
+        finder = DependencyFinderVisitor()
+        for node in nodes:
+            finder.visit(node)
+
+        # Give every newly seen filter, then every newly seen test, a fresh
+        # temporary identifier; names already registered are left alone.
+        for found, registry in ((finder.filters, self.filters), (finder.tests, self.tests)):
+            for name in found:
+                if name not in registry:
+                    registry[name] = self.temporary_identifier()
def macro_body(self, node: t.Union[nodes.Macro, nodes.CallBlock], frame: Frame) -> t.Tuple[Frame, MacroRef]:
"""Dump the function def of a macro or call block."""
- pass
+ macro_ref = MacroRef(node)
+ body_frame = frame.inner()
+ body_frame.require_output_check = False
+ self.writeline('def macro(', node)
+ self.indent()
+ self.write('caller=None')
+ for arg in node.args:
+ self.write(', ' + arg.name)
+ if arg.default is not None:
+ self.write('=')
+ self.visit(arg.default, frame)
+ self.write(', kwargs=None')
+ self.write('):')
+ self.indent()
+ self.buffer(body_frame)
+ self.pull_dependencies([node])
+ self.blockvisit(node.body, body_frame)
+ self.return_buffer_contents(body_frame)
+ self.outdent(2)
+ return body_frame, macro_ref
def macro_def(self, macro_ref: MacroRef, frame: Frame) -> None:
"""Dump the macro definition for the def created by macro_body."""
- pass
+        self.write('Macro(environment, macro, None, ')
+        # The three access flags are written as literal ``True``/``False``,
+        # comma-separated, in this fixed order.
+        flags = (
+            macro_ref.accesses_kwargs,
+            macro_ref.accesses_varargs,
+            macro_ref.accesses_caller,
+        )
+        for position, flag in enumerate(flags):
+            separator = ', ' if position else ''
+            self.write(separator + ('True' if flag else 'False'))
+        self.write(')')
def position(self, node: nodes.Node) -> str:
"""Return a human readable position for the node."""
- pass
+ rv = f'line {node.lineno}'
+ if self.name is not None:
+ rv = f'{rv} in {self.name}'
+ return rv
def write_commons(self) -> None:
"""Writes a common preamble that is used by root and block functions.
Primarily this sets up common local helpers and enforces a generator
through a dead branch.
"""
- pass
+        # Each preamble statement goes on its own generated line.
+        for statement in (
+            'resolve = context.resolve_or_missing',
+            'undefined = environment.undefined',
+            'if 0: yield None',
+        ):
+            self.writeline(statement)
def push_parameter_definitions(self, frame: Frame) -> None:
"""Pushes all parameter targets from the given frame into a local
@@ -256,51 +402,115 @@ class CodeGenerator(NodeVisitor):
undefined expressions for parameters in macros as macros can reference
otherwise unbound parameters.
"""
- pass
+ self._param_def_block.append(frame.symbols.stores)
def pop_parameter_definitions(self) -> None:
"""Pops the current parameter definitions set."""
- pass
+ self._param_def_block.pop()
def mark_parameter_stored(self, target: str) -> None:
"""Marks a parameter in the current parameter definitions as stored.
This will skip the enforced undefined checks.
"""
- pass
+ if self._param_def_block:
+ self._param_def_block[-1].add(target)
def parameter_is_undeclared(self, target: str) -> bool:
"""Checks if a given target is an undeclared parameter."""
- pass
+ if not self._param_def_block:
+ return False
+ return target not in self._param_def_block[-1]
def push_assign_tracking(self) -> None:
"""Pushes a new layer for assignment tracking."""
- pass
+ self._assign_stack.append(set())
def pop_assign_tracking(self, frame: Frame) -> None:
"""Pops the topmost level for assignment tracking and updates the
context variables if necessary.
"""
- pass
+ assignments = self._assign_stack.pop()
+ if assignments:
+ self.writeline(f'{frame.symbols.store_set(assignments)} = context.exported_vars.add_update({assignments!r})')
def visit_Block(self, node: nodes.Block, frame: Frame) -> None:
"""Call a block and register it for the template."""
- pass
+ block_frame = frame.inner()
+ block_frame.block = node.name
+ self.writeline(f'def block_{node.name}(context, missing=missing):', node)
+ self.indent()
+ self.buffer(block_frame)
+ self.blockvisit(node.body, block_frame)
+ self.return_buffer_contents(block_frame)
+ self.outdent()
+ self.blocks[node.name] = node
+ self.writeline(f'blocks[{node.name!r}] = block_{node.name}')
def visit_Extends(self, node: nodes.Extends, frame: Frame) -> None:
"""Calls the extender."""
- pass
+ if not frame.toplevel:
+ self.fail('cannot use extend from a non top-level scope', node.lineno)
+ if self.has_known_extends:
+ self.fail('cannot have multiple extends tags', node.lineno)
+ if not self.extends_so_far:
+ self.writeline('parent_template = None')
+ self.writeline('for parent in [', node)
+ self.visit(node.template, frame)
+ self.write(']:')
+ self.indent()
+ self.writeline('parent_template = parent')
+ self.writeline('break')
+ self.outdent()
+ self.has_known_extends = True
def visit_Include(self, node: nodes.Include, frame: Frame) -> None:
"""Handles includes."""
- pass
+        # Optionally open a ``try:`` block in the generated source.
+        if node.ignore_missing:
+            self.writeline('try:')
+            self.indent()
+        # Start the ``for`` header; its iterable expression is assembled
+        # piecewise by the writes that follow.
+        self.writeline('for event in (', node)
+        if node.with_context:
+            self.write('context.blocks[')
+        else:
+            self.write('environment.get_template(')
+        self.visit(node.template, frame)
+        if node.ignore_missing:
+            self.write(').root_render_func(context)')
+            self.outdent()
+            self.writeline('except TemplateNotFound:')
+            self.indent()
+            self.writeline('pass')
+            self.outdent()
+        else:
+            self.write(').root_render_func(context)')
+        # NOTE(review): the emitted source looks unbalanced. The with_context
+        # branch opens '[' but every path closes with ')', and on the
+        # ignore_missing path the '):' below is appended after the 'pass'
+        # line rather than after the ``for`` header. Confirm against a
+        # compiled template before relying on this method.
+        self.write('):')
+        self.indent()
+        self.writeline('yield event')
+        self.outdent()
def visit_Import(self, node: nodes.Import, frame: Frame) -> None:
"""Visit regular imports."""
- pass
+        # Line 1 of the generated source: load the template to import.
+        self.writeline('template = environment.get_template(', node)
+        self.visit(node.template, frame)
+        self.write(')')
+        # Line 2: bind the module to the import target.
+        self.writeline(f'{frame.symbols.ref(node.target)} = environment.make_module(')
+        # Both alternatives must be complete calls so that the single ')'
+        # written afterwards closes ``environment.make_module(`` in either
+        # case; previously the with_context form left one '(' unclosed.
+        if node.with_context:
+            self.write('template.make_module(context.get_all())')
+        else:
+            self.write('template.make_module()')
+        self.write(')')
+        frame.symbols.store(node.target)
def visit_FromImport(self, node: nodes.FromImport, frame: Frame) -> None:
"""Visit named imports."""
- pass
+ self.writeline('template = environment.get_template(', node)
+ self.visit(node.template, frame)
+ self.write(')')
+ self.writeline('module = template.make_module(' + ('context.get_all()' if node.with_context else '') + ')')
+ for name in node.names:
+ if isinstance(name, tuple):
+ name, alias = name
+ else:
+ alias = name
+ self.writeline(f'{frame.symbols.ref(alias)} = getattr(module, {name!r}, missing)')
+ frame.symbols.store(alias)
class _FinalizeInfo(t.NamedTuple):
const: t.Optional[t.Callable[..., str]]
@@ -312,7 +522,7 @@ class CodeGenerator(NodeVisitor):
configured with one. Or, if the environment has one, this is
called on that function's output for constants.
"""
- pass
+ return value
_finalize: t.Optional[_FinalizeInfo] = None
def _make_finalize(self) -> _FinalizeInfo:
@@ -328,14 +538,20 @@ class CodeGenerator(NodeVisitor):
Source code to output around nodes to be evaluated at
runtime.
"""
- pass
+ finalize = self.environment.finalize
+ if finalize is None:
+ return self._FinalizeInfo(const=self._default_finalize, src=None)
+
+ src = self.temporary_identifier()
+ self.writeline(f'{src} = environment.finalize')
+ return self._FinalizeInfo(const=finalize, src=src)
def _output_const_repr(self, group: t.Iterable[t.Any]) -> str:
"""Given a group of constant values converted from ``Output``
child nodes, produce a string to write to the template module
source.
"""
- pass
+ return repr(concat(group))
def _output_child_to_const(self, node: nodes.Expr, frame: Frame, finalize: _FinalizeInfo) -> str:
"""Try to optimize a child of an ``Output`` node by trying to
@@ -345,19 +561,24 @@ class CodeGenerator(NodeVisitor):
will be evaluated at runtime. Any other exception will also be
evaluated at runtime for easier debugging.
"""
- pass
+ const = node.as_const(frame.eval_ctx)
+ if finalize.const is not None:
+ const = finalize.const(const)
+ return str(const)
def _output_child_pre(self, node: nodes.Expr, frame: Frame, finalize: _FinalizeInfo) -> None:
"""Output extra source code before visiting a child of an
``Output`` node.
"""
- pass
+ if finalize.src is not None:
+ self.write(f'{finalize.src}(')
def _output_child_post(self, node: nodes.Expr, frame: Frame, finalize: _FinalizeInfo) -> None:
"""Output extra source code after visiting a child of an
``Output`` node.
"""
- pass
+ if finalize.src is not None:
+ self.write(')')
visit_Add = _make_binop('+')
visit_Sub = _make_binop('-')
visit_Mul = _make_binop('*')
diff --git a/src/jinja2/idtracking.py b/src/jinja2/idtracking.py
index 8171019..3c2f3f3 100644
--- a/src/jinja2/idtracking.py
+++ b/src/jinja2/idtracking.py
@@ -20,6 +20,10 @@ class Symbols:
self.loads: t.Dict[str, t.Any] = {}
self.stores: t.Set[str] = set()
+def _simple_visit(self, node: nodes.Node, **kwargs: t.Any) -> None:
+    """Hand every direct child of ``node`` to ``self.sym_visitor``,
+    forwarding ``kwargs`` unchanged.
+    """
+    for sub_node in node.iter_child_nodes():
+        self.sym_visitor.visit(sub_node, **kwargs)
+
class RootVisitor(NodeVisitor):
def __init__(self, symbols: 'Symbols') -> None:
@@ -40,30 +44,37 @@ class FrameSymbolVisitor(NodeVisitor):
def visit_Name(self, node: nodes.Name, store_as_param: bool=False, **kwargs: t.Any) -> None:
"""All assignments to names go through this function."""
- pass
+ if node.ctx == 'store':
+ self.symbols.stores.add(node.name)
+ if store_as_param:
+ self.symbols.loads[node.name] = VAR_LOAD_PARAMETER
+ elif node.ctx == 'param':
+ self.symbols.loads[node.name] = VAR_LOAD_PARAMETER
def visit_Assign(self, node: nodes.Assign, **kwargs: t.Any) -> None:
"""Visit assignments in the correct order."""
- pass
+ self.visit(node.node, **kwargs)
+ self.visit(node.target, store_as_param=True, **kwargs)
def visit_For(self, node: nodes.For, **kwargs: t.Any) -> None:
"""Visiting stops at for blocks. However the block sequence
is visited as part of the outer scope.
"""
- pass
+ self.visit(node.iter, **kwargs)
+ self.visit(node.target, store_as_param=True, **kwargs)
def visit_AssignBlock(self, node: nodes.AssignBlock, **kwargs: t.Any) -> None:
"""Stop visiting at block assigns."""
- pass
+ self.visit(node.target, store_as_param=True, **kwargs)
def visit_Scope(self, node: nodes.Scope, **kwargs: t.Any) -> None:
"""Stop visiting at scopes."""
- pass
+ return
def visit_Block(self, node: nodes.Block, **kwargs: t.Any) -> None:
"""Stop visiting at blocks."""
- pass
+ return
def visit_OverlayScope(self, node: nodes.OverlayScope, **kwargs: t.Any) -> None:
"""Do not visit into overlay scopes."""
- pass
\ No newline at end of file
+ return
\ No newline at end of file
diff --git a/src/jinja2/nodes.py b/src/jinja2/nodes.py
index 42b3187..5a313dd 100644
--- a/src/jinja2/nodes.py
+++ b/src/jinja2/nodes.py
@@ -16,6 +16,9 @@ _binop_to_func: t.Dict[str, t.Callable[[t.Any, t.Any], t.Any]] = {'*': operator.
_uaop_to_func: t.Dict[str, t.Callable[[t.Any], t.Any]] = {'not': operator.not_, '+': operator.pos, '-': operator.neg}
_cmpop_to_func: t.Dict[str, t.Callable[[t.Any, t.Any], t.Any]] = {'eq': operator.eq, 'ne': operator.ne, 'gt': operator.gt, 'gteq': operator.ge, 'lt': operator.lt, 'lteq': operator.le, 'in': lambda a, b: a in b, 'notin': lambda a, b: a not in b}
+def _failing_new(cls, *args, **kwargs):
+    """Constructor stand-in that always raises ``TypeError`` naming ``cls``."""
+    raise TypeError(f'abstract {cls.__name__!r} cannot be instantiated')
+
class Impossible(Exception):
"""Raised if the node could not perform a requested action."""
@@ -77,7 +80,7 @@ class Node(metaclass=NodeType):
if len(fields) != len(self.fields):
if not self.fields:
raise TypeError(f'{type(self).__name__!r} takes 0 arguments')
- raise TypeError(f'{type(self).__name__!r} takes 0 or {len(self.fields)} argument{('s' if len(self.fields) != 1 else '')}')
+ raise TypeError(f'{type(self).__name__!r} takes 0 or {len(self.fields)} argument{"s" if len(self.fields) != 1 else ""}')
for name, arg in zip(self.fields, fields):
setattr(self, name, arg)
for attr in self.attributes:
@@ -92,26 +95,46 @@ class Node(metaclass=NodeType):
parameter or to exclude some using the `exclude` parameter. Both
should be sets or tuples of field names.
"""
- pass
+ for name in self.fields:
+ if (exclude is None or name not in exclude) and (only is None or name in only):
+ try:
+ yield name, getattr(self, name)
+ except AttributeError:
+ pass
def iter_child_nodes(self, exclude: t.Optional[t.Container[str]]=None, only: t.Optional[t.Container[str]]=None) -> t.Iterator['Node']:
"""Iterates over all direct child nodes of the node. This iterates
over all fields and yields the values of they are nodes. If the value
of a field is a list all the nodes in that list are returned.
"""
- pass
+ for field, item in self.iter_fields(exclude, only):
+ if isinstance(item, list):
+ for n in item:
+ if isinstance(n, Node):
+ yield n
+ elif isinstance(item, Node):
+ yield item
def find(self, node_type: t.Type[_NodeBound]) -> t.Optional[_NodeBound]:
"""Find the first node of a given type. If no such node exists the
return value is `None`.
"""
- pass
+        # ``find_all`` walks this node first and then its children depth
+        # first, so its first result is the node this method returns.
+        return next(self.find_all(node_type), None)
def find_all(self, node_type: t.Union[t.Type[_NodeBound], t.Tuple[t.Type[_NodeBound], ...]]) -> t.Iterator[_NodeBound]:
"""Find all the nodes of a given type. If the type is a tuple,
the check is performed for any of the tuple items.
"""
- pass
+ if isinstance(self, node_type):
+ yield self
+ for child in self.iter_child_nodes():
+ yield from child.find_all(node_type)
def set_ctx(self, ctx: str) -> 'Node':
"""Reset the context of a node and all child nodes. Per default the
@@ -119,15 +142,26 @@ class Node(metaclass=NodeType):
most common one. This method is used in the parser to set assignment
targets and other nodes to a store context.
"""
- pass
+ if hasattr(self, 'ctx'):
+ self.ctx = ctx
+ for child in self.iter_child_nodes():
+ child.set_ctx(ctx)
+ return self
def set_lineno(self, lineno: int, override: bool=False) -> 'Node':
"""Set the line numbers of the node and children."""
- pass
+ if not hasattr(self, 'lineno') or override:
+ self.lineno = lineno
+ for child in self.iter_child_nodes():
+ child.set_lineno(lineno, override)
+ return self
def set_environment(self, environment: 'Environment') -> 'Node':
"""Set the environment for all nodes."""
- pass
+ self.environment = environment
+ for child in self.iter_child_nodes():
+ child.set_environment(environment)
+ return self
def __eq__(self, other: t.Any) -> bool:
if type(self) is not type(other):
@@ -303,11 +337,11 @@ class Expr(Node):
.. versionchanged:: 2.4
the `eval_ctx` parameter was added.
"""
- pass
+ raise Impossible()
def can_assign(self) -> bool:
"""Check if it's possible to assign something to this node."""
- pass
+ return False
class BinExpr(Expr):
"""Baseclass for all binary expressions."""
@@ -361,7 +395,31 @@ class Const(Literal):
constant value in the generated code, otherwise it will raise
an `Impossible` exception.
"""
- pass
+        # Scalars (and None) are wrapped as they are.
+        if value is None or isinstance(value, (bool, int, float, str, Markup)):
+            return cls(value, lineno=lineno, environment=environment)
+        # Sequences are converted item by item; the container type is kept.
+        if isinstance(value, (tuple, list)):
+            items = []
+            for item in value:
+                try:
+                    items.append(cls.from_untrusted(item, lineno=lineno,
+                                                    environment=environment).value)
+                except Impossible as exc:
+                    raise Impossible(f'Constant {value!r} includes non-constant value') from exc
+            if isinstance(value, tuple):
+                return cls(tuple(items), lineno=lineno, environment=environment)
+            return cls(items, lineno=lineno, environment=environment)
+        # Mappings are converted entry by entry; keys are coerced with str().
+        if isinstance(value, dict):
+            converted = {}
+            # Use distinct loop names: rebinding ``value`` here would make the
+            # error message report a single entry instead of the whole dict.
+            for key, entry in value.items():
+                try:
+                    converted[str(key)] = cls.from_untrusted(entry, lineno=lineno,
+                                                             environment=environment).value
+                except Impossible as exc:
+                    raise Impossible(f'Constant {value!r} includes non-constant value') from exc
+            return cls(converted, lineno=lineno, environment=environment)
+        raise Impossible(f'Value of type {type(value).__name__} is not JSON serializable')
class TemplateData(Literal):
"""A constant template string."""
diff --git a/src/jinja2/runtime.py b/src/jinja2/runtime.py
index ae56b56..6b4e04f 100644
--- a/src/jinja2/runtime.py
+++ b/src/jinja2/runtime.py
@@ -1,5 +1,11 @@
"""The runtime functions and state used by compiled templates."""
import functools
+
+def _dict_method_all(method: 't.Callable[..., t.Any]') -> 't.Callable[..., t.Any]':
+    """Wrap a dict method so it runs on both ``self.vars`` and
+    ``self.parent`` and returns the two results chained, ``vars`` first.
+
+    The annotations are written as strings because this function is
+    defined above ``import typing as t``; evaluating ``t`` when the
+    ``def`` statement runs would raise ``NameError`` during import.
+    """
+    def wrapped(self: 'Context') -> 't.Any':
+        # ``chain`` is looked up when the wrapper is called; presumably
+        # ``itertools.chain`` imported further down the module (confirm).
+        return chain(method(self.vars), method(self.parent))
+    return wrapped
import sys
import typing as t
from collections import abc
@@ -38,15 +44,15 @@ def identity(x: V) -> V:
"""Returns its argument. Useful for certain things in the
environment.
"""
- pass
+ return x
def markup_join(seq: t.Iterable[t.Any]) -> str:
"""Concatenation that escapes if necessary and converts to string."""
- pass
+ return Markup('').join(seq)
def str_join(seq: t.Iterable[t.Any]) -> str:
"""Simple args to string conversion and concatenation."""
- pass
+ return ''.join(map(str, seq))
def new_context(environment: 'Environment', template_name: t.Optional[str], blocks: t.Dict[str, t.Callable[['Context'], t.Iterator[str]]], vars: t.Optional[t.Dict[str, t.Any]]=None, shared: bool=False, globals: t.Optional[t.MutableMapping[str, t.Any]]=None, locals: t.Optional[t.Mapping[str, t.Any]]=None) -> 'Context':
"""Internal helper for context creation."""