imported_module = importlib.import_module(".py.{}".format(name), __name__)
try:
    NODE_CLASS_MAPPINGS = {**NODE_CLASS_MAPPINGS, **imported_module.NODE_CLASS_MAPPINGS}
    NODE_DISPLAY_NAME_MAPPINGS = {**NODE_DISPLAY_NAME_MAPPINGS, **imported_module.NODE_DISPLAY_NAME_MAPPINGS}
except:
    pass

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
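A minimal sketch of the same dynamic-import loop with the failure reported instead of silently swallowed; the traceback reporting and the narrower except clause are illustrative assumptions, not part of the original node pack:

import importlib
import traceback

imported_module = importlib.import_module(".py.{}".format(name), __name__)
try:
    NODE_CLASS_MAPPINGS = {**NODE_CLASS_MAPPINGS, **imported_module.NODE_CLASS_MAPPINGS}
    NODE_DISPLAY_NAME_MAPPINGS = {**NODE_DISPLAY_NAME_MAPPINGS, **imported_module.NODE_DISPLAY_NAME_MAPPINGS}
except Exception:
    # Surface the reason a module contributed no node mappings instead of hiding it.
    traceback.print_exc()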
elif bb_name == 'resnet50':
    bb_net = list(resnet50(pretrained=ResNet50_Weights.DEFAULT if pretrained else None).children())
    bb = nn.Sequential(OrderedDict({'conv1': nn.Sequential(*bb_net[0:3]), 'conv2': bb_net[4], 'conv3': bb_net[5], 'conv4': bb_net[6]}))
else:
    bb = eval('{}({})'.format(bb_name, params_settings))
    if pretrained:
        bb = load_weights(bb, bb_name)
return bb
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
    mask_matrix: Attention mask for cyclic shift.
"""
B, L, C = x.shape
H, W = self.H, self.W
assert L == H * W, "input feature has wrong size"

shortcut = x
x = self.norm1(x)
    x: Input feature, tensor size (B, H*W, C).
    H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"

x = x.view(B, H, W, C)
)

if self.config.squeeze_block:
    self.squeeze_module = nn.Sequential(*[
        eval(self.config.squeeze_block.split('_x')[0])(channels[0]+sum(self.config.cxt), channels[0])
        for _ in range(eval(self.config.squeeze_block.split('_x')[1]))
    ])

self.decoder = Decoder(channels)
if self.config.refine:
    if self.config.refine == 'itself':
        self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, out_channels=3)
    else:
        self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1'))

if self.config.freeze_bb:
    # Freeze the backbone...
class Decoder(nn.Module):
    def __init__(self, channels):
        super(Decoder, self).__init__()
        self.config = Config()
        DecoderBlock = eval(self.config.dec_blk)
        LateralBlock = eval(self.config.lat_blk)

        if self.config.dec_ipt:
            self.split = self.config.dec_ipt_split
                 bias=False):

        super(DeformableConv2d, self).__init__()

        assert type(kernel_size) == tuple or type(kernel_size) == int

        kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size)
        self.stride = stride if type(stride) == tuple else (stride, stride)
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
    return image, label


def cv_random_flip(img, label):
    if random.random() > 0.5:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
        label = label.transpose(Image.FLIP_LEFT_RIGHT)
    return img, label
def random_rotate(image, label, angle=15):
    mode = Image.BICUBIC
    if random.random() > 0.8:
        random_angle = np.random.randint(-angle, angle)
        image = image.rotate(random_angle, mode)
        label = label.rotate(random_angle, mode)
    return image, label


def color_enhance(image):
    bright_intensity = random.randint(5, 15) / 10.0
    image = ImageEnhance.Brightness(image).enhance(bright_intensity)
    contrast_intensity = random.randint(5, 15) / 10.0
    image = ImageEnhance.Contrast(image).enhance(contrast_intensity)
    color_intensity = random.randint(0, 20) / 10.0
    image = ImageEnhance.Color(image).enhance(color_intensity)
    sharp_intensity = random.randint(0, 30) / 10.0
    image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)
    return image
def random_pepper(img, N=0.0015):
    img = np.array(img)
    noiseNum = int(N * img.shape[0] * img.shape[1])
    for i in range(noiseNum):
        randX = random.randint(0, img.shape[0] - 1)
        randY = random.randint(0, img.shape[1] - 1)
        if random.randint(0, 1) == 0:
            img[randX, randY] = 0
        else:
            img[randX, randY] = 255
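A possible vectorized equivalent of the pepper-noise loop above, using NumPy fancy indexing; the function name and the channel-broadcast handling are illustrative assumptions, not part of the original module:

import numpy as np

def random_pepper_vectorized(img, N=0.0015):
    img = np.array(img)
    noise_num = int(N * img.shape[0] * img.shape[1])
    xs = np.random.randint(0, img.shape[0], noise_num)
    ys = np.random.randint(0, img.shape[1], noise_num)
    values = np.random.choice([0, 255], size=noise_num).astype(img.dtype)
    # Broadcast over the channel axis for colour images, direct assignment for grayscale.
    img[xs, ys] = values[:, None] if img.ndim == 3 else values
    return img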
class Decoder(nn.Module):
    def __init__(self, channels):
        super(Decoder, self).__init__()
        self.config = Config()
        DecoderBlock = eval('BasicDecBlk')
        LateralBlock = eval('BasicLatBlk')

        self.decoder_block4 = DecoderBlock(channels[0], channels[1])
        self.decoder_block3 = DecoderBlock(channels[1], channels[2])
model_file_path = ""
try:
    model_file_path = os.path.join(
        os.path.normpath(folder_paths.folder_names_and_paths[model_folder_name][0][0]), model_name)
except:
    pass
if not os.path.exists(model_file_path):
    model_file_path = os.path.join(folder_paths.models_dir, model_folder_name, model_name)
self.load(model_file_path, device=device)
    log(f"{NODE_NAME} is skip, because No Input.", message_type='error')
    return (None, None)

if random_output:
    index = random.randint(1, len(output_list))
    output = output_list[index - 1]

ret_image = None
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import pickle
import copy
import re
import json
        pickle.dump(obj, f)

def load_pickle(file_name:str) -> object:
    with open(file_name, 'rb') as f:
        obj = pickle.load(f)
    return obj

def load_light_leak_images() -> list:
net = BriaRMBG()
model_path = ""
try:
    model_path = os.path.join(os.path.normpath(folder_paths.folder_names_and_paths['rmbg'][0][0]), "model.pth")
except:
    pass
if not os.path.exists(model_path):
    model_path = os.path.join(folder_paths.models_dir, "rmbg", "RMBG-1.4", "model.pth")
if not os.path.exists(model_path):
model_file_path = ""
try:
    model_file_path = os.path.join(os.path.normpath(folder_paths.folder_names_and_paths[model_folder_name][0][0]), model_name)
except:
    pass
if not os.path.exists(model_file_path):
    model_file_path = os.path.join(folder_paths.models_dir, model_folder_name, model_name)
    return replace_case(old, new, text[:index] + new + text[index + len(old):])

def random_numbers(total:int, random_range:int, seed:int=0, sum_of_numbers:int=0) -> list:
    random.seed(seed)
    numbers = [random.randint(-random_range//2, random_range//2) for _ in range(total - 1)]
    avg = sum(numbers) // total
    ret_list = []
    for i in numbers:
    return target_width, target_height

def generate_random_name(prefix:str, suffix:str, length:int) -> str:
    name = ''.join(random.choice("abcdefghijklmnopqrstupvxyz1234567890") for x in range(length))
    return prefix + name + suffix

def check_image_file(file_name:str, interval:int) -> object:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except:
    pass


import uvicorn
    from rich.console import Console

    console = Console()
    rich_available = True
except Exception:
    pass

def handle_exception(request: Request, e: Exception):
    err = {
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except:
    pass

NUM_THREADS = str(4)
if cache_file.exists():
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            model_type_cache = json.load(f)
            assert isinstance(model_type_cache, dict)
    except:
        pass

res = []
for it in stable_diffusion_dir.glob(f"*.*"):
if sdxl_cache_file.exists():
    try:
        with open(sdxl_cache_file, "r", encoding="utf-8") as f:
            sdxl_model_type_cache = json.load(f)
            assert isinstance(sdxl_model_type_cache, dict)
    except:
        pass

for it in stable_diffusion_xl_dir.glob(f"*.*"):
    if it.suffix not in [".safetensors", ".ckpt"]:
for it in cache_dir.glob("**/*/model_index.json"):
    with open(it, "r", encoding="utf-8") as f:
        try:
            data = json.load(f)
        except:
            continue

        _class_name = data["_class_name"]
        name = folder_name_to_show_name(it.parent.parent.parent.name)
def generate_filename(directory: Path, original_filename, *options) -> str:
    text = str(directory.absolute()) + original_filename
    for v in options:
        text += "%s" % v
    md5_hash = hashlib.md5()
    md5_hash.update(text.encode("utf-8"))
    return md5_hash.hexdigest() + ".jpg"
from .const import DEFAULT_MODEL_DIR


def md5sum(filename):
    md5 = hashlib.md5()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(128 * md5.block_size), b""):
            md5.update(chunk)
infos = image.info

try:
    image = ImageOps.exif_transpose(image)
except:
    pass

if gray:
    image = image.convert("L")
out_height = ceil_modulo(height, mod)
out_width = ceil_modulo(width, mod)

if min_size is not None:
    assert min_size % mod == 0
    out_width = max(min_size, out_width)
    out_height = max(min_size, out_height)
alpha_channel = None
try:
    image = ImageOps.exif_transpose(image)
except:
    pass
# exif_transpose removes the EXIF rotation info, so image.info must be read after exif_transpose
infos = image.info
import subprocess
import sys


def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


def install_plugins_package():
    "",
)
if isinstance(image, str):
    image = cv2.imread(image)[..., ::-1]
    assert image is not None, f"Can't read ori_image image from{image}!"
elif isinstance(image, torch.Tensor):
    image = image.cpu().numpy()
else:
    assert isinstance(
        image, np.ndarray
    ), f"Unknown format of ori_image: {type(image)}"
edit_image = image.clip(1, 255)  # for mask reason
edit_image = check_channels(edit_image)
# edit_image = resize_image(
if masked_image is None:
    pos_imgs = np.zeros((w, h, 1))
if isinstance(masked_image, str):
    masked_image = cv2.imread(masked_image)[..., ::-1]
    assert (
        masked_image is not None
    ), f"Can't read draw_pos image from{masked_image}!"
    pos_imgs = 255 - masked_image
elif isinstance(masked_image, torch.Tensor):
    pos_imgs = masked_image.cpu().numpy()
else:
    assert isinstance(
        masked_image, np.ndarray
    ), f"Unknown format of draw_pos: {type(masked_image)}"
    pos_imgs = 255 - masked_image
pos_imgs = pos_imgs[..., 0:1]
pos_imgs = cv2.convertScaleAbs(pos_imgs)
    use_linear_in_transformer=False,
):
    super().__init__()
    if use_spatial_transformer:
        assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

    if context_dim is not None:
        assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
        from omegaconf.listconfig import ListConfig
        if type(context_dim) == ListConfig:
            context_dim = list(context_dim)
if num_heads_upsample == -1:
    num_heads_upsample = num_heads

if num_heads == -1:
    assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

if num_head_channels == -1:
    assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.dims = dims
self.image_size = image_size
self.in_channels = in_channels
                    "as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
    # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
    assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
    assert len(num_attention_blocks) == len(self.num_res_blocks)
    assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
    print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
          f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
          f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
positions = batch[self.position_key]
n_lines = batch['n_lines']
language = batch['language']
texts = batch['texts']
assert len(glyphs) == len(positions)
for i in range(len(glyphs)):
    if bs is not None:
        glyphs[i] = glyphs[i][:bs]
    info['inv_mask'] = inv_mask
    return x, dict(c_crossattn=[c], c_concat=[control], text_info=info)

def apply_model(self, x_noisy, t, cond, *args, **kwargs):
    assert isinstance(cond, dict)
    diffusion_model = self.model.diffusion_model
    _cond = torch.cat(cond['c_crossattn'], 1)
    _hint = torch.cat(cond['c_concat'], 1)
        c = c.mode()
    else:
        c = self.cond_stage_model(c)
else:
    assert hasattr(self.cond_stage_model, self.cond_stage_forward)
    c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
    num_ddpm_timesteps=self.ddpm_num_timesteps,
    verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
    alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)

self.register_buffer("betas", to_torch(self.model.betas))
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)

if mask is not None:
    assert x0 is not None
    img_orig = self.model.q_sample(
        x0, ts
    )  # TODO: deterministic forward pass?
    img = img_orig * mask + (1.0 - mask) * img

if ucg_schedule is not None:
    assert len(ucg_schedule) == len(time_range)
    unconditional_guidance_scale = ucg_schedule[i]

outs = self.p_sample_ddim(
else:
    e_t = model_output

if score_corrector is not None:
    assert self.model.parameterization == "eps", "not implemented"
    e_t = score_corrector.modify_score(
        self.model, e_t, x, t, c, **corrector_kwargs
    )
    else self.ddim_timesteps
)
num_reference_steps = timesteps.shape[0]

assert t_enc <= num_reference_steps
num_steps = t_enc

if use_original_steps:
)
if unconditional_guidance_scale == 1.0:
    noise_pred = self.model.apply_model(x_next, t, c)
else:
    assert unconditional_conditioning is not None
    e_t_uncond, noise_pred = torch.chunk(
        self.model.apply_model(
            torch.cat((x_next, x_next)),
def get_clip_token_for_string(tokenizer, string):
    batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
                               return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
    tokens = batch_encoding["input_ids"]
    assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
    return tokens[0, 1]
def get_bert_token_for_string(tokenizer, string):
    token = tokenizer(string)
    assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
    token = token[0, 1]
    return token
if hasattr(embedder, 'tokenizer'):  # using Stable Diffusion's CLIP encoder
    get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
    token_dim = 768
    if hasattr(embedder, 'vit'):
        assert emb_type == 'vit'
        self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
    self.get_recog_emb = None
else:  # using LDM's BERT encoder
# img: CHW
def resize_norm_img(self, img, max_wh_ratio):
    imgC, imgH, imgW = self.rec_image_shape
    assert imgC == img.shape[0]
    imgW = int((imgH * max_wh_ratio))

    h, w = img.shape[1:]
# img_list: list of tensors with shape chw 0-255
def pred_imglist(self, img_list, show_debug=False, is_ori=False):
    img_num = len(img_list)
    assert img_num > 0
    # Calculate the aspect ratio of all text bars
    width_list = []
    for img in img_list:
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
    assert type(colorize_nlabels)==int
    self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
    self.monitor = monitor
self.use_ema = ema_decay is not None
if self.use_ema:
    self.ema_decay = ema_decay
    assert 0. < ema_decay < 1.
    self.model_ema = LitEma(self, decay=ema_decay)
    print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if not only_inputs:
    xrec, posterior = self(x)
    if x.shape[1] > 3:
        # colorize with random projection
        assert xrec.shape[1] > 3
        x = self.to_rgb(x)
        xrec = self.to_rgb(xrec)
    log["samples"] = self.decode(torch.randn_like(posterior.sample()))
with self.ema_scope():
    xrec_ema, posterior_ema = self(x)
    if x.shape[1] > 3:
        # colorize with random projection
        assert xrec_ema.shape[1] > 3
        xrec_ema = self.to_rgb(xrec_ema)
    log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
    log["reconstructions_ema"] = xrec_ema
    log["inputs"] = x
    return log

def to_rgb(self, x):
    assert self.image_key == "segmentation"
    if not hasattr(self, "colorize"):
        self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
    x = F.conv2d(x, weight=self.colorize)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
    self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                              num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
    alphas_cumprod = self.model.alphas_cumprod
    assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
    to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

    self.register_buffer('betas', to_torch(self.model.betas))
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)

if mask is not None:
    assert x0 is not None
    img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
    img = img_orig * mask + (1. - mask) * img

if ucg_schedule is not None:
    assert len(ucg_schedule) == len(time_range)
    unconditional_guidance_scale = ucg_schedule[i]

outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
else:
    x_in = torch.cat([x] * 2)
    t_in = torch.cat([t] * 2)
    if isinstance(c, dict):
        assert isinstance(unconditional_conditioning, dict)
        c_in = dict()
        for k in c:
            if isinstance(c[k], list):
                    unconditional_conditioning[k],
                    c[k]])
    elif isinstance(c, list):
        c_in = list()
        assert isinstance(unconditional_conditioning, list)
        for i in range(len(c)):
            c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
    else:
else:
    e_t = model_output

if score_corrector is not None:
    assert self.model.parameterization == "eps", 'not implemented'
    e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
           unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
    num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]

    assert t_enc <= num_reference_steps
    num_steps = t_enc

    if use_original_steps:
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
    noise_pred = self.model.apply_model(x_next, t, c)
else:
    assert unconditional_conditioning is not None
    e_t_uncond, noise_pred = torch.chunk(
        self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                               torch.cat((unconditional_conditioning, c))), 2)
    reset_ema=False,
    reset_num_ema_updates=False,
):
    super().__init__()
    assert parameterization in [
        "eps",
        "x0",
        "v",
    ], 'currently only supporting "eps" and "x0" and "v"'
    self.parameterization = parameterization
    print(
        f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
if monitor is not None:
    self.monitor = monitor
self.make_it_fit = make_it_fit
if reset_ema:
    assert exists(ckpt_path)
if ckpt_path is not None:
    self.init_from_ckpt(
        ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet
    )
    if reset_ema:
        assert self.use_ema
        print(
            f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
        )
if reset_num_ema_updates:
    print(
        " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ "
    )
    assert self.use_ema
    self.model_ema.reset_num_updates()

self.register_schedule(
(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert (
    alphas_cumprod.shape[0] == self.num_timesteps
), "alphas have to be defined for each timestep"

to_torch = partial(torch.tensor, dtype=torch.float32)
else:
    raise NotImplementedError("mu not supported")
lvlb_weights[0] = lvlb_weights[1]
self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
assert not torch.isnan(self.lvlb_weights).all()

@contextmanager
def ema_scope(self, context=None):
if not name in sd:
    continue
old_shape = sd[name].shape
new_shape = param.shape
assert len(old_shape) == len(new_shape)
if len(new_shape) > 2:
    # we only modify first two axes
    assert new_shape[2:] == old_shape[2:]
# assumes first axis corresponds to output dim
if not new_shape == old_shape:
    new_param = param.clone()
):
    self.force_null_conditioning = force_null_conditioning
    self.num_timesteps_cond = default(num_timesteps_cond, 1)
    self.scale_by_std = scale_by_std
    assert self.num_timesteps_cond <= kwargs["timesteps"]
    # for backwards compatibility after implementation of DiffusionWrapper
    if conditioning_key is None:
        conditioning_key = "concat" if concat_mode else "crossattn"
if ckpt_path is not None:
    self.init_from_ckpt(ckpt_path, ignore_keys)
    self.restarted_from_ckpt = True
    if reset_ema:
        assert self.use_ema
        print(
            f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
        )
if reset_num_ema_updates:
    print(
        " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ "
    )
    assert self.use_ema
    self.model_ema.reset_num_updates()

def make_cond_schedule(
    and self.global_step == 0
    and batch_idx == 0
    and not self.restarted_from_ckpt
):
    assert (
        self.scale_factor == 1.0
    ), "rather not use custom rescaling and std-rescaling simultaneously"
    # set rescale weight to 1./std of encodings
    print("### USING STD-RESCALING ###")
    x = super().get_input(batch, self.first_stage_key)
    self.cond_stage_model.train = disabled_train
    for param in self.cond_stage_model.parameters():
        param.requires_grad = False
else:
    assert config != "__is_first_stage__"
    assert config != "__is_unconditional__"
    model = instantiate_from_config(config)
    self.cond_stage_model = model
        c = c.mode()
    else:
        c = self.cond_stage_model(c)
else:
    assert hasattr(self.cond_stage_model, self.cond_stage_forward)
    c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
    0, self.num_timesteps, (x.shape[0],), device=self.device
).long()
# t = torch.randint(500, 501, (x.shape[0],), device=self.device).long()
if self.model.conditioning_key is not None:
    assert c is not None
    if self.cond_stage_trainable:
        c = self.get_learned_conditioning(c)
    if self.shorten_cond_schedule:  # TODO: drop this option
t_in = t
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)

if score_corrector is not None:
    assert self.parameterization == "eps"
    model_out = score_corrector.modify_score(
        self, model_out, x, t, c, **corrector_kwargs
    )
for i in iterator:
    ts = torch.full((b,), i, device=self.device, dtype=torch.long)
    if self.shorten_cond_schedule:
        assert self.model.conditioning_key != "hybrid"
        tc = self.cond_ids[ts].to(cond.device)
        cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
        score_corrector=score_corrector,
        corrector_kwargs=corrector_kwargs,
    )
    if mask is not None:
        assert x0 is not None
        img_orig = self.q_sample(x0, ts)
        img = img_orig * mask + (1.0 - mask) * img
    else reversed(range(0, timesteps))
)

if mask is not None:
    assert x0 is not None
    assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match

for i in iterator:
    ts = torch.full((b,), i, device=device, dtype=torch.long)
    if self.shorten_cond_schedule:
        assert self.model.conditioning_key != "hybrid"
        tc = self.cond_ids[ts].to(cond.device)
        cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
    print("Diffusion model optimizing logvar")
    params.append(self.logvar)
opt = torch.optim.AdamW(params, lr=lr)
if self.use_scheduler:
    assert "target" in self.scheduler_config
    scheduler = instantiate_from_config(self.scheduler_config)

    print("Setting up LambdaLR scheduler...")
    "sequential_crossattn", False
)
self.diffusion_model = instantiate_from_config(diff_model_config)
self.conditioning_key = conditioning_key
assert self.conditioning_key in [
    None,
    "concat",
    "crossattn",
    "hybrid",
    "adm",
    "hybrid-adm",
    "crossattn-adm",
]

def forward(
    self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None
    xc = torch.cat([x] + c_concat, dim=1)
    cc = torch.cat(c_crossattn, 1)
    out = self.diffusion_model(xc, t, context=cc)
elif self.conditioning_key == "hybrid-adm":
    assert c_adm is not None
    xc = torch.cat([x] + c_concat, dim=1)
    cc = torch.cat(c_crossattn, 1)
    out = self.diffusion_model(xc, t, context=cc, y=c_adm)
elif self.conditioning_key == "crossattn-adm":
    assert c_adm is not None
    cc = torch.cat(c_crossattn, 1)
    out = self.diffusion_model(x, t, context=cc, y=c_adm)
elif self.conditioning_key == "adm":
    **kwargs,
):
    super().__init__(*args, **kwargs)
    # assumes that neither the cond_stage nor the low_scale_model contain trainable params
    assert not self.cond_stage_trainable
    self.instantiate_low_stage(low_scale_config)
    self.low_scale_key = low_scale_key
    self.noise_level_key = noise_level_key
# maybe guide away from empty text label and highest noise level and maximally degraded zx?
uc = dict()
for k in c:
    if k == "c_crossattn":
        assert isinstance(c[k], list) and len(c[k]) == 1
        uc[k] = [uc_tmp]
    elif k == "c_adm":  # todo: only run with text-based guidance?
        assert isinstance(c[k], torch.Tensor)
        # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
        uc[k] = c[k]
    elif isinstance(c[k], list):
self.keep_dims = keep_finetune_dims
self.c_concat_log_start = c_concat_log_start
self.c_concat_log_end = c_concat_log_end
if exists(self.finetune_keys):
    assert exists(ckpt_path), "can only finetune from a given checkpoint"
if exists(ckpt_path):
    self.init_from_ckpt(ckpt_path, ignore_keys)
print(
    f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only"
)
new_entry = torch.zeros_like(param)  # zero init
assert exists(new_entry), "did not find matching parameter to modify"
new_entry[:, : self.keep_dims, ...] = sd[k]
sd[k] = new_entry
    **kwargs,
):
    super().__init__(concat_keys, *args, **kwargs)
    self.masked_image_key = masked_image_key
    assert self.masked_image_key in concat_keys

@torch.no_grad()
def get_input(
def get_input(
    self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
):
    # note: restricted to non-trainable encoders currently
    assert (
        not self.cond_stage_trainable
    ), "trainable cond stages not yet supported for inpainting"
    z, c, x, xrec, xc = super().get_input(
        batch,
        self.first_stage_key,
    return_original_cond=True,
    bs=bs,
)

assert exists(self.concat_keys)
c_cat = list()
for ck in self.concat_keys:
    cc = (
def get_input(
    self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
):
    # note: restricted to non-trainable encoders currently
    assert (
        not self.cond_stage_trainable
    ), "trainable cond stages not yet supported for depth2img"
    z, c, x, xrec, xc = super().get_input(
        batch,
        self.first_stage_key,
    return_original_cond=True,
    bs=bs,
)

assert exists(self.concat_keys)
assert len(self.concat_keys) == 1
c_cat = list()
for ck in self.concat_keys:
    cc = batch[ck]
self.reshuffle_patch_size = reshuffle_patch_size
self.low_scale_model = None
if low_scale_config is not None:
    print("Initializing a low-scale model")
    assert exists(low_scale_key)
    self.instantiate_low_stage(low_scale_config)
    self.low_scale_key = low_scale_key
def get_input(
    self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
):
    # note: restricted to non-trainable encoders currently
    assert (
        not self.cond_stage_trainable
    ), "trainable cond stages not yet supported for upscaling-ft"
    z, c, x, xrec, xc = super().get_input(
        batch,
        self.first_stage_key,
    return_original_cond=True,
    bs=bs,
)

assert exists(self.concat_keys)
assert len(self.concat_keys) == 1
# optionally make spatial noise_level here
c_cat = list()
noise_level = None
for ck in self.concat_keys:
    cc = batch[ck]
    cc = rearrange(cc, "b h w c -> b c h w")
    if exists(self.reshuffle_patch_size):
        assert isinstance(self.reshuffle_patch_size, int)
        cc = rearrange(
            cc,
            "b c (p1 h) (p2 w) -> b (p1 p2 c) h w",
if schedule == 'discrete':
    if betas is not None:
        log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
    else:
        assert alphas_cumprod is not None
        log_alphas = 0.5 * torch.log(alphas_cumprod)
    self.total_N = len(log_alphas)
    self.T = 1.
t_continuous = t_continuous.expand((x.shape[0]))
if guidance_type == "uncond":
    return noise_pred_fn(x, t_continuous)
elif guidance_type == "classifier":
    assert classifier_fn is not None
    t_input = get_model_input_time(t_continuous)
    cond_grad = cond_grad_fn(x, t_input)
    sigma_t = noise_schedule.marginal_std(t_continuous)
            c_in = torch.cat([unconditional_condition, condition])
            noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
            return noise_uncond + guidance_scale * (noise - noise_uncond)

assert model_type in ["noise", "x_start", "v"]
assert guidance_type in ["uncond", "classifier", "classifier-free"]
return model_fn
    with torch.no_grad():
        x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
                                     solver_type=solver_type)
elif method == 'multistep':
    assert steps >= order
    timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
    assert timesteps.shape[0] - 1 == steps
    with torch.no_grad():
        vec_t = timesteps[0].expand((x.shape[0]))
        model_prev_list = [self.model_fn(x, vec_t)]
    raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                          num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

self.register_buffer('betas', to_torch(self.model.betas))
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

if mask is not None:
    assert x0 is not None
    img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
    img = img_orig * mask + (1. - mask) * img
    e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
    e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

if score_corrector is not None:
    assert self.model.parameterization == "eps"
    e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

return e_t
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
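For reference, a sketch of how this sinusoidal embedding typically continues (consistent with the tensor2tensor-style construction the docstring cites; the exact zero-padding of odd dimensions is an assumption):

import math
import torch
import torch.nn.functional as F

def get_timestep_embedding(timesteps, embedding_dim):
    assert len(timesteps.shape) == 1
    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)  # (N, 2 * half_dim)
    if embedding_dim % 2 == 1:
        emb = F.pad(emb, (0, 1, 0, 0))  # zero-pad to the requested dimension
    return emb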
    return x + h_


def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    assert attn_type in [
        "vanilla",
        "vanilla-xformers",
        "memory-efficient-cross-attn",
        "linear",
        "none",
    ], f"attn_type {attn_type} unknown"
    assert attn_kwargs is None
    if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        # print(f"Using torch.nn.functional.scaled_dot_product_attention")
        return AttnBlock2_0(in_channels)
    # assume aligned context, cat along channel axis
    x = torch.cat((x, context), dim=1)
if self.use_timestep:
    # timestep embedding
    assert t is not None
    temb = get_timestep_embedding(t, self.ch)
    temb = self.temb.dense[0](temb)
    temb = nonlinearity(temb)
class Upsampler(nn.Module):
    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        num_blocks = int(np.log2(out_size // in_size)) + 1
        factor_up = 1.0 + (out_size % in_size)
        print(
print(
    f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode"
)
raise NotImplementedError()
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
    in_channels, in_channels, kernel_size=4, stride=2, padding=1
if use_conv:
    self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

def forward(self, x):
    assert x.shape[1] == self.channels
    if self.dims == 3:
        x = F.interpolate(
            x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
    self.op = conv_nd(
        dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
    )
else:
    assert self.channels == self.out_channels
    self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

def forward(self, x):
    assert x.shape[1] == self.channels
    return self.op(x)
self.channels = channels
if num_head_channels == -1:
    self.num_heads = num_heads
else:
    assert (
        channels % num_head_channels == 0
    ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
    self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
    use_linear_in_transformer=False,
):
    super().__init__()
    if use_spatial_transformer:
        assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

    if context_dim is not None:
        assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
        from omegaconf.listconfig import ListConfig
        if type(context_dim) == ListConfig:
            context_dim = list(context_dim)
if num_heads_upsample == -1:
    num_heads_upsample = num_heads

if num_heads == -1:
    assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

if num_head_channels == -1:
    assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

self.image_size = image_size
self.in_channels = in_channels
                    "as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
    # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
    assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
    assert len(num_attention_blocks) == len(self.num_res_blocks)
    assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
    print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
          f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
          f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
    :param context: conditioning plugged in via crossattn
    :param y: an [N] Tensor of labels, if class-conditional.
    :return: an [N x C x ...] Tensor of outputs.
    """
    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"
    hs = []
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb)
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)

if self.num_classes is not None:
    assert y.shape[0] == x.shape[0]
    emb = emb + self.label_emb(y)

h = x.type(self.dtype)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

to_torch = partial(torch.tensor, dtype=torch.float32)
def forward(self, x, noise_level=None):
    if noise_level is None:
        noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
    else:
        assert isinstance(noise_level, torch.Tensor)
    z = self.q_sample(x, noise_level)
    return z, noise_level
for obj in (mean1, logvar1, mean2, logvar2):
    if isinstance(obj, torch.Tensor):
        tensor = obj
        break
assert tensor is not None, "at least one argument must be a Tensor"

# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
    sname = self.m_name2s_name[key]
    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
else:
    assert not key in self.m_name2s_name

def copy_to(self, model):
    m_param = dict(model.named_parameters())
for key in m_param:
    if m_param[key].requires_grad:
        m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
    else:
        assert not key in self.m_name2s_name

def store(self, parameters):
    """
    layer="last",
    layer_idx=None,
):  # clip-vit-base-patch32
    super().__init__()
    assert layer in self.LAYERS
    self.tokenizer = CLIPTokenizer.from_pretrained(version)
    self.transformer = CLIPTextModel.from_pretrained(version)
    self.device = device
self.freeze()
self.layer = layer
self.layer_idx = layer_idx
if layer == "hidden":
    assert layer_idx is not None
    assert 0 <= abs(layer_idx) <= 12

def freeze(self):
    self.transformer = self.transformer.eval()
    'reshape': Im2Seq,
    'rnn': EncoderWithRNN,
    'svtr': EncoderWithSVTR
}
assert encoder_type in support_encoder_dict, '{} must in {}'.format(
    encoder_type, support_encoder_dict.keys())

self.encoder = support_encoder_dict[encoder_type](
    self.encoder_reshape.out_channels, **kwargs)
class RecModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert 'in_channels' in config, 'in_channels must in model config'
        backbone_type = config.backbone.pop('type')
        assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}'
        self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)

        neck_type = config.neck.pop('type')
        assert neck_type in neck_dict, f'neck.type must in {neck_dict}'
        self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)

        head_type = config.head.pop('type')
        assert head_type in head_dict, f'head.type must in {head_dict}'
        self.head = head_dict[head_type](self.neck.out_channels, **config.head)

        self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}'
             epsilon=1e-6,
             prenorm=True):
    super().__init__()
    if isinstance(norm_layer, str):
        self.norm1 = eval(norm_layer)(dim, eps=epsilon)
    else:
        self.norm1 = norm_layer(dim)
    if mixer == 'Global' or mixer == 'Local':
    raise TypeError("The mixer must be one of [Global, Local, Conv]")

self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
if isinstance(norm_layer, str):
    self.norm2 = eval(norm_layer)(dim, eps=epsilon)
else:
    self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
        bias_attr=False))

def forward(self, x):
    B, C, H, W = x.shape
    assert H == self.img_size[0] and W == self.img_size[1], \
        f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
    x = self.proj(x).flatten(2).permute(0, 2, 1)
    return x
    stride=stride,
    padding=1,
    # weight_attr=ParamAttr(initializer=KaimingNormal())
)
self.norm = eval(sub_norm)(out_channels)
if act is not None:
    self.act = act()
else:
# self.add_parameter("pos_embed", self.pos_embed)

self.pos_drop = nn.Dropout(p=drop_rate)
Block_unit = eval(block_unit)

dpr = np.linspace(0, drop_path_rate, sum(depth))
self.blocks1 = nn.ModuleList(
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=eval(act),
attn_drop=attn_drop_rate,
drop_path=dpr[0:depth[0]][i],
norm_layer=norm_layer,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=eval(act),
attn_drop=attn_drop_rate,
drop_path=dpr[depth[0]:depth[0] + depth[1]][i],
norm_layer=norm_layer,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=eval(act),
attn_drop=attn_drop_rate,
drop_path=dpr[depth[0] + depth[1]:][i],
norm_layer=norm_layer,
    bias=False)
self.hardswish = nn.Hardswish()
self.dropout = nn.Dropout(p=last_drop)
if not prenorm:
    self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon)
self.use_lenhead = use_lenhead
if use_lenhead:
    self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
t = max(cropper_t, image_t)
r = min(cropper_r, image_r)
b = min(cropper_b, image_b)

assert (
    0 <= l < r and 0 <= t < b
), f"cropper and image not overlap, {l},{t},{r},{b}"

cropped_image = image[t:b, l:r, :]
padding_l = max(0, image_l - cropper_l)
    num_ddpm_timesteps=self.ddpm_num_timesteps,
    verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod  # torch.Size([1000])
assert (
    alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

self.register_buffer("betas", to_torch(self.model.betas))
)


def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"):
    assert isinstance(x, torch.Tensor)
    return _upfirdn2d_ref(
        x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
    )
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops."""
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    if f is None:
        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    assert f.dtype == torch.float32 and not f.requires_grad
    batch_size, num_channels, in_height, in_width = x.shape
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
104 mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. 105 activation="lrelu", # Activation function: 'relu', 'lrelu', etc. 106 conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. 107 ): 108 assert architecture in ["orig", "skip", "resnet"] 109 super().__init__() 110 self.in_channels = in_channels 111 self.cmap_dim = cmap_dim
154 # Conditioning. 155 if self.cmap_dim > 0: 156 x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) 157 158 assert x.dtype == dtype 159 return x, const_e 160 161
180 use_fp16=False, # Use FP16 for this block? 181 fp16_channels_last=False, # Use channels-last memory format with FP16? 182 freeze_layers=0, # Freeze-D: Number of layers to freeze. 183 ): 184 assert in_channels in [0, tmp_channels] 185 assert architecture in ["orig", "skip", "resnet"] 186 super().__init__() 187 self.in_channels = in_channels
181 fp16_channels_last=False, # Use channels-last memory format with FP16? 182 freeze_layers=0, # Freeze-D: Number of layers to freeze. 183 ): 184 assert in_channels in [0, tmp_channels] 185 assert architecture in ["orig", "skip", "resnet"] 186 super().__init__() 187 self.in_channels = in_channels 188 self.resolution = resolution
284 x = self.conv0(x) 285 feat = x.clone() 286 x = self.conv1(x) 287 288 assert x.dtype == dtype 289 return x, img, feat 290 291
420 421 422 def _unbroadcast(x, shape): 423 extra_dims = x.ndim - len(shape) 424 assert extra_dims >= 0 425 dim = [ 426 i 427 for i in range(x.ndim)
430 if len(dim): 431 x = x.sum(dim=dim, keepdim=True) 432 if extra_dims: 433 x = x.reshape(-1, *x.shape[extra_dims + 1 :]) 434 assert x.shape == shape 435 return x 436 437
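A small sketch of what `_unbroadcast` (defined above) does: it sums a tensor over the dimensions that were expanded by broadcasting, e.g. to reduce a gradient back to a parameter's shape. The shapes below are arbitrary:

import torch

a_shape = (4, 1, 3)                 # original parameter shape
grad = torch.randn(2, 4, 5, 3)      # gradient of a value broadcast to (2, 4, 5, 3)

g = _unbroadcast(grad, a_shape)
print(g.shape)                      # torch.Size([4, 1, 3])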
557 self.noise_strength = torch.nn.Parameter(torch.zeros([])) 558 self.bias = torch.nn.Parameter(torch.zeros([out_channels])) 559 560 def forward(self, x, w, noise_mode="none", fused_modconv=True, gain=1): 561 assert noise_mode in ["random", "const", "none"] 562 in_resolution = self.resolution // self.up 563 styles = self.affine(w) 564
695 if self.architecture == "skip": 696 img = self.torgb(x, mod_vector) 697 img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format) 698 699 assert x.dtype == dtype 700 return x, img 701 702
912 **spectral_kwargs, 913 ): 914 super(FFC, self).__init__() 915 916 assert stride == 1 or stride == 2, "Stride should be 1 or 2." 917 self.stride = stride 918 919 in_cg = int(in_channels * ratio_gin)
1122 1123 1124 class ConcatTupleLayer(nn.Module): 1125 def forward(self, x): 1126 assert isinstance(x, tuple) 1127 x_l, x_g = x 1128 assert torch.is_tensor(x_l) or torch.is_tensor(x_g) 1129 if not torch.is_tensor(x_g):
1124 class ConcatTupleLayer(nn.Module): 1125 def forward(self, x): 1126 assert isinstance(x, tuple) 1127 x_l, x_g = x 1128 assert torch.is_tensor(x_l) or torch.is_tensor(x_g) 1129 if not torch.is_tensor(x_g): 1130 return x_l 1131 return torch.cat(x, dim=1)
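A hedged sketch of how `ConcatTupleLayer` (defined above) merges the local / global branches of an FFC block; the channel split used here is made up:

import torch

layer = ConcatTupleLayer()

x_local = torch.randn(1, 48, 32, 32)     # local branch
x_global = torch.randn(1, 16, 32, 32)    # global (spectral) branch
print(layer((x_local, x_global)).shape)  # torch.Size([1, 64, 32, 32])

# If the global branch is disabled it may arrive as a non-tensor placeholder,
# in which case only the local branch is passed through.
print(layer((x_local, 0)).shape)         # torch.Size([1, 48, 32, 32])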
1221 use_fp16=False, # Use FP16 for this block? 1222 fp16_channels_last=False, # Use channels-last memory format with FP16? 1223 **layer_kwargs, # Arguments for SynthesisLayer. 1224 ): 1225 assert architecture in ["orig", "skip", "resnet"] 1226 super().__init__() 1227 self.in_channels = in_channels 1228 self.w_dim = w_dim
1375 y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) 1376 img = img.add_(y) if img is not None else y 1377 1378 x = x.to(dtype=dtype) 1379 assert x.dtype == dtype 1380 assert img is None or img.dtype == torch.float32 1381 return x, img 1382
1376 img = img.add_(y) if img is not None else y 1377 1378 x = x.to(dtype=dtype) 1379 assert x.dtype == dtype 1380 assert img is None or img.dtype == torch.float32 1381 return x, img 1382 1383
1392 channel_max=512, # Maximum number of channels in any layer. 1393 num_fp16_res=0, # Use FP16 for the N highest resolutions. 1394 **block_kwargs, # Arguments for SynthesisBlock. 1395 ): 1396 assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 1397 super().__init__() 1398 self.w_dim = w_dim 1399 self.img_resolution = img_resolution
1548 1549 # Apply truncation. 1550 if truncation_psi != 1: 1551 with torch.autograd.profiler.record_function("truncate"): 1552 assert self.w_avg_beta is not None 1553 if self.num_ws is None or truncation_cutoff is None: 1554 x = self.w_avg.lerp(x, truncation_psi) 1555 else:
104 105 def expand_image( 106 cv2_img, top: int, right: int, bottom: int, left: int, softness: float, space: float 107 ): 108 assert cv2_img.shape[2] == 3 109 origin_h, origin_w = cv2_img.shape[:2] 110 new_width = cv2_img.shape[1] + left + right 111 new_height = cv2_img.shape[0] + top + bottom
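A tiny arithmetic sketch of the canvas growth computed above; the values are arbitrary and only the two size formulas from the snippet are reproduced:

import numpy as np

cv2_img = np.zeros((512, 768, 3), dtype=np.uint8)   # H=512, W=768, 3 channels
top, right, bottom, left = 64, 0, 64, 128

new_width = cv2_img.shape[1] + left + right          # 768 + 128 + 0 = 896
new_height = cv2_img.shape[0] + top + bottom         # 512 + 64 + 64 = 640
print(new_width, new_height)                         # 896 640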
105 (timesteps,) = betas.shape 106 self.num_timesteps = int(timesteps) 107 self.linear_start = linear_start 108 self.linear_end = linear_end 109 assert ( 110 alphas_cumprod.shape[0] == self.num_timesteps 111 ), "alphas have to be defined for each timestep" 112 113 to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device) 114
170 raise NotImplementedError("mu not supported") 171 # TODO how to choose this term 172 lvlb_weights[0] = lvlb_weights[1] 173 self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) 174 assert not torch.isnan(self.lvlb_weights).any() 175 176 177 class LatentDiffusion(DDPM):
139 140 def forward(self, x, style, noise_mode="random", gain=1): 141 x = self.conv(x, style) 142 143 assert noise_mode in ["random", "const", "none"] 144 145 if self.use_noise: 146 if noise_mode == "random":
429 x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) 430 431 # Apply truncation. 432 if truncation_psi != 1: 433 assert self.w_avg_beta is not None 434 if self.num_ws is None or truncation_cutoff is None: 435 x = self.w_avg.lerp(x, truncation_psi) 436 else:
511 self.img_resolution = img_resolution 512 self.img_channels = img_channels 513 514 resolution_log2 = int(np.log2(img_resolution)) 515 assert img_resolution == 2**resolution_log2 and img_resolution >= 4 516 self.resolution_log2 = resolution_log2 517 518 def nf(stage):
823 if min(self.input_resolution) <= self.window_size: 824 # if window size is larger than input resolution, we don't partition windows 825 self.shift_size = 0 826 self.window_size = min(self.input_resolution) 827 assert ( 828 0 <= self.shift_size < self.window_size 829 ), "shift_size must be in [0, window_size)" 830 831 if self.shift_size > 0: 832 down_ratio = 1
1644 demodulate=True, 1645 ): 1646 super().__init__() 1647 resolution_log2 = int(np.log2(img_resolution)) 1648 assert img_resolution == 2**resolution_log2 and img_resolution >= 4 1649 1650 self.num_layers = resolution_log2 * 2 - 3 * 2 1651 self.img_resolution = img_resolution
1785 self.img_resolution = img_resolution 1786 self.img_channels = img_channels 1787 1788 resolution_log2 = int(np.log2(img_resolution)) 1789 assert img_resolution == 2**resolution_log2 and img_resolution >= 4 1790 self.resolution_log2 = resolution_log2 1791 1792 if cmap_dim == None:
20 raise ValueError('ddim_eta must be 0 for PLMS') 21 self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 22 num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose) 23 alphas_cumprod = self.model.alphas_cumprod 24 assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 25 to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 26 27 self.register_buffer('betas', to_torch(self.model.betas))
137 ts = torch.full((b,), step, device=device, dtype=torch.long) 138 ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 139 140 if mask is not None: 141 assert x0 is not None 142 img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 143 img = img_orig * mask + (1. - mask) * img 144
174 e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 175 e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 176 177 if score_corrector is not None: 178 assert self.model.parameterization == "eps" 179 e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 180 181 return e_t
94 ) 95 96 # Batch single image 97 if image.ndim == 3: 98 assert ( 99 image.shape[0] == 3 100 ), "Image outside a batch should be of shape (3, H, W)" 101 image = image.unsqueeze(0) 102 103 # Batch and add channel dim for single mask
113 # Batched masks no channel dim 114 else: 115 mask = mask.unsqueeze(1) 116 117 assert ( 118 image.ndim == 4 and mask.ndim == 4 119 ), "Image and Mask must have 4 dimensions" 120 assert ( 121 image.shape[-2:] == mask.shape[-2:] 122 ), "Image and Mask must have the same spatial dimensions"
116 117 assert ( 118 image.ndim == 4 and mask.ndim == 4 119 ), "Image and Mask must have 4 dimensions" 120 assert ( 121 image.shape[-2:] == mask.shape[-2:] 122 ), "Image and Mask must have the same spatial dimensions" 123 assert ( 124 image.shape[0] == mask.shape[0] 125 ), "Image and Mask must have the same batch size"
119 ), "Image and Mask must have 4 dimensions" 120 assert ( 121 image.shape[-2:] == mask.shape[-2:] 122 ), "Image and Mask must have the same spatial dimensions" 123 assert ( 124 image.shape[0] == mask.shape[0] 125 ), "Image and Mask must have the same batch size" 126 127 # Check image is in [-1, 1] 128 if image.min() < -1 or image.max() > 1:
146 raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") 147 148 # Batch single image 149 if image.ndim == 3: 150 assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" 151 image = image.unsqueeze(0) 152 153 # Batch and add channel dim for single mask
163 # Batched masks no channel dim 164 else: 165 mask = mask.unsqueeze(1) 166 167 assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" 168 assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" 169 assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" 170
164 else: 165 mask = mask.unsqueeze(1) 166 167 assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" 168 assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" 169 assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" 170 171 # Check image is in [-1, 1]
165 mask = mask.unsqueeze(1) 166 167 assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" 168 assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" 169 assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" 170 171 # Check image is in [-1, 1] 172 if image.min() < -1 or image.max() > 1:
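A hedged sketch of the shape contract the asserts above enforce: a single (3, H, W) image and an (H, W) mask are promoted to batched 4D tensors before the dimension, spatial-size, and batch-size checks run (toy sizes only):

import torch

image = torch.rand(3, 64, 64) * 2 - 1        # single image, values in [-1, 1]
mask = (torch.rand(64, 64) > 0.5).float()    # single mask, no channel dim

if image.ndim == 3:
    image = image.unsqueeze(0)               # (1, 3, 64, 64)
if mask.ndim == 2:
    mask = mask.unsqueeze(0).unsqueeze(0)    # (1, 1, 64, 64)

assert image.ndim == 4 and mask.ndim == 4
assert image.shape[-2:] == mask.shape[-2:]
assert image.shape[0] == mask.shape[0]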
738 739 for image_ in image: 740 self.check_image(image_, prompt, prompt_embeds) 741 else: 742 assert False 743 744 # Check `controlnet_conditioning_scale` 745 if (
764 "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" 765 " the same length as the number of controlnets" 766 ) 767 else: 768 assert False 769 770 if len(control_guidance_start) != len(control_guidance_end): 771 raise ValueError(
1593 control_images.append(control_image_) 1594 1595 control_image = control_images 1596 else: 1597 assert False 1598 1599 # 4. Preprocess mask and image - resizes image and mask w.r.t height and width 1600 mask, masked_image, init_image = prepare_mask_and_masked_image(
68 Args: 69 tokens (Union[str, List[str]]): The tokens to be added. 70 """ 71 num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs) 72 assert num_added_tokens != 0, ( 73 f"The tokenizer already contains the token {tokens}. Please pass " 74 "a different `placeholder_token` that is not already in the " 75 "tokenizer." 76 ) 77 78 def get_token_info(self, token: str) -> dict: 79 """Get the information of a token, including its start and end index in
278 Args: 279 embeddings (List[dict]): A list of embedding to be check. 280 """ 281 names = [emb["name"] for emb in embeddings] 282 assert len(names) == len(set(names)), ( 283 "Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'" 284 ) 285 286 def check_ids_overlap(self, embeddings): 287 """Check whether overlap exist in token ids of 'external_embeddings'.
293 ids_range.sort() # sort by 'start' 294 # check if 'end' has overlapping 295 for idx in range(len(ids_range) - 1): 296 name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1] 297 assert ids_range[idx][1] <= ids_range[idx + 1][0], ( 298 f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'." 299 ) 300 301 def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]): 302 """Add external embeddings to this layer.
420 # check if the next embedding need to replace is valid 421 actually_ids_to_replace = [ 422 int(i) for i in input_ids[e_idx : e_idx + end - start] 423 ] 424 assert actually_ids_to_replace == target_ids_to_replace, ( 425 f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. " 426 f"Expect '{target_ids_to_replace}' for embedding " 427 f"'{name}' but found '{actually_ids_to_replace}'." 428 ) 429 430 new_embedding.append(ext_emb) 431
452 will be used. Defaults to None. 453 454 input_ids: shape like [bz, LENGTH] or [LENGTH]. 455 """ 456 assert input_ids.ndim in [1, 2] 457 if input_ids.ndim == 1: 458 input_ids = input_ids.unsqueeze(0) 459
493 494 # TODO: support add tokens as dict, then we can load pretrained tokens. 495 """ 496 if initialize_tokens is not None: 497 assert len(initialize_tokens) == len( 498 placeholder_tokens 499 ), "initialize_tokens should have the same length as placeholder_tokens" 500 for ii in range(len(placeholder_tokens)): 501 tokenizer.add_placeholder_token( 502 placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token
508 embedding_layer 509 ) 510 embedding_layer = text_encoder.text_model.embeddings.token_embedding 511 512 assert embedding_layer is not None, ( 513 "Getting the embedding layer of the current text encoder is not supported. " 514 "Please check your configuration." 515 ) 516 initialize_embedding = [] 517 if initialize_tokens is not None: 518 for ii in range(len(placeholder_tokens)):
165 166 167 def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None): 168 """Slow reference implementation of `bias_act()` using standard PyTorch ops.""" 169 assert isinstance(x, torch.Tensor) 170 assert clamp is None or clamp >= 0 171 spec = activation_funcs[act] 172 alpha = float(alpha if alpha is not None else spec.def_alpha)
166 167 def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None): 168 """Slow reference implementation of `bias_act()` using standard PyTorch ops.""" 169 assert isinstance(x, torch.Tensor) 170 assert clamp is None or clamp >= 0 171 spec = activation_funcs[act] 172 alpha = float(alpha if alpha is not None else spec.def_alpha) 173 gain = float(gain if gain is not None else spec.def_gain)
174 clamp = float(clamp if clamp is not None else -1) 175 176 # Add bias. 177 if b is not None: 178 assert isinstance(b, torch.Tensor) and b.ndim == 1 179 assert 0 <= dim < x.ndim 180 assert b.shape[0] == x.shape[dim] 181 x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
175 176 # Add bias. 177 if b is not None: 178 assert isinstance(b, torch.Tensor) and b.ndim == 1 179 assert 0 <= dim < x.ndim 180 assert b.shape[0] == x.shape[dim] 181 x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 182
176 # Add bias. 177 if b is not None: 178 assert isinstance(b, torch.Tensor) and b.ndim == 1 179 assert 0 <= dim < x.ndim 180 assert b.shape[0] == x.shape[dim] 181 x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 182 183 # Evaluate activation function.
226 227 Returns: 228 Tensor of the same shape and datatype as `x`. 229 """ 230 assert isinstance(x, torch.Tensor) 231 assert impl in ["ref", "cuda"] 232 return _bias_act_ref( 233 x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp
227 Returns: 228 Tensor of the same shape and datatype as `x`. 229 """ 230 assert isinstance(x, torch.Tensor) 231 assert impl in ["ref", "cuda"] 232 return _bias_act_ref( 233 x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp 234 )
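A usage sketch for the public wrapper above, assuming (as its docstring suggests) that it is named `bias_act`; the shapes, activation name, and values are illustrative only:

import torch

x = torch.randn(2, 8, 4, 4)
b = torch.zeros(8)                        # one bias per channel along dim=1

y = bias_act(x, b, dim=1, act="lrelu")    # add bias, then leaky ReLU
print(y.shape)                            # torch.Size([2, 8, 4, 4])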
237 def _get_filter_size(f): 238 if f is None: 239 return 1, 1 240 241 assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 242 fw = f.shape[-1] 243 fh = f.shape[0] 244
243 fh = f.shape[0] 244 245 fw = int(fw) 246 fh = int(fh) 247 assert fw >= 1 and fh >= 1 248 return fw, fh 249 250
255 256 def _parse_scaling(scaling): 257 if isinstance(scaling, int): 258 scaling = [scaling, scaling] 259 assert isinstance(scaling, (list, tuple)) 260 assert all(isinstance(x, int) for x in scaling) 261 sx, sy = scaling 262 assert sx >= 1 and sy >= 1
256 def _parse_scaling(scaling): 257 if isinstance(scaling, int): 258 scaling = [scaling, scaling] 259 assert isinstance(scaling, (list, tuple)) 260 assert all(isinstance(x, int) for x in scaling) 261 sx, sy = scaling 262 assert sx >= 1 and sy >= 1 263 return sx, sy
258 scaling = [scaling, scaling] 259 assert isinstance(scaling, (list, tuple)) 260 assert all(isinstance(x, int) for x in scaling) 261 sx, sy = scaling 262 assert sx >= 1 and sy >= 1 263 return sx, sy 264 265
265 266 def _parse_padding(padding): 267 if isinstance(padding, int): 268 padding = [padding, padding] 269 assert isinstance(padding, (list, tuple)) 270 assert all(isinstance(x, int) for x in padding) 271 if len(padding) == 2: 272 padx, pady = padding
266 def _parse_padding(padding): 267 if isinstance(padding, int): 268 padding = [padding, padding] 269 assert isinstance(padding, (list, tuple)) 270 assert all(isinstance(x, int) for x in padding) 271 if len(padding) == 2: 272 padx, pady = padding 273 padding = [padx, padx, pady, pady]
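A small sketch of the scalar-to-pair normalization done by `_parse_scaling` above; the values are arbitrary, and `_parse_padding` similarly expands an int or a 2-element pair into four per-side pads:

sx, sy = _parse_scaling(2)         # a bare int is duplicated for both axes
print(sx, sy)                      # 2 2

sx, sy = _parse_scaling((2, 1))    # a pair is validated and passed through
print(sx, sy)                      # 2 1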
306 # Validate. 307 if f is None: 308 f = 1 309 f = torch.as_tensor(f, dtype=torch.float32) 310 assert f.ndim in [0, 1, 2] 311 assert f.numel() > 0 312 if f.ndim == 0: 313 f = f[np.newaxis]
307 if f is None: 308 f = 1 309 f = torch.as_tensor(f, dtype=torch.float32) 310 assert f.ndim in [0, 1, 2] 311 assert f.numel() > 0 312 if f.ndim == 0: 313 f = f[np.newaxis] 314
316 if separable is None: 317 separable = f.ndim == 1 and f.numel() >= 8 318 if f.ndim == 1 and not separable: 319 f = f.ger(f) 320 assert f.ndim == (1 if separable else 2) 321 322 # Apply normalize, flip, gain, and device. 323 if normalize:
465 466 def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): 467 """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.""" 468 # Validate arguments. 469 assert isinstance(x, torch.Tensor) and x.ndim == 4 470 if f is None: 471 f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 472 assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
468 # Validate arguments. 469 assert isinstance(x, torch.Tensor) and x.ndim == 4 470 if f is None: 471 f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 472 assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 473 assert not f.requires_grad 474 batch_size, num_channels, in_height, in_width = x.shape 475 # upx, upy = _parse_scaling(up)
469 assert isinstance(x, torch.Tensor) and x.ndim == 4 470 if f is None: 471 f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 472 assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 473 assert not f.requires_grad 474 batch_size, num_channels, in_height, in_width = x.shape 475 # upx, upy = _parse_scaling(up) 476 # downx, downy = _parse_scaling(down)
740 Returns: 741 Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 742 """ 743 # Validate arguments. 744 assert isinstance(x, torch.Tensor) and (x.ndim == 4) 745 assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 746 assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2]) 747 assert isinstance(up, int) and (up >= 1)
741 Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 742 """ 743 # Validate arguments. 744 assert isinstance(x, torch.Tensor) and (x.ndim == 4) 745 assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 746 assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2]) 747 assert isinstance(up, int) and (up >= 1) 748 assert isinstance(down, int) and (down >= 1)
742 """ 743 # Validate arguments. 744 assert isinstance(x, torch.Tensor) and (x.ndim == 4) 745 assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 746 assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2]) 747 assert isinstance(up, int) and (up >= 1) 748 assert isinstance(down, int) and (down >= 1) 749 # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}"
743 # Validate arguments. 744 assert isinstance(x, torch.Tensor) and (x.ndim == 4) 745 assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 746 assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2]) 747 assert isinstance(up, int) and (up >= 1) 748 assert isinstance(down, int) and (down >= 1) 749 # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}" 750 out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
744 assert isinstance(x, torch.Tensor) and (x.ndim == 4) 745 assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 746 assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2]) 747 assert isinstance(up, int) and (up >= 1) 748 assert isinstance(down, int) and (down >= 1) 749 # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}" 750 out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 751 fw, fh = _get_filter_size(f)
464 465 for ii in range(b): 466 keep = int((i + 1) / iterations * torch.sum(mask[ii, ...])) 467 468 assert torch.sum(mask[ii][indices[ii, :keep]]) == keep, "the first `keep` indexed mask entries must all be 1" 469 mask[ii][indices[ii, :keep]] = 0 470 471 mask = mask.reshape(b, 1, h, w)
80 self._init_session(new_model_name) 81 self.model_name = new_model_name 82 83 def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: 84 img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest() 85 return self.forward(rgb_np_img, req.clicks, img_md5) 86 87 @torch.inference_mode()
213 self.proj = nn.Linear(dim, dim) 214 215 self.use_rel_pos = use_rel_pos 216 if self.use_rel_pos: 217 assert ( 218 input_size is not None 219 ), "Input size must be provided if using relative positional encoding." 220 # initialize relative positional embeddings 221 self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 222 self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
220 self.proj = nn.Linear(dim, dim) 221 222 self.use_rel_pos = use_rel_pos 223 if self.use_rel_pos: 224 assert ( 225 input_size is not None 226 ), "Input size must be provided if using relative positional encoding." 227 # initialize relative positional embeddings 228 self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 229 self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
363 resolution=(14, 14), 364 ): 365 super().__init__() 366 # (h, w) 367 assert isinstance(resolution, tuple) and len(resolution) == 2 368 self.num_heads = num_heads 369 self.scale = key_dim**-0.5 370 self.key_dim = key_dim
465 super().__init__() 466 self.dim = dim 467 self.input_resolution = input_resolution 468 self.num_heads = num_heads 469 assert window_size > 0, "window_size must be greater than 0" 470 self.window_size = window_size 471 self.mlp_ratio = mlp_ratio 472
471 self.mlp_ratio = mlp_ratio 472 473 self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 474 475 assert dim % num_heads == 0, "dim must be divisible by num_heads" 476 head_dim = dim // num_heads 477 478 window_resolution = (window_size, window_size)
496 497 def forward(self, x): 498 H, W = self.input_resolution 499 B, L, C = x.shape 500 assert L == H * W, "input feature has wrong size" 501 res_x = x 502 if H == self.window_size and W == self.window_size: 503 x = self.attn(x)
772 block.apply(lambda x: _set_lr_scale(x, lr_scales[i])) 773 i += 1 774 if layer.downsample is not None: 775 layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1])) 776 assert i == depth 777 for m in [self.norm_head, self.head]: 778 m.apply(lambda x: _set_lr_scale(x, lr_scales[-1])) 779
781 p.param_name = k 782 783 def _check_lr_scale(m): 784 for p in m.parameters(): 785 assert hasattr(p, "lr_scale"), p.param_name 786 787 self.apply(_check_lr_scale) 788
197 super().__init__() 198 self.embedding_dim = embedding_dim 199 self.internal_dim = embedding_dim // downsample_rate 200 self.num_heads = num_heads 201 assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 203 self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
44 image (np.ndarray): The image for calculating masks. Expects an 45 image in HWC uint8 format, with pixel values in [0, 255]. 46 image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 """ 48 assert image_format in [ 49 "RGB", 50 "BGR", 51 ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 if image_format != self.model.image_format: 53 image = image[..., ::-1] 54
77 1x3xHxW, which has been transformed with ResizeLongestSide. 78 original_image_size (tuple(int, int)): The size of the image 79 before transformation, in (H, W) format. 80 """ 81 assert ( 82 len(transformed_image.shape) == 4 83 and transformed_image.shape[1] == 3 84 and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 85 ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 86 self.reset_image() 87 88 self.original_size = original_image_size
139 140 # Transform input prompts 141 coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 142 if point_coords is not None: 143 assert ( 144 point_labels is not None 145 ), "point_labels must be supplied if point_coords is supplied." 146 point_coords = self.transform.apply_coords(point_coords, self.original_size) 147 coords_torch = torch.as_tensor( 148 point_coords, dtype=torch.float, device=self.device
265 if not self.is_image_set: 266 raise RuntimeError( 267 "An image must be set with .set_image(...) to generate an embedding." 268 ) 269 assert ( 270 self.features is not None 271 ), "Features must exist if an image has been set." 272 return self.features 273 274 @property
44 image (np.ndarray): The image for calculating masks. Expects an 45 image in HWC uint8 format, with pixel values in [0, 255]. 46 image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 """ 48 assert image_format in [ 49 "RGB", 50 "BGR", 51 ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 # import pdb;pdb.set_trace() 53 if image_format != self.model.image_format: 54 image = image[..., ::-1]
79 1x3xHxW, which has been transformed with ResizeLongestSide. 80 original_image_size (tuple(int, int)): The size of the image 81 before transformation, in (H, W) format. 82 """ 83 assert ( 84 len(transformed_image.shape) == 4 85 and transformed_image.shape[1] == 3 86 and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 87 ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 88 self.reset_image() 89 90 self.original_size = original_image_size
142 143 # Transform input prompts 144 coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 145 if point_coords is not None: 146 assert ( 147 point_labels is not None 148 ), "point_labels must be supplied if point_coords is supplied." 149 point_coords = self.transform.apply_coords(point_coords, self.original_size) 150 coords_torch = torch.as_tensor( 151 point_coords, dtype=torch.float, device=self.device
272 if not self.is_image_set: 273 raise RuntimeError( 274 "An image must be set with .set_image(...) to generate an embedding." 275 ) 276 assert ( 277 self.features is not None 278 ), "Features must exist if an image has been set." 279 return self.features 280 281 @property
383 @field_validator("sd_seed") 384 @classmethod 385 def sd_seed_validator(cls, v: int) -> int: 386 if v == -1: 387 return random.randint(1, 99999999) 388 return v 389 390 @field_validator("controlnet_conditioning_scale")
7 8 def test_load_png_image(): 9 with open(png_img_p, "rb") as f: 10 np_img, alpha_channel = load_img(f.read()) 11 assert np_img.shape == (256, 256, 3) 12 assert alpha_channel.shape == (256, 256) 13 14
8 def test_load_png_image(): 9 with open(png_img_p, "rb") as f: 10 np_img, alpha_channel = load_img(f.read()) 11 assert np_img.shape == (256, 256, 3) 12 assert alpha_channel.shape == (256, 256) 13 14 15 def test_load_jpg_image():
14 15 def test_load_jpg_image(): 16 with open(jpg_img_p, "rb") as f: 17 np_img, alpha_channel = load_img(f.read()) 18 assert np_img.shape == (394, 448, 3) 19 assert alpha_channel is None
15 def test_load_jpg_image(): 16 with open(jpg_img_p, "rb") as f: 17 np_img, alpha_channel = load_img(f.read()) 18 assert np_img.shape == (394, 448, 3) 19 assert alpha_channel is None
41 enable_controlnet=False, 42 ) 43 ) 44 45 assert "Disable controlnet" in caplog.text 46 47 48 def test_switch_controlnet_method(caplog):
66 controlnet_method=new_method, 67 ) 68 ) 69 70 assert f"Switch Controlnet method from {old_method} to {new_method}" in caplog.text
47 48 res_mask = gen_frontend_mask(bgr_np_img) 49 _save(res_mask, "test_remove_bg_frontend_mask.png") 50 51 assert len(bgr_np_img.shape) == 2 52 _save(bgr_np_img, "test_remove_bg_mask.jpeg") 53 54
56 model = AnimeSeg() 57 img = cv2.imread(str(current_dir / "anime_test.png")) 58 img_base64 = encode_pil_to_base64(Image.fromarray(img), 100, {}) 59 res = model.gen_image(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64)) 60 assert len(res.shape) == 3 61 assert res.shape[-1] == 4 62 _save(res, "test_anime_seg.png") 63
57 img = cv2.imread(str(current_dir / "anime_test.png")) 58 img_base64 = encode_pil_to_base64(Image.fromarray(img), 100, {}) 59 res = model.gen_image(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64)) 60 assert len(res.shape) == 3 61 assert res.shape[-1] == 4 62 _save(res, "test_anime_seg.png") 63 64 res = model.gen_mask(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
61 assert res.shape[-1] == 4 62 _save(res, "test_anime_seg.png") 63 64 res = model.gen_mask(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64)) 65 assert len(res.shape) == 2 66 _save(res, "test_anime_seg_mask.png") 67 68
25 26 27 def assert_keys(keys: List[str], infos, res_infos): 28 for k in keys: 29 assert k in infos 30 assert k in res_infos 31 assert infos[k] == res_infos[k] 32
26 27 def assert_keys(keys: List[str], infos, res_infos): 28 for k in keys: 29 assert k in infos 30 assert k in res_infos 31 assert infos[k] == res_infos[k] 32 33
27 def assert_keys(keys: List[str], infos, res_infos): 28 for k in keys: 29 assert k in infos 30 assert k in res_infos 31 assert infos[k] == res_infos[k] 32 33 34 def run_test(file_path, keys):
37 str(save_dir / gt_name), 38 res, 39 [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0], 40 ) 41 assert ok, save_dir / gt_name 42 43 """ 44 Note that JPEG is lossy compression, so even if it is the highest quality 100,
44 ret_images = [] 45 light_leak_images = load_light_leak_images() 46 if light == 'random': 47 random.seed(time.time()) 48 light_index = random.randint(0,31) 49 else: 50 light_index = int(light) - 1 51
187 def __call__(self, img: PIL.Image.Image, target: dict): 188 init_boxes = len(target["boxes"]) 189 max_patience = 10 190 for i in range(max_patience): 191 w = random.randint(self.min_size, min(img.width, self.max_size)) 192 h = random.randint(self.min_size, min(img.height, self.max_size)) 193 region = T.RandomCrop.get_params(img, [h, w]) 194 result_img, result_target = crop(img, target, region)
188 init_boxes = len(target["boxes"]) 189 max_patience = 10 190 for i in range(max_patience): 191 w = random.randint(self.min_size, min(img.width, self.max_size)) 192 h = random.randint(self.min_size, min(img.height, self.max_size)) 193 region = T.RandomCrop.get_params(img, [h, w]) 194 result_img, result_target = crop(img, target, region) 195 if (
217 def __init__(self, p=0.5): 218 self.p = p 219 220 def __call__(self, img, target): 221 if random.random() < self.p: 222 return hflip(img, target) 223 return img, target 224
224 225 226 class RandomResize(object): 227 def __init__(self, sizes, max_size=None): 228 assert isinstance(sizes, (list, tuple)) 229 self.sizes = sizes 230 self.max_size = max_size 231
229 self.sizes = sizes 230 self.max_size = max_size 231 232 def __call__(self, img, target=None): 233 size = random.choice(self.sizes) 234 return resize(img, target, size, self.max_size) 235 236
238 def __init__(self, max_pad): 239 self.max_pad = max_pad 240 241 def __call__(self, img, target): 242 pad_x = random.randint(0, self.max_pad) 243 pad_y = random.randint(0, self.max_pad) 244 return pad(img, target, (pad_x, pad_y)) 245
239 self.max_pad = max_pad 240 241 def __call__(self, img, target): 242 pad_x = random.randint(0, self.max_pad) 243 pad_y = random.randint(0, self.max_pad) 244 return pad(img, target, (pad_x, pad_y)) 245 246
255 self.transforms2 = transforms2 256 self.p = p 257 258 def __call__(self, img, target): 259 if random.random() < self.p: 260 return self.transforms1(img, target) 261 return self.transforms2(img, target) 262
108 xs = self.body(tensor_list.tensors) 109 out: Dict[str, NestedTensor] = {} 110 for name, x in xs.items(): 111 m = tensor_list.mask 112 assert m is not None 113 mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 114 out[name] = NestedTensor(x, mask) 115 # import ipdb; ipdb.set_trace()
135 ) 136 else: 137 raise NotImplementedError("Unknown backbone name: {}".format(name)) 138 # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 139 assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available." 140 assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]] 141 num_channels_all = [256, 512, 1024, 2048] 142 num_channels = num_channels_all[4 - len(return_interm_indices) :]
136 else: 137 raise NotImplementedError("Unknown backbone name: {}".format(name)) 138 # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 139 assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available." 140 assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]] 141 num_channels_all = [256, 512, 1024, 2048] 142 num_channels = num_channels_all[4 - len(return_interm_indices) :] 143 super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
174 train_backbone = True 175 if not train_backbone: 176 raise ValueError("Please set lr_backbone > 0") 177 return_interm_indices = args.return_interm_indices 178 assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]] 179 args.backbone_freeze_keywords 180 use_checkpoint = getattr(args, "use_checkpoint", False) 181
207 bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :] 208 else: 209 raise NotImplementedError("Unknown backbone {}".format(args.backbone)) 210 211 assert len(bb_num_channels) == len( 212 return_interm_indices 213 ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}" 214 215 model = Joiner(backbone, position_embedding) 216 model.num_channels = bb_num_channels
213 ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}" 214 215 model = Joiner(backbone, position_embedding) 216 model.num_channels = bb_num_channels 217 assert isinstance( 218 bb_num_channels, List 219 ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels)) 220 # import ipdb; ipdb.set_trace() 221 return model
46 47 def forward(self, tensor_list: NestedTensor): 48 x = tensor_list.tensors 49 mask = tensor_list.mask 50 assert mask is not None 51 not_mask = ~mask 52 y_embed = not_mask.cumsum(1, dtype=torch.float32) 53 x_embed = not_mask.cumsum(2, dtype=torch.float32)
97 98 def forward(self, tensor_list: NestedTensor): 99 x = tensor_list.tensors 100 mask = tensor_list.mask 101 assert mask is not None 102 not_mask = ~mask 103 y_embed = not_mask.cumsum(1, dtype=torch.float32) 104 x_embed = not_mask.cumsum(2, dtype=torch.float32)
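A hedged toy sketch of the cumulative-sum indexing above: inside the valid (unmasked) region each pixel receives its 1-based row / column coordinate, while padded positions stop counting (the mask layout below is made up):

import torch

mask = torch.zeros(1, 4, 6, dtype=torch.bool)
mask[:, :, 4:] = True                            # pretend the last 2 columns are padding

not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)

print(x_embed[0, 0])      # tensor([1., 2., 3., 4., 4., 4.])
print(y_embed[0, :, 0])   # tensor([1., 2., 3., 4.])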
211 self.num_heads = num_heads 212 self.window_size = window_size 213 self.shift_size = shift_size 214 self.mlp_ratio = mlp_ratio 215 assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)" 216 217 self.norm1 = norm_layer(dim) 218 self.attn = WindowAttention(
243 mask_matrix: Attention mask for cyclic shift. 244 """ 245 B, L, C = x.shape 246 H, W = self.H, self.W 247 assert L == H * W, "input feature has wrong size" 248 249 shortcut = x 250 x = self.norm1(x)
317 x: Input feature, tensor size (B, H*W, C). 318 H, W: Spatial resolution of the input feature. 319 """ 320 B, L, C = x.shape 321 assert L == H * W, "input feature has wrong size" 322 323 x = x.view(B, H, W, C) 324
746 # collect for nesttensors 747 outs_dict = {} 748 for idx, out_i in enumerate(outs): 749 m = tensor_list.mask 750 assert m is not None 751 mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0] 752 outs_dict[idx] = NestedTensor(out_i, mask) 753
759 self._freeze_stages() 760 761 762 def build_swin_transformer(modelname, pretrain_img_size, **kw): 763 assert modelname in [ 764 "swin_T_224_1k", 765 "swin_B_224_22k", 766 "swin_B_384_22k", 767 "swin_L_224_22k", 768 "swin_L_384_22k", 769 ] 770 771 model_para_dict = { 772 "swin_T_224_1k": dict(
105 self.head_dim = embed_dim // num_heads 106 self.v_dim = v_dim 107 self.l_dim = l_dim 108 109 assert ( 110 self.head_dim * self.num_heads == self.embed_dim 111 ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." 112 self.scale = self.head_dim ** (-0.5) 113 self.dropout = dropout 114
83 self.sub_sentence_present = sub_sentence_present 84 85 # setting query dim 86 self.query_dim = query_dim 87 assert query_dim == 4 88 89 # for dn training 90 self.num_patterns = num_patterns
129 ) 130 in_channels = hidden_dim 131 self.input_proj = nn.ModuleList(input_proj_list) 132 else: 133 assert two_stage_type == "no", "two_stage_type must be 'no' when num_feature_levels == 1" 134 self.input_proj = nn.ModuleList( 135 [ 136 nn.Sequential(
144 self.aux_loss = aux_loss 145 self.box_pred_damping = box_pred_damping = None 146 147 self.iter_update = iter_update 148 assert iter_update, "iter_update must be enabled" 149 150 # prepare pred layers 151 self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
169 self.transformer.decoder.class_embed = self.class_embed 170 171 # two stage 172 self.two_stage_type = two_stage_type 173 assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format( 174 two_stage_type 175 ) 176 if two_stage_type != "no": 177 if two_stage_bbox_embed_share: 178 assert dec_pred_bbox_embed_share
174 two_stage_type 175 ) 176 if two_stage_type != "no": 177 if two_stage_bbox_embed_share: 178 assert dec_pred_bbox_embed_share 179 self.transformer.enc_out_bbox_embed = _bbox_embed 180 else: 181 self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
180 else: 181 self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed) 182 183 if two_stage_class_embed_share: 184 assert dec_pred_bbox_embed_share 185 self.transformer.enc_out_class_embed = _class_embed 186 else: 187 self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
283 for l, feat in enumerate(features): 284 src, mask = feat.decompose() 285 srcs.append(self.input_proj[l](src)) 286 masks.append(mask) 287 assert mask is not None 288 if self.num_feature_levels > len(srcs): 289 _len_srcs = len(srcs) 290 for l in range(_len_srcs, self.num_feature_levels):
224 225 bs, num_query, _ = query.shape 226 bs, num_value, _ = value.shape 227 228 assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value 229 230 value = self.value_proj(value) 231 if key_padding_mask is not None:
77 self.num_encoder_layers = num_encoder_layers 78 self.num_unicoder_layers = num_unicoder_layers 79 self.num_decoder_layers = num_decoder_layers 80 self.num_queries = num_queries 81 assert query_dim == 4 82 83 # choose encoder layer type 84 encoder_layer = DeformableTransformerEncoderLayer(
107 else: 108 feature_fusion_layer = None 109 110 encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 111 assert encoder_norm is None 112 self.encoder = TransformerEncoder( 113 encoder_layer, 114 num_encoder_layers,
158 else: 159 self.level_embed = None 160 161 self.learnable_tgt_init = learnable_tgt_init 162 assert learnable_tgt_init, "learnable_tgt_init must be enabled" 163 self.embed_init_tgt = embed_init_tgt 164 if (two_stage_type != "no" and embed_init_tgt) or (two_stage_type == "no"): 165 self.tgt_embed = nn.Embedding(self.num_queries, d_model)
168 self.tgt_embed = None 169 170 # for two stage 171 self.two_stage_type = two_stage_type 172 assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format( 173 two_stage_type 174 ) 175 if two_stage_type == "standard": 176 # anchor selection at the output of encoder 177 self.enc_output = nn.Linear(d_model, d_model)
613 self.layers = [] 614 self.num_layers = num_layers 615 self.norm = norm 616 self.return_intermediate = return_intermediate 617 assert return_intermediate, "only return_intermediate is supported" 618 self.query_dim = query_dim 619 assert query_dim in [2, 4], "query_dim should be 2 or 4 but got {}".format(query_dim) 620 self.num_feature_levels = num_feature_levels
615 self.norm = norm 616 self.return_intermediate = return_intermediate 617 assert return_intermediate, "only return_intermediate is supported" 618 self.query_dim = query_dim 619 assert query_dim in [2, 4], "query_dim should be 2 or 4 but got {}".format(query_dim) 620 self.num_feature_levels = num_feature_levels 621 622 self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
669 reference_points[:, :, None] 670 * torch.cat([valid_ratios, valid_ratios], -1)[None, :] 671 ) # nq, bs, nlevel, 4 672 else: 673 assert reference_points.shape[-1] == 2 674 reference_points_input = reference_points[:, :, None] * valid_ratios[None, :] 675 query_sine_embed = gen_sineembed_for_position( 676 reference_points_input[:, :, 0, :]
845 self.norm3 = nn.LayerNorm(d_model) 846 847 self.key_aware_proj = None 848 self.use_text_feat_guide = use_text_feat_guide 849 assert not use_text_feat_guide 850 self.use_text_cross_attention = use_text_cross_attention 851 852 def rm_self_attn_modules(self):
889 Input: 890 - tgt/tgt_query_pos: nq, bs, d_model 891 - 892 """ 893 assert cross_attn_mask is None 894 895 # self attention 896 if self.self_attn is not None:
252 } 253 Returns: 254 _type_: _description_ 255 """ 256 assert isinstance(text_dict, dict) 257 258 y = text_dict["encoded_text"] 259 text_token_mask = text_dict["text_token_mask"]
11 def build_model(args): 12 # we use register to maintain models from catdet6 on. 13 from .registry import MODULE_BUILD_FUNCS 14 15 assert args.modelname in MODULE_BUILD_FUNCS._module_dict 16 build_func = MODULE_BUILD_FUNCS.get(args.modelname) 17 model = build_func(args) 18 return model
46 and M = len(boxes2) 47 """ 48 # degenerate boxes gives inf / nan results 49 # so do an early check 50 assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 51 assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 52 # except: 53 # import ipdb; ipdb.set_trace()
47 """ 48 # degenerate boxes gives inf / nan results 49 # so do an early check 50 assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 51 assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 52 # except: 53 # import ipdb; ipdb.set_trace() 54 iou, union = box_iou(boxes1, boxes2)
89 - giou: N, 4 90 """ 91 # degenerate boxes gives inf / nan results 92 # so do an early check 93 assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 94 assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 95 assert boxes1.shape == boxes2.shape 96 iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4
90 """ 91 # degenerate boxes gives inf / nan results 92 # so do an early check 93 assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 94 assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 95 assert boxes1.shape == boxes2.shape 96 iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4 97
91 # degenerate boxes gives inf / nan results 92 # so do an early check 93 assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 94 assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 95 assert boxes1.shape == boxes2.shape 96 iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4 97 98 lt = torch.min(boxes1[:, :2], boxes2[:, :2])
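A minimal sketch of the degenerate-box precondition asserted above: boxes are (x0, y0, x1, y1) and must satisfy x1 >= x0 and y1 >= y0 before IoU / GIoU is computed (the example boxes are made up):

import torch

boxes = torch.tensor([[0.0, 0.0, 2.0, 2.0],
                      [1.0, 1.0, 3.0, 4.0]])
assert (boxes[:, 2:] >= boxes[:, :2]).all()    # both boxes are well-formed

bad = torch.tensor([[2.0, 2.0, 1.0, 3.0]])     # x1 < x0: would trip the assert
print((bad[:, 2:] >= bad[:, :2]).all())        # tensor(False)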
9 import functools 10 import io 11 import json 12 import os 13 import pickle 14 import subprocess 15 import time 16 from collections import OrderedDict, defaultdict, deque
10 import io 11 import json 12 import os 13 import pickle 14 import subprocess 15 import time 16 from collections import OrderedDict, defaultdict, deque 17 from typing import List, Optional
142 print("gathering on cpu") 143 dist.all_gather(size_list, local_size, group=cpu_group) 144 size_list = [int(size.item()) for size in size_list] 145 max_size = max(size_list) 146 assert isinstance(local_size.item(), int) 147 local_size = int(local_size.item()) 148 149 # receiving Tensor from all ranks
211 212 data_list = [] 213 for size, tensor in zip(size_list, tensor_list): 214 buffer = tensor.cpu().numpy().tobytes()[:size] 215 data_list.append(pickle.loads(buffer)) 216 217 return data_list 218
252 def update(self, **kwargs): 253 for k, v in kwargs.items(): 254 if isinstance(v, torch.Tensor): 255 v = v.item() 256 assert isinstance(v, (float, int)) 257 self.meters[k].update(v) 258 259 def __getattr__(self, attr):
362 def get_sha(): 363 cwd = os.path.dirname(os.path.abspath(__file__)) 364 365 def _run(command): 366 return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() 367 368 sha = "N/A" 369 diff = "clean"
369 diff = "clean" 370 branch = "N/A" 371 try: 372 sha = _run(["git", "rev-parse", "HEAD"]) 373 subprocess.check_output(["git", "diff"], cwd=cwd) 374 diff = _run(["git", "diff-index", "HEAD"]) 375 diff = "has uncommited changes" if diff else "clean" 376 branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
369 diff = "clean" 370 branch = "N/A" 371 try: 372 sha = _run(["git", "rev-parse", "HEAD"]) 373 subprocess.check_output(["git", "diff"], cwd=cwd) 374 diff = _run(["git", "diff-index", "HEAD"]) 375 diff = "has uncommited changes" if diff else "clean" 376 branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
373 subprocess.check_output(["git", "diff"], cwd=cwd) 374 diff = _run(["git", "diff-index", "HEAD"]) 375 diff = "has uncommitted changes" if diff else "clean" 376 branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) 377 except Exception: 378 pass 379 message = f"sha: {sha}, status: {diff}, branch: {branch}" 380 return message 381
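A hedged usage sketch for `get_sha()` above; the printed summary line is illustrative, and the function degrades to "N/A" / "clean" when git information is unavailable:

print(get_sha())
# e.g. "sha: 3f2a9c1..., status: has uncommitted changes, branch: main"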
426 # type: (Device) -> NestedTensor # noqa 427 cast_tensor = self.tensors.to(device) 428 mask = self.mask 429 if mask is not None: 430 assert mask is not None 431 cast_mask = mask.to(device) 432 else: 433 cast_mask = None
433 cast_mask = None 434 return NestedTensor(cast_tensor, cast_mask) 435 436 def to_img_list_single(self, tensor, mask): 437 assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim()) 438 maxH = (~mask).sum(0).max() 439 maxW = (~mask).sum(1).max() 440 img = tensor[:, :maxH, :maxW]
2 # Modified from mmcv 3 # ========================================================== 4 5 import json 6 import pickle 7 from abc import ABCMeta, abstractmethod 8 from pathlib import Path 9
54 55 56 class PickleHandler(BaseFileHandler): 57 def load_from_fileobj(self, file, **kwargs): 58 return pickle.load(file, **kwargs) 59 60 def load_from_path(self, filepath, **kwargs): 61 return super(PickleHandler, self).load_from_path(filepath, mode="rb", **kwargs)
74 75 class YamlHandler(BaseFileHandler): 76 def load_from_fileobj(self, file, **kwargs): 77 kwargs.setdefault("Loader", Loader) 78 return yaml.load(file, **kwargs) 79 80 def dump_to_fileobj(self, obj, file, **kwargs): 81 kwargs.setdefault("Dumper", Dumper)
39 img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 40 ) -> torch.FloatTensor: 41 # img: tensor(3,H,W) or tensor(B,3,H,W) 42 # return: same as img 43 assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim() 44 if img.dim() == 3: 45 assert img.size(0) == 3, 'img.size(0) should be 3 but "%d". (%s)' % ( 46 img.size(0), 47 str(img.size()), 48 ) 49 img_perm = img.permute(1, 2, 0) 50 mean = torch.Tensor(mean) 51 std = torch.Tensor(std)
41 # img: tensor(3,H,W) or tensor(B,3,H,W) 42 # return: same as img 43 assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim() 44 if img.dim() == 3: 45 assert img.size(0) == 3, 'img.size(0) should be 3 but "%d". (%s)' % ( 46 img.size(0), 47 str(img.size()), 48 ) 49 img_perm = img.permute(1, 2, 0) 50 mean = torch.Tensor(mean) 51 std = torch.Tensor(std)
51 std = torch.Tensor(std) 52 img_res = img_perm * std + mean 53 return img_res.permute(2, 0, 1) 54 else: # img.dim() == 4 55 assert img.size(1) == 3, 'img.size(1) should be 3 but "%d". (%s)' % ( 56 img.size(1), 57 str(img.size()), 58 ) 59 img_perm = img.permute(0, 2, 3, 1) 60 mean = torch.Tensor(mean) 61 std = torch.Tensor(std)
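A broadcasting sketch of the de-normalization performed above (ImageNet mean / std); the standalone round-trip below only illustrates the arithmetic and does not call the function in this file:

import torch

mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

img = torch.rand(3, 8, 8)                                         # image in [0, 1]
normalized = (img - mean[:, None, None]) / std[:, None, None]     # what a transform did
restored = normalized * std[:, None, None] + mean[:, None, None]  # what the code above undoes
print(torch.allclose(restored, img, atol=1e-6))                   # True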
283 """ 284 Input: 285 pred, gt: Tensor() 286 """ 287 assert pred.shape == gt.shape 288 self.tp += torch.logical_and(pred == 1, gt == 1).sum().item() 289 self.fp += torch.logical_and(pred == 1, gt == 0).sum().item() 290 self.tn += torch.logical_and(pred == 0, gt == 0).sum().item()
322 raise NotImplementedError("Unknown type {}".format(type(args))) 323 324 325 def stat_tensors(tensor): 326 assert tensor.dim() == 1 327 tensor_sm = tensor.softmax(0) 328 entropy = (tensor_sm * torch.log(tensor_sm + 1e-9)).sum() 329
506 self.best_res = init_res 507 self.best_ep = -1 508 509 self.better = better 510 assert better in ["large", "small"] 511 512 def isbetter(self, new_res, old_res): 513 if self.better == "large":
598 599 def get_phrases_from_posmap( 600 posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer 601 ): 602 assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor" 603 if posmap.dim() == 1: 604 non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist() 605 token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
6 class MaskPreview(SaveImage): 7 def __init__(self): 8 self.output_dir = folder_paths.get_temp_directory() 9 self.type = "temp" 10 self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstuvwxyz1234567890") for x in range(5)) 11 self.compress_level = 4 12 13 @classmethod
75 For large resolutions, 'binary_mask' may consume large amounts of 76 memory. 77 """ 78 79 assert (points_per_side is None) != ( 80 point_grids is None 81 ), "Exactly one of points_per_side or point_grid must be provided." 82 if points_per_side is not None: 83 self.point_grids = build_all_layer_point_grids( 84 points_per_side,
89 self.point_grids = point_grids 90 else: 91 raise ValueError("Can't have both points_per_side and point_grid be None.") 92 93 assert output_mode in [ 94 "binary_mask", 95 "uncompressed_rle", 96 "coco_rle", 97 ], f"Unknown output_mode {output_mode}." 98 if output_mode == "coco_rle": 99 from pycocotools import mask as mask_utils # type: ignore # noqa: F401 100
215 resolution=(14, 14), 216 ): 217 super().__init__() 218 # (h, w) 219 assert isinstance(resolution, tuple) and len(resolution) == 2 220 self.num_heads = num_heads 221 self.scale = key_dim ** -0.5 222 self.key_dim = key_dim
306 super().__init__() 307 self.dim = dim 308 self.input_resolution = input_resolution 309 self.num_heads = num_heads 310 assert window_size > 0, 'window_size must be greater than 0' 311 self.window_size = window_size 312 self.mlp_ratio = mlp_ratio 313
313 314 self.drop_path = DropPath( 315 drop_path) if drop_path > 0. else nn.Identity() 316 317 assert dim % num_heads == 0, 'dim must be divisible by num_heads' 318 head_dim = dim // num_heads 319 320 window_resolution = (window_size, window_size)
332 333 def forward(self, x): 334 H, W = self.input_resolution 335 B, L, C = x.shape 336 assert L == H * W, "input feature has wrong size" 337 res_x = x 338 if H == self.window_size and W == self.window_size: 339 x = self.attn(x)
568 i += 1 569 if layer.downsample is not None: 570 layer.downsample.apply( 571 lambda x: _set_lr_scale(x, lr_scales[i - 1])) 572 assert i == depth 573 for m in [self.norm_head, self.head]: 574 m.apply(lambda x: _set_lr_scale(x, lr_scales[-1])) 575
577 p.param_name = k 578 579 def _check_lr_scale(m): 580 for p in m.parameters(): 581 assert hasattr(p, 'lr_scale'), p.param_name 582 583 self.apply(_check_lr_scale) 584
39 1x3xHxW, which has been transformed with ResizeLongestSide. 40 original_image_size (tuple(int, int)): The size of the image 41 before transformation, in (H, W) format. 42 """ 43 assert ( 44 len(transformed_image.shape) == 4 45 and transformed_image.shape[1] == 3 46 and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 47 ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 48 self.reset_image() 49 50 self.original_size = original_image_size
13 cloth = tensor2pil(tensor_image) 14 model_folder_path = os.path.join(folder_paths.models_dir, "segformer_b2_clothes") 15 try: 16 model_folder_path = os.path.normpath(folder_paths.folder_names_and_paths['segformer_b2_clothes'][0][0]) 17 except: 18 pass 19 20 processor = SegformerImageProcessor.from_pretrained(model_folder_path) 21 model = AutoModelForSemanticSegmentation.from_pretrained(model_folder_path)
94 font_size = char_size + line_random[j] 95 font_size = int(font_size * scale / 100) 96 if font_size < 4: 97 font_size = 4 98 axis_x = _x + offset // 3 if random.random() > 0.5 else _x - offset // 3 99 axis_y = _y + offset // 3 if random.random() > 0.5 else _y - offset // 3 100 char_dict = {'char':lines[i][j], 101 'axis':(axis_x, axis_y),
95 font_size = int(font_size * scale / 100) 96 if font_size < 4: 97 font_size = 4 98 axis_x = _x + offset // 3 if random.random() > 0.5 else _x - offset // 3 99 axis_y = _y + offset // 3 if random.random() > 0.5 else _y - offset // 3 100 char_dict = {'char':lines[i][j], 101 'axis':(axis_x, axis_y), 102 'size':font_size}