Skip to content

Commit

Permalink
compiler: convert printer to f-string
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 31, 2025
1 parent baebc8c commit 8030825
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 38 deletions.
2 changes: 0 additions & 2 deletions devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,8 +916,6 @@ def __init_finalize__(self, **kwargs):
self.cflags.append('-fsycl-targets=nvptx64-cuda')
elif isinstance(platform, IntelDevice):
self.cflags.append('-fsycl-targets=spir64')
else:
raise NotImplementedError("Unsupported platform %s" % platform)


class CustomCompiler(Compiler):
Expand Down
73 changes: 40 additions & 33 deletions devito/ir/cgen/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def func_prefix(self, expr, abs=False):

def parenthesize(self, item, level, strict=False):
if isinstance(item, BooleanFunction):
return "(%s)" % self._print(item)
return f"({self._print(item)})"
return super().parenthesize(item, level, strict=strict)

def _print_type(self, expr):
Expand All @@ -120,7 +120,7 @@ def _print_Function(self, expr):
return super()._print_Function(expr)

def _print_CondEq(self, expr):
return "%s == %s" % (self._print(expr.lhs), self._print(expr.rhs))
return f"{self._print(expr.lhs)} == {self._print(expr.rhs)}"

def _print_Indexed(self, expr):
"""
Expand All @@ -131,7 +131,7 @@ def _print_Indexed(self, expr):
U[t,x,y,z] -> U[t][x][y][z]
"""
inds = ''.join(['[' + self._print(x) + ']' for x in expr.indices])
return '%s%s' % (self._print(expr.base.label), inds)
return f'{self._print(expr.base.label)}{inds}'

def _print_FIndexed(self, expr):
"""
Expand All @@ -146,7 +146,7 @@ def _print_FIndexed(self, expr):
label = expr.accessor.label
except AttributeError:
label = expr.base.label
return '%s(%s)' % (self._print(label), inds)
return f'{self._print(label)}({inds})'

def _print_Rational(self, expr):
"""Print a Rational as a C-like float/float division."""
Expand All @@ -155,10 +155,8 @@ def _print_Rational(self, expr):
# to be 32-bit floats.
# http://en.cppreference.com/w/cpp/language/floating_literal
p, q = int(expr.p), int(expr.q)
if self.dtype == np.float64:
return '%d.0/%d.0' % (p, q)
else:
return '%d.0F/%d.0F' % (p, q)
prec = self.prec_literal(expr)
return f'{p}.0{prec}/{q}.0{prec}'

def _print_math_func(self, expr, nest=False, known=None):
cls = type(expr)
Expand Down Expand Up @@ -208,15 +206,21 @@ def _print_SafeInv(self, expr):

def _print_Mod(self, expr):
"""Print a Mod as a C-like %-based operation."""
args = ['(%s)' % self._print(a) for a in expr.args]
args = [f'({self._print(a)})' for a in expr.args]
return '%'.join(args)

def _print_Mul(self, expr):
term = super()._print_Mul(expr)
# avoid (-1)*...
term = term.replace("(-1)*", "-")
# Avoid (-1) / ...
term = term.replace("(-1)/", f"-{self._prec(expr)(1)}/")
args = [a for a in expr.args if a != -1]
neg = (len(expr.args) - len(args)) % 2

if len(args) > 1:
term = super()._print_Mul(expr.func(*args, evaluate=False))
else:
term = self.parenthesize(args[0], precedence(expr))

if neg:
term = f'-{term}'

return term

def _print_fmath_func(self, name, expr):
Expand All @@ -230,7 +234,7 @@ def _print_Min(self, expr):
expr.func(*expr.args[1:]),
evaluate=False))
elif has_integer_args(*expr.args) and len(expr.args) == 2:
return "MIN(%s)" % self._print(expr.args)[1:-1]
return f"MIN({self._print(expr.args)[1:-1]})"
else:
return self._print_fmath_func('min', expr)

Expand All @@ -240,7 +244,7 @@ def _print_Max(self, expr):
expr.func(*expr.args[1:]),
evaluate=False))
elif has_integer_args(*expr.args) and len(expr.args) == 2:
return "MAX(%s)" % self._print(expr.args)[1:-1]
return f"MAX({self._print(expr.args)[1:-1]})"
else:
return self._print_fmath_func('max', expr)

Expand All @@ -251,7 +255,7 @@ def _print_Abs(self, expr):
# AOMPCC errors with abs, always use fabs
if isinstance(self.compiler, AOMPCompiler) and \
not np.issubdtype(self._prec(expr), np.integer):
return "fabs(%s)" % self._print(arg)
return f"fabs({self._print(arg)})"
return self._print_fmath_func('abs', expr)

def _print_Add(self, expr, order=None):
Expand All @@ -265,7 +269,7 @@ def _print_Add(self, expr, order=None):
for term in terms:
t = self._print(term)
if precedence(term) < PREC:
l.extend(["+", "(%s)" % t])
l.extend(["+", f"({t})"])
elif t.startswith('-'):
l.extend(["-", t[1:]])
else:
Expand Down Expand Up @@ -305,44 +309,44 @@ def _print_Float(self, expr):
return f'{rv}{self.prec_literal(expr)}'

def _print_Differentiable(self, expr):
return "(%s)" % self._print(expr._expr)
return f"({self._print(expr._expr)})"

_print_EvalDerivative = _print_Add

def _print_CallFromPointer(self, expr):
indices = [self._print(i) for i in expr.params]
return "%s->%s(%s)" % (expr.pointer, expr.call, ', '.join(indices))
return f"{expr.pointer}->{expr.call}({', '.join(indices)})"

def _print_CallFromComposite(self, expr):
indices = [self._print(i) for i in expr.params]
return "%s.%s(%s)" % (expr.pointer, expr.call, ', '.join(indices))
return f"{expr.pointer}.{expr.call}({', '.join(indices)})"

def _print_FieldFromPointer(self, expr):
return "%s->%s" % (expr.pointer, expr.field)
return f"{expr.pointer}->{expr.field}"

def _print_FieldFromComposite(self, expr):
return "%s.%s" % (expr.pointer, expr.field)
return f"{expr.pointer}.{expr.field}"

def _print_ListInitializer(self, expr):
return "{%s}" % ', '.join([self._print(i) for i in expr.params])
return f"{{{', '.join(self._print(i) for i in expr.params)}}}"

def _print_IndexedPointer(self, expr):
return "%s%s" % (expr.base, ''.join('[%s]' % self._print(i) for i in expr.index))
return f"{expr.base}{''.join(f'[{self._print(i)}]' for i in expr.index)}"

def _print_IntDiv(self, expr):
lhs = self._print(expr.lhs)
if not expr.lhs.is_Atom:
lhs = '(%s)' % (lhs)
lhs = f"({lhs})"
rhs = self._print(expr.rhs)
PREC = precedence(expr)
return self.parenthesize("%s / %s" % (lhs, rhs), PREC)
return self.parenthesize(f"{lhs} / {rhs}", PREC)

def _print_InlineIf(self, expr):
cond = self._print(expr.cond)
true_expr = self._print(expr.true_expr)
false_expr = self._print(expr.false_expr)
PREC = precedence(expr)
return self.parenthesize("(%s) ? %s : %s" % (cond, true_expr, false_expr), PREC)
return self.parenthesize(f"({cond}) ? {true_expr} : {false_expr}", PREC)

def _print_UnaryOp(self, expr, op=None, parenthesize=False):
op = op or expr._op
Expand All @@ -356,20 +360,23 @@ def _print_Cast(self, expr):
return self._print_UnaryOp(expr, op=cast)

def _print_ComponentAccess(self, expr):
return "%s.%s" % (self._print(expr.base), expr.sindex)
return f"{self._print(expr.base)}.{expr.sindex}"

def _print_DefFunction(self, expr):
arguments = [self._print(i) for i in expr.arguments]
if expr.template:
template = '<%s>' % ','.join([str(i) for i in expr.template])
ctemplate = ','.join([str(i) for i in expr.template])
template = f'<{ctemplate}>'
else:
template = ''
return "%s%s(%s)" % (expr.name, template, ','.join(arguments))
args = ','.join(arguments)
return f"{expr.name}{template}({args})"

def _print_SizeOf(self, expr):
return f'sizeof({self._print(expr.intype)}{self._print(expr.stars)})'

_print_MathFunction = _print_DefFunction
def _print_MathFunction(self, expr):
return f"{self._ns}{self._print_DefFunction(expr)}"

def _print_Fallback(self, expr):
return expr.__str__()
Expand All @@ -385,7 +392,7 @@ def _print_Fallback(self, expr):

# Lifted from SymPy so that we go through our own `_print_math_func`
for k in ('exp log sin cos tan ceiling floor').split():
setattr(BasePrinter, '_print_%s' % k, BasePrinter._print_math_func)
setattr(BasePrinter, f'_print_{k}', BasePrinter._print_math_func)


# Always parenthesize IntDiv and InlineIf within expressions
Expand Down
4 changes: 2 additions & 2 deletions devito/passes/iet/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
List, Break, Return, FindNodes, FindSymbols, Transformer,
make_callable)
from devito.passes.iet.engine import iet_pass
from devito.symbolics import CondEq, DefFunction
from devito.symbolics import CondEq, MathFunction
from devito.tools import dtype_to_ctype
from devito.types import Eq, Inc, LocalObject, Symbol

Expand Down Expand Up @@ -58,7 +58,7 @@ def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None):
irs, byproduct = rcompile(eqns)

name = sregistry.make_name(prefix='is_finite')
retval = Return(DefFunction('isfinite', accumulator))
retval = Return(MathFunction('isfinite', accumulator))
body = irs.iet.body.body + (retval,)
efunc = make_callable(name, body, retval='int')

Expand Down
3 changes: 2 additions & 1 deletion devito/tools/dtypes_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ class c_restrict_void_p(ctypes.c_void_p):

name = "%s%d" % (base_name, count)
ctype = type(name, (ctypes.Structure,),
{'_fields_': [(i, base_ctype)] for i in field_names[:count]})
{'_fields_': [(i, base_ctype) for i in field_names[:count]],
'_base_dtype': name})

ctypes_vector_mapper[dtype] = ctype

Expand Down

0 comments on commit 8030825

Please sign in to comment.