Skip to content

Commit

Permalink
Readability changes to cosntants tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Chaluvadi committed Mar 28, 2024
1 parent 2fbaace commit 1bc0f3a
Showing 1 changed file with 42 additions and 27 deletions.
69 changes: 42 additions & 27 deletions tests/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,23 @@

import pytest

import arrayfire_wrapper.dtypes as dtypes
import arrayfire_wrapper.lib as wrapper
from arrayfire_wrapper.dtypes import (
Dtype,
c32,
c64,
c_api_value_to_dtype,
f16,
f32,
f64,
s16,
s32,
s64,
u8,
u16,
u32,
u64,
)

invalid_shape = (
random.randint(1, 10),
Expand All @@ -14,6 +29,9 @@
)


types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]


@pytest.mark.parametrize(
"shape",
[
Expand All @@ -27,7 +45,7 @@
def test_constant_shape(shape: tuple) -> None:
"""Test if constant creates an array with the correct shape."""
number = 5.0
dtype = dtypes.s16
dtype = s16

result = wrapper.constant(number, shape, dtype)

Expand All @@ -46,9 +64,9 @@ def test_constant_shape(shape: tuple) -> None:
)
def test_constant_complex_shape(shape: tuple) -> None:
"""Test if constant_complex creates an array with the correct shape."""
dtype = dtypes.c32
dtype = c32

dtype = dtypes.c32
dtype = c32
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

Expand All @@ -71,7 +89,7 @@ def test_constant_complex_shape(shape: tuple) -> None:
)
def test_constant_long_shape(shape: tuple) -> None:
"""Test if constant_long creates an array with the correct shape."""
dtype = dtypes.s64
dtype = s64
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

Expand All @@ -93,7 +111,7 @@ def test_constant_long_shape(shape: tuple) -> None:
)
def test_constant_ulong_shape(shape: tuple) -> None:
"""Test if constant_ulong creates an array with the correct shape."""
dtype = dtypes.u64
dtype = u64
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

Expand All @@ -109,15 +127,15 @@ def test_constant_shape_invalid() -> None:
"""Test if constant handles a shape with greater than 4 dimensions"""
with pytest.raises(TypeError):
number = 5.0
dtype = dtypes.s16
dtype = s16

wrapper.constant(number, invalid_shape, dtype)


def test_constant_complex_shape_invalid() -> None:
"""Test if constant_complex handles a shape with greater than 4 dimensions"""
with pytest.raises(TypeError):
dtype = dtypes.c32
dtype = c32
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

Expand All @@ -128,7 +146,7 @@ def test_constant_complex_shape_invalid() -> None:
def test_constant_long_shape_invalid() -> None:
"""Test if constant_long handles a shape with greater than 4 dimensions"""
with pytest.raises(TypeError):
dtype = dtypes.s64
dtype = s64
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

Expand All @@ -139,7 +157,7 @@ def test_constant_long_shape_invalid() -> None:
def test_constant_ulong_shape_invalid() -> None:
"""Test if constant_ulong handles a shape with greater than 4 dimensions"""
with pytest.raises(TypeError):
dtype = dtypes.u64
dtype = u64
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

Expand All @@ -148,50 +166,47 @@ def test_constant_ulong_shape_invalid() -> None:


@pytest.mark.parametrize(
"dtype_index",
[i for i in range(13)],
"dtype",
types,
)
def test_constant_dtype(dtype_index: int) -> None:
def test_constant_dtype(dtype: Dtype) -> None:
"""Test if constant creates an array with the correct dtype."""
if dtype_index in [1, 3] or (dtype_index == 2 and not wrapper.get_dbl_support()):
if dtype in [c32, c64] or (dtype == f64 and not wrapper.get_dbl_support()):
pytest.skip()

dtype = dtypes.c_api_value_to_dtype(dtype_index)

rand_array = wrapper.randu((1, 1), dtype)
value = wrapper.get_scalar(rand_array, dtype)
shape = (2, 2)
if isinstance(value, (int, float)):
result = wrapper.constant(value, shape, dtype)
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
else:
pytest.skip()


@pytest.mark.parametrize(
"dtype_index",
[i for i in range(13)],
"dtype",
types,
)
def test_constant_complex_dtype(dtype_index: int) -> None:
def test_constant_complex_dtype(dtype: Dtype) -> None:
"""Test if constant_complex creates an array with the correct dtype."""
if dtype_index not in [1, 3] or (dtype_index == 3 and not wrapper.get_dbl_support()):
if dtype not in [c32, c64] or (dtype == c64 and not wrapper.get_dbl_support()):
pytest.skip()

dtype = dtypes.c_api_value_to_dtype(dtype_index)
rand_array = wrapper.randu((1, 1), dtype)
value = wrapper.get_scalar(rand_array, dtype)
shape = (2, 2)

if isinstance(value, (int, float, complex)):
result = wrapper.constant_complex(value, shape, dtype)
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
else:
pytest.skip()


def test_constant_long_dtype() -> None:
"""Test if constant_long creates an array with the correct dtype."""
dtype = dtypes.s64
dtype = s64

rand_array = wrapper.randu((1, 1), dtype)
value = wrapper.get_scalar(rand_array, dtype)
Expand All @@ -200,14 +215,14 @@ def test_constant_long_dtype() -> None:
if isinstance(value, (int, float)):
result = wrapper.constant_long(value, shape, dtype)

assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
else:
pytest.skip()


def test_constant_ulong_dtype() -> None:
"""Test if constant_ulong creates an array with the correct dtype."""
dtype = dtypes.u64
dtype = u64

rand_array = wrapper.randu((1, 1), dtype)
value = wrapper.get_scalar(rand_array, dtype)
Expand All @@ -216,6 +231,6 @@ def test_constant_ulong_dtype() -> None:
if isinstance(value, (int, float)):
result = wrapper.constant_ulong(value, shape, dtype)

assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
else:
pytest.skip()

0 comments on commit 1bc0f3a

Please sign in to comment.