diff --git a/easycv/models/detection/detectors/yolox/yolox.py b/easycv/models/detection/detectors/yolox/yolox.py index 4e8eb452..3df8bf80 100644 --- a/easycv/models/detection/detectors/yolox/yolox.py +++ b/easycv/models/detection/detectors/yolox/yolox.py @@ -52,6 +52,22 @@ def __init__(self, assert model_type in self.param_map, f'invalid model_type for yolox {model_type}, valid ones are {list(self.param_map.keys())}' + if num_classes is not None: + # adapt to previous export model (before easycv0.6.0) + logging.warning( + 'Warning: You are now attend to use an old YOLOX model before easycv0.6.0 with key num_classes' + ) + head = dict( + type='YOLOXHead', + model_type=model_type, + num_classes=num_classes, + ) + + # the change of backbone/neck/head only support model_type as 's' + if model_type != 's': + head_type = head.get('type', None) + assert backbone == 'CSPDarknet' and neck_type == 'yolo' and neck_mode == 'all' and head_type == 'YOLOXHead', 'We only support the architecture modification for YOLOX-S.' + self.pretrained = pretrained in_channels = [256, 512, 1024] @@ -68,19 +84,14 @@ def __init__(self, asff_channel=asff_channel, use_att=use_att) - if num_classes is not None: - # adapt to previous export model (before easycv0.6.0) - logging.warning( - 'Warning: You are now attend to use an old YOLOX model before easycv0.6.0 with key num_classes' - ) - head = dict( - type='YOLOXHead', - model_type=model_type, - num_classes=num_classes, - ) - if head is not None: # head is None for YOLOX-edge to define a special head + # set and check model type in head as the same of yolox + head_model_type = head.get('model_type', None) + if head_model_type is None: + head['model_type'] = model_type + else: + assert model_type == head_model_type, 'Please provide the same model_type of YOLOX in config.' self.head = build_head(head) self.num_classes = self.head.num_classes