From 843da73b6c2889884de620070a5bec7ba4b4dea9 Mon Sep 17 00:00:00 2001
From: Igor Tatarnikov
Date: Mon, 7 Oct 2024 18:18:51 +0100
Subject: [PATCH] Added new fields to accept weights only

---
 cellfinder/napari/detect/detect.py            | 14 ++++++++++++++
 cellfinder/napari/detect/detect_containers.py |  5 +++++
 2 files changed, 19 insertions(+)

diff --git a/cellfinder/napari/detect/detect.py b/cellfinder/napari/detect/detect.py
index 5f2de700..a7b26bd7 100644
--- a/cellfinder/napari/detect/detect.py
+++ b/cellfinder/napari/detect/detect.py
@@ -254,7 +254,9 @@ def widget(
         classification_options,
         skip_classification: bool,
         use_pre_trained_weights: bool,
+        weights_only: bool,
         trained_model: Optional[Path],
+        model_weights: Optional[Path],
         batch_size: int,
         misc_options,
         start_plane: int,
@@ -299,6 +301,8 @@ def widget(
             should be attempted
         use_pre_trained_weights : bool
             Select to use pre-trained model weights
+        weights_only : bool
+            Select to only provide the model weights
         batch_size : int
             How many points to classify at one time
         skip_classification : bool
@@ -307,6 +311,9 @@ def widget(
         trained_model : Optional[Path]
             Trained model file path (home directory (default) -> pretrained
             weights)
+        model_weights : Optional[Path]
+            Model weights file path (home directory (default) -> pretrained
+            weights)
         start_plane : int
             First plane to process (to process a subset of the data)
         end_plane : int
@@ -372,12 +379,19 @@ def widget(
             max_cluster_size,
         )
 
+        if weights_only:
+            trained_model = None
+
         if use_pre_trained_weights:
             trained_model = None
+            model_weights = None
+
         classification_inputs = ClassificationInputs(
             skip_classification,
             use_pre_trained_weights,
+            weights_only,
             trained_model,
+            model_weights,
             batch_size,
         )
 
diff --git a/cellfinder/napari/detect/detect_containers.py b/cellfinder/napari/detect/detect_containers.py
index 953e6248..19a47ff6 100644
--- a/cellfinder/napari/detect/detect_containers.py
+++ b/cellfinder/napari/detect/detect_containers.py
@@ -113,12 +113,15 @@ class ClassificationInputs(InputContainer):
 
     skip_classification: bool = False
     use_pre_trained_weights: bool = True
+    weights_only: bool = False
     trained_model: Optional[Path] = Path.home()
+    model_weights: Optional[Path] = Path.home()
     batch_size: int = 64
 
     def as_core_arguments(self) -> dict:
         args = super().as_core_arguments()
         del args["use_pre_trained_weights"]
+        del args["weights_only"]
         return args
 
     @classmethod
@@ -128,7 +131,9 @@ def widget_representation(cls) -> dict:
             use_pre_trained_weights=dict(
                 value=cls.defaults()["use_pre_trained_weights"]
             ),
+            weights_only=dict(value=cls.defaults()["weights_only"]),
             trained_model=dict(value=cls.defaults()["trained_model"]),
+            model_weights=dict(value=cls.defaults()["model_weights"]),
             skip_classification=dict(
                 value=cls.defaults()["skip_classification"]
             ),
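
Note for reviewers: the sketch below illustrates why as_core_arguments() deletes both use_pre_trained_weights and weights_only. They are GUI-only toggles, so they are stripped before the container's fields are forwarded to the core cellfinder call, while the new model_weights path is kept. This is a minimal, stand-alone sketch, not cellfinder's actual code: the InputContainer base class here is assumed to behave like dataclasses.asdict, and the file name "weights.h5" is purely illustrative.

from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional


@dataclass
class InputContainer:
    """Stand-in for cellfinder's real base class (assumption: it exposes
    the dataclass fields as a plain dict via as_core_arguments())."""

    def as_core_arguments(self) -> dict:
        return asdict(self)


@dataclass
class ClassificationInputs(InputContainer):
    skip_classification: bool = False
    use_pre_trained_weights: bool = True
    weights_only: bool = False  # new GUI-only toggle added by this patch
    trained_model: Optional[Path] = Path.home()
    model_weights: Optional[Path] = Path.home()  # new weights-only path
    batch_size: int = 64

    def as_core_arguments(self) -> dict:
        args = super().as_core_arguments()
        # GUI-only toggles are removed so that only arguments the core
        # classification call understands are passed through.
        del args["use_pre_trained_weights"]
        del args["weights_only"]
        return args


# Usage: selecting "weights only" in the widget clears trained_model, so
# only model_weights (plus the remaining core arguments) is forwarded.
inputs = ClassificationInputs(
    weights_only=True, trained_model=None, model_weights=Path("weights.h5")
)
print(sorted(inputs.as_core_arguments()))
# -> ['batch_size', 'model_weights', 'skip_classification', 'trained_model']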