Skip to content

Commit

Permalink
Simplify (?) and add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
rolfhm committed Oct 19, 2023
1 parent 7560abb commit afb17af
Showing 1 changed file with 108 additions and 36 deletions.
144 changes: 108 additions & 36 deletions loki/transform/transform_scalar_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,57 +39,126 @@ def check_if_scalar_syntax(arg, dummy):


def single_sum(expr):
"""
Return a Sum object of expr if expr is not an instance of pymbolic.primitives.Sum.
Otherwise return expr
Parameters
----------
expr: any pymbolic expression
"""
if isinstance(expr, pmbl.Sum):
return expr
else:
return Sum((expr,))


def sum_ints(expr):
def product_value(expr):
"""
If expr is an instance of pymbolic.primitives.Product, try to evaluate it
If it is possible, return the value as an int.
If it is not possible, try to simplify the the product and return as a Product
If it is not a pymbolic.primitives.Product , return expr
Note: Negative numbers and subtractions in Sums are represented as Product of
the integer -1 and the symbol. This complicates matters.
Parameters
----------
expr: any pymbolic expression
"""
if isinstance(expr, pmbl.Product):
m = 1
new_children = []
for c in expr.children:
if isinstance(c, IntLiteral):
m = m*c.value
elif isinstance(c, int):
m = m*c
else:
new_children += [c]

if m == 0:
return 0
elif not new_children:
return m
else:
if m > 1:
m = IntLiteral(m)
elif m < -1:
m = Product((-1, IntLiteral(abs(m))))
return m*Product(as_tuple(new_children))
else:
return expr


def simplify_sum(expr):
"""
If expr is an instance of pymbolic.primitives.Sum,
try to simplify it by evaluating any Products and adding up ints and IntLiterals.
If the sum can be reduced to a number, it returns an IntLiteral
If the Sum reduces to one expression, it returns that expression
Parameters
----------
expr: any pymbolic expression
"""

if isinstance(expr, pmbl.Sum):
n = 0
new_children = []
for c in expr.children:
c = product_value(c)
if isinstance(c, IntLiteral):
n += c.value
elif (isinstance(c, pmbl.Product) and
all(isinstance(cc, IntLiteral) or isinstance(cc,int) for cc in c.children)):
m = 1
for cc in c.children:
if isinstance(cc, IntLiteral):
m = m*cc.value
else:
m = m*cc
n += m
elif isinstance(c, int):
n += c
else:
new_children += [c]

if n != 0:
new_children += [IntLiteral(n)]
if new_children:
if n > 0:
new_children += [IntLiteral(n)]
elif n < 0:
new_children += [Product((-1,IntLiteral(abs(n))))]

if len(new_children) > 1:
return Sum(as_tuple(new_children))
else:
return new_children[0]

expr.children = as_tuple(new_children)

else:
return IntLiteral(n)
else:
return expr


def construct_range_index(lower, length):
"""
Construct a range index from lower to lower + length - 1
if lower == IntLiteral(1):
new_high = length
elif isinstance(lower, IntLiteral) and isinstance(length, IntLiteral):
new_high = IntLiteral(value = length.value + lower.value - 1)
elif isinstance(lower, IntLiteral):
new_high = single_sum(length) + IntLiteral(value = lower.value - 1)
elif isinstance(length, IntLiteral):
new_high = single_sum(lower) + IntLiteral(value = length.value - 1)
else:
new_high = single_sum(length) + lower - IntLiteral(1)
Parameters
----------
lower : any pymbolic expression
length: any pymbolic expression
"""

sum_ints(new_high)
new_high = simplify_sum(single_sum(length) + lower - IntLiteral(1))

return RangeIndex((lower, new_high))


def process_symbol(symbol, caller, call):
"""
Map symbol in call.routine to the appropriate symbol in caller,
taking any parents into account
Parameters
----------
symbol: Loki variable in call.routine
caller: Subroutine object containing call
call : Call object
"""

if isinstance(symbol, IntLiteral):
return symbol
Expand All @@ -107,19 +176,22 @@ def process_symbol(symbol, caller, call):
raise RuntimeError('[Loki::fix_scalar_syntax] Unable to resolve argument dimension. Module variable?')


def construct_length(xrange, routine, call):
def construct_length(xrange, caller, call):
"""
Construct an expression for the length of xrange,
defined in call.routine, in caller.
Parameters
----------
xrange: RangeIndex object defined in call.routine
caller: Subroutine object
call : call contained in caller
"""

new_start = process_symbol(xrange.start, routine, call)
new_stop = process_symbol(xrange.stop, routine, call)
new_start = process_symbol(xrange.start, caller, call)
new_stop = process_symbol(xrange.stop, caller, call)

if isinstance(new_start, IntLiteral) and isinstance(new_stop, IntLiteral):
return IntLiteral(value = new_stop.value - new_start.value + 1)
elif isinstance(new_start, IntLiteral):
return single_sum(new_stop) - IntLiteral(value = new_start.value - 1)
elif isinstance(new_stop, IntLiteral):
return single_sum(IntLiteral(value = new_stop.value + 1)) - new_start
else:
return single_sum(new_stop) - new_start + IntLiteral(1)
return simplify_sum(single_sum(new_stop) - new_start + IntLiteral(1))


def fix_scalar_syntax(routine):
Expand Down

0 comments on commit afb17af

Please sign in to comment.