From a8f64f3ea3745efff1e49c897ae1f886a6b80507 Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Mon, 24 Jun 2024 17:29:58 -0400
Subject: [PATCH] Move some code from instructlab.utils

This code was only used by instructlab.sdg, so move it over here
instead of leaving it behind in the `instructlab` repo.

Part of issue #11

Signed-off-by: Russell Bryant
---
 requirements.txt                     |  1 +
 src/instructlab/sdg/generate_data.py | 13 ++---
 src/instructlab/sdg/utils.py         | 87 +++++++++++++++++++++++++++-
 3 files changed, 92 insertions(+), 9 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index fb0e4a2a..ed37c6c1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,7 @@
 click>=8.1.7,<9.0.0
 httpx>=0.25.0,<1.0.0
 jinja2
+langchain-text-splitters
 openai>=1.13.3,<2.0.0
 rouge_score
 tqdm>=4.66.2,<5.0.0
diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py
index f6c052ce..cc0a9ef3 100644
--- a/src/instructlab/sdg/generate_data.py
+++ b/src/instructlab/sdg/generate_data.py
@@ -16,11 +16,6 @@
 # Third Party
 # instructlab - All of these need to go away (other than sdg) - issue #6
 from instructlab.configuration import get_model_family
-from instructlab.utils import (
-    chunk_document,
-    max_seed_example_tokens,
-    num_chars_from_tokens,
-)
 from jinja2 import Template
 from rouge_score import rouge_scorer
 import click
@@ -375,7 +370,7 @@ def _gen_test_data(
     server_ctx_size,
     output_file_test,
 ):
-    max_seed_chars = num_chars_from_tokens(max_seed_tokens)
+    max_seed_chars = utils.num_chars_from_tokens(max_seed_tokens)
     for seed_example in seed_instruction_data:
         if (
             len(seed_example["instruction"])
@@ -398,7 +393,7 @@
         documents = seed_example["document"]

         if documents:
-            seed_example["document"] = chunk_document(
+            seed_example["document"] = utils.chunk_document(
                 documents=documents,
                 server_ctx_size=server_ctx_size,
                 chunk_word_count=chunk_word_count,
@@ -493,7 +488,9 @@ def generate_data(
     prompt_template = check_prompt_file(
         prompt_file_path, get_model_family(model_family, model_name)
     )
-    max_seed_tokens = max_seed_example_tokens(server_ctx_size, len(prompt_template))
+    max_seed_tokens = utils.max_seed_example_tokens(
+        server_ctx_size, len(prompt_template)
+    )
     name = Path(model_name).stem  # Just in case it is a file path
     date_suffix = datetime.now().replace(microsecond=0).isoformat().replace(":", "_")
diff --git a/src/instructlab/sdg/utils.py b/src/instructlab/sdg/utils.py
index 83423fe8..5ad54ada 100644
--- a/src/instructlab/sdg/utils.py
+++ b/src/instructlab/sdg/utils.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 # Standard
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union
 import copy
 import dataclasses
 import io
@@ -15,11 +15,14 @@
 # instructlab - TODO these need to go away, issue #6
 from instructlab.configuration import DEFAULT_API_KEY, DEFAULT_MODEL_OLD
 from instructlab.utils import get_sysprompt
+from langchain_text_splitters import RecursiveCharacterTextSplitter
 from openai import OpenAI, OpenAIError
 import httpx

 StrOrOpenAIObject = Union[str, object]

+DEFAULT_CHUNK_OVERLAP = 100
+

 class GenerateException(Exception):
     """An exception raised during generate step."""
@@ -219,3 +222,85 @@ def jload(f, mode="r"):
     """Load a .json file into a dictionary."""
     with _make_r_io_base(f, mode) as f_:
         return json.load(f_)
+
+
+def num_tokens_from_words(num_words) -> int:
+    return int(num_words * 1.3)  # 1 word ~ 1.3 tokens
+
+
+def num_chars_from_tokens(num_tokens) -> int:
+    return int(num_tokens * 4)  # 1 token ~ 4 English characters
+
+
+def num_tokens_from_chars(num_chars) -> int:
+    return int(num_chars / 4)  # 1 token ~ 4 English characters
+
+
+def max_seed_example_tokens(server_ctx_size, prompt_num_chars) -> int:
+    """
+    Estimate the maximum number of tokens any seed example can have, based
+    on the server context size and the number of characters in the selected prompt.
+
+    A lot has to fit into the given server context size:
+    - The prompt itself, which varies in size a bit by model family and by
+      whether it is a knowledge or skill prompt.
+    - Two seed examples, which we append to the prompt template.
+    - A knowledge document chunk, if this is a knowledge example.
+    - The generated completion, which can vary substantially in length.
+
+    This is a rough estimate of the maximum size any seed example
+    (question + answer + context values from the yaml) should be so that we
+    do not routinely exceed the server's maximum context size.
+
+    NOTE: This does not take knowledge document chunks into account. It is meant
+    to calculate the maximum size that any seed example should be, whether knowledge
+    or skill. Knowledge seed examples will want to stay well below this limit.
+
+    NOTE: This is a very simplistic calculation, and examples with lots of numbers
+    or punctuation may have quite a different token count than the estimates here,
+    depending on the model (and thus tokenizer) in use. That's ok, as it's only
+    meant to be a rough estimate.
+
+    Args:
+        server_ctx_size (int): Size of the server context, in tokens.
+        prompt_num_chars (int): Number of characters in the prompt (not including the examples).
+    """
+    # Ensure we have at least 1024 tokens available for a response.
+    max_seed_tokens = server_ctx_size - 1024
+    # Subtract the estimated number of tokens in our prompt template.
+    max_seed_tokens = max_seed_tokens - num_tokens_from_chars(prompt_num_chars)
+    # Divide the remaining tokens by 2, since we insert 2 seed examples.
+    max_seed_tokens = int(max_seed_tokens / 2)
+    return max_seed_tokens
+
+
+def chunk_document(documents: List, server_ctx_size, chunk_word_count) -> List[str]:
+    """
+    Iterate over the documents and split them into chunks based on the word
+    count provided by the user.
+
+    Args:
+        documents (list): List of documents retrieved from git (can also consist of a single document).
+        server_ctx_size (int): Context window size of the server, in tokens.
+        chunk_word_count (int): Maximum number of words per chunk.
+    Returns:
+        List[str]: List of document chunks.
+    """
+    num_tokens_per_doc = num_tokens_from_words(chunk_word_count)
+    if num_tokens_per_doc > server_ctx_size - 1024:
+        raise ValueError(
+            f"Given word count ({chunk_word_count}) per doc will exceed the "
+            f"server context window size ({server_ctx_size})"
+        )
+    content = []
+    text_splitter = RecursiveCharacterTextSplitter(
+        separators=["\n\n", "\n", " "],
+        chunk_size=num_chars_from_tokens(num_tokens_per_doc),
+        chunk_overlap=DEFAULT_CHUNK_OVERLAP,
+    )
+
+    for doc in documents:
+        temp = text_splitter.create_documents([doc])
+        content.extend([item.page_content for item in temp])
+
+    return content
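
For reviewers, a rough sketch of how the moved helpers fit together after
this patch. The context size, prompt length, and word count below are
made-up illustrative values, not anything taken from the patch itself; the
only assumption is the package layout above, i.e. that the helpers are
importable from instructlab.sdg.utils once the patch is applied.

    # Hypothetical usage of the relocated helpers.
    from instructlab.sdg import utils

    server_ctx_size = 4096   # server context window, in tokens (hypothetical)
    prompt_num_chars = 2000  # rendered prompt template length (hypothetical)

    # Walking through max_seed_example_tokens with these numbers:
    #   4096 - 1024 reserved for the completion        = 3072 tokens
    #   3072 - 2000/4 (prompt, ~4 chars per token)     = 2572 tokens
    #   2572 / 2, since two seed examples are inserted = 1286 tokens
    max_seed = utils.max_seed_example_tokens(server_ctx_size, prompt_num_chars)
    print(max_seed)  # 1286

    # Split a document into ~1000-word chunks. 1000 words ~ 1300 tokens,
    # well under the 3072 tokens left after the 1024-token response reserve,
    # so this does not raise ValueError.
    chunks = utils.chunk_document(
        documents=["some long document text ..."],
        server_ctx_size=server_ctx_size,
        chunk_word_count=1000,
    )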