diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index 1221eff7..aac007b7 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -33,6 +33,7 @@ PipelineContext, ) from instructlab.sdg.utils import GenerateException, models +from instructlab.sdg.utils.json import jldump from instructlab.sdg.utils.taxonomy import ( leaf_node_to_samples, read_taxonomy_leaf_nodes, @@ -112,15 +113,9 @@ def _gen_train_data( } messages_data.append(_convert_to_messages(sample)) - with open(output_file_train, "w", encoding="utf-8") as outfile: - for entry in train_data: - json.dump(entry, outfile, ensure_ascii=False) - outfile.write("\n") + jldump(train_data, output_file_train) - with open(output_file_messages, "w", encoding="utf-8") as outfile: - for entry in messages_data: - json.dump(entry, outfile, ensure_ascii=False) - outfile.write("\n") + jldump(messages_data, output_file_messages) def _knowledge_seed_example_to_test_data(seed_example, system_prompt): @@ -170,10 +165,7 @@ def _gen_test_data( } ) - with open(output_file_test, "w", encoding="utf-8") as outfile: - for entry in test_data: - json.dump(entry, outfile, ensure_ascii=False) - outfile.write("\n") + jldump(test_data, output_file_test) def _check_pipeline_dir(pipeline): diff --git a/src/instructlab/sdg/utils/json.py b/src/instructlab/sdg/utils/json.py index 8fd25268..041d817b 100644 --- a/src/instructlab/sdg/utils/json.py +++ b/src/instructlab/sdg/utils/json.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Standard +from typing import Any, Iterable import io import json import os @@ -46,3 +47,16 @@ def jload(f, mode="r"): """Load a .json file into a dictionary.""" with _make_r_io_base(f, mode) as f_: return json.load(f_) + + +def jldump(data: Iterable[Any], out: str | io.IOBase) -> None: + """Dump a list to a file in jsonl format. + + Args: + data: The data to be written.
+ out: io.IOBase or file path + """ + with _make_w_io_base(out, "w") as outfile: + for entry in data: + json.dump(entry, outfile, ensure_ascii=False) + outfile.write("\n")