diff --git a/src/fast_langdetect/ft_detect/infer.py b/src/fast_langdetect/ft_detect/infer.py index a2d7ba1..0921906 100644 --- a/src/fast_langdetect/ft_detect/infer.py +++ b/src/fast_langdetect/ft_detect/infer.py @@ -6,129 +6,121 @@ import logging import os from pathlib import Path -from typing import Dict, Union, List +from typing import Dict, Union, List, Optional, Any import fasttext from robust_downloader import download logger = logging.getLogger(__name__) MODELS = {"low_mem": None, "high_mem": None} -FTLANG_CACHE = os.getenv("FTLANG_CACHE", "/tmp/fasttext-langdetect") +CACHE_DIRECTORY = os.getenv("FTLANG_CACHE", "/tmp/fasttext-langdetect") +LOCAL_SMALL_MODEL_PATH = Path(__file__).parent / "resources" / "lid.176.ftz" +# Suppress FastText output if possible try: - # silences warnings as the package does not properly use the python 'warnings' package - # see https://github.com/facebookresearch/fastText/issues/1056 fasttext.FastText.eprint = lambda *args, **kwargs: None except Exception: pass class DetectError(Exception): + """Custom exception for language detection errors.""" pass -def get_model_map(low_memory=False): +def load_model(low_memory: bool = False, download_proxy: Optional[str] = None) -> fasttext.FastText: """ - Getting model map - :param low_memory: - :return: + Load the FastText model based on memory preference. 
+ + :param low_memory: Indicates whether to load a smaller, memory-efficient model + :param download_proxy: Proxy to use for downloading the large model if necessary + :return: Loaded FastText model + :raises DetectError: If the model cannot be loaded """ + model_dict_key = "low_mem" if low_memory else "high_mem" + if MODELS[model_dict_key]: + return MODELS[model_dict_key] + if low_memory: - return "low_mem", FTLANG_CACHE, "lid.176.ftz", "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz" + model_path = LOCAL_SMALL_MODEL_PATH else: - return "high_mem", FTLANG_CACHE, "lid.176.bin", "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin" - - -def get_model_loaded( - low_memory: bool = False, - download_proxy: str = None -): - """ - Getting model loaded - :param low_memory: - :param download_proxy: - :return: - """ - mode, cache, name, url = get_model_map(low_memory) - loaded = MODELS.get(mode, None) - if loaded: - return loaded - model_path = os.path.join(cache, name) - if Path(model_path).exists(): - if Path(model_path).is_dir(): - raise Exception(f"{model_path} is a directory") - try: - loaded_model = fasttext.load_model(model_path) - MODELS[mode] = loaded_model - except Exception as e: - logger.error(f"Error loading model {model_path}: {e}") - download(url=url, folder=cache, filename=name, proxy=download_proxy) - raise e - else: - return loaded_model - - download(url=url, folder=cache, filename=name, proxy=download_proxy, retry_max=3, timeout=20) - loaded_model = fasttext.load_model(model_path) - MODELS[mode] = loaded_model - return loaded_model + model_path = Path(CACHE_DIRECTORY) / "lid.176.bin" + if not model_path.exists(): + model_url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin" + try: + logger.info(f"Downloading large model from {model_url} to {model_path}") + download( + url=model_url, + folder=CACHE_DIRECTORY, + filename="lid.176.bin", + proxy=download_proxy, + retry_max=3, + 
timeout=20 + ) + except Exception as e: + logger.error(f"Failed to download the large model: {e}") + raise DetectError("Unable to download the large model due to network issues.") + + try: + loaded_model = fasttext.load_model(str(model_path)) + MODELS[model_dict_key] = loaded_model + return loaded_model + except Exception as e: + logger.error(f"Failed to load the model '{model_path}': {e}") + model_type = "local small" if low_memory else "large" + raise DetectError(f"Unable to load the {model_type} model due to an error.") def detect(text: str, *, low_memory: bool = True, - model_download_proxy: str = None + model_download_proxy: Optional[str] = None ) -> Dict[str, Union[str, float]]: """ - Detect language of text - - This function assumes to be given a single line of text. We split words on whitespace (space, newline, tab, vertical tab) and the control characters carriage return, formfeed and the null character. + Detect the language of a text using FastText. - :param text: Text for language detection - :param low_memory: Whether to use low memory mode - :param model_download_proxy: model download proxy - :return: {"lang": "en", "score": 0.99} - :raise ValueError: predict processes one line at a time (remove \'\\n\') + :param text: The text for language detection + :param low_memory: Whether to use a memory-efficient model + :param model_download_proxy: Download proxy for the model if needed + :return: A dictionary with detected language and confidence score + :raises DetectError: If detection fails """ - model = get_model_loaded(low_memory=low_memory, download_proxy=model_download_proxy) + model = load_model(low_memory=low_memory, download_proxy=model_download_proxy) labels, scores = model.predict(text) - label = labels[0].replace("__label__", '') - score = min(float(scores[0]), 1.0) + language_label = labels[0].replace("__label__", '') + confidence_score = min(float(scores[0]), 1.0) return { - "lang": label, - "score": score, + "lang": language_label, 
+ "score": confidence_score, } def detect_multilingual(text: str, *, low_memory: bool = True, - model_download_proxy: str = None, + model_download_proxy: Optional[str] = None, k: int = 5, threshold: float = 0.0, on_unicode_error: str = "strict" - ) -> List[dict]: + ) -> List[Dict[str, Any]]: """ - Given a string, get a list of labels and a list of corresponding probabilities. - k controls the number of returned labels. A choice of 5, will return the 5 most probable labels. - By default this returns only the most likely label and probability. threshold filters the returned labels by a threshold on probability. A choice of 0.5 will return labels with at least 0.5 probability. - k and threshold will be applied together to determine the returned labels. - - NOTE:This function assumes to be given a single line of text. We split words on whitespace (space, newline, tab, vertical tab) and the control characters carriage return, formfeed and the null character. - - :param text: Text for language detection - :param low_memory: Whether to use low memory mode - :param model_download_proxy: model download proxy - :param k: Predict top k languages - :param threshold: Threshold for prediction - :param on_unicode_error: Error handling - :return: + Detect multiple potential languages and their probabilities in a given text. 
+ + :param text: The text for language detection + :param low_memory: Whether to use a memory-efficient model + :param model_download_proxy: Proxy for downloading the model + :param k: Number of top language predictions to return + :param threshold: Minimum score threshold for predictions + :param on_unicode_error: Error handling for Unicode errors + :return: A list of dictionaries, each containing a language and its confidence score + :raises DetectError: If detection fails """ - model = get_model_loaded(low_memory=low_memory, download_proxy=model_download_proxy) - labels, scores = model.predict(text=text, k=k, threshold=threshold, on_unicode_error=on_unicode_error) - detect_result = [] + model = load_model(low_memory=low_memory, download_proxy=model_download_proxy) + labels, scores = model.predict(text, k=k, threshold=threshold, on_unicode_error=on_unicode_error) + results = [] for label, score in zip(labels, scores): - label = label.replace("__label__", '') - score = min(float(score), 1.0) - detect_result.append({ - "lang": label, - "score": score, + language_label = label.replace("__label__", '') + confidence_score = min(float(score), 1.0) + results.append({ + "lang": language_label, + "score": confidence_score, }) - return sorted(detect_result, key=lambda i: i['score'], reverse=True) + return sorted(results, key=lambda x: x['score'], reverse=True) diff --git a/src/fast_langdetect/ft_detect/resources/lid.176.ftz b/src/fast_langdetect/ft_detect/resources/lid.176.ftz new file mode 100644 index 0000000..1fb85b3 Binary files /dev/null and b/src/fast_langdetect/ft_detect/resources/lid.176.ftz differ