expose max_num_tokens as configurable
max-num-tokens is a convenient way to run a shorter or longer SDG run.
Locally I have been modifying the pipeline YAML from 2048 to 512, which ends up generating less data.
Exposing this through the CLI lets power users run different types of SDG runs.
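
For illustration, a hedged sketch of a call using the new keyword; only the arguments visible in this diff are shown, and the earlier required arguments are elided because they are not part of this commit:

# Hypothetical invocation sketch -- earlier required arguments are elided
# with "..." since they are not shown in this commit.
from instructlab.sdg.generate_data import generate_data

generate_data(
    ...,                   # required arguments elided (not shown in this diff)
    pipeline="full",
    batch_size=None,
    checkpoint_dir=None,
    max_num_tokens=512,    # new knob: smaller token budget, shorter SDG run
)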

Signed-off-by: Charlie Doern <[email protected]>
cdoern committed Nov 7, 2024
1 parent 4c82c05 commit 835a194
Showing 4 changed files with 64 additions and 7 deletions.
10 changes: 9 additions & 1 deletion src/instructlab/sdg/generate_data.py
@@ -20,7 +20,11 @@
# pylint: disable=ungrouped-imports
from instructlab.sdg.datamixing import DataMixer, _get_question_hack, _get_response_hack
from instructlab.sdg.eval_data import generate_eval_task_data, mmlubench_pipe_init
from instructlab.sdg.llmblock import MODEL_FAMILY_MERLINITE, MODEL_FAMILY_MIXTRAL
from instructlab.sdg.llmblock import (
DEFAULT_MAX_NUM_TOKENS,
MODEL_FAMILY_MERLINITE,
MODEL_FAMILY_MIXTRAL,
)
from instructlab.sdg.pipeline import (
FULL_PIPELINES_PACKAGE,
SIMPLE_PIPELINES_PACKAGE,
@@ -183,6 +187,7 @@ def _context_init(
save_freq: int,
batch_num_workers: Optional[int],
batch_size: Optional[int],
max_num_tokens: Optional[int] = DEFAULT_MAX_NUM_TOKENS,
):
extra_kwargs = {}
if batch_size is not None:
@@ -196,6 +201,7 @@ def _context_init(
num_instructions_to_generate=num_instructions_to_generate,
checkpoint_dir=checkpoint_dir,
save_freq=save_freq,
max_num_tokens=max_num_tokens,
**extra_kwargs,
)

@@ -281,6 +287,7 @@ def generate_data(
pipeline: Optional[str] = "simple",
batch_size: Optional[int] = None,
checkpoint_dir: Optional[str] = None,
max_num_tokens: Optional[int] = DEFAULT_MAX_NUM_TOKENS,
) -> None:
"""Generate data for training and testing a model.
@@ -343,6 +350,7 @@
1, # save_freq
batch_size=batch_size,
batch_num_workers=num_cpus,
max_num_tokens=max_num_tokens,
)

knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe = _sdg_init(
27 changes: 23 additions & 4 deletions src/instructlab/sdg/llmblock.py
@@ -16,6 +16,8 @@

logger = logging.getLogger(__name__)

DEFAULT_MAX_NUM_TOKENS = 4096

MODEL_FAMILY_MIXTRAL = "mixtral"
MODEL_FAMILY_MERLINITE = "merlinite"

@@ -62,6 +64,7 @@ def __init__(
ctx,
pipe,
block_name,
max_num_token_override,
config_path,
output_cols,
model_prompt=None,
@@ -81,8 +84,19 @@ def __init__(
self.parser_name = parser_kwargs.get("parser_name", None)
self.parsing_pattern = parser_kwargs.get("parsing_pattern", None)
self.parser_cleanup_tags = parser_kwargs.get("parser_cleanup_tags", None)
# max_num_tokens should only be applied to knowledge blocks;
# gen_knowledge is the full/simple pipeline's knowledge generation block
if block_name != "gen_knowledge":
logger.debug(
f"Not applying max_num_tokens to block {block_name}. This is only applicable for gen_knowledge."
)
max_num_token_override = DEFAULT_MAX_NUM_TOKENS
self.gen_kwargs = self._gen_kwargs(
gen_kwargs, model=self.ctx.model_id, temperature=0, max_tokens=4096
max_num_token_override,
gen_kwargs,
model=self.ctx.model_id,
temperature=0,
max_tokens=DEFAULT_MAX_NUM_TOKENS,
)
# Whether the LLM server supports a list of input prompts
# and supports the n parameter to generate n outputs per input
@@ -142,23 +156,26 @@ def _format_prompt(self, sample: Dict) -> str:

return prompt if model_prompt is None else model_prompt.format(prompt=prompt)

def _gen_kwargs(self, gen_kwargs, **defaults):
def _gen_kwargs(self, max_num_token_override, gen_kwargs, **defaults):
gen_kwargs = {**defaults, **gen_kwargs}
if (
"n" in gen_kwargs
and isinstance(gen_kwargs["n"], str)
and gen_kwargs["n"] == "scaled"
):
gen_kwargs["n"] = self.ctx.num_instructions_to_generate
if "max_tokens" in gen_kwargs:
gen_kwargs["max_tokens"] = int(gen_kwargs["max_tokens"])
if "temperature" in gen_kwargs:
gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
if max_num_token_override != DEFAULT_MAX_NUM_TOKENS:
gen_kwargs["max_tokens"] = max_num_token_override
elif "max_tokens" in gen_kwargs:
gen_kwargs["max_tokens"] = int(gen_kwargs["max_tokens"])
return gen_kwargs

def _generate(self, samples) -> list:
prompts = [self._format_prompt(sample) for sample in samples]
logger.debug(f"STARTING GENERATION FOR LLMBlock USING PROMPTS: {prompts}")
print(self.gen_kwargs)
if self.server_supports_batched:
response = self.ctx.client.completions.create(
prompt=prompts, **self.gen_kwargs
@@ -259,6 +276,7 @@ def __init__(
ctx,
pipe,
block_name,
max_num_token_override,
config_paths,
output_cols,
selector_column_name,
@@ -271,6 +289,7 @@ def __init__(
ctx,
pipe,
block_name,
max_num_token_override,
config_paths[0][0],
output_cols,
model_prompt=model_prompt,
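
To make the precedence introduced above easier to follow, here is a standalone sketch (not the actual LLMBlock API) of how the override interacts with the pipeline YAML's max_tokens and the default:

# Minimal sketch mirroring the logic in LLMBlock.__init__ / _gen_kwargs above:
# only gen_knowledge honors the override, an explicit override beats the
# YAML's max_tokens, and everything else falls back to the default.
DEFAULT_MAX_NUM_TOKENS = 4096

def resolve_max_tokens(block_name, override, yaml_max_tokens=None):
    if block_name != "gen_knowledge":
        override = DEFAULT_MAX_NUM_TOKENS
    if override != DEFAULT_MAX_NUM_TOKENS:
        return override
    if yaml_max_tokens is not None:
        return int(yaml_max_tokens)
    return DEFAULT_MAX_NUM_TOKENS

assert resolve_max_tokens("gen_knowledge", 512) == 512            # override wins
assert resolve_max_tokens("gen_knowledge", 4096, "2048") == 2048  # YAML value kept
assert resolve_max_tokens("some_other_block", 512) == 4096        # non-knowledge blocks ignore it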
15 changes: 13 additions & 2 deletions src/instructlab/sdg/pipeline.py
@@ -48,6 +48,7 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
central executor pool.
dataset_num_procs: The number of processes to use when performing parallel
map operations on individual datasets.
max_num_tokens: The maximum number of tokens to generate per sample.
"""

# The default batch size of 8 has been determined as a good default for
@@ -65,6 +66,7 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
checkpoint_dir: Optional[str] = None
save_freq: Optional[int] = 1
max_num_tokens: Optional[int] = 4096
batch_size: int = DEFAULT_BATCH_SIZE
batch_num_workers: Optional[int] = None

@@ -191,11 +193,20 @@ def _generate_single(self, dataset) -> Dataset:
block_name = block_prop["name"]
block_type = _lookup_block_type(block_prop["type"])
block_config = block_prop["config"]
max_num_token_override = self.ctx.max_num_tokens
drop_columns = block_prop.get("drop_columns", [])
drop_duplicates_cols = block_prop.get("drop_duplicates", False)
block = block_type(self.ctx, self, block_name, **block_config)
if block_type in (llmblock.LLMBlock, llmblock.ConditionalLLMBlock):
block = block_type(
self.ctx,
self,
block_name,
max_num_token_override,
**block_config,
)
else:
block = block_type(self.ctx, self, block_name, **block_config)
logger.info("Running block: %s", block_name)

# Execute the block and wrap errors with the block name/type
dataset = block.generate(dataset)
except Exception as err:
19 changes: 19 additions & 0 deletions tests/test_llmblock.py
@@ -37,6 +37,7 @@ def test_model_prompt_empty_string(self, mock_load_config):
ctx=self.mock_ctx,
pipe=self.mock_pipe,
block_name="test_block",
max_num_token_override=2048,
config_path="",
output_cols=[],
model_prompt="",
@@ -57,6 +58,7 @@ def test_model_prompt_none(self, mock_load_config):
ctx=self.mock_ctx,
pipe=self.mock_pipe,
block_name="test_block",
max_num_token_override=2048,
config_path="",
output_cols=[],
model_prompt=None, # Or simply omit model_prompt as it defaults to None
@@ -76,6 +78,7 @@ def test_model_prompt_none(self, mock_load_config):
ctx=self.mock_ctx,
pipe=self.mock_pipe,
block_name="test_block",
max_num_token_override=2048,
config_path="",
output_cols=[],
model_prompt="FOO {prompt} BAR",
@@ -86,3 +89,19 @@
"FOO pear\nintroduction\nprinciples\nexamples\ngeneration BAR",
"model_prompt should be a non-empty string when set to None",
)

@patch("src.instructlab.sdg.block.Block._load_config")
def test_max_num_tokens_override(self, mock_load_config):
mock_load_config.return_value = self.config_return_value
# Ensure that a max_num_token_override below the default is applied to gen_kwargs
block = LLMBlock(
ctx=self.mock_ctx,
pipe=self.mock_pipe,
block_name="gen_knowledge",
max_num_token_override=512,
config_path="",
output_cols=[],
model_prompt="",
)
num_tokens = block.gen_kwargs["max_tokens"]
assert num_tokens == 512
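
To exercise just the new test (assuming pytest is the project's test runner, as the tests/ layout suggests):

pytest tests/test_llmblock.py -k test_max_num_tokens_override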
