Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify yolox and model load #339

Merged
merged 3 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions easycv/models/detection/detectors/yolox/yolox.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,22 @@ def __init__(self,

assert model_type in self.param_map, f'invalid model_type for yolox {model_type}, valid ones are {list(self.param_map.keys())}'

if num_classes is not None:
# adapt to previous export model (before easycv0.6.0)
logging.warning(
'Warning: You are now attend to use an old YOLOX model before easycv0.6.0 with key num_classes'
)
head = dict(
type='YOLOXHead',
model_type=model_type,
num_classes=num_classes,
)

# the change of backbone/neck/head only support model_type as 's'
if model_type != 's':
head_type = head.get('type', None)
assert backbone == 'CSPDarknet' and neck_type == 'yolo' and neck_mode == 'all' and head_type == 'YOLOXHead', 'We only support the architecture modification for YOLOX-S.'

self.pretrained = pretrained

in_channels = [256, 512, 1024]
Expand All @@ -68,19 +84,14 @@ def __init__(self,
asff_channel=asff_channel,
use_att=use_att)

if num_classes is not None:
# adapt to previous export model (before easycv0.6.0)
logging.warning(
'Warning: You are now attend to use an old YOLOX model before easycv0.6.0 with key num_classes'
)
head = dict(
type='YOLOXHead',
model_type=model_type,
num_classes=num_classes,
)

if head is not None:
# head is None for YOLOX-edge to define a special head
# set and check model type in head as the same of yolox
head_model_type = head.get('model_type', None)
if head_model_type is None:
head['model_type'] = model_type
else:
assert model_type == head_model_type, 'Please provide the same model_type of YOLOX in config.'
self.head = build_head(head)
self.num_classes = self.head.num_classes

Expand Down
151 changes: 149 additions & 2 deletions easycv/utils/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import re
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.parallel import is_module_wrapper
from mmcv.runner import load_checkpoint as mmcv_load_checkpoint
from mmcv.runner import _load_checkpoint as _load_checkpoint
from mmcv.runner.checkpoint import get_state_dict, weights_to_cpu
from torch import distributed as dist
from torch.optim import Optimizer

from easycv.file import io
Expand Down Expand Up @@ -46,6 +51,148 @@ def get_checkpoint(filename):
return filename


def load_and_check_state_dict(module: nn.Module,
                              state_dict: Union[dict, OrderedDict],
                              strict: bool = False,
                              logger: Optional[logging.Logger] = None) -> None:
    """Load state_dict to a module.

    This method is modified from :meth:`mmcv.runner.checkpoint.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.
    Raise error when state_dict is highly mismatched.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (dict or OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.

    Raises:
        RuntimeError: If ``strict`` is ``True`` and any key mismatches, or if
            a parameter-size mismatch is reported for a layer whose name does
            not contain ``'cls'`` (treated as a highly mismatched checkpoint).
    """
    unexpected_keys: List[str] = []
    all_missing_keys: List[str] = []
    err_msg: List[str] = []

    metadata = getattr(state_dict, '_metadata', None)
    # Shallow-copy so the caller's state_dict is not consumed by loading.
    state_dict = state_dict.copy()  # type: ignore
    if metadata is not None:
        state_dict._metadata = metadata  # type: ignore

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    def get_dist_info() -> Tuple[int, int]:
        # Rank/world size lookup so the mismatch report is only emitted once
        # (on rank 0) during distributed training.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            rank = 0
            world_size = 1
        return rank, world_size

    load(module)
    # break load->load reference cycle
    load = None  # type: ignore

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)  # type: ignore
        if strict:
            raise RuntimeError(err_msg)
        else:
            if logger is not None:
                logger.warning(err_msg)
            else:
                print(err_msg)

        # Even when not strict, a size mismatch outside the classification
        # ('cls') head means the checkpoint architecture is incompatible;
        # refuse to continue rather than silently skip those weights.
        err_msg_list = err_msg.split('\n')

        for error_msg_info in err_msg_list:
            if 'size mismatch' in error_msg_info and 'cls' not in error_msg_info:
                raise RuntimeError(
                    'Please check your pretrained model. The parameters do not match outside of the cls layer.'
                )


def load_and_check_checkpoint(
        model: torch.nn.Module,
        filename: str,
        map_location: Union[str, Callable, None] = None,
        strict: bool = False,
        logger: Optional[logging.Logger] = None,
        revise_keys: list = [(r'^module\.', '')]) -> Union[dict, OrderedDict]:
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.
        revise_keys (list): A list of customized keywords to modify the
            state_dict in checkpoint. Each item is a (pattern, replacement)
            pair of the regular expression operations. Default: strip
            the prefix 'module.' by [(r'^module\\.', '')].

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        RuntimeError: If the checkpoint file holds no dict-like state_dict.
    """
    checkpoint = _load_checkpoint(filename, map_location, logger)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        # Include the offending path so the user can locate the bad file.
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    # strip prefix of state_dict
    metadata = getattr(state_dict, '_metadata', OrderedDict())
    for p, r in revise_keys:
        state_dict = OrderedDict(
            {re.sub(p, r, k): v
             for k, v in state_dict.items()})
    # Keep metadata in state_dict
    state_dict._metadata = metadata

    # load state_dict
    load_and_check_state_dict(model, state_dict, strict, logger)
    return checkpoint


def load_checkpoint(model,
filename,
map_location='cpu',
Expand All @@ -72,7 +219,7 @@ def load_checkpoint(model,
dict or OrderedDict: The loaded checkpoint.
"""
filename = get_checkpoint(filename)
return mmcv_load_checkpoint(
return load_and_check_checkpoint(
model,
filename,
map_location=map_location,
Expand Down
Loading