From bc96cbc734e589a127d3e64428d810bde3628241 Mon Sep 17 00:00:00 2001 From: Chaluvadi Date: Tue, 12 Mar 2024 11:42:09 -0400 Subject: [PATCH] readability changes pt.2 --- tests/test_constants.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/test_constants.py b/tests/test_constants.py index 929c8ed..e513a0b 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -29,7 +29,7 @@ ) -types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64] +all_types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64] @pytest.mark.parametrize( @@ -66,7 +66,6 @@ def test_constant_complex_shape(shape: tuple) -> None: """Test if constant_complex creates an array with the correct shape.""" dtype = c32 - dtype = c32 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -167,11 +166,11 @@ def test_constant_ulong_shape_invalid() -> None: @pytest.mark.parametrize( "dtype", - types, + all_types, ) def test_constant_dtype(dtype: Dtype) -> None: """Test if constant creates an array with the correct dtype.""" - if dtype in [c32, c64] or (dtype == f64 and not wrapper.get_dbl_support()): + if is_cmplx_type(dtype) or not is_system_supported(dtype): pytest.skip() rand_array = wrapper.randu((1, 1), dtype) @@ -186,11 +185,11 @@ def test_constant_dtype(dtype: Dtype) -> None: @pytest.mark.parametrize( "dtype", - types, + all_types, ) def test_constant_complex_dtype(dtype: Dtype) -> None: """Test if constant_complex creates an array with the correct dtype.""" - if dtype not in [c32, c64] or (dtype == c64 and not wrapper.get_dbl_support()): + if not is_cmplx_type(dtype) or not is_system_supported(dtype): pytest.skip() rand_array = wrapper.randu((1, 1), dtype) @@ -234,3 +233,14 @@ def test_constant_ulong_dtype() -> None: assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype else: pytest.skip() + + +def is_cmplx_type(dtype: Dtype) -> bool: + return dtype == c32 or dtype == c64 + + +def is_system_supported(dtype: Dtype) -> bool: + if dtype in [f64, c64] and not wrapper.get_dbl_support(): + return False + + return True