146 147 # for i, ref_img in enumerate(reference_faces): 148 # ref_img.save(f'./{i + 1}.png') 149 print(f'detected {len(reference_faces)} faces') 150 assert len( 151 reference_faces) > 0, 'No face detected in the reference images' 152 153 # process the ref_imgs 154 H = height
166 caching=None, 167 style_fidelity=0.5): 168 num_sample = t.shape[0] 169 num_diff_condition = context.shape[0] // num_sample 170 assert num_diff_condition == 2 or num_diff_condition == 3 171 t = t.repeat_interleave(num_diff_condition, dim=0) 172 # embeddings 173 self.share_cache['num_diff_condition'] = num_diff_condition
210 enable_encoder=False, 211 **kwargs): 212 """UNet of Stable Diffusion 1.x (1.1~1.5).""" 213 # sanity check 214 assert version in ('sd-v1-1_ema', 'sd-v1-1_nonema', 'sd-v1-2_ema', 215 'sd-v1-2_nonema', 'sd-v1-3_ema', 'sd-v1-3_nonema', 216 'sd-v1-4_ema', 'sd-v1-4_nonema', 'sd-v1-5_ema', 217 'sd-v1-5_nonema', 'sd-v1-5-inpainting_nonema') 218 219 # dedue dimension 220 in_dim = 4
66 67 # classifier-free guidance (arXiv:2207.12598) 68 # model_kwargs[0]: conditional kwargs 69 # model_kwargs[1]: non-conditional kwargs 70 assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 71 # model.share_cache["do_mimic"] = True 72 73 assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
69 # model_kwargs[1]: non-conditional kwargs 70 assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 71 # model.share_cache["do_mimic"] = True 72 73 assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 74 cat_xt = xt.repeat(3, 1, 1, 1) 75 conditional_embed = model_kwargs[0]['context'] 76 non_conditional_embed = model_kwargs[1]['context']
88 out = u_out + guide_scale * (y_out - u_out) + classifier * ( 89 y_out_with_ref - y_out) 90 91 if guide_rescale is not None: 92 assert guide_rescale >= 0 and guide_rescale <= 1 93 ratio = ( 94 y_out.flatten(1).std(dim=1) / 95 (out.flatten(1).std(dim=1) + 1e-12)).view((-1, ) + (1, ) *
100 # classifier-free guidance (arXiv:2207.12598) 101 # model_kwargs[0]: conditional kwargs 102 # model_kwargs[1]: non-conditional kwargs 103 # make it batch inference for both conditional and non-conditional 104 assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 105 cat_xt = xt.repeat(2, 1, 1, 1) 106 conditional_embed = model_kwargs[0]['context'] 107 non_conditional_embed = model_kwargs[1]['context']
120 # out = base_out + guide_scale * (y_out - base_out) 121 # out = base_out + guide_rescale * (y_out - base_out) 122 # rescale the output according to arXiv:2305.08891 123 if guide_rescale is not None: 124 assert guide_rescale >= 0 and guide_rescale <= 1 125 ratio = ( 126 y_out.flatten(1).std(dim=1) / 127 (out.flatten(1).std(dim=1) + 1e-12)).view((-1, ) + (1, ) *
141 142 # restrict the range of x0 143 if percentile is not None: 144 # NOTE: percentile should only be used when data is within range [-1, 1] 145 assert percentile > 0 and percentile <= 1 146 s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1) 147 s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1)) 148 x0 = torch.min(s, torch.max(-s, x0)) / s
68 def __init__(self, 69 length=77, 70 padding='zero', 71 bpe_path='tokenizers/clip/bpe_simple_vocab_16e6.txt.gz'): 72 assert padding in ('zero', 'eos') 73 self.length = length 74 self.padding = padding 75 self.bpe_path = bpe_path
32 num_heads, 33 causal=False, 34 attn_dropout=0.0, 35 proj_dropout=0.0): 36 assert dim % num_heads == 0 37 super().__init__() 38 self.dim = dim 39 self.num_heads = num_heads
99 activation='quick_gelu', 100 attn_dropout=0.0, 101 proj_dropout=0.0, 102 norm_eps=1e-5): 103 assert activation in ['quick_gelu', 'gelu', 'swi_glu'] 104 super().__init__() 105 self.dim = dim 106 self.mlp_ratio = mlp_ratio
140 num_heads, 141 activation='gelu', 142 proj_dropout=0.0, 143 norm_eps=1e-5): 144 assert dim % num_heads == 0 145 super().__init__() 146 self.dim = dim 147 self.mlp_ratio = mlp_ratio
206 norm_eps=1e-5): 207 if image_size % patch_size != 0: 208 print('[WARNING] image_size is not divisible by patch_size', 209 flush=True) 210 assert pool_type in ('token', 'token_fc', 'attn_pool') 211 out_dim = out_dim or dim 212 super().__init__() 213 self.image_size = image_size
298 attn_dropout=0.0, 299 proj_dropout=0.0, 300 embedding_dropout=0.0, 301 norm_eps=1e-5): 302 assert pool_type in ('argmax', 'last') 303 out_dim = out_dim or dim 304 super().__init__() 305 self.vocab_size = vocab_size
504 505 # load checkpoint 506 if pretrained and pretrained_name: 507 path = Path(__file__).parents[5] / "models" / "clip" / "openai-clip-vit-large-14.pth" 508 assert pretrained_name in str(path) 509 # load 510 model.load_state_dict(torch.load(path, 511 map_location=device,
118 119 class SSH(nn.Module): 120 121 def __init__(self, in_dim, out_dim): 122 assert out_dim % 4 == 0 123 leaky = 0.1 if out_dim <= 64 else 0.0 124 super().__init__() 125 self.conv3X3 = conv_bn_no_relu(in_dim, out_dim // 2, stride=1)
257 258 class RetinaFace(nn.Module): 259 260 def __init__(self, backbone='resnet50'): 261 assert backbone in CONFIGS 262 super().__init__() 263 self.cfg = CONFIGS[backbone] 264
305 """ 306 imgs: [B, C, H, W] within range [0, 1]. 307 """ 308 # preprocess 309 assert mode in ['RGB', 'BGR'] 310 if mode == 'RGB': 311 imgs = imgs.flip(1) 312 imgs = 255.0 * imgs - imgs.new_tensor(self.cfg['mean']).view(
91 dropout=0.0): 92 # consider head_dim first, then num_heads 93 num_heads = dim // head_dim if head_dim else num_heads 94 head_dim = dim // num_heads 95 assert num_heads * head_dim == dim 96 context_dim = context_dim or dim 97 super().__init__() 98 self.dim = dim
187 context=None, 188 mask=None, 189 caching=None, 190 style_fidelity=0.5): 191 assert caching in (None, 'write', 'read') 192 193 # self-attention 194 if self.disable_self_attn:
195 x = x + self.self_attn(self.norm1(x), context, mask) 196 else: 197 y = self.norm1(x) 198 if caching == 'read': 199 assert self.cache is not None 200 201 # read cache & self-attention 202 ctx = torch.cat([
258 context=None, 259 mask=None, 260 caching=None, 261 style_fidelity=0.5): 262 assert caching in (None, 'write', 'read') 263 b, c, h, w = x.size() 264 identity = x 265
29 30 class Resample(nn.Module): 31 32 def __init__(self, dim, mode): 33 assert mode in ['none', 'upsample', 'downsample'] 34 super().__init__() 35 self.dim = dim 36 self.mode = mode
24 25 class GaussianDiffusion(object): 26 27 def __init__(self, sigmas, prediction_type='eps'): 28 assert prediction_type in {'x0', 'eps', 'v'} 29 self.sigmas = sigmas # noise coefficients 30 self.alphas = torch.sqrt(1 - sigmas**2) # signal coefficients 31 self.num_timesteps = len(sigmas)
73 log_var = torch.log(var).clamp_(-20, 20) 74 75 # prediction 76 if guide_scale is None: 77 assert isinstance(model_kwargs, dict) 78 out = model(xt, t=t, **model_kwargs) 79 else: 80 # classifier-free guidance (arXiv:2207.12598)
79 else: 80 # classifier-free guidance (arXiv:2207.12598) 81 # model_kwargs[0]: conditional kwargs 82 # model_kwargs[1]: non-conditional kwargs 83 assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 84 y_out = model(xt, t=t, **model_kwargs[0]) 85 if guide_scale == 1.: 86 out = y_out
89 out = u_out + guide_scale * (y_out - u_out) 90 91 # rescale the output according to arXiv:2305.08891 92 if guide_rescale is not None: 93 assert guide_rescale >= 0 and guide_rescale <= 1 94 ratio = (y_out.flatten(1).std(dim=1) / 95 (out.flatten(1).std(dim=1) + 96 1e-12)).view((-1, ) + (1, ) * (y_out.ndim - 1))
109 110 # restrict the range of x0 111 if percentile is not None: 112 # NOTE: percentile should only be used when data is within range [-1, 1] 113 assert percentile > 0 and percentile <= 1 114 s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1) 115 s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1)) 116 x0 = torch.min(s, torch.max(-s, x0)) / s
142 callback=None, 143 seed=-1, 144 **kwargs): 145 # sanity check 146 assert isinstance(steps, (int, torch.LongTensor)) 147 assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1) 148 assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1) 149 assert discretization in (None, 'leading', 'linspace', 'trailing')
143 seed=-1, 144 **kwargs): 145 # sanity check 146 assert isinstance(steps, (int, torch.LongTensor)) 147 assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1) 148 assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1) 149 assert discretization in (None, 'leading', 'linspace', 'trailing') 150 assert discard_penultimate_step in (None, True, False)
144 **kwargs): 145 # sanity check 146 assert isinstance(steps, (int, torch.LongTensor)) 147 assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1) 148 assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1) 149 assert discretization in (None, 'leading', 'linspace', 'trailing') 150 assert discard_penultimate_step in (None, True, False) 151 assert return_intermediate in (None, 'x0', 'xt')
145 # sanity check 146 assert isinstance(steps, (int, torch.LongTensor)) 147 assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1) 148 assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1) 149 assert discretization in (None, 'leading', 'linspace', 'trailing') 150 assert discard_penultimate_step in (None, True, False) 151 assert return_intermediate in (None, 'x0', 'xt') 152 assert solver == 'ddim'
146 assert isinstance(steps, (int, torch.LongTensor)) 147 assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1) 148 assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1) 149 assert discretization in (None, 'leading', 'linspace', 'trailing') 150 assert discard_penultimate_step in (None, True, False) 151 assert return_intermediate in (None, 'x0', 'xt') 152 assert solver == 'ddim' 153 # function of diffusion solver
147 assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1) 148 assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1) 149 assert discretization in (None, 'leading', 'linspace', 'trailing') 150 assert discard_penultimate_step in (None, True, False) 151 assert return_intermediate in (None, 'x0', 'xt') 152 assert solver == 'ddim' 153 # function of diffusion solver 154 solver_fn = {'ddim': sample_ddim}[solver]
148 assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1) 149 assert discretization in (None, 'leading', 'linspace', 'trailing') 150 assert discard_penultimate_step in (None, True, False) 151 assert return_intermediate in (None, 'x0', 'xt') 152 assert solver == 'ddim' 153 # function of diffusion solver 154 solver_fn = {'ddim': sample_ddim}[solver] 155
155 156 # options 157 schedule = 'karras' if 'karras' in solver else None 158 discretization = discretization or 'linspace' 159 seed = seed if seed >= 0 else random.randint(0, 2**31) 160 if isinstance(steps, torch.LongTensor): 161 discard_penultimate_step = False 162 if discard_penultimate_step is None:
import os
import subprocess
import urllib.request
import shutil


def run_command(command):
    """Run ``command`` and raise if it exits with a non-zero status.

    Args:
        command: Either a shell command line given as ``str``/``bytes``
            (executed through the shell, exactly as before), or a sequence
            of program arguments (executed directly, without a shell).

    Raises:
        RuntimeError: If the process finishes with a non-zero return code.
            The message contains the exit code and the command.
    """
    # Only a command *line* needs the shell. On POSIX, combining an argument
    # list with shell=True passes every item after the first to the shell
    # itself rather than to the program, so sequences are run without one.
    use_shell = isinstance(command, (str, bytes))
    result = subprocess.run(command, shell=use_shell)
    if result.returncode != 0:
        raise RuntimeError(f"Command failed with exit code {result.returncode}: {command}")
17 18 def download_file(url, output_path): 19 if not os.path.exists(output_path): 20 try: 21 with urllib.request.urlopen(url) as response, open(output_path, 'wb') as out_file: 22 shutil.copyfileobj(response, out_file) 23 except Exception as e: 24 print(f"Failed to download {output_path}. Error: {e}")