Skip to content

Commit

Permalink
feat: add adaptive throttling and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eshwarprasadS committed Jan 10, 2025
1 parent 1c54405 commit 323ff80
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 13 deletions.
54 changes: 42 additions & 12 deletions src/instructlab/sdg/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# First Party
from instructlab.sdg.checkpointing import Checkpointer
from instructlab.sdg.utils import pandas
from instructlab.sdg.throttlers import AdaptiveThrottler

# Local
from .blocks import llmblock
Expand Down Expand Up @@ -156,6 +157,9 @@ def generate(self, dataset, checkpoint_name=None) -> Dataset:
logger.info("Running pipeline single-threaded")
return self._generate_single(dataset)


# if self.ctx.batch_num_workers is None, calculate default number of workers, same as outlined here: https://docs.python.org/3.11/library/concurrent.futures.html
self.ctx.batch_num_workers = min(32, os.cpu_count() + 4)
# Otherwise, split the dataset into batches and run each batch as a
# future in the thread pool
logger.info(
Expand All @@ -165,18 +169,44 @@ def generate(self, dataset, checkpoint_name=None) -> Dataset:
)
input_splits = self._split_dataset(dataset)
output_splits = []
with ThreadPoolExecutor(max_workers=self.ctx.batch_num_workers) as executor:
futures = [
executor.submit(self._generate_single, input_split)
for input_split in input_splits
]

# Collect the results of each batch as they finish. This needs to
# wait for them all, so the order of waiting doesn't matter
for future in futures:
ds = future.result()
output_splits.append(ds)
checkpointer.checkpoint(ds)


throttler = AdaptiveThrottler(
min_workers=1,
max_workers=self.ctx.batch_num_workers, # Upper limit from config
initial_workers=self.ctx.batch_num_workers//2, # Start at 50% of max
)

if not input_splits:
logger.warning("Input splits are empty. Returning empty dataset.")
return concatenate_datasets([])

while input_splits:
# Get the current number of workers from the throttler
current_workers = throttler.get_workers()
# Take a slice of input splits to process concurrently
input_splits_batch = input_splits[:current_workers]
input_splits = input_splits[current_workers:]

with ThreadPoolExecutor(max_workers=current_workers) as executor:
# Submit tasks for each batch
futures = [
executor.submit(self._generate_single, input_split)
for input_split in input_splits_batch
]

# Collect the results of each batch as they finish. This needs to
# wait for them all, so the order of waiting doesn't matter
for future in futures:
try:
ds = future.result() # Block until the task is complete
output_splits.append(ds) # Store the successful result
checkpointer.checkpoint(ds) # Save progress
throttler.adjust_workers(success=True) # Increase workers on success
except Exception as err:
logger.error("Error in pipeline batch generation: %s", err)
throttler.adjust_workers(success=False)

checkpointer.done()
if pre_generated_data:
output_splits.append(pre_generated_data)
Expand Down
31 changes: 31 additions & 0 deletions src/instructlab/sdg/throttlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import threading

DEFAULT_TOLERANCE = 0.2  # Fraction to reduce workers on failure


class AdaptiveThrottler:
    """Thread-safe controller for an adaptive worker count.

    Tracks a current number of workers bounded by ``[min_workers,
    max_workers]``: each reported success grows the count by one, each
    reported failure shrinks it multiplicatively by ``tolerance``.
    """

    def __init__(self, min_workers, max_workers, initial_workers, tolerance=DEFAULT_TOLERANCE):
        """Initialize the throttler.

        Args:
            min_workers: Lower limit of workers; must be >= 1 for use
                with ThreadPoolExecutor (max_workers=0 is invalid there).
            max_workers: Upper limit of workers, same as num_cpus cli argument.
            initial_workers: Starting worker count; clamped into
                [min_workers, max_workers] so callers computing it (e.g.
                ``max_workers // 2``) can never start at 0 or above the cap.
            tolerance: Fraction by which to reduce workers on error.

        Raises:
            ValueError: If min_workers > max_workers.
        """
        if min_workers > max_workers:
            raise ValueError(
                f"min_workers ({min_workers}) must not exceed max_workers ({max_workers})"
            )
        self.min_workers = min_workers  # Lower limit of workers
        self.max_workers = max_workers  # Upper limit of workers, same as num_cpus cli argument
        # Clamp the starting point into the valid range; an unclamped 0
        # (e.g. from max_workers // 2 when max_workers == 1) would later
        # make ThreadPoolExecutor raise.
        self.current_workers = min(max(initial_workers, min_workers), max_workers)
        self.tolerance = tolerance  # Reduce workers by this fraction on error
        self.lock = threading.Lock()  # Ensure thread-safe updates

    def adjust_workers(self, success=True):
        """Adjust the number of workers based on success or failure."""
        with self.lock:  # Use a lock to avoid race conditions in multi-threading
            if success:
                # Gradually increase workers up to max_workers
                if self.current_workers < self.max_workers:
                    self.current_workers += 1
            else:
                # Reduce workers by a fraction on failure, respecting min_workers
                if self.current_workers > self.min_workers:
                    self.current_workers = max(
                        self.min_workers,
                        int(self.current_workers * (1 - self.tolerance)),
                    )

    def get_workers(self):
        """Get the current number of workers."""
        with self.lock:
            return self.current_workers
39 changes: 38 additions & 1 deletion tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from contextlib import contextmanager
from threading import Event
from unittest import mock

# Third Party
from datasets import Dataset
import pytest

# First Party
from instructlab.sdg import Block, Pipeline, PipelineBlockError
from instructlab.sdg.throttlers import AdaptiveThrottler

## Helpers ##

Expand Down Expand Up @@ -194,3 +194,40 @@ def test_block_generation_error_properties_from_strings():
str(gen_err)
== f"{PipelineBlockError.__name__}({block_type}/{block_name}): {inner_err}"
)

def test_pipeline_with_adaptive_throttler(sample_dataset, threaded_ctx):
    """Test that the Pipeline integrates correctly with AdaptiveThrottler.

    Drives an AdaptiveThrottler from a mocked block's generate() so that
    successes and failures adjust the worker count, then checks the count
    stays within its configured bounds.
    """

    # Mock block.generate to simulate failures and successes
    block_type_mock = mock.MagicMock()
    throttler = AdaptiveThrottler(min_workers=1, max_workers=5, initial_workers=3)

    def mock_generate(dataset):
        # Simulate failure for batches whose size is a multiple of 3
        if len(dataset) % 3 == 0:
            throttler.adjust_workers(success=False)  # Report failure
            raise Exception("Simulated failure")
        throttler.adjust_workers(success=True)  # Report success
        return dataset  # Return same sample dataset for all cases

    block_type_mock().generate.side_effect = mock_generate

    # Configure the pipeline
    pipe_cfg = [
        {
            "name": "block-one",
            "type": "test",
            "config": {},
        }
    ]
    with block_types({"test": block_type_mock}):
        result = Pipeline(threaded_ctx, "", pipe_cfg).generate(sample_dataset)

    # Assertions
    assert result is not None  # Ensure we got some output
    # The worker count must stay within the configured bounds. (Asserting
    # it stays strictly below max_workers would be flaky: an all-success
    # run legitimately saturates at max_workers.)
    assert 1 <= throttler.current_workers <= 5
    # Ensure the block was called at least once (the original line was a
    # bare comparison with no assert, i.e. a no-op)
    assert block_type_mock().generate.call_count > 0

0 comments on commit 323ff80

Please sign in to comment.