Skip to content

Commit

Permalink
Sketch out add_lexicographic_happens_after
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Jan 20, 2023
1 parent 93b4126 commit 01d3688
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 3 deletions.
83 changes: 80 additions & 3 deletions loopy/kernel/dependency.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# FIXME Add copyright header


import islpy as isl
from islpy import dim_type
import pymbolic.primitives as p

from dataclasses import dataclass
Expand All @@ -7,6 +11,8 @@

from loopy import LoopKernel
from loopy.symbolic import WalkMapper
from loopy.translation_unit import for_each_kernel
from loopy.typing import ExpressionT

@dataclass(frozen=True)
class HappensAfter:
Expand All @@ -32,7 +38,7 @@ def __init__(self, kernel: LoopKernel, var_names: set):

super.__init__()

def map_subscript(self, expr: p.expression, inames: frozenset, insn_id: str):
def map_subscript(self, expr: ExpressionT, inames: frozenset, insn_id: str):

domain = self.kernel.get_inames_domain(inames)

Expand Down Expand Up @@ -110,7 +116,7 @@ def compute_happens_after(knl: LoopKernel) -> LoopKernel:
# return the kernel with the new instructions
return knl.copy(instructions=new_insns)

def add_lexicographic_happens_after(knl: LoopKernel) -> None:
def add_lexicographic_happens_after_orig(knl: LoopKernel) -> None:
"""
TODO properly format this documentation.
Expand All @@ -122,7 +128,7 @@ def add_lexicographic_happens_after(knl: LoopKernel) -> None:
"""

# we want to modify the output dimension and OUT = 3
dim_type = isl.dim_type(3)
dim_type = isl.dim_type.out

# generate an unordered mapping from statement instances to points in the
# loop domain
Expand All @@ -148,3 +154,74 @@ def add_lexicographic_happens_after(knl: LoopKernel) -> None:

# determine a lexicographic order on the space the schedules belong to


@for_each_kernel
def add_lexicographic_happens_after(knl: LoopKernel) -> LoopKernel:

new_insns = []

for iafter, insn_after in enumerate(knl.instructions):
if iafter == 0:
new_insns.append(insn_after)
else:
insn_before = knl.instructions[iafter - 1]
shared_inames = insn_after.within_inames & insn_before.within_inames
unshared_before = insn_before.within_inames

domain_before = knl.get_inames_domain(insn_before.within_inames)
domain_after = knl.get_inames_domain(insn_after.within_inames)

happens_before = isl.Map.from_domain_and_range(
domain_before, domain_after)
for idim in range(happens_before.dim(dim_type.out)):
happens_before = happens_before.set_dim_name(
dim_type.out, idim,
happens_before.get_dim_name(dim_type.out, idim) + "'")
n_inames_before = happens_before.dim(dim_type.in_)
happens_before_set = happens_before.move_dims(
dim_type.out, 0,
dim_type.in_, 0,
n_inames_before).range()

shared_inames_order_before = [
domain_before.get_dim_name(dim_type.out, idim)
for idim in range(domain_before.dim(dim_type.out))
if domain_before.get_dim_name(dim_type.out, idim)
in shared_inames]
shared_inames_order_after = [
domain_after.get_dim_name(dim_type.out, idim)
for idim in range(domain_after.dim(dim_type.out))
if domain_after.get_dim_name(dim_type.out, idim)
in shared_inames]

assert shared_inames_order_after == shared_inames_order_before
shared_inames_order = shared_inames_order_after

affs = isl.affs_from_space(happens_before_set.space)

lex_set = isl.Set.empty(happens_before_set.space)
for iinnermost, innermost_iname in enumerate(shared_inames_order):
innermost_set = affs[innermost_iname].lt_set(
affs[innermost_iname+"'"])

for outer_iname in shared_inames_order[:iinnermost]:
innermost_set = innermost_set & (
affs[outer_iname].eq_set(affs[outer_iname + "'"]))

lex_set = lex_set | innermost_set

lex_map = isl.Map.from_range(lex_set).move_dims(
dim_type.in_, 0,
dim_type.out, 0,
n_inames_before)

happens_before = happens_before & lex_map

pu.db

new_insns.append(insn_after)

return knl.copy(instructions=new_insns)



32 changes: 32 additions & 0 deletions test/test_dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# FIXME Add copyright header


import sys
import loopy as lp


def test_lex_dependencies():
knl = lp.make_kernel(
[
"{[a,b]:0<=a,b<7}",
"{[i,j]: 0<=i,j<n and 0<=a,b<5}",
"{[k,l]: 0<=k,l<n and 0<=a,b<3}"
],
"""
v[a,b,i,j] = 2*v[a,b,i,j]
v[a,b,k,l] = 2*v[a,b,k,l]
""")

from loopy.kernel.dependency import add_lexicographic_happens_after

add_lexicographic_happens_after(knl)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
else:
from pytest import main
main([__file__])

# vim: foldmethod=marker

0 comments on commit 01d3688

Please sign in to comment.