♻️ refactor(infer.py): improve model loading and detection functions
- Refactor `get_model_loaded` to `load_model` with enhanced error handling
- Rename constants for better clarity and consistency
- Simplify language detection functions and improve docstrings
- Use `Optional` typing and rename parameters for clarity
sudoskys committed Sep 29, 2024
1 parent a1a784d commit ede2d25
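A minimal migration sketch for callers of this module, assuming the package is installed and importable as `fast_langdetect` (the old call is shown in a comment; error handling follows the new `DetectError` contract):

    from fast_langdetect.ft_detect.infer import load_model, DetectError

    try:
        # Before this commit: model = get_model_loaded(low_memory=True)
        model = load_model(low_memory=True)  # uses the bundled lid.176.ftz, no download needed
    except DetectError as exc:
        # Download and load failures now surface as DetectError instead of bare exceptions
        print(f"language model unavailable: {exc}")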
Showing 2 changed files with 75 additions and 83 deletions.
src/fast_langdetect/ft_detect/infer.py (158 changes: 75 additions & 83 deletions)
@@ -6,129 +6,121 @@
 import logging
 import os
 from pathlib import Path
-from typing import Dict, Union, List
+from typing import Dict, Union, List, Optional, Any
 
 import fasttext
 from robust_downloader import download
 
 logger = logging.getLogger(__name__)
 MODELS = {"low_mem": None, "high_mem": None}
-FTLANG_CACHE = os.getenv("FTLANG_CACHE", "/tmp/fasttext-langdetect")
+CACHE_DIRECTORY = os.getenv("FTLANG_CACHE", "/tmp/fasttext-langdetect")
+LOCAL_SMALL_MODEL_PATH = Path(__file__).parent / "resources" / "lid.176.ftz"
 
 # Suppress FastText output if possible
 try:
     # silences warnings as the package does not properly use the python 'warnings' package
     # see https://github.com/facebookresearch/fastText/issues/1056
     fasttext.FastText.eprint = lambda *args, **kwargs: None
 except Exception:
     pass
 
 
+class DetectError(Exception):
+    """Custom exception for language detection errors."""
+    pass
+
+
-def get_model_map(low_memory=False):
+def load_model(low_memory: bool = False, download_proxy: Optional[str] = None) -> fasttext.FastText:
     """
-    Getting model map
-    :param low_memory:
-    :return:
+    Load the FastText model based on memory preference.
+    :param low_memory: Indicates whether to load a smaller, memory-efficient model
+    :param download_proxy: Proxy to use for downloading the large model if necessary
+    :return: Loaded FastText model
+    :raises DetectError: If the model cannot be loaded
     """
+    model_dict_key = "low_mem" if low_memory else "high_mem"
+    if MODELS[model_dict_key]:
+        return MODELS[model_dict_key]
+
     if low_memory:
-        return "low_mem", FTLANG_CACHE, "lid.176.ftz", "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz"
+        model_path = LOCAL_SMALL_MODEL_PATH
     else:
-        return "high_mem", FTLANG_CACHE, "lid.176.bin", "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
-
-
-def get_model_loaded(
-        low_memory: bool = False,
-        download_proxy: str = None
-):
-    """
-    Getting model loaded
-    :param low_memory:
-    :param download_proxy:
-    :return:
-    """
-    mode, cache, name, url = get_model_map(low_memory)
-    loaded = MODELS.get(mode, None)
-    if loaded:
-        return loaded
-    model_path = os.path.join(cache, name)
-    if Path(model_path).exists():
-        if Path(model_path).is_dir():
-            raise Exception(f"{model_path} is a directory")
-        try:
-            loaded_model = fasttext.load_model(model_path)
-            MODELS[mode] = loaded_model
-        except Exception as e:
-            logger.error(f"Error loading model {model_path}: {e}")
-            download(url=url, folder=cache, filename=name, proxy=download_proxy)
-            raise e
-        else:
-            return loaded_model
-
-    download(url=url, folder=cache, filename=name, proxy=download_proxy, retry_max=3, timeout=20)
-    loaded_model = fasttext.load_model(model_path)
-    MODELS[mode] = loaded_model
-    return loaded_model
+        model_path = Path(CACHE_DIRECTORY) / "lid.176.bin"
+        if not model_path.exists():
+            model_url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
+            try:
+                logger.info(f"Downloading large model from {model_url} to {model_path}")
+                download(
+                    url=model_url,
+                    folder=CACHE_DIRECTORY,
+                    filename="lid.176.bin",
+                    proxy=download_proxy,
+                    retry_max=3,
+                    timeout=20
+                )
+            except Exception as e:
+                logger.error(f"Failed to download the large model: {e}")
+                raise DetectError("Unable to download the large model due to network issues.")
+
+    try:
+        loaded_model = fasttext.load_model(str(model_path))
+        MODELS[model_dict_key] = loaded_model
+        return loaded_model
+    except Exception as e:
+        logger.error(f"Failed to load the model '{model_path}': {e}")
+        model_type = "local small" if low_memory else "large"
+        raise DetectError(f"Unable to load the {model_type} model due to an error.")
 
 
 def detect(text: str, *,
            low_memory: bool = True,
-           model_download_proxy: str = None
+           model_download_proxy: Optional[str] = None
            ) -> Dict[str, Union[str, float]]:
     """
-    Detect language of text
-    This function assumes to be given a single line of text. We split words on whitespace (space, newline, tab, vertical tab) and the control characters carriage return, formfeed and the null character.
-    :param text: Text for language detection
-    :param low_memory: Whether to use low memory mode
-    :param model_download_proxy: model download proxy
-    :return: {"lang": "en", "score": 0.99}
-    :raise ValueError: predict processes one line at a time (remove \'\\n\')
+    Detect the language of a text using FastText.
+    :param text: The text for language detection
+    :param low_memory: Whether to use a memory-efficient model
+    :param model_download_proxy: Download proxy for the model if needed
+    :return: A dictionary with detected language and confidence score
+    :raises DetectError: If detection fails
     """
-    model = get_model_loaded(low_memory=low_memory, download_proxy=model_download_proxy)
+    model = load_model(low_memory=low_memory, download_proxy=model_download_proxy)
     labels, scores = model.predict(text)
-    label = labels[0].replace("__label__", '')
-    score = min(float(scores[0]), 1.0)
+    language_label = labels[0].replace("__label__", '')
+    confidence_score = min(float(scores[0]), 1.0)
     return {
-        "lang": label,
-        "score": score,
+        "lang": language_label,
+        "score": confidence_score,
     }
 
 
 def detect_multilingual(text: str, *,
                         low_memory: bool = True,
-                        model_download_proxy: str = None,
+                        model_download_proxy: Optional[str] = None,
                         k: int = 5,
                         threshold: float = 0.0,
                         on_unicode_error: str = "strict"
-                        ) -> List[dict]:
+                        ) -> List[Dict[str, Any]]:
     """
-    Given a string, get a list of labels and a list of corresponding probabilities.
-    k controls the number of returned labels. A choice of 5, will return the 5 most probable labels.
-    By default this returns only the most likely label and probability. threshold filters the returned labels by a threshold on probability. A choice of 0.5 will return labels with at least 0.5 probability.
-    k and threshold will be applied together to determine the returned labels.
-    NOTE:This function assumes to be given a single line of text. We split words on whitespace (space, newline, tab, vertical tab) and the control characters carriage return, formfeed and the null character.
-    :param text: Text for language detection
-    :param low_memory: Whether to use low memory mode
-    :param model_download_proxy: model download proxy
-    :param k: Predict top k languages
-    :param threshold: Threshold for prediction
-    :param on_unicode_error: Error handling
-    :return:
+    Detect multiple potential languages and their probabilities in a given text.
+    :param text: The text for language detection
+    :param low_memory: Whether to use a memory-efficient model
+    :param model_download_proxy: Proxy for downloading the model
+    :param k: Number of top language predictions to return
+    :param threshold: Minimum score threshold for predictions
+    :param on_unicode_error: Error handling for Unicode errors
+    :return: A list of dictionaries, each containing a language and its confidence score
+    :raises DetectError: If detection fails
     """
-    model = get_model_loaded(low_memory=low_memory, download_proxy=model_download_proxy)
-    labels, scores = model.predict(text=text, k=k, threshold=threshold, on_unicode_error=on_unicode_error)
-    detect_result = []
+    model = load_model(low_memory=low_memory, download_proxy=model_download_proxy)
+    labels, scores = model.predict(text, k=k, threshold=threshold, on_unicode_error=on_unicode_error)
+    results = []
     for label, score in zip(labels, scores):
-        label = label.replace("__label__", '')
-        score = min(float(score), 1.0)
-        detect_result.append({
-            "lang": label,
-            "score": score,
+        language_label = label.replace("__label__", '')
+        confidence_score = min(float(score), 1.0)
+        results.append({
+            "lang": language_label,
+            "score": confidence_score,
         })
-    return sorted(detect_result, key=lambda i: i['score'], reverse=True)
+    return sorted(results, key=lambda x: x['score'], reverse=True)
Binary file not shown.
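A short usage sketch of the refactored detection functions; the import path assumes the package's src layout and the scores shown are illustrative:

    from fast_langdetect.ft_detect.infer import detect, detect_multilingual

    # fastText's predict() handles a single line at a time, so strip newlines first.
    text = "Hello, world!".replace("\n", " ")

    print(detect(text, low_memory=True))
    # e.g. {'lang': 'en', 'score': 0.99}

    print(detect_multilingual(text, k=3, threshold=0.0))
    # e.g. [{'lang': 'en', 'score': 0.99}, ...], sorted by score in descending order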
