Skip to content

Commit

Permalink
offshoot gen_test_data() from very long generate_data()
Browse files Browse the repository at this point in the history
Function generate_data() is huge and messy. Refactor it.
Move code around `json.dump` to gen_test_data()
because it is under loop
`for instruction_data_entry in instruction_data`
by accidentally.

Code is just moved, without internal changes.

Signed-off-by: Costa Shulyupin <[email protected]>
  • Loading branch information
makelinux authored and russellb committed Jun 24, 2024
1 parent f3090c1 commit 883bb54
Showing 1 changed file with 74 additions and 55 deletions.
129 changes: 74 additions & 55 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,53 +362,19 @@ def read_taxonomy(*args, **kwargs):
return instructlab.utils.read_taxonomy(*args, **kwargs)


def generate_data(
logger,
api_base,
tls_insecure,
model_family: str,
yaml_rules: Optional[str] = None,
output_dir: Optional[str] = None,
taxonomy: Optional[str] = None,
taxonomy_base: Optional[str] = None,
prompt_file_path: Optional[str] = None,
model_name: Optional[str] = None,
num_cpus: Optional[int] = None,
num_instructions_to_generate: Optional[int] = None,
num_prompt_instructions=2,
request_batch_size=5,
temperature=1.0,
top_p=1.0,
rouge_threshold: Optional[float] = None,
console_output=True,
api_key: Optional[str] = None,
chunk_word_count=None,
server_ctx_size=None,
tls_client_cert: Optional[str] = None,
tls_client_key: Optional[str] = None,
tls_client_passwd: Optional[str] = None,
):
seed_instruction_data = []
machine_seed_instruction_data = []
generate_start = time.time()
def unescape(s):
return bytes(s, "utf-8").decode("utf-8")

if not os.path.exists(output_dir):
os.mkdir(output_dir)

# check taxonomy first then seed_tasks_path
# throw an error if both not found
# pylint: disable=broad-exception-caught,raise-missing-from
if taxonomy and os.path.exists(taxonomy):
seed_instruction_data = read_taxonomy(
logger, taxonomy, taxonomy_base, yaml_rules
)
else:
raise SystemExit(f"Error: taxonomy ({taxonomy}) does not exist.")

prompt_template = check_prompt_file(
prompt_file_path, get_model_family(model_family, model_name)
)
max_seed_tokens = max_seed_example_tokens(server_ctx_size, len(prompt_template))
def _gen_test_data(
logger,
seed_instruction_data,
max_seed_tokens,
taxonomy,
chunk_word_count,
server_ctx_size,
output_file_test,
):
max_seed_chars = num_chars_from_tokens(max_seed_tokens)
for seed_example in seed_instruction_data:
if (
Expand All @@ -426,9 +392,6 @@ def generate_data(
if not seeds:
raise SystemExit("Nothing to generate. Exiting.")

def unescape(s):
return bytes(s, "utf-8").decode("utf-8")

test_data = []
for seed_example in seed_instruction_data:
user = seed_example["instruction"]
Expand Down Expand Up @@ -457,6 +420,60 @@ def unescape(s):
fg="red",
)
raise click.exceptions.Exit(1)
# utils.jdump(test_data, os.path.join(output_dir, output_file_test))
with open(output_file_test, "w", encoding="utf-8") as outfile:
for entry in test_data:
json.dump(entry, outfile, ensure_ascii=False)
outfile.write("\n")


def generate_data(
logger,
api_base,
tls_insecure,
model_family: str,
yaml_rules: Optional[str] = None,
output_dir: Optional[str] = None,
taxonomy: Optional[str] = None,
taxonomy_base: Optional[str] = None,
prompt_file_path: Optional[str] = None,
model_name: Optional[str] = None,
num_cpus: Optional[int] = None,
num_instructions_to_generate: Optional[int] = None,
num_prompt_instructions=2,
request_batch_size=5,
temperature=1.0,
top_p=1.0,
rouge_threshold: Optional[float] = None,
console_output=True,
api_key: Optional[str] = None,
chunk_word_count=None,
server_ctx_size=None,
tls_client_cert: Optional[str] = None,
tls_client_key: Optional[str] = None,
tls_client_passwd: Optional[str] = None,
):
seed_instruction_data = []
machine_seed_instruction_data = []
generate_start = time.time()

if not os.path.exists(output_dir):
os.mkdir(output_dir)

# check taxonomy first then seed_tasks_path
# throw an error if both not found
# pylint: disable=broad-exception-caught,raise-missing-from
if taxonomy and os.path.exists(taxonomy):
seed_instruction_data = read_taxonomy(
logger, taxonomy, taxonomy_base, yaml_rules
)
else:
raise SystemExit(f"Error: taxonomy ({taxonomy}) does not exist.")

prompt_template = check_prompt_file(
prompt_file_path, get_model_family(model_family, model_name)
)
max_seed_tokens = max_seed_example_tokens(server_ctx_size, len(prompt_template))

name = Path(model_name).stem # Just in case it is a file path
date_suffix = datetime.now().replace(microsecond=0).isoformat().replace(":", "_")
Expand All @@ -466,6 +483,15 @@ def unescape(s):
output_file_discarded = os.path.join(
output_dir, f"discarded_{name}_{date_suffix}.log"
)
_gen_test_data(
logger,
seed_instruction_data,
max_seed_tokens,
taxonomy,
chunk_word_count,
server_ctx_size,
os.path.join(output_dir, output_file_test),
)
logger.debug(f"Generating to: {os.path.join(output_dir, output_file)}")

request_idx = 0
Expand Down Expand Up @@ -599,13 +625,6 @@ def unescape(s):
for entry in train_data:
json.dump(entry, outfile, ensure_ascii=False)
outfile.write("\n")
# utils.jdump(test_data, os.path.join(output_dir, output_file_test))
with open(
os.path.join(output_dir, output_file_test), "w", encoding="utf-8"
) as outfile:
for entry in test_data:
json.dump(entry, outfile, ensure_ascii=False)
outfile.write("\n")

progress_bar.close()

Expand Down

0 comments on commit 883bb54

Please sign in to comment.