Skip to content

Commit

Permalink
Add two unit tests for docling model path
Browse files Browse the repository at this point in the history
These simple unit tests just test the cases where we found a
config.yaml to parse for the docling model path and where we didn't.

Signed-off-by: Ben Browning <[email protected]>
(cherry picked from commit 0e9d75d)
  • Loading branch information
bbrowning authored and mergify[bot] committed Nov 15, 2024
1 parent 5244929 commit a6689e3
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

# Standard
from unittest import mock
import pathlib
import typing

# Third Party
from datasets import Dataset
Expand All @@ -17,6 +19,14 @@
# Local
from .taxonomy import MockTaxonomy

TESTS_PATH = pathlib.Path(__file__).parent.absolute()


@pytest.fixture
def testdata_path() -> typing.Generator[pathlib.Path, None, None]:
"""Path to local test data directory"""
yield TESTS_PATH / "testdata"


def get_ctx(**kwargs) -> PipelineContext:
kwargs.setdefault("client", mock.MagicMock())
Expand Down
36 changes: 35 additions & 1 deletion tests/test_generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import yaml

# First Party
from instructlab.sdg.generate_data import _context_init, generate_data
from instructlab.sdg.generate_data import _context_init, _sdg_init, generate_data
from instructlab.sdg.llmblock import LLMBlock
from instructlab.sdg.pipeline import PipelineContext

Expand Down Expand Up @@ -548,3 +548,37 @@ def test_context_init_batch_size_optional():
batch_num_workers=32,
)
assert ctx.batch_size == 20


def test_sdg_init_docling_path_config_found(testdata_path):
with patch.dict(os.environ):
os.environ["XDG_DATA_HOME"] = str(testdata_path.joinpath("mock_xdg_data_dir"))
ctx = _context_init(
None,
"mixtral",
"foo.bar",
1,
"/checkpoint/dir",
1,
batch_size=20,
batch_num_workers=32,
)
_, _, _, docling_model_path = _sdg_init(ctx, "full")
assert docling_model_path == "/mock/docling-models"


def test_sdg_init_docling_path_config_not_found(testdata_path):
with patch.dict(os.environ):
os.environ["XDG_DATA_HOME"] = str(testdata_path.joinpath("nonexistent_dir"))
ctx = _context_init(
None,
"mixtral",
"foo.bar",
1,
"/checkpoint/dir",
1,
batch_size=20,
batch_num_workers=32,
)
_, _, _, docling_model_path = _sdg_init(ctx, "full")
assert docling_model_path is None
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
models:
- path: /mock/docling-models
source: https://huggingface.co/ds4sd/docling-models
revision: main

0 comments on commit a6689e3

Please sign in to comment.