diff --git a/src/fast_langdetect/ft_detect/infer.py b/src/fast_langdetect/ft_detect/infer.py
index a6fabb5..d227a36 100644
--- a/src/fast_langdetect/ft_detect/infer.py
+++ b/src/fast_langdetect/ft_detect/infer.py
@@ -47,7 +47,8 @@ class DetectError(Exception):
     pass
 
 
-def load_model(low_memory: bool = False, download_proxy: Optional[str] = None,
+def load_model(low_memory: bool = False,
+               download_proxy: Optional[str] = None,
                use_strict_mode: bool = False) -> "fasttext.FastText._FastText":
     """
     Load the FastText model based on memory preference.
@@ -75,6 +76,16 @@ def load_local_small_model():
             logger.error(f"Failed to load the local small model '{LOCAL_SMALL_MODEL_PATH}': {e}")
             raise DetectError("Unable to load low-memory model from local resources.")
 
+    def load_large_model():
+        """Try to load the large model."""
+        try:
+            loaded_model = fasttext.load_model(str(model_path))
+            _model_cache.set_model(ModelType.HIGH_MEMORY, loaded_model)
+            return loaded_model
+        except Exception as e:
+            logger.error(f"Failed to load the large model '{model_path}': {e}")
+            return None
+
     if low_memory:
         # Attempt to load the local small model
         return load_local_small_model()
@@ -82,50 +93,38 @@ def load_local_small_model():
     # Path for the large model
     model_path = Path(CACHE_DIRECTORY) / "lid.176.bin"
 
-    def load_large_model():
-        if model_path.exists():
-            try:
-                loaded_model = fasttext.load_model(str(model_path))
-                _model_cache.set_model(ModelType.HIGH_MEMORY, loaded_model)
-                return loaded_model
-            except Exception as e:
-                logger.error(f"Failed to load the large model '{model_path}': {e}")
-                if use_strict_mode:
-                    raise DetectError("Strict mode enabled: Unable to load the large model.")
-                else:
-                    logger.info("Attempting to fall back to local small model.")
-        return None
-
-    # Attempt to load large model
-    loaded_model = load_large_model()
-    if loaded_model:
-        return loaded_model
-
-    if not use_strict_mode:
-        # If not strict or download fails, attempt to download
-        model_url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
-        try:
-            logger.info(f"Downloading large model from {model_url} to {model_path}")
-            download(
-                url=model_url,
-                folder=CACHE_DIRECTORY,
-                filename="lid.176.bin",
-                proxy=download_proxy,
-                retry_max=3,
-                timeout=20
-            )
-            # Try loading the model again after download
-            return load_large_model()
-        except Exception as e:
-            logger.error(f"Failed to download the large model: {e}")
-            logger.info("Attempting to fall back to local small model.")
-
-    # Fallback to the small model if download fails and not in strict mode
-    if not use_strict_mode:
+    if model_path.exists():
+        # Attempt to load large model
+        loaded_model = load_large_model()
+        if loaded_model:
+            return loaded_model
+
+    # If the large model is not present, attempt to download (only if necessary)
+    model_url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
+    try:
+        logger.info(f"Downloading large model from {model_url} to {model_path}")
+        download(
+            url=model_url,
+            folder=CACHE_DIRECTORY,
+            filename="lid.176.bin",
+            proxy=download_proxy,
+            retry_max=3,
+            timeout=20
+        )
+        # Try loading the model again after download
+        loaded_model = load_large_model()
+        if loaded_model:
+            return loaded_model
+    except Exception as e:
+        logger.error(f"Failed to download the large model: {e}")
+
+    # Handle fallback logic for strict and non-strict modes
+    if use_strict_mode:
+        raise DetectError("Strict mode enabled: Unable to download or load the large model.")
+    else:
+        logger.info("Attempting to fall back to local small model.")
         return load_local_small_model()
-    raise DetectError("Strict mode enabled: Unable to download or load the large model.")
-
 
 
 def detect(text: str, *, low_memory: bool = True,