# for i, ref_img in enumerate(reference_faces):
#     ref_img.save(f'./{i + 1}.png')
print(f'detected {len(reference_faces)} faces')
assert len(reference_faces) > 0, 'No face detected in the reference images'

# process the ref_imgs
H = height
            caching=None,
            style_fidelity=0.5):
    num_sample = t.shape[0]
    num_diff_condition = context.shape[0] // num_sample
    assert num_diff_condition in (2, 3)
    t = t.repeat_interleave(num_diff_condition, dim=0)

    # embeddings
    self.share_cache['num_diff_condition'] = num_diff_condition
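
# A minimal self-contained sketch of the batching above (shapes are
# illustrative assumptions, not from the original file): `context` stacks
# two or three conditioning branches per sample, so `t` is repeated with
# `repeat_interleave` to line up with it.
import torch

num_sample, seq_len, embed_dim = 2, 77, 768
t = torch.randint(0, 1000, (num_sample, ))
context = torch.randn(num_sample * 3, seq_len, embed_dim)  # 3 branches

num_diff_condition = context.shape[0] // num_sample  # -> 3
t = t.repeat_interleave(num_diff_condition, dim=0)  # [t0, t0, t0, t1, t1, t1]
assert t.shape[0] == context.shape[0]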
            enable_encoder=False,
            **kwargs):
    """UNet of Stable Diffusion 1.x (1.1~1.5)."""
    # sanity check
    assert version in ('sd-v1-1_ema', 'sd-v1-1_nonema', 'sd-v1-2_ema',
                       'sd-v1-2_nonema', 'sd-v1-3_ema', 'sd-v1-3_nonema',
                       'sd-v1-4_ema', 'sd-v1-4_nonema', 'sd-v1-5_ema',
                       'sd-v1-5_nonema', 'sd-v1-5-inpainting_nonema')

    # deduce dimension
    in_dim = 4
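
# A plausible sketch of the dimension deduction above: standard SD 1.x
# latents have 4 input channels, while the inpainting variant conventionally
# takes 9 (4 latent + 4 masked-latent + 1 mask). Whether the original code
# branches exactly like this is an assumption.
version = 'sd-v1-5-inpainting_nonema'  # example value from the assert above
in_dim = 9 if 'inpainting' in version else 4  # -> 9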
# classifier-free guidance (arXiv:2207.12598)
# model_kwargs[0]: conditional kwargs
# model_kwargs[1]: non-conditional kwargs
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
# model.share_cache["do_mimic"] = True

cat_xt = xt.repeat(3, 1, 1, 1)
conditional_embed = model_kwargs[0]['context']
non_conditional_embed = model_kwargs[1]['context']
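
# A self-contained sketch of how a triple batch like the one above is
# typically consumed: one forward pass, then chunked back into the three
# guidance branches. The stacking order, the `reference_embed` name, and
# the dummy model are assumptions.
import torch

model = lambda x, t, context: x  # dummy stand-in for the UNet
xt = torch.randn(2, 4, 64, 64)
t = torch.randint(0, 1000, (2, )).repeat(3)
conditional_embed = torch.randn(2, 77, 768)
non_conditional_embed = torch.randn(2, 77, 768)
reference_embed = torch.randn(2, 77, 768)

cat_xt = xt.repeat(3, 1, 1, 1)
cat_context = torch.cat(
    [conditional_embed, non_conditional_embed, reference_embed], dim=0)
y_out, u_out, y_out_with_ref = model(
    cat_xt, t=t, context=cat_context).chunk(3, dim=0)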
out = u_out + guide_scale * (y_out - u_out) + classifier * (
    y_out_with_ref - y_out)

if guide_rescale is not None:
    assert 0 <= guide_rescale <= 1
    ratio = (
        y_out.flatten(1).std(dim=1) /
        (out.flatten(1).std(dim=1) + 1e-12)).view((-1, ) +
                                                  (1, ) * (y_out.ndim - 1))
# classifier-free guidance (arXiv:2207.12598)
# model_kwargs[0]: conditional kwargs
# model_kwargs[1]: non-conditional kwargs
# make it batch inference for both conditional and non-conditional
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
cat_xt = xt.repeat(2, 1, 1, 1)
conditional_embed = model_kwargs[0]['context']
non_conditional_embed = model_kwargs[1]['context']
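
# The standard classifier-free-guidance combine for the two-branch batch
# above (arXiv:2207.12598), assuming the conditional half is stacked first.
# Values are illustrative.
import torch

cat_out = torch.randn(4, 4, 64, 64)  # model output for cat_xt
y_out, u_out = cat_out.chunk(2, dim=0)  # conditional / non-conditional
guide_scale = 7.5
out = u_out + guide_scale * (y_out - u_out)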
# out = base_out + guide_scale * (y_out - base_out)
# out = base_out + guide_rescale * (y_out - base_out)
# rescale the output according to arXiv:2305.08891
if guide_rescale is not None:
    assert 0 <= guide_rescale <= 1
    ratio = (
        y_out.flatten(1).std(dim=1) /
        (out.flatten(1).std(dim=1) + 1e-12)).view((-1, ) +
                                                  (1, ) * (y_out.ndim - 1))
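
# The block above is cut off mid-expression; per arXiv:2305.08891 the
# rescale usually finishes by std-matching the guided output and blending
# it with the original. A self-contained sketch under that assumption:
import torch

y_out = torch.randn(1, 4, 8, 8)
out = y_out + 6.5 * (y_out - torch.randn(1, 4, 8, 8))
guide_rescale = 0.5
ratio = (y_out.flatten(1).std(dim=1) /
         (out.flatten(1).std(dim=1) + 1e-12)).view((-1, ) +
                                                   (1, ) * (y_out.ndim - 1))
out = guide_rescale * (out * ratio) + (1 - guide_rescale) * out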
# restrict the range of x0
if percentile is not None:
    # NOTE: percentile should only be used when data is within range [-1, 1]
    assert 0 < percentile <= 1
    s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
    s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
    x0 = torch.min(s, torch.max(-s, x0)) / s
def __init__(self,
             length=77,
             padding='zero',
             bpe_path='tokenizers/clip/bpe_simple_vocab_16e6.txt.gz'):
    assert padding in ('zero', 'eos')
    self.length = length
    self.padding = padding
    self.bpe_path = bpe_path
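
# A minimal sketch of the two padding modes accepted above, with made-up
# CLIP-style token ids (the real ids come from the BPE vocab at bpe_path).
sos, eos = 49406, 49407
tokens = [sos, 320, 1125, eos]  # hypothetical encoded prompt
length, padding = 8, 'zero'
if padding == 'zero':
    tokens = tokens + [0] * (length - len(tokens))
else:  # 'eos'
    tokens = tokens + [eos] * (length - len(tokens))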
            num_heads,
            causal=False,
            attn_dropout=0.0,
            proj_dropout=0.0):
    assert dim % num_heads == 0
    super().__init__()
    self.dim = dim
    self.num_heads = num_heads
            activation='quick_gelu',
            attn_dropout=0.0,
            proj_dropout=0.0,
            norm_eps=1e-5):
    assert activation in ['quick_gelu', 'gelu', 'swi_glu']
    super().__init__()
    self.dim = dim
    self.mlp_ratio = mlp_ratio
            num_heads,
            activation='gelu',
            proj_dropout=0.0,
            norm_eps=1e-5):
    assert dim % num_heads == 0
    super().__init__()
    self.dim = dim
    self.mlp_ratio = mlp_ratio
            norm_eps=1e-5):
    if image_size % patch_size != 0:
        print('[WARNING] image_size is not divisible by patch_size',
              flush=True)
    assert pool_type in ('token', 'token_fc', 'attn_pool')
    out_dim = out_dim or dim
    super().__init__()
    self.image_size = image_size
            attn_dropout=0.0,
            proj_dropout=0.0,
            embedding_dropout=0.0,
            norm_eps=1e-5):
    assert pool_type in ('argmax', 'last')
    out_dim = out_dim or dim
    super().__init__()
    self.vocab_size = vocab_size
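
# A sketch of the two pooling modes asserted above, following the usual
# CLIP text-encoder convention: 'argmax' grabs the feature at the EOT
# token (the highest token id), 'last' takes the final position. Shapes
# and the convention itself are assumptions here.
import torch

b, seq_len, dim = 2, 77, 512
x = torch.randn(b, seq_len, dim)
tokens = torch.randint(1, 49408, (b, seq_len))
pool_type = 'argmax'
if pool_type == 'argmax':
    pooled = x[torch.arange(b), tokens.argmax(dim=-1)]
else:  # 'last'
    pooled = x[:, -1]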
# load checkpoint
if pretrained and pretrained_name:
    path = Path(__file__).parents[5] / "models" / "clip" / "openai-clip-vit-large-14.pth"
    assert pretrained_name in str(path)
    # load
    model.load_state_dict(torch.load(path, map_location=device))
class SSH(nn.Module):

    def __init__(self, in_dim, out_dim):
        assert out_dim % 4 == 0
        leaky = 0.1 if out_dim <= 64 else 0.0
        super().__init__()
        self.conv3X3 = conv_bn_no_relu(in_dim, out_dim // 2, stride=1)
class RetinaFace(nn.Module):

    def __init__(self, backbone='resnet50'):
        assert backbone in CONFIGS
        super().__init__()
        self.cfg = CONFIGS[backbone]
305 """ 306 imgs: [B, C, H, W] within range [0, 1]. 307 """ 308 # preprocess 309 assert mode in ['RGB', 'BGR'] 310 if mode == 'RGB': 311 imgs = imgs.flip(1) 312 imgs = 255.0 * imgs - imgs.new_tensor(self.cfg['mean']).view(
            dropout=0.0):
    # consider head_dim first, then num_heads
    num_heads = dim // head_dim if head_dim else num_heads
    head_dim = dim // num_heads
    assert num_heads * head_dim == dim
    context_dim = context_dim or dim
    super().__init__()
    self.dim = dim
            context=None,
            mask=None,
            caching=None,
            style_fidelity=0.5):
    assert caching in (None, 'write', 'read')

    # self-attention
    if self.disable_self_attn:
        x = x + self.self_attn(self.norm1(x), context, mask)
    else:
        y = self.norm1(x)
        if caching == 'read':
            assert self.cache is not None

            # read cache & self-attention
            ctx = torch.cat([
            context=None,
            mask=None,
            caching=None,
            style_fidelity=0.5):
    assert caching in (None, 'write', 'read')
    b, c, h, w = x.size()
    identity = x
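
# A sketch of the write/read protocol the asserts above imply: a reference
# pass stores its pre-attention features, and a later pass concatenates
# them as extra self-attention context. Names, shapes, and the concat axis
# are assumptions.
import torch

cache = None
y_ref = torch.randn(2, 64, 320)  # norm1(x) of the reference pass
caching = 'write'
if caching == 'write':
    cache = y_ref  # remember reference features

y = torch.randn(2, 64, 320)  # norm1(x) of the main pass
caching = 'read'
if caching == 'read':
    assert cache is not None
    ctx = torch.cat([y, cache], dim=1)  # attend over self + reference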
class Resample(nn.Module):

    def __init__(self, dim, mode):
        assert mode in ['none', 'upsample', 'downsample']
        super().__init__()
        self.dim = dim
        self.mode = mode
class GaussianDiffusion(object):

    def __init__(self, sigmas, prediction_type='eps'):
        assert prediction_type in {'x0', 'eps', 'v'}
        self.sigmas = sigmas  # noise coefficients
        self.alphas = torch.sqrt(1 - sigmas**2)  # signal coefficients
        self.num_timesteps = len(sigmas)
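
# The coefficients above satisfy alphas**2 + sigmas**2 == 1, i.e. a
# variance-preserving schedule, so the forward process is presumably of
# the form x_t = alphas[t] * x_0 + sigmas[t] * noise. A runnable sketch
# with an assumed (illustrative) sigma schedule:
import torch

sigmas = torch.linspace(0.01, 0.99, 1000)
alphas = torch.sqrt(1 - sigmas**2)
x0 = torch.randn(1, 4, 8, 8)
t = 500
xt = alphas[t] * x0 + sigmas[t] * torch.randn_like(x0)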
log_var = torch.log(var).clamp_(-20, 20)

# prediction
if guide_scale is None:
    assert isinstance(model_kwargs, dict)
    out = model(xt, t=t, **model_kwargs)
else:
    # classifier-free guidance (arXiv:2207.12598)
    # model_kwargs[0]: conditional kwargs
    # model_kwargs[1]: non-conditional kwargs
    assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
    y_out = model(xt, t=t, **model_kwargs[0])
    if guide_scale == 1.:
        out = y_out
    else:
        u_out = model(xt, t=t, **model_kwargs[1])
        out = u_out + guide_scale * (y_out - u_out)

# rescale the output according to arXiv:2305.08891
if guide_rescale is not None:
    assert 0 <= guide_rescale <= 1
    ratio = (y_out.flatten(1).std(dim=1) /
             (out.flatten(1).std(dim=1) +
              1e-12)).view((-1, ) + (1, ) * (y_out.ndim - 1))
# restrict the range of x0
if percentile is not None:
    # NOTE: percentile should only be used when data is within range [-1, 1]
    assert 0 < percentile <= 1
    s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
    s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
    x0 = torch.min(s, torch.max(-s, x0)) / s
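
# A small numeric check of the percentile clamp above: values beyond the
# chosen quantile are squashed back into [-1, 1]. Inputs are illustrative.
import torch

x0 = torch.tensor([[-3.0, -0.5, 0.2, 4.0]])
percentile = 0.75
s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)  # -> 3.25
s = s.clamp_(1.0).view(-1, 1)
x0 = torch.min(s, torch.max(-s, x0)) / s  # now within [-1, 1]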
            callback=None,
            seed=-1,
            **kwargs):
    # sanity check
    assert isinstance(steps, (int, torch.LongTensor))
    assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
    assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
    assert discretization in (None, 'leading', 'linspace', 'trailing')
    assert discard_penultimate_step in (None, True, False)
    assert return_intermediate in (None, 'x0', 'xt')
    assert solver == 'ddim'

    # function of diffusion solver
    solver_fn = {'ddim': sample_ddim}[solver]
    # options
    schedule = 'karras' if 'karras' in solver else None
    discretization = discretization or 'linspace'
    seed = seed if seed >= 0 else random.randint(0, 2**31)
    if isinstance(steps, torch.LongTensor):
        discard_penultimate_step = False
    if discard_penultimate_step is None:
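
# A sketch of the three discretization schemes asserted earlier, mapping
# `steps` sampling steps onto `num_timesteps` training steps. The exact
# formulas in the original may differ; the leading/linspace/trailing
# distinction follows arXiv:2305.08891.
import torch

num_timesteps, steps = 1000, 50
leading = torch.arange(0, num_timesteps, num_timesteps // steps)
linspace = torch.linspace(0, num_timesteps - 1, steps).round().long()
trailing = torch.arange(num_timesteps - 1, -1, -(num_timesteps // steps))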
import os
import subprocess
import urllib.request
import shutil


def run_command(command):
    result = subprocess.run(command, shell=True)
    if result.returncode != 0:
        raise RuntimeError(
            f"Command failed with exit code {result.returncode}: {command}")

def download_file(url, output_path):
    # skip files that are already present on disk
    if not os.path.exists(output_path):
        try:
            with urllib.request.urlopen(url) as response, \
                    open(output_path, 'wb') as out_file:
                shutil.copyfileobj(response, out_file)
        except Exception as e:
            print(f"Failed to download {output_path}. Error: {e}")
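
# Example usage of the two helpers above; the URL and paths are
# placeholders, not real endpoints.
if __name__ == '__main__':
    run_command('echo hello')
    download_file('https://example.com/checkpoint.bin', './checkpoint.bin')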