Skip to content

Commit

Permalink
feat: remove test_*.py indirection for producer tests
Browse files Browse the repository at this point in the history
This PR removes the indirection through `test_*_functions.py` and
`test_*_relation.py` files and uses file globbing on the test definition
files for test collection instead. This is the first step in removing
the `test_*.py` files alltogether. These files are heavily repetitive
while only providing minimum value (adding some metadata and homing some
skip definitions, both of which can be done otherwise) and, therefor, a
burden for maintainability.

Signed-off-by: Ingo Müller <[email protected]>
  • Loading branch information
ingomueller-net committed Dec 13, 2024
1 parent ab6b9e3 commit 65068bd
Show file tree
Hide file tree
Showing 23 changed files with 152 additions and 502 deletions.
14 changes: 14 additions & 0 deletions substrait_consumer/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,26 @@
from substrait_consumer.consumers.acero_consumer import AceroConsumer
from substrait_consumer.consumers.datafusion_consumer import DataFusionConsumer
from substrait_consumer.consumers.duckdb_consumer import DuckDBConsumer
from substrait_consumer.functional.common import load_custom_duckdb_table
from substrait_consumer.producers.datafusion_producer import DataFusionProducer
from substrait_consumer.producers.duckdb_producer import DuckDBProducer
from substrait_consumer.producers.ibis_producer import IbisProducer
from substrait_consumer.producers.isthmus_producer import IsthmusProducer


@pytest.fixture
def db_con():
db_connection = duckdb.connect()
db_connection.execute("INSTALL substrait")
db_connection.execute("LOAD substrait")

load_custom_duckdb_table(db_connection)

yield db_connection

db_connection.close()


@pytest.fixture(scope="session")
def prepare_tpch_parquet_data():
"""
Expand Down
5 changes: 3 additions & 2 deletions substrait_consumer/functional/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def wrapper(snapshot, record_property, *args, **kwargs):


def substrait_producer_sql_test(
test_name: str,
path: Path,
snapshot: Snapshot,
record_property,
db_con: DuckDBPyConnection,
Expand Down Expand Up @@ -271,7 +271,8 @@ def substrait_producer_sql_test(

producer.setup(db_con, local_files, named_tables)

category, group, name = test_name.split(":")
path = str(path).split(".")[0].split("/")
category, group, name = path[0], path[1], path[-1]
record_property("category", category)
record_property("group", group)
record_property("name", name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

from substrait_consumer.functional.approximation_configs import AGGREGATE_FUNCTIONS
from substrait_consumer.functional.common import (
generate_snapshot_results, substrait_consumer_sql_test,
substrait_producer_sql_test)
generate_snapshot_results,
substrait_consumer_sql_test,
)
from substrait_consumer.parametrization import custom_parametrization


Expand All @@ -28,30 +29,6 @@ def setup_teardown_function(request):

cls.db_connection.close()

@custom_parametrization(AGGREGATE_FUNCTIONS)
@pytest.mark.produce_substrait_snapshot
def test_producer_approximation_functions(
self,
snapshot,
record_property,
test_name: str,
local_files: dict[str, str],
named_tables: dict[str, str],
sql_query: tuple,
producer,
) -> None:
test_name = f"function:approximation:{test_name}"
substrait_producer_sql_test(
test_name,
snapshot,
record_property,
self.db_connection,
local_files,
named_tables,
sql_query,
producer,
)

@custom_parametrization(AGGREGATE_FUNCTIONS)
@pytest.mark.consume_substrait_snapshot
def test_consumer_approximation_functions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from substrait_consumer.functional.arithmetic_decimal_configs import (
AGGREGATE_FUNCTIONS, SCALAR_FUNCTIONS)
from substrait_consumer.functional.common import (
generate_snapshot_results, load_custom_duckdb_table,
substrait_consumer_sql_test, substrait_producer_sql_test)
generate_snapshot_results,
load_custom_duckdb_table,
substrait_consumer_sql_test,
)
from substrait_consumer.parametrization import custom_parametrization
from substrait_consumer.consumers.datafusion_consumer import DataFusionConsumer

Expand Down Expand Up @@ -40,30 +42,6 @@ def setup_teardown_function(request):

cls.db_connection.close()

@custom_parametrization(SCALAR_FUNCTIONS + AGGREGATE_FUNCTIONS)
@pytest.mark.produce_substrait_snapshot
def test_producer_arithmetic_decimal_functions(
self,
snapshot,
record_property,
test_name: str,
local_files: dict[str, str],
named_tables: dict[str, str],
sql_query: tuple,
producer,
) -> None:
test_name = f"function:arithmetic_decimal:{test_name}"
substrait_producer_sql_test(
test_name,
snapshot,
record_property,
self.db_connection,
local_files,
named_tables,
sql_query,
producer,
)

@custom_parametrization(SCALAR_FUNCTIONS + AGGREGATE_FUNCTIONS)
@pytest.mark.consume_substrait_snapshot
@pytest.mark.usefixtures('mark_consumer_tests_as_xfail')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from substrait_consumer.functional.arithmetic_configs import (
AGGREGATE_FUNCTIONS, SCALAR_FUNCTIONS)
from substrait_consumer.functional.common import (
generate_snapshot_results, load_custom_duckdb_table,
substrait_consumer_sql_test, substrait_producer_sql_test)
generate_snapshot_results,
load_custom_duckdb_table,
substrait_consumer_sql_test,
)
from substrait_consumer.parametrization import custom_parametrization
from substrait_consumer.consumers.datafusion_consumer import DataFusionConsumer

Expand Down Expand Up @@ -45,30 +47,6 @@ def setup_teardown_function(request):

cls.db_connection.close()

@custom_parametrization(SCALAR_FUNCTIONS + AGGREGATE_FUNCTIONS)
@pytest.mark.produce_substrait_snapshot
def test_producer_arithmetic_functions(
self,
snapshot,
record_property,
test_name: str,
local_files: dict[str, str],
named_tables: dict[str, str],
sql_query: tuple,
producer,
) -> None:
test_name = f"function:arithmetic:{test_name}"
substrait_producer_sql_test(
test_name,
snapshot,
record_property,
self.db_connection,
local_files,
named_tables,
sql_query,
producer,
)

@custom_parametrization(SCALAR_FUNCTIONS + AGGREGATE_FUNCTIONS)
@pytest.mark.consume_substrait_snapshot
@pytest.mark.usefixtures('mark_consumer_tests_as_xfail')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from substrait_consumer.functional.boolean_configs import (
AGGREGATE_FUNCTIONS, SCALAR_FUNCTIONS)
from substrait_consumer.functional.common import (
generate_snapshot_results, load_custom_duckdb_table,
substrait_consumer_sql_test, substrait_producer_sql_test)
generate_snapshot_results,
load_custom_duckdb_table,
substrait_consumer_sql_test,
)
from substrait_consumer.parametrization import custom_parametrization


Expand All @@ -30,30 +32,6 @@ def setup_teardown_function(request):

cls.db_connection.close()

@custom_parametrization(SCALAR_FUNCTIONS + AGGREGATE_FUNCTIONS)
@pytest.mark.produce_substrait_snapshot
def test_producer_boolean_functions(
self,
snapshot,
record_property,
test_name: str,
local_files: dict[str, str],
named_tables: dict[str, str],
sql_query: tuple,
producer,
) -> None:
test_name = f"function:boolean:{test_name}"
substrait_producer_sql_test(
test_name,
snapshot,
record_property,
self.db_connection,
local_files,
named_tables,
sql_query,
producer,
)

@custom_parametrization(SCALAR_FUNCTIONS + AGGREGATE_FUNCTIONS)
@pytest.mark.consume_substrait_snapshot
def test_consumer_boolean_functions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import pytest

from substrait_consumer.functional.common import (
generate_snapshot_results, load_custom_duckdb_table,
substrait_consumer_sql_test, substrait_producer_sql_test)
generate_snapshot_results,
load_custom_duckdb_table,
substrait_consumer_sql_test,
)
from substrait_consumer.functional.comparison_configs import SCALAR_FUNCTIONS
from substrait_consumer.parametrization import custom_parametrization

Expand All @@ -29,30 +31,6 @@ def setup_teardown_function(request):

cls.db_connection.close()

@custom_parametrization(SCALAR_FUNCTIONS)
@pytest.mark.produce_substrait_snapshot
def test_producer_comparison_functions(
self,
snapshot,
record_property,
test_name: str,
local_files: dict[str, str],
named_tables: dict[str, str],
sql_query: tuple,
producer,
) -> None:
test_name = f"function:comparison:{test_name}"
substrait_producer_sql_test(
test_name,
snapshot,
record_property,
self.db_connection,
local_files,
named_tables,
sql_query,
producer,
)

@custom_parametrization(SCALAR_FUNCTIONS)
@pytest.mark.consume_substrait_snapshot
def test_consumer_comparison_functions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import pytest

from substrait_consumer.functional.common import (
generate_snapshot_results, substrait_consumer_sql_test,
substrait_producer_sql_test)
generate_snapshot_results,
substrait_consumer_sql_test,
)
from substrait_consumer.functional.datetime_configs import SCALAR_FUNCTIONS
from substrait_consumer.parametrization import custom_parametrization
from substrait_consumer.consumers.datafusion_consumer import DataFusionConsumer
Expand Down Expand Up @@ -44,30 +45,6 @@ def setup_teardown_function(request):

cls.db_connection.close()

@custom_parametrization(SCALAR_FUNCTIONS)
@pytest.mark.produce_substrait_snapshot
def test_producer_datetime_functions(
self,
snapshot,
record_property,
test_name: str,
local_files: dict[str, str],
named_tables: dict[str, str],
sql_query: tuple,
producer,
) -> None:
test_name = f"function:datetime:{test_name}"
substrait_producer_sql_test(
test_name,
snapshot,
record_property,
self.db_connection,
local_files,
named_tables,
sql_query,
producer,
)

@custom_parametrization(SCALAR_FUNCTIONS)
@pytest.mark.consume_substrait_snapshot
@pytest.mark.usefixtures('mark_consumer_tests_as_xfail')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import pytest

from substrait_consumer.functional.common import (
generate_snapshot_results, substrait_consumer_sql_test,
substrait_producer_sql_test)
generate_snapshot_results,
substrait_consumer_sql_test,
)
from substrait_consumer.functional.logarithmic_configs import SCALAR_FUNCTIONS
from substrait_consumer.parametrization import custom_parametrization

Expand All @@ -28,30 +29,6 @@ def setup_teardown_function(request):

cls.db_connection.close()

@custom_parametrization(SCALAR_FUNCTIONS)
@pytest.mark.produce_substrait_snapshot
def test_producer_logarithmic_functions(
self,
snapshot,
record_property,
test_name: str,
local_files: dict[str, str],
named_tables: dict[str, str],
sql_query: tuple,
producer,
) -> None:
test_name = f"function:logarithmic:{test_name}"
substrait_producer_sql_test(
test_name,
snapshot,
record_property,
self.db_connection,
local_files,
named_tables,
sql_query,
producer,
)

@custom_parametrization(SCALAR_FUNCTIONS)
@pytest.mark.consume_substrait_snapshot
def test_consumer_logarithmic_functions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import pytest

from substrait_consumer.functional.common import (
generate_snapshot_results, substrait_consumer_sql_test,
substrait_producer_sql_test)
generate_snapshot_results,
substrait_consumer_sql_test,
)
from substrait_consumer.functional.rounding_configs import SCALAR_FUNCTIONS
from substrait_consumer.parametrization import custom_parametrization
from substrait_consumer.consumers.datafusion_consumer import DataFusionConsumer
Expand Down Expand Up @@ -38,30 +39,6 @@ def setup_teardown_function(request):

cls.db_connection.close()

@custom_parametrization(SCALAR_FUNCTIONS)
@pytest.mark.produce_substrait_snapshot
def test_producer_rounding_functions(
self,
snapshot,
record_property,
test_name: str,
local_files: dict[str, str],
named_tables: dict[str, str],
sql_query: tuple,
producer,
) -> None:
test_name = f"function:rounding:{test_name}"
substrait_producer_sql_test(
test_name,
snapshot,
record_property,
self.db_connection,
local_files,
named_tables,
sql_query,
producer,
)

@custom_parametrization(SCALAR_FUNCTIONS)
@pytest.mark.consume_substrait_snapshot
@pytest.mark.usefixtures('mark_consumer_tests_as_xfail')
Expand Down
Loading

0 comments on commit 65068bd

Please sign in to comment.