From 1bc0f3a84cf3733040a0a0f6a2bc6ce790dc933e Mon Sep 17 00:00:00 2001 From: Chaluvadi Date: Tue, 12 Mar 2024 10:48:35 -0400 Subject: [PATCH] Readability changes to cosntants tests --- tests/test_constants.py | 69 +++++++++++++++++++++++++---------------- 1 file changed, 42 insertions(+), 27 deletions(-) diff --git a/tests/test_constants.py b/tests/test_constants.py index 855d94a..929c8ed 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -2,8 +2,23 @@ import pytest -import arrayfire_wrapper.dtypes as dtypes import arrayfire_wrapper.lib as wrapper +from arrayfire_wrapper.dtypes import ( + Dtype, + c32, + c64, + c_api_value_to_dtype, + f16, + f32, + f64, + s16, + s32, + s64, + u8, + u16, + u32, + u64, +) invalid_shape = ( random.randint(1, 10), @@ -14,6 +29,9 @@ ) +types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64] + + @pytest.mark.parametrize( "shape", [ @@ -27,7 +45,7 @@ def test_constant_shape(shape: tuple) -> None: """Test if constant creates an array with the correct shape.""" number = 5.0 - dtype = dtypes.s16 + dtype = s16 result = wrapper.constant(number, shape, dtype) @@ -46,9 +64,9 @@ def test_constant_shape(shape: tuple) -> None: ) def test_constant_complex_shape(shape: tuple) -> None: """Test if constant_complex creates an array with the correct shape.""" - dtype = dtypes.c32 + dtype = c32 - dtype = dtypes.c32 + dtype = c32 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -71,7 +89,7 @@ def test_constant_complex_shape(shape: tuple) -> None: ) def test_constant_long_shape(shape: tuple) -> None: """Test if constant_long creates an array with the correct shape.""" - dtype = dtypes.s64 + dtype = s64 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -93,7 +111,7 @@ def test_constant_long_shape(shape: tuple) -> None: ) def test_constant_ulong_shape(shape: tuple) -> None: """Test if constant_ulong creates an array with the correct shape.""" - dtype = dtypes.u64 + dtype = u64 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -109,7 +127,7 @@ def test_constant_shape_invalid() -> None: """Test if constant handles a shape with greater than 4 dimensions""" with pytest.raises(TypeError): number = 5.0 - dtype = dtypes.s16 + dtype = s16 wrapper.constant(number, invalid_shape, dtype) @@ -117,7 +135,7 @@ def test_constant_shape_invalid() -> None: def test_constant_complex_shape_invalid() -> None: """Test if constant_complex handles a shape with greater than 4 dimensions""" with pytest.raises(TypeError): - dtype = dtypes.c32 + dtype = c32 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -128,7 +146,7 @@ def test_constant_complex_shape_invalid() -> None: def test_constant_long_shape_invalid() -> None: """Test if constant_long handles a shape with greater than 4 dimensions""" with pytest.raises(TypeError): - dtype = dtypes.s64 + dtype = s64 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -139,7 +157,7 @@ def test_constant_long_shape_invalid() -> None: def test_constant_ulong_shape_invalid() -> None: """Test if constant_ulong handles a shape with greater than 4 dimensions""" with pytest.raises(TypeError): - dtype = dtypes.u64 + dtype = u64 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -148,50 +166,47 @@ def test_constant_ulong_shape_invalid() -> None: @pytest.mark.parametrize( - "dtype_index", - [i for i in range(13)], + "dtype", + types, ) -def test_constant_dtype(dtype_index: int) -> None: +def test_constant_dtype(dtype: Dtype) -> None: """Test if constant creates an array with the correct dtype.""" - if dtype_index in [1, 3] or (dtype_index == 2 and not wrapper.get_dbl_support()): + if dtype in [c32, c64] or (dtype == f64 and not wrapper.get_dbl_support()): pytest.skip() - dtype = dtypes.c_api_value_to_dtype(dtype_index) - rand_array = wrapper.randu((1, 1), dtype) value = wrapper.get_scalar(rand_array, dtype) shape = (2, 2) if isinstance(value, (int, float)): result = wrapper.constant(value, shape, dtype) - assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype + assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype else: pytest.skip() @pytest.mark.parametrize( - "dtype_index", - [i for i in range(13)], + "dtype", + types, ) -def test_constant_complex_dtype(dtype_index: int) -> None: +def test_constant_complex_dtype(dtype: Dtype) -> None: """Test if constant_complex creates an array with the correct dtype.""" - if dtype_index not in [1, 3] or (dtype_index == 3 and not wrapper.get_dbl_support()): + if dtype not in [c32, c64] or (dtype == c64 and not wrapper.get_dbl_support()): pytest.skip() - dtype = dtypes.c_api_value_to_dtype(dtype_index) rand_array = wrapper.randu((1, 1), dtype) value = wrapper.get_scalar(rand_array, dtype) shape = (2, 2) if isinstance(value, (int, float, complex)): result = wrapper.constant_complex(value, shape, dtype) - assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype + assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype else: pytest.skip() def test_constant_long_dtype() -> None: """Test if constant_long creates an array with the correct dtype.""" - dtype = dtypes.s64 + dtype = s64 rand_array = wrapper.randu((1, 1), dtype) value = wrapper.get_scalar(rand_array, dtype) @@ -200,14 +215,14 @@ def test_constant_long_dtype() -> None: if isinstance(value, (int, float)): result = wrapper.constant_long(value, shape, dtype) - assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype + assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype else: pytest.skip() def test_constant_ulong_dtype() -> None: """Test if constant_ulong creates an array with the correct dtype.""" - dtype = dtypes.u64 + dtype = u64 rand_array = wrapper.randu((1, 1), dtype) value = wrapper.get_scalar(rand_array, dtype) @@ -216,6 +231,6 @@ def test_constant_ulong_dtype() -> None: if isinstance(value, (int, float)): result = wrapper.constant_ulong(value, shape, dtype) - assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype + assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype else: pytest.skip()