diff --git a/boxmot/trackers/basetracker.py b/boxmot/trackers/basetracker.py index 4987ae7ba6..c56ac0f669 100644 --- a/boxmot/trackers/basetracker.py +++ b/boxmot/trackers/basetracker.py @@ -125,6 +125,11 @@ def per_class_decorator(update_method): Decorator for the update method to handle per-class processing. """ def wrapper(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None): + + #handle different types of inputs + if dets is None or len(dets) == 0: + dets = np.empty((0, 6)) + if self.per_class: # Initialize an array to store the tracks for each class per_class_tracks = [] diff --git a/tests/unit/test_trackers.py b/tests/unit/test_trackers.py index fa7d7ff8ad..763ed23bcb 100644 --- a/tests/unit/test_trackers.py +++ b/tests/unit/test_trackers.py @@ -165,3 +165,22 @@ def test_per_class_tracker_active_tracks(tracker_type): assert tracker.per_class_active_tracks[0], "No active tracks for class 0" assert tracker.per_class_active_tracks[65], "No active tracks for class 65" + +@pytest.mark.parametrize("tracker_type", ALL_TRACKERS) +@pytest.mark.parametrize("dets", [None, np.array([])]) +def test_tracker_with_no_detections(tracker_type, dets): + tracker_conf = get_tracker_config(tracker_type) + tracker = create_tracker( + tracker_type=tracker_type, + tracker_config=tracker_conf, + reid_weights=WEIGHTS / 'mobilenetv2_x1_4_dukemtmcreid.pt', + device='cpu', + half=False, + per_class=False + ) + + rgb = np.random.randint(255, size=(640, 640, 3), dtype=np.uint8) + embs = np.random.random(size=(2, 512)) + + output = tracker.update(dets, rgb, embs) + assert output.size == 0, "Output should be empty when no detections are provided" \ No newline at end of file