-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathadapt.py
163 lines (125 loc) · 4.24 KB
/
adapt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from pathlib import Path
import json
from math import sqrt
import numpy as np
import torch
from abc import ABCMeta, abstractmethod
class ScoreAdapter(metaclass=ABCMeta):
    """Abstract interface exposing a denoising model as a score function.

    Subclasses implement ``denoise`` and the abstract metadata accessors
    declared here; ``score`` is derived from ``denoise``.
    """

    @abstractmethod
    def denoise(self, xs, σ, **kwargs):
        """Return the denoised estimate of ``xs`` at noise level ``σ``."""
        pass

    def score(self, xs, σ, **kwargs):
        """Return the score estimate ``(denoise(xs, σ) - xs) / σ ** 2``.

        Extra keyword arguments are forwarded unchanged to ``denoise``.
        """
        denoised = self.denoise(xs, σ, **kwargs)
        residual = denoised - xs
        return residual / (σ ** 2)

    @abstractmethod
    def data_shape(self):
        """Return the per-sample shape; the base body is only an example."""
        return (3, 256, 256)  # for example

    def samps_centered(self):
        """Return whether samples are centered.

        If centered, samples are expected to be in range [-1, 1], else [0, 1].
        """
        return True

    @property
    @abstractmethod
    def σ_max(self):
        """Largest noise level supported by the model."""
        pass

    @property
    @abstractmethod
    def σ_min(self):
        """Smallest noise level supported by the model."""
        pass

    def cond_info(self, batch_size):
        """Return conditioning inputs for a batch; empty by default."""
        return {}

    @abstractmethod
    def unet_is_cond(self):
        """Return whether the underlying network takes conditioning inputs."""
        return False

    @abstractmethod
    def use_cls_guidance(self):
        """Return whether classifier guidance is used."""
        return False  # most models do not use cls guidance

    def classifier_grad(self, xs, σ, ys):
        """Return the classifier gradient; unsupported unless overridden.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    @abstractmethod
    def snap_t_to_nearest_tick(self, t):
        """Map ``t`` onto the model's time grid, returning a 2-tuple.

        The base body returns ``(t, None)`` unchanged. Need to confirm for
        each model; a continuous time model doesn't need snapping.
        """
        return t, None

    @property
    def device(self):
        """Device stored in ``self._device`` (must be set by the subclass)."""
        return self._device

    def checkpoint_root(self):
        """Return the path at which the pretrained checkpoints are stored.

        The path is ``<data_root>/diffusion_ckpts``, where ``data_root`` is
        read from the ``env.json`` file located next to this module.
        """
        env_file = Path(__file__).resolve().with_name("env.json")
        with env_file.open("r") as f:
            data_root = json.load(f)['data_root']
        return Path(data_root) / "diffusion_ckpts"
def karras_t_schedule(ρ=7, N=10, σ_max=80, σ_min=0.002):
    """Return ``N`` noise levels running from ``σ_max`` to ``σ_min``.

    Levels are interpolated linearly in ``σ ** (1 / ρ)`` space and raised
    back to the power ``ρ``::

        t_i = (σ_max^(1/ρ) + i / (N - 1) * (σ_min^(1/ρ) - σ_max^(1/ρ))) ^ ρ

    Args:
        ρ: exponent of the warped space in which levels are evenly spaced.
        N: number of levels to produce.
        σ_max: first level of the schedule.
        σ_min: last level of the schedule (reached when ``N >= 2``).

    Returns:
        A list of ``N`` levels. It is empty when ``N <= 0``; when ``N == 1``
        it holds only the starting level ``σ_max``.
    """
    warped_max = σ_max ** (1 / ρ)
    warped_min = σ_min ** (1 / ρ)
    # A single-level schedule would otherwise evaluate 0 / (N - 1) == 0 / 0;
    # clamping the denominator to 1 keeps the fraction at 0, i.e. σ_max.
    denom = max(N - 1, 1)
    return [
        (warped_max + (i / denom) * (warped_min - warped_max)) ** ρ
        for i in range(N)
    ]
def power_schedule(σ_max, σ_min, num_stages):
    """Return ``num_stages`` noise levels spaced evenly in log space.

    The levels run from ``σ_max`` to ``σ_min`` (both included when
    ``num_stages >= 2``) and are returned as a NumPy array.
    """
    log_start, log_stop = np.log(σ_max), np.log(σ_min)
    log_levels = np.linspace(log_start, log_stop, num_stages)
    return np.exp(log_levels)
class Karras:
    """Sampler that steps a score model down a Karras-style noise schedule."""

    @classmethod
    @torch.no_grad()
    def inference(
        cls, model, batch_size, num_t, *,
        σ_max=80, cls_scaling=1,
        init_xs=None, heun=True,
        langevin=False,
        S_churn=80, S_min=0.05, S_max=50, S_noise=1.003,
    ):
        """Generate samples, yielding the state before and after every step.

        Args:
            model: object providing ``σ_max``, ``σ_min``, ``score``,
                ``cond_info``, ``unet_is_cond``, ``use_cls_guidance``,
                ``classifier_grad``, ``snap_t_to_nearest_tick``,
                ``data_shape`` and ``device``.
            batch_size: number of samples; also passed to ``cond_info``.
            num_t: number of noise levels / integration steps.
            σ_max: requested starting noise level, capped by ``model.σ_max``.
            cls_scaling: multiplier applied to the classifier gradient.
            init_xs: optional starting state; when ``None`` the state is
                drawn as Gaussian noise scaled by the first noise level.
            heun: apply the second-order correction on every step except
                the one that lands on noise level 0.
            langevin: when true, inject extra noise before each step whose
                level lies strictly between ``S_min`` and ``S_max``.
            S_churn: total churn; each step uses ``S_churn / num_t``.
            S_min: lower bound (exclusive) of the churn window.
            S_max: upper bound (exclusive) of the churn window.
            S_noise: scale of the injected noise.

        Yields:
            The initial state, then the state after each of the ``num_t``
            steps.
        """
        σ_max = min(σ_max, model.σ_max)
        levels = karras_t_schedule(
            ρ=7, N=num_t, σ_max=σ_max, σ_min=model.σ_min
        )
        assert len(levels) == num_t
        levels = [model.snap_t_to_nearest_tick(t)[0] for t in levels]
        levels.append(0)  # 0 is the destination
        σ_max = levels[0]

        cond_inputs = model.cond_info(batch_size)

        def ode_direction(xs, σ):
            # Derivative of the state w.r.t. the noise level: -σ * score,
            # with the score optionally augmented by classifier guidance.
            score_kwargs = cond_inputs if model.unet_is_cond() else {}
            score = model.score(xs, σ, **score_kwargs)
            if model.use_cls_guidance():
                guidance = model.classifier_grad(xs, σ, cond_inputs["y"])
                score += guidance * cls_scaling
            return -1 * σ * score

        if init_xs is None:
            xs = σ_max * torch.randn(
                batch_size, *model.data_shape(), device=model.device
            )
        else:
            xs = init_xs.to(model.device)

        yield xs

        for t_cur, t_next in zip(levels, levels[1:]):
            if langevin and S_min < t_cur < S_max:
                # Temporarily raise the noise level before stepping down.
                xs, t_cur = cls.noise_backward_in_time(
                    model, xs, t_cur, S_noise, S_churn / num_t
                )

            Δt = t_next - t_cur
            slope_cur = ode_direction(xs, σ=t_cur)
            xs_euler = xs + Δt * slope_cur

            # Heun's 2nd order method; don't apply on the last step
            if heun and t_next != 0:
                slope_next = ode_direction(xs_euler, σ=t_next)
                xs = xs + Δt * (slope_cur + slope_next) / 2
            else:
                xs = xs_euler

            yield xs

    @staticmethod
    def noise_backward_in_time(model, xs, t_i, S_noise, S_churn_i):
        """Add noise to ``xs`` so that its level rises from ``t_i``.

        The raised level is ``t_i * (1 + γ)`` with
        ``γ = min(sqrt(2) - 1, S_churn_i)``, snapped through
        ``model.snap_t_to_nearest_tick``. Gaussian noise scaled by
        ``S_noise`` and by ``sqrt(t_hat ** 2 - t_i ** 2)`` is added.

        Returns:
            Tuple ``(noised xs, raised noise level)``.
        """
        noise = S_noise * torch.randn_like(xs)
        churn = min(sqrt(2)-1, S_churn_i)
        raised = model.snap_t_to_nearest_tick(t_i * (1 + churn))[0]
        extra_std = sqrt(raised ** 2 - t_i ** 2)
        return xs + noise * extra_std, raised
def test():
    """Placeholder self-test hook; currently performs no checks."""


if __name__ == "__main__":
    # Script entry point: run the placeholder self-test.
    test()