From a6689e3984cbeb20353ff35679da503ebbe47066 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Thu, 14 Nov 2024 12:26:12 -0500 Subject: [PATCH] Add two unit tests for docling model path These simple unit tests just test the cases where we found a config.yaml to parse for the docling model path and where we didn't. Signed-off-by: Ben Browning (cherry picked from commit 0e9d75d6872ad469d6e5476a2c90d5546a80ed9b) --- tests/conftest.py | 10 ++++++ tests/test_generate_data.py | 36 ++++++++++++++++++- .../instructlab/sdg/models/config.yaml | 4 +++ 3 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 tests/testdata/mock_xdg_data_dir/instructlab/sdg/models/config.yaml diff --git a/tests/conftest.py b/tests/conftest.py index 80d61903..ed3fd8c4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,8 @@ # Standard from unittest import mock +import pathlib +import typing # Third Party from datasets import Dataset @@ -17,6 +19,14 @@ # Local from .taxonomy import MockTaxonomy +TESTS_PATH = pathlib.Path(__file__).parent.absolute() + + +@pytest.fixture +def testdata_path() -> typing.Generator[pathlib.Path, None, None]: + """Path to local test data directory""" + yield TESTS_PATH / "testdata" + def get_ctx(**kwargs) -> PipelineContext: kwargs.setdefault("client", mock.MagicMock()) diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py index f382a351..0d04a80f 100644 --- a/tests/test_generate_data.py +++ b/tests/test_generate_data.py @@ -20,7 +20,7 @@ import yaml # First Party -from instructlab.sdg.generate_data import _context_init, generate_data +from instructlab.sdg.generate_data import _context_init, _sdg_init, generate_data from instructlab.sdg.llmblock import LLMBlock from instructlab.sdg.pipeline import PipelineContext @@ -548,3 +548,37 @@ def test_context_init_batch_size_optional(): batch_num_workers=32, ) assert ctx.batch_size == 20 + + +def test_sdg_init_docling_path_config_found(testdata_path): + with patch.dict(os.environ): + os.environ["XDG_DATA_HOME"] = str(testdata_path.joinpath("mock_xdg_data_dir")) + ctx = _context_init( + None, + "mixtral", + "foo.bar", + 1, + "/checkpoint/dir", + 1, + batch_size=20, + batch_num_workers=32, + ) + _, _, _, docling_model_path = _sdg_init(ctx, "full") + assert docling_model_path == "/mock/docling-models" + + +def test_sdg_init_docling_path_config_not_found(testdata_path): + with patch.dict(os.environ): + os.environ["XDG_DATA_HOME"] = str(testdata_path.joinpath("nonexistent_dir")) + ctx = _context_init( + None, + "mixtral", + "foo.bar", + 1, + "/checkpoint/dir", + 1, + batch_size=20, + batch_num_workers=32, + ) + _, _, _, docling_model_path = _sdg_init(ctx, "full") + assert docling_model_path is None diff --git a/tests/testdata/mock_xdg_data_dir/instructlab/sdg/models/config.yaml b/tests/testdata/mock_xdg_data_dir/instructlab/sdg/models/config.yaml new file mode 100644 index 00000000..657cfdf3 --- /dev/null +++ b/tests/testdata/mock_xdg_data_dir/instructlab/sdg/models/config.yaml @@ -0,0 +1,4 @@ +models: +- path: /mock/docling-models + source: https://huggingface.co/ds4sd/docling-models + revision: main