Fix linting and formatting for context-aware chunking
Signed-off-by: Khaled Sulayman <[email protected]>
khaledsulayman committed Nov 6, 2024
1 parent 7c5050a commit e71e99a
Showing 4 changed files with 128 additions and 78 deletions.
9 changes: 7 additions & 2 deletions src/instructlab/sdg/generate_data.py
@@ -367,12 +367,17 @@ def generate_data(
is_knowledge = False
leaf_node_path = leaf_node[0]["taxonomy_path"].replace("->", "_")
samples = leaf_node_to_samples(
leaf_node, taxonomy, server_ctx_size, chunk_word_count, document_output_dir, model_name
leaf_node,
taxonomy,
server_ctx_size,
chunk_word_count,
document_output_dir,
model_name,
)

if not samples:
raise GenerateException("Error: No samples found in leaf node.")

if "document" in samples.column_names:
pipe = knowledge_pipe
is_knowledge = True
96 changes: 56 additions & 40 deletions src/instructlab/sdg/utils/chunkers.py
@@ -1,12 +1,12 @@
import json
import logging
import re
import yaml
# Standard
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Iterable, List, Tuple, DefaultDict
from typing import DefaultDict, Iterable, List, Tuple, Union
import json
import logging
import re

# Third Party
from datasets import Dataset, concatenate_datasets
@@ -16,7 +16,7 @@
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
from tabulate import tabulate
from transformers import AutoTokenizer

import yaml

logger = logging.getLogger(__name__)
_DEFAULT_CHUNK_OVERLAP = 100
@@ -25,6 +25,7 @@
def _num_tokens_from_words(num_words) -> int:
return int(num_words * 1.3) # 1 word ~ 1.3 token


def _num_chars_from_tokens(num_tokens) -> int:
return int(num_tokens * 4) # 1 token ~ 4 English character
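
These two heuristics drive the chunk sizing used later in _process_parsed_docling_json: the requested word count is converted to an approximate token budget, then to a character count for the text splitter. A quick worked sketch (not part of this diff) using the default chunk_word_count of 1024:

# Illustrative sketch of the conversions above (values are approximate by design)
def _num_tokens_from_words(num_words) -> int:
    return int(num_words * 1.3)  # 1 word ~ 1.3 tokens

def _num_chars_from_tokens(num_tokens) -> int:
    return int(num_tokens * 4)  # 1 token ~ 4 English characters

chunk_word_count = 1024
num_tokens_per_doc = _num_tokens_from_words(chunk_word_count)  # 1331
chunk_size = _num_chars_from_tokens(num_tokens_per_doc)        # 5324 characters per chunk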

@@ -36,34 +37,35 @@ class FileTypes(Enum):

class ChunkerBase(ABC):
@abstractmethod
def chunk_documents():
def chunk_documents(self):
pass


class DocumentChunker:
"""A factory chunker class that instantiates the applicable chunker
Currently, only Markdown and PDF are supported. For Markdown, returns
TextSplitChunker, and for PDF, returns ContextAwareChunker"""

def __new__(
cls,
leaf_node = None,
taxonomy_path = None,
output_dir: Path = None,
leaf_node,
taxonomy_path,
output_dir: Path,
server_ctx_size=4096,
chunk_word_count=1024,
tokenizer_model_name: str = None,
tokenizer_model_name: str | None = None,
):
"""Insantiate the appropriate chunker for the provided document
Args:
leaf_node: a leaf node dict containing "documents",
"filepaths", and "taxonomy_path" keys
output_dir (Path): directory where artifacts should be stored
server_ctx_size (int): Context window size of server
chunk_word_count (int): Maximum number of words to chunk a document
tokenizer_model_name (str): name of huggingface model to get
tokenizer from
tokenizer from
Returns:
TextSplitChunker | ContextAwareChunker: Object of the appropriate
chunker class for the provided filetype
@@ -91,44 +93,48 @@ def __new__(
raise ValueError(f"Received multiple document types")

Check warning on line 93 in src/instructlab/sdg/utils/chunkers.py (GitHub Actions / pylint): W1309: Using an f-string that does not have any interpolated variables (f-string-without-interpolation)

if FileTypes.MD in doc_dict:
doc_contents = [d for d, _ in doc_dict[FileTypes.MD]]
return TextSplitChunker(
doc_dict[FileTypes.MD],
doc_contents,
server_ctx_size,
chunk_word_count,
output_dir,
)

if FileTypes.PDF in doc_dict:
doc_paths = [p for _, p in doc_dict[FileTypes.PDF]]
return ContextAwareChunker(
doc_dict[FileTypes.PDF],
doc_paths,
filepaths,
taxonomy_path / leaf_node_path / "qna.yaml",
output_dir,
output_dir,
chunk_word_count,
tokenizer_model_name,
)
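
For orientation, a minimal usage sketch of the factory described in the docstring; the leaf-node contents, paths, and the shape of the result are illustrative assumptions, not code from this commit, and the taxonomy files are assumed to exist on disk:

# Hypothetical example: DocumentChunker dispatches on file type and returns
# either a TextSplitChunker (Markdown) or a ContextAwareChunker (PDF).
from pathlib import Path

leaf_node = [
    {
        "documents": ["# Example\nSome markdown content"],  # document contents
        "filepaths": [Path("docs/example.md")],             # matching paths
        "taxonomy_path": "knowledge->example",
    }
]

chunker = DocumentChunker(
    leaf_node=leaf_node,
    taxonomy_path=Path("taxonomy"),
    output_dir=Path("/tmp/docling-artifacts"),
    server_ctx_size=4096,
    chunk_word_count=1024,
    tokenizer_model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
)
chunks = chunker.chunk_documents()  # chunked text for downstream SDG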

@staticmethod
def _split_docs_by_filetype(documents: List[str], filepaths: List[Path]) -> defaultdict[any, List]:
def _split_docs_by_filetype(
documents: List[str], filepaths: List[Path]
) -> DefaultDict[FileTypes, List[Tuple[str, Path]]]:
"""Separate documents into lists based on their filetype.
Currently, only Markdown and PDF are supported.
Args:
documents (List[str]): A list of the document contents as strings
filepaths (List[Path]): Corresponding document filepaths
Returns:
defaultdict: Dictionary with either ".md" or ".pdf" as a key.
DefaultDict: Dictionary with either ".md" or ".pdf" as a key.
Markdown items contain document contents, PDF items contain
paths to documents.
"""
doc_dict = defaultdict(list)
for doc, path in zip(documents, filepaths):
if path.suffix == ".md":
# append doc contents
doc_dict[FileTypes.MD].append(doc)
doc_dict[FileTypes.MD].append((doc, path))
elif path.suffix == ".pdf":
# append doc paths
doc_dict[FileTypes.PDF].append(path)
doc_dict[FileTypes.PDF].append((doc, path))
else:
raise ValueError(
f"Received document of type .{path.suffix}, which is not a supported filetype"
@@ -185,7 +191,11 @@ def __init__(
self.leaf_node_path = leaf_node_path
self.output_dir = self._path_validator(output_dir)
self.chunk_word_count = chunk_word_count
self.tokenizer_model_name = tokenizer_model_name if tokenizer_model_name is not None else "mistralai/Mixtral-8x7B-Instruct-v0.1"
self.tokenizer_model_name = (
tokenizer_model_name
if tokenizer_model_name is not None
else "mistralai/Mixtral-8x7B-Instruct-v0.1"
)
self.qna_yaml = self._load_qna_yaml(
self._path_validator(leaf_node_path) if leaf_node_path else None
)
@@ -229,7 +239,7 @@ def _path_validator(self, path) -> Path:
raise FileNotFoundError(f"{path} does not exist.")
return path

def _load_qna_yaml(self, qna_yaml_path: Path) -> dict:
def _load_qna_yaml(self, qna_yaml_path: Path | None) -> dict:
"""
Load the qna YAML file.
Args:
@@ -265,8 +275,10 @@ def _process_parsed_docling_json(self, json_fp: Path) -> Dataset:
num_tokens_per_doc = _num_tokens_from_words(self.chunk_word_count)
chunk_size = _num_chars_from_tokens(num_tokens_per_doc)
return chunk_markdowns(fused_texts, chunk_size)

def fuse_texts(self, text_list: List, short_length_threshold: int = 130):

def fuse_texts(
self, text_list: List, short_length_threshold: int = 130
) -> List[str]:
"""
Fuse short texts with preceding longer texts if their token count is below the threshold.
Args:
@@ -277,11 +289,13 @@ def fuse_texts(self, text_list: List, short_length_threshold: int = 130):
Returns:
list: List of fused texts.
"""
fused_texts = []
fused_texts: List[str] = []
previous_long_text = ""

for text in text_list:
token_count = self.get_token_count(text, self.tokenizer) # Use tokenizer for token count
token_count = self.get_token_count(
text, self.tokenizer
) # Use tokenizer for token count

if token_count <= short_length_threshold and previous_long_text:
# Append the short text to the last long text
@@ -292,7 +306,7 @@ def fuse_texts(self, text_list: List, short_length_threshold: int = 130):
previous_long_text = text

return fused_texts
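
The fusing rule is easier to see with a tiny standalone sketch (made-up texts and a word-count stand-in for the tokenizer; the real method uses self.tokenizer and a 130-token threshold):

# Illustrative sketch of the fuse_texts idea, outside the class
from typing import List

def fuse_texts_sketch(text_list: List[str], short_length_threshold: int = 3) -> List[str]:
    fused: List[str] = []
    previous_long_text = ""
    for text in text_list:
        token_count = len(text.split())  # stand-in for tokenizer-based counting
        if token_count <= short_length_threshold and previous_long_text:
            fused[-1] += " " + text        # glue the short fragment onto the last long text
            previous_long_text = fused[-1]
        else:
            fused.append(text)
            previous_long_text = text
    return fused

print(fuse_texts_sketch(["A long enough paragraph about chunking.", "Fig. 1", "Another complete paragraph about tables."]))
# -> ['A long enough paragraph about chunking. Fig. 1', 'Another complete paragraph about tables.']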

def create_tokenizer(self, model_name: str):
"""
Create a tokenizer instance from a pre-trained model or a local directory.
@@ -311,7 +325,6 @@ def create_tokenizer(self, model_name: str):
logger.error(f"Failed to load tokenizer from {model_name}: {str(e)}")
raise


def get_token_count(self, text, tokenizer):
"""
Get the number of tokens in a text using the provided tokenizer.
@@ -323,7 +336,6 @@ def get_token_count(self, text, tokenizer):
"""
return len(tokenizer.tokenize(text))
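
Together, create_tokenizer and get_token_count amount to the standard Hugging Face pattern; a short sketch using the chunker's documented default model (fetching it requires network access or a local cache):

# Hedged sketch: token counting with a Hugging Face tokenizer
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
text = "Context-aware chunking keeps related sentences in the same chunk."
print(len(tokenizer.tokenize(text)))  # number of tokens the chunker would count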


def add_heading_formatting(self, text):
"""
Add heading formatting to the text if the first part is short.
@@ -341,7 +353,6 @@ def add_heading_formatting(self, text):
text = ".".join(text)
return text


def generate_table_from_parsed_rep(self, item):
"""
Generate the table from the parsed representation and return as a string.
@@ -371,7 +382,6 @@ def generate_table_from_parsed_rep(self, item):
table_text += f"\nCaption: {caption}\n"
return table_text


def get_table(self, json_book, table_ref):
"""
Retrieve a table from a document based on a reference string.
@@ -382,10 +392,11 @@ def get_table(self, json_book, table_ref):
str: Formatted table string.
"""
parts = table_ref.split("/")
table_text = self.generate_table_from_parsed_rep(json_book[parts[1]][int(parts[2])])
table_text = self.generate_table_from_parsed_rep(
json_book[parts[1]][int(parts[2])]
)
return table_text
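
The table_ref string is split on "/" and used to index into the parsed docling JSON; a hedged sketch of that lookup (the exact reference format, e.g. "#/tables/2", is an assumption based on the indexing above):

# Hypothetical reference resolution, mirroring the split/index logic above
table_ref = "#/tables/2"
parts = table_ref.split("/")                 # ["#", "tables", "2"]
collection, index = parts[1], int(parts[2])  # "tables", 2
# element = json_book[collection][index]     # i.e. json_book["tables"][2]
# table_text = self.generate_table_from_parsed_rep(element)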


def get_table_page_number(self, json_book, idx):
"""
Get the page number of a table or other document element.
@@ -458,7 +469,9 @@ def build_chunks_from_docling_json(
"equation",
]: # 'page-header',
if book_element["type"] == "table":
current_book_page_number = self.get_table_page_number(json_book, idx)
current_book_page_number = self.get_table_page_number(
json_book, idx
)
else:
current_book_page_number = book_element["prov"][0]["page"]
book_text = book_element["text"]
@@ -492,16 +505,20 @@ def build_chunks_from_docling_json(
>= max_token_per_chunk
and len(current_buffer) > 1
):
chunk_text = '\n\n'.join(current_buffer[:-1])
print(f"Current chunk size {self.get_token_count(chunk_text, tokenizer)} and max is {max_token_per_chunk}")
chunk_text = "\n\n".join(current_buffer[:-1])
print(
f"Current chunk size {self.get_token_count(chunk_text, tokenizer)} and max is {max_token_per_chunk}"
)

document_chunks.append("\n\n".join(current_buffer[:-1]))

if (
self.get_token_count(current_buffer[-1], tokenizer)
>= max_token_per_chunk
):
print(f"This is too big a document to be left in the current buffer {self.get_token_count(current_buffer[-1], tokenizer)}")
print(
f"This is too big a document to be left in the current buffer {self.get_token_count(current_buffer[-1], tokenizer)}"
)
document_chunks.append(current_buffer[-1])
current_buffer = []
else:
@@ -525,10 +542,9 @@ def build_chunks_from_docling_json(
document_chunks.append("\n\n".join(current_buffer))
return document_chunks
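
The flush rule visible in the loop above can be summarized in a small standalone sketch (illustrative only; the real method also tracks page numbers, tables, and headings):

# Simplified sketch of the buffer-flush rule in build_chunks_from_docling_json
from typing import Callable, List

def flush_if_needed(
    current_buffer: List[str],
    document_chunks: List[str],
    token_count: Callable[[str], int],
    max_token_per_chunk: int,
) -> List[str]:
    if (
        token_count("\n\n".join(current_buffer)) >= max_token_per_chunk
        and len(current_buffer) > 1
    ):
        # Emit everything except the newest element as one chunk
        document_chunks.append("\n\n".join(current_buffer[:-1]))
        if token_count(current_buffer[-1]) >= max_token_per_chunk:
            # The newest element alone is over budget: emit it too and start fresh
            document_chunks.append(current_buffer[-1])
            return []
        # Otherwise the newest element starts the next chunk
        return [current_buffer[-1]]
    return current_buffer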


def export_documents(self, converted_docs: Iterable[ConvertedDocument]):
"""Write converted documents to json files
Check for successful conversions and write those to the docling artifacts directory.
Returns:
Path: path to directory with docling json artifacts