[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jan 10, 2025
1 parent 35f2922 commit 60d8681
Showing 3 changed files with 33 additions and 20 deletions.
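
The changes below are purely mechanical formatter fixes (single quotes normalized to double quotes, long lines wrapped, inline-comment spacing), the kind applied when the CI bot runs the repository's formatting hooks. As a hedged illustration only — the formatter choice and line length here are assumptions, not taken from this commit — the same quote normalization can be reproduced locally with black's Python API:

# Minimal sketch, assuming the `black` package is installed; the actual
# scvi-tools hooks may use ruff-format or a different line length.
import black

# One of the lines changed in the diff below, before the auto fix:
src = "print('SSSSS', key)\n"

# Formatting with an assumed 99-character line length normalizes the quotes:
formatted = black.format_str(src, mode=black.Mode(line_length=99))
print(formatted)  # -> print("SSSSS", key)
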
45 changes: 30 additions & 15 deletions src/scvi/model/base/_archesmixin.py
@@ -92,7 +92,8 @@ def load_query_data(
         )

         attr_dict, var_names, load_state_dict, pyro_param_store = _get_loaded_data(
-            reference_model, device=device)
+            reference_model, device=device
+        )

         if isinstance(adata, MuData):
             for modality in adata.mod:
@@ -148,7 +149,7 @@ def load_query_data(
         # model tweaking
         new_state_dict = model.module.state_dict()
         for key, load_ten in load_state_dict.items():
-            print('SSSSS', key)
+            print("SSSSS", key)
             new_ten = new_state_dict[key]
             load_ten = load_ten.to(new_ten.device)
             if new_ten.size() == load_ten.size():
@@ -189,19 +190,23 @@ def _arches_pyro_setup(model, pyro_param_store):
         model.module.on_load(model)
         param_names = pyro.get_param_store().get_all_param_names()
         param_store = pyro.get_param_store().get_state()
-        pyro.clear_param_store() # we will re-add the params with the correct loaded values.
+        pyro.clear_param_store()  # we will re-add the params with the correct loaded values.
         block_parameter = []
         for name in param_names:
-            new_param = param_store['params'][name]
-            new_constraint = param_store['constraints'][name]
-            old_param = pyro_param_store['params'].pop(name, None).to(new_param.device)
-            old_constraint = pyro_param_store['constraints'].pop(name, None)
+            new_param = param_store["params"][name]
+            new_constraint = param_store["constraints"][name]
+            old_param = pyro_param_store["params"].pop(name, None).to(new_param.device)
+            old_constraint = pyro_param_store["constraints"].pop(name, None)
             if old_param is None:
-                logging.warning(f'Parameter {name} in pyro param_store but not found in reference model.')
+                logging.warning(
+                    f"Parameter {name} in pyro param_store but not found in reference model."
+                )
                 pyro.param(name, new_param, constraint=new_constraint)
                 continue
             if type(new_constraint) is not type(old_constraint):
-                logging.warning(f'Constraint mismatch for {name} in pyro param_store. Cannot transfer map parameter.')
+                logging.warning(
+                    f"Constraint mismatch for {name} in pyro param_store. Cannot transfer map parameter."
+                )
                 pyro.param(name, new_param, constraint=new_constraint)
                 continue
             old_param = transform_to(old_constraint)(old_param).detach().requires_grad_()
@@ -212,16 +217,26 @@
             else:
                 dim_diff = new_param.size()[-1] - old_param.size()[-1]
                 if dim_diff:
-                    updated_param = torch.cat([old_param, new_param[..., -dim_diff:]], dim=-1).detach().requires_grad_()
+                    updated_param = (
+                        torch.cat([old_param, new_param[..., -dim_diff:]], dim=-1)
+                        .detach()
+                        .requires_grad_()
+                    )
                     pyro.param(name, updated_param, constraint=old_constraint)
                 elif new_param.size()[0] - old_param.size()[0]:
                     dim_diff = new_param.size()[0] - old_param.size()[0]
-                    updated_param = torch.cat([old_param, new_param[-dim_diff:, ...]], dim=0).detach().requires_grad_()
+                    updated_param = (
+                        torch.cat([old_param, new_param[-dim_diff:, ...]], dim=0)
+                        .detach()
+                        .requires_grad_()
+                    )
                     pyro.param(name, updated_param, constraint=old_constraint)
                 else:
-                    ValueError('Parameter size mismatch in other dimension than 0 or 1. This is not supported.')
+                    ValueError(
+                        "Parameter size mismatch in other dimension than 0 or 1. This is not supported."
+                    )

-        if hasattr(model, '_block_parameters'):
+        if hasattr(model, "_block_parameters"):
             model._block_parameters = block_parameter

     @staticmethod
@@ -409,14 +424,14 @@ def _get_loaded_data(reference_model, device=None):
             reference_model, load_adata=False, map_location=device
         )
         pyro_param_store = load_state_dict.pop("pyro_param_store", None)
-        print('PPPP loading')
+        print("PPPP loading")
     else:
         attr_dict = reference_model._get_user_attributes()
         attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"}
         var_names = _get_var_names(reference_model.adata)
         load_state_dict = deepcopy(reference_model.module.state_dict())
         pyro_param_store = pyro.get_param_store().get_state()
-        print('PPPP loaded')
+        print("PPPP loaded")

     return attr_dict, var_names, load_state_dict, pyro_param_store

4 changes: 2 additions & 2 deletions src/scvi/module/base/_base_module.py
@@ -392,9 +392,9 @@ def on_load(self, model, **kwargs):
         old_history = model.history_.copy() if model.history_ is not None else None
         model.train(max_steps=1, **self.on_load_kwargs)
         model.history_ = old_history
-        if 'pyro_param_store' in kwargs:
+        if "pyro_param_store" in kwargs:
             # For scArches shapes are changed and we don't want to overwrite these changed shapes.
-            pyro.get_param_store().set_state(kwargs['pyro_param_store'])
+            pyro.get_param_store().set_state(kwargs["pyro_param_store"])

     def create_predictive(
         self,
4 changes: 1 addition & 3 deletions src/scvi/train/_trainingplans.py
@@ -1070,9 +1070,7 @@ def __init__(
         # We let SVI take care of all optimization
         self.automatic_optimization = False
         self.block_fn = (
-            lambda obj: pyro.poutine.block(obj, hide=blocked)
-            if blocked is not None
-            else obj
+            lambda obj: pyro.poutine.block(obj, hide=blocked) if blocked is not None else obj
         )

         self.svi = pyro.infer.SVI(
