Skip to content

Commit

Permalink
offshoot gen_train_data() from very long generate_data()
Browse files Browse the repository at this point in the history
Function generate_data() is huge and messy.

Refactor it.

Signed-off-by: Costa Shulyupin <[email protected]>
  • Loading branch information
makelinux authored and russellb committed Jun 24, 2024
1 parent 883bb54 commit c0f7320
Showing 1 changed file with 23 additions and 19 deletions.
42 changes: 23 additions & 19 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,26 @@ def _gen_test_data(
outfile.write("\n")


def _gen_train_data(machine_instruction_data, output_file_train):
train_data = []
for synth_example in machine_instruction_data:
user = synth_example["instruction"]
if len(synth_example["input"]) > 0:
user += "\n" + synth_example["input"]
train_data.append(
{
"system": utils.get_sysprompt(),
"user": unescape(user),
"assistant": unescape(synth_example["output"]),
}
)
# utils.jdump(train_data, output_file_train)
with open(output_file_train, "w", encoding="utf-8") as outfile:
for entry in train_data:
json.dump(entry, outfile, ensure_ascii=False)
outfile.write("\n")


def generate_data(
logger,
api_base,
Expand Down Expand Up @@ -606,25 +626,9 @@ def generate_data(
f"Generated {total} instructions(discarded {discarded}), rouged {total - keep}, kept {keep} instructions"
)
utils.jdump(machine_instruction_data, os.path.join(output_dir, output_file))
train_data = []
for synth_example in machine_instruction_data:
user = synth_example["instruction"]
if len(synth_example["input"]) > 0:
user += "\n" + synth_example["input"]
train_data.append(
{
"system": utils.get_sysprompt(),
"user": unescape(user),
"assistant": unescape(synth_example["output"]),
}
)
# utils.jdump(train_data, os.path.join(output_dir, output_file_train))
with open(
os.path.join(output_dir, output_file_train), "w", encoding="utf-8"
) as outfile:
for entry in train_data:
json.dump(entry, outfile, ensure_ascii=False)
outfile.write("\n")
_gen_train_data(
machine_instruction_data, os.path.join(output_dir, output_file_train)
)

progress_bar.close()

Expand Down

0 comments on commit c0f7320

Please sign in to comment.