Skip to content

Commit

Permalink
test c vector extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd committed Mar 11, 2022
1 parent 56c64d9 commit 8713fc3
Showing 1 changed file with 89 additions and 0 deletions.
89 changes: 89 additions & 0 deletions test/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,95 @@ def test_glibc_bessel_functions(dtype):
rtol=1e-6, atol=1e-6)


def test_c_vector_extensions():
knl = lp.make_kernel(
"{[i, j1, j2, j3]: 0<=i<10 and 0<=j1,j2,j3<4}",
"""
<> temp1[j1] = x[i, j1]
<> temp2[j2] = 2*temp1[j2] + 1 {inames=i:j2}
y[i, j3] = temp2[j3]
""",
[lp.GlobalArg("x, y", shape=lp.auto, dtype=float)],
seq_dependencies=True,
target=lp.CVectorExtensionsTarget())

knl = lp.tag_inames(knl, "j2:vec, j1:ilp, j3:ilp")
knl = lp.tag_array_axes(knl, "temp1,temp2", "vec")

print(lp.generate_code_v2(knl).device_code())


def test_omp_simd_tag():
knl = lp.make_kernel(
"{[i]: 0<=i<16}",
"""
y[i] = 2 * x[i]
""")

knl = lp.add_dtypes(knl, {"x": "float64"})
knl = lp.split_iname(knl, "i", 4)
knl = lp.tag_inames(knl, {"i_inner": lp.OpenMPSIMDTag()})

code_str = lp.generate_code_v2(knl).device_code()

assert any(line.strip() == "#pragma omp simd"
for line in code_str.split("\n"))


def test_vec_tag_with_omp_simd_fallback():
knl = lp.make_kernel(
"{[i, j1, j2, j3]: 0<=i<10 and 0<=j1,j2,j3<4}",
"""
<> temp1[j1] = x[i, j1]
<> temp2[j2] = 2*temp1[j2] + 1 {inames=i:j2}
y[i, j3] = temp2[j3]
""",
[lp.GlobalArg("x, y", shape=lp.auto, dtype=float)],
seq_dependencies=True,
target=lp.ExecutableCVectorExtensionsTarget())

knl = lp.tag_inames(knl, {"j1": lp.VectorizeTag(lp.OpenMPSIMDTag()),
"j2": lp.VectorizeTag(lp.OpenMPSIMDTag()),
"j3": lp.VectorizeTag(lp.OpenMPSIMDTag())})
knl = lp.tag_array_axes(knl, "temp1,temp2", "vec")

code_str = lp.generate_code_v2(knl).device_code()

assert len([line
for line in code_str.split("\n")
if line.strip() == "#pragma omp simd"]) == 2

x = np.random.rand(10, 4)
_, (out,) = knl(x=x)
np.testing.assert_allclose(out, 2*x+1)


def test_vec_extensions_with_multiple_loopy_body_insns():
knl = lp.make_kernel(
"{[n]: 0<=n<N}",
"""
for n
... nop {id=expr_start}
<> tmp = 2.0
dat0[n, 0] = tmp {id=expr_insn}
... nop {id=statement0}
end
""",
seq_dependencies=True,
target=lp.ExecutableCVectorExtensionsTarget())

knl = lp.add_dtypes(knl, {"dat0": "float64"})
knl = lp.split_iname(knl, "n", 4, slabs=(1, 1),
inner_iname="n_batch")
knl = lp.privatize_temporaries_with_inames(knl, "n_batch")
knl = lp.tag_array_axes(knl, "tmp", "vec")
knl = lp.tag_inames(knl, {
"n_batch": lp.VectorizeTag(lp.OpenMPSIMDTag())})

_, (out,) = knl(N=100)
np.testing.assert_allclose(out, 2*np.ones((100, 1)))


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit 8713fc3

Please sign in to comment.