Skip to content

Commit

Permalink
Move some code from instructlab.utils
Browse files Browse the repository at this point in the history
This code was only used by instructlab.sdg, so move it over here
instead of leaving it back in the `instructlab` repo.

Part of issue #11

Signed-off-by: Russell Bryant <[email protected]>
  • Loading branch information
russellb committed Jun 24, 2024
1 parent cba3a62 commit a8f64f3
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 9 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
click>=8.1.7,<9.0.0
httpx>=0.25.0,<1.0.0
jinja2
langchain-text-splitters
openai>=1.13.3,<2.0.0
rouge_score
tqdm>=4.66.2,<5.0.0
13 changes: 5 additions & 8 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@
# Third Party
# instructlab - All of these need to go away (other than sdg) - issue #6
from instructlab.configuration import get_model_family
from instructlab.utils import (
chunk_document,
max_seed_example_tokens,
num_chars_from_tokens,
)
from jinja2 import Template
from rouge_score import rouge_scorer
import click
Expand Down Expand Up @@ -375,7 +370,7 @@ def _gen_test_data(
server_ctx_size,
output_file_test,
):
max_seed_chars = num_chars_from_tokens(max_seed_tokens)
max_seed_chars = utils.num_chars_from_tokens(max_seed_tokens)
for seed_example in seed_instruction_data:
if (
len(seed_example["instruction"])
Expand All @@ -398,7 +393,7 @@ def _gen_test_data(

documents = seed_example["document"]
if documents:
seed_example["document"] = chunk_document(
seed_example["document"] = utils.chunk_document(
documents=documents,
server_ctx_size=server_ctx_size,
chunk_word_count=chunk_word_count,
Expand Down Expand Up @@ -493,7 +488,9 @@ def generate_data(
prompt_template = check_prompt_file(
prompt_file_path, get_model_family(model_family, model_name)
)
max_seed_tokens = max_seed_example_tokens(server_ctx_size, len(prompt_template))
max_seed_tokens = utils.max_seed_example_tokens(
server_ctx_size, len(prompt_template)
)

name = Path(model_name).stem # Just in case it is a file path
date_suffix = datetime.now().replace(microsecond=0).isoformat().replace(":", "_")
Expand Down
87 changes: 86 additions & 1 deletion src/instructlab/sdg/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from typing import Optional, Sequence, Union
from typing import List, Optional, Sequence, Union
import copy
import dataclasses
import io
Expand All @@ -15,11 +15,14 @@
# instructlab - TODO these need to go away, issue #6
from instructlab.configuration import DEFAULT_API_KEY, DEFAULT_MODEL_OLD
from instructlab.utils import get_sysprompt
from langchain_text_splitters import RecursiveCharacterTextSplitter
from openai import OpenAI, OpenAIError
import httpx

StrOrOpenAIObject = Union[str, object]

DEFAULT_CHUNK_OVERLAP = 100


class GenerateException(Exception):
"""An exception raised during generate step."""
Expand Down Expand Up @@ -219,3 +222,85 @@ def jload(f, mode="r"):
"""Load a .json file into a dictionary."""
with _make_r_io_base(f, mode) as f_:
return json.load(f_)


def num_tokens_from_words(num_words) -> int:
    """Estimate a token count from a word count (roughly 1.3 tokens per word)."""
    words_to_tokens_ratio = 1.3
    return int(num_words * words_to_tokens_ratio)


def num_chars_from_tokens(num_tokens) -> int:
    """Estimate a character count from a token count (roughly 4 English characters per token)."""
    chars_per_token = 4
    return int(num_tokens * chars_per_token)


def num_tokens_from_chars(num_chars) -> int:
    """Estimate a token count from a character count (roughly 4 English characters per token)."""
    chars_per_token = 4
    return int(num_chars / chars_per_token)


def max_seed_example_tokens(server_ctx_size, prompt_num_chars) -> int:
    """
    Estimate the largest token budget a single seed example may consume.

    A lot has to fit into the given server context size at once:
      - the prompt itself, which varies a bit by model family and knowledge vs skill
      - two seed examples, which are appended to the prompt template
      - a knowledge document chunk, if this is a knowledge example
      - the generated completion, which can vary substantially in length

    This is a rough upper bound on the size of any one seed example
    (question + answer + context values from the yaml) such that requests
    have a hope of not often exceeding the server's maximum context size.

    NOTE: Knowledge document chunks are not accounted for here; this is the
    maximum size for any seed example, knowledge or skill, and knowledge
    seed examples will want to stay well below this limit.
    NOTE: The estimate is very simplistic -- examples with lots of numbers
    or punctuation may tokenize quite differently depending on the model
    (and thus tokenizer) in use. That's ok; only a rough estimate is needed.

    Args:
        server_ctx_size (int): Size of the server context, in tokens.
        prompt_num_chars (int): Number of characters in the prompt (not including the examples)
    """
    # Keep at least 1024 tokens free for the response, then subtract the
    # estimated token footprint of the prompt template itself.
    available = server_ctx_size - 1024 - num_tokens_from_chars(prompt_num_chars)
    # Two seed examples are inserted into the prompt, so each gets half.
    return int(available / 2)


def chunk_document(documents: List, server_ctx_size, chunk_word_count) -> List[str]:
    """
    Iterates over the documents and splits them into chunks based on the word count provided by the user.
    Args:
        documents (list): List of documents retrieved from git (can also consist of a single document).
        server_ctx_size (int): Context window size of server.
        chunk_word_count (int): Maximum number of words to chunk a document.
    Returns:
        List[str]: List of chunked documents.
    Raises:
        ValueError: If the requested chunk size cannot fit in the server
            context window while leaving 1024 tokens for the response.
    """
    no_tokens_per_doc = num_tokens_from_words(chunk_word_count)
    # Reserve 1024 tokens of the context window for the model's response.
    if no_tokens_per_doc > int(server_ctx_size - 1024):
        raise ValueError(
            f"Error: Given word count ({chunk_word_count}) per doc will exceed the"
            f" server context window size ({server_ctx_size})"
        )

    text_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", " "],
        chunk_size=num_chars_from_tokens(no_tokens_per_doc),
        chunk_overlap=DEFAULT_CHUNK_OVERLAP,
    )

    content = []
    for doc in documents:
        chunks = text_splitter.create_documents([doc])
        content.extend(chunk.page_content for chunk in chunks)

    return content

0 comments on commit a8f64f3

Please sign in to comment.