expose max-num-tokens and num_instructions_to_generate as configurable
max-num-tokens is a nice way to run a shorter or longer SDG run.
Locally I have been modifying the pipeline YAML from 2048 to 512, which ends up just generating less data.
Exposing this through the CLI could allow power users to run different kinds of SDG runs!

num_instructions_to_generate is currently ignored, meaning that if a user passes in anything other than 30, they will still
only get 30 questions when training skills. Add `_batch_kwargs`, which patches in `num_samples`; in skills generation this is
the value of `num_instructions_to_generate`.
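
For illustration, a minimal standalone sketch of the intended routing (the block names and DEFAULT_MAX_NUM_TOKENS mirror the diff below; the helper function itself is hypothetical, not part of the library): the token cap is honoured only by the knowledge block, while the requested instruction count is forwarded to the freeform/grounded question blocks as `num_samples`.

```python
# Hypothetical helper, for illustration only: mirrors the per-block routing
# added to LLMBlock.__init__ in the diff below.
DEFAULT_MAX_NUM_TOKENS = 4096

def effective_kwargs(block_name, max_num_tokens, num_instructions_to_generate):
    gen_defaults = {"max_tokens": max_num_tokens}
    batch_defaults = {}
    # Only the knowledge block honours the user-supplied token cap.
    if block_name != "gen_knowledge":
        gen_defaults["max_tokens"] = DEFAULT_MAX_NUM_TOKENS
    # The freeform/grounded question blocks honour the requested sample count.
    if block_name in ("gen_questions", "gen_grounded_questions"):
        batch_defaults["num_samples"] = num_instructions_to_generate
    return gen_defaults, batch_defaults

print(effective_kwargs("gen_knowledge", 512, 30))  # ({'max_tokens': 512}, {})
print(effective_kwargs("gen_questions", 512, 50))  # ({'max_tokens': 4096}, {'num_samples': 50})
```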

Signed-off-by: Charlie Doern <[email protected]>
cdoern committed Nov 7, 2024
1 parent 4c82c05 commit ed7ea9f
Showing 4 changed files with 56 additions and 5 deletions.
4 changes: 4 additions & 0 deletions src/instructlab/sdg/generate_data.py
@@ -183,6 +183,7 @@ def _context_init(
save_freq: int,
batch_num_workers: Optional[int],
batch_size: Optional[int],
max_num_tokens: Optional[int] = 4096,
):
extra_kwargs = {}
if batch_size is not None:
@@ -196,6 +197,7 @@ def _context_init(
num_instructions_to_generate=num_instructions_to_generate,
checkpoint_dir=checkpoint_dir,
save_freq=save_freq,
max_num_tokens=max_num_tokens,
**extra_kwargs,
)

@@ -281,6 +283,7 @@ def generate_data(
pipeline: Optional[str] = "simple",
batch_size: Optional[int] = None,
checkpoint_dir: Optional[str] = None,
max_num_tokens: Optional[int] = 4096,
) -> None:
"""Generate data for training and testing a model.
@@ -343,6 +346,7 @@ def generate_data(
1, # save_freq
batch_size=batch_size,
batch_num_workers=num_cpus,
max_num_tokens=max_num_tokens,
)

knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe = _sdg_init(
34 changes: 31 additions & 3 deletions src/instructlab/sdg/llmblock.py
@@ -16,6 +16,10 @@

logger = logging.getLogger(__name__)

DEFAULT_MAX_NUM_TOKENS = 4096

DEFAULT_NUM_SAMPLES = 30

MODEL_FAMILY_MIXTRAL = "mixtral"
MODEL_FAMILY_MERLINITE = "merlinite"

@@ -62,6 +66,8 @@ def __init__(
ctx,
pipe,
block_name,
max_num_tokens,
num_instructions_to_generate,
config_path,
output_cols,
model_prompt=None,
@@ -81,8 +87,19 @@ def __init__(
self.parser_name = parser_kwargs.get("parser_name", None)
self.parsing_pattern = parser_kwargs.get("parsing_pattern", None)
self.parser_cleanup_tags = parser_kwargs.get("parser_cleanup_tags", None)
# max_num_tokens should only be applicable to knowledge blocks
if block_name != "gen_knowledge":
max_num_tokens = DEFAULT_MAX_NUM_TOKENS
# num_instructions_to_generate should apply to gen_questions for freeform and grounded
if block_name == "gen_questions" or block_name == "gen_grounded_questions":
self.batch_params = self._batch_kwargs(
batch_kwargs, num_tokens=num_instructions_to_generate
)
self.gen_kwargs = self._gen_kwargs(
gen_kwargs, model=self.ctx.model_id, temperature=0, max_tokens=4096
gen_kwargs,
model=self.ctx.model_id,
temperature=0,
max_tokens=max_num_tokens,
)
# Whether the LLM server supports a list of input prompts
# and supports the n parameter to generate n outputs per input
@@ -142,6 +159,13 @@ def _format_prompt(self, sample: Dict) -> str:

return prompt if model_prompt is None else model_prompt.format(prompt=prompt)

def _batch_kwargs(self, batch_kwargs, **defaults):
batch_kwargs = {**defaults, **batch_kwargs}
if "num_samples" in batch_kwargs:
batch_kwargs["num_samples"] = int(batch_kwargs["num_samples"])

return batch_kwargs

def _gen_kwargs(self, gen_kwargs, **defaults):
gen_kwargs = {**defaults, **gen_kwargs}
if (
@@ -150,10 +174,10 @@ def _gen_kwargs(self, gen_kwargs, **defaults):
and gen_kwargs["n"] == "scaled"
):
gen_kwargs["n"] = self.ctx.num_instructions_to_generate
if "max_tokens" in gen_kwargs:
gen_kwargs["max_tokens"] = int(gen_kwargs["max_tokens"])
if "temperature" in gen_kwargs:
gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
if "max_tokens" in gen_kwargs:
gen_kwargs["max_tokens"] = int(gen_kwargs["max_tokens"])
return gen_kwargs

def _generate(self, samples) -> list:
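
The merge in `_gen_kwargs` and `_batch_kwargs` is plain dictionary unpacking, so a value set in the pipeline YAML still overrides the defaults passed in from the context. A minimal standalone sketch of that precedence (the function below is illustrative, not the library's API):

```python
# Illustrative only: reproduces the {**defaults, **overrides} merge order used
# by _gen_kwargs/_batch_kwargs above.
def merge(overrides, **defaults):
    merged = {**defaults, **overrides}
    if "max_tokens" in merged:
        merged["max_tokens"] = int(merged["max_tokens"])
    return merged

# A context default loses to an explicit pipeline-YAML setting...
print(merge({"max_tokens": "2048"}, max_tokens=512))  # {'max_tokens': 2048}
# ...and is used only when the YAML does not set the key.
print(merge({}, max_tokens=512))                      # {'max_tokens': 512}
```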
@@ -259,6 +283,8 @@ def __init__(
ctx,
pipe,
block_name,
max_num_tokens,
num_instructions_to_generate,
config_paths,
output_cols,
selector_column_name,
@@ -271,6 +297,8 @@ def __init__(
ctx,
pipe,
block_name,
max_num_tokens,
num_instructions_to_generate,
config_paths[0][0],
output_cols,
model_prompt=model_prompt,
17 changes: 15 additions & 2 deletions src/instructlab/sdg/pipeline.py
@@ -48,6 +48,7 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
central executor pool.
dataset_num_procs: The number of processes to use when performing parallel
map operations on individual datasets.
max_num_tokens: the maximum number of tokens to generate per sample.
"""

# The default batch size of 8 has been determined as a good default for
@@ -65,6 +66,7 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
checkpoint_dir: Optional[str] = None
save_freq: Optional[int] = 1
max_num_tokens: Optional[int] = 4096
batch_size: int = DEFAULT_BATCH_SIZE
batch_num_workers: Optional[int] = None

@@ -191,11 +193,22 @@ def _generate_single(self, dataset) -> Dataset:
block_name = block_prop["name"]
block_type = _lookup_block_type(block_prop["type"])
block_config = block_prop["config"]
max_num_tokens = self.ctx.max_num_tokens
num_instructions_to_generate = self.ctx.num_instructions_to_generate
drop_columns = block_prop.get("drop_columns", [])
drop_duplicates_cols = block_prop.get("drop_duplicates", False)
block = block_type(self.ctx, self, block_name, **block_config)
if block_type in (llmblock.LLMBlock, llmblock.ConditionalLLMBlock):
block = block_type(
self.ctx,
self,
block_name,
max_num_tokens,
num_instructions_to_generate,
**block_config,
)
else:
block = block_type(self.ctx, self, block_name, **block_config)
logger.info("Running block: %s", block_name)

# Execute the block and wrap errors with the block name/type
dataset = block.generate(dataset)
except Exception as err:
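
Only the LLM-backed block types take the two new positional arguments; every other block type keeps its original constructor call. A self-contained sketch of that dispatch (the classes here are stand-ins for illustration, not imports from the package):

```python
# Stand-in classes; they only record what they are constructed with.
class LLMBlock:
    def __init__(self, ctx, pipe, block_name, *extra, **config):
        self.extra, self.config = extra, config

class ConditionalLLMBlock(LLMBlock):
    pass

class OtherBlock:
    def __init__(self, ctx, pipe, block_name, **config):
        self.config = config

def build_block(block_type, ctx, pipe, name, config, max_num_tokens, num_instructions):
    # Mirrors the branch added to _generate_single: only LLM-backed blocks
    # receive the two extra positional arguments.
    if block_type in (LLMBlock, ConditionalLLMBlock):
        return block_type(ctx, pipe, name, max_num_tokens, num_instructions, **config)
    return block_type(ctx, pipe, name, **config)

print(build_block(LLMBlock, None, None, "gen_knowledge", {}, 512, 30).extra)  # (512, 30)
print(build_block(OtherBlock, None, None, "some_filter", {}, 512, 30).config)  # {}
```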
6 changes: 6 additions & 0 deletions tests/test_llmblock.py
@@ -37,6 +37,8 @@ def test_model_prompt_empty_string(self, mock_load_config):
ctx=self.mock_ctx,
pipe=self.mock_pipe,
block_name="test_block",
max_num_tokens=2048,
num_instructions_to_generate=30,
config_path="",
output_cols=[],
model_prompt="",
@@ -57,6 +59,8 @@ def test_model_prompt_none(self, mock_load_config):
ctx=self.mock_ctx,
pipe=self.mock_pipe,
block_name="test_block",
max_num_tokens=2048,
num_instructions_to_generate=30,
config_path="",
output_cols=[],
model_prompt=None, # Or simply omit model_prompt as it defaults to None
@@ -76,6 +80,8 @@ def test_model_prompt_none(self, mock_load_config):
ctx=self.mock_ctx,
pipe=self.mock_pipe,
block_name="test_block",
max_num_tokens=2048,
num_instructions_to_generate=30,
config_path="",
output_cols=[],
model_prompt="FOO {prompt} BAR",
