            if self.patch_type == self.ATTN1:
                to_return.control.patch_attn1 = to_return
            elif self.patch_type == self.ATTN2:
                to_return.control.patch_attn2 = to_return
        except Exception:
            pass
        return to_return

    def clean_module_mem(self):
        # Best-effort cleanup: clear each module's injection holder, ignoring
        # modules whose holder cannot be cleaned.
        for attn_module in self.attn_modules:
            try:
                attn_module.injection_holder.clean()
            except Exception:
                pass
        for gn_module in self.gn_modules:
            try:
                gn_module.injection_holder.clean()
            except Exception:
                pass

    def cleanup(self):
        self.clean_module_mem()

                for i in range(cn_idx, len(ref_controlnets)):
                    if ref_controlnets[i].order == order:
                        cn_idx = i
                        break
                assert order == ref_controlnets[cn_idx].order
                if ref_controlnets[cn_idx].any_attn_strength_to_apply():
                    effective_strength = ref_controlnets[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle)
                    real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1 - effective_strength)

                for i in range(cn_idx, len(ref_controlnets)):
                    if ref_controlnets[i].order == order:
                        cn_idx = i
                        break
                assert order == ref_controlnets[cn_idx].order
                if ref_controlnets[cn_idx].any_attn_strength_to_apply():
                    effective_strength = ref_controlnets[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle)
                    real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1 - effective_strength)

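# Note (illustrative, not code from this file): the write-back in the two
# excerpts above is a plain linear interpolation between the banked reference
# features and the current attention context, weighted by the effective
# strength (a float or a mask broadcastable to the feature shape). A minimal
# standalone sketch of that blend:
import torch

def blend_reference(stored: torch.Tensor, current: torch.Tensor, strength) -> torch.Tensor:
    # strength == 1.0 keeps the stored reference features; 0.0 keeps the
    # current features; a mask tensor blends per element.
    return stored * strength + current * (1.0 - strength)
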
                for i in range(cn_idx, len(ref_controlnets)):
                    if ref_controlnets[i].order == order:
                        cn_idx = i
                        break
                assert order == ref_controlnets[cn_idx].order
                style_fidelity = bank_styles.style_cfgs[idx]
                var_acc = bank_styles.var_bank[idx]
                mean_acc = bank_styles.mean_bank[idx]

        outs = []

        hs = []
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x

        patched_model = super().patch_model(device_to, patch_weights)
        try:
            # Keep the motion wrapper on the same device as the rest of the
            # patched model; ignore models that have no motion wrapper.
            if self.model.motion_wrapper is not None:
                self.model.motion_wrapper.to(device=device_to)
        except Exception:
            pass
        return patched_model

    def unpatch_model(self, device_to=None, unpatch_weights=True):
        try:
            if self.model.motion_wrapper is not None:
                self.model.motion_wrapper.to(device=device_to)
        except Exception:
            pass
        if unpatch_weights:
            return super().unpatch_model(device_to)
        else:

        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        assert attention_mode == "Temporal"

        self.attention_mode = attention_mode
        self.is_cross_attention = kwargs["context_dim"] is not None

        operations=comfy.ops.disable_weight_init,
        **kwargs,
    ):
        super().__init__()
        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            # from omegaconf.listconfig import ListConfig
            # if type(context_dim) == ListConfig:
            #     context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size

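# Note (assumption, not code from this file): with these two settings, a UNet
# typically derives one value from the other per block: either a fixed head
# count with a variable per-head width, or a fixed per-head width with a
# variable head count. A minimal sketch of that resolution for a block with
# `ch` channels:
def resolve_heads(ch: int, num_heads: int = -1, num_head_channels: int = -1):
    # Returns (head_count, per_head_width) for the current channel width.
    if num_head_channels == -1:
        return num_heads, ch // num_heads
    return ch // num_head_channels, num_head_channels
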
        self.num_res_blocks = num_res_blocks

        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))

        transformer_depth = transformer_depth[:]

            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),

        outs = []

        hs = []
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x

    should be shared between the original and its copy.
    memo is the dictionary passed into __deepcopy__. Ignore this argument if
    not calling from within __deepcopy__.
    '''
    assert isinstance(shared_attribute_names, (list, tuple))

    shared_attributes = {k: getattr(obj, k) for k in shared_attribute_names}

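# Note (illustrative sketch, not the file's implementation): the excerpt above
# belongs to a "deepcopy with shared attributes" helper. The usual pattern is
# to detach the shared attributes, deep-copy the remainder, then re-attach the
# originals by reference on both the source object and the clone. A
# self-contained version of that pattern:
import copy

def deepcopy_with_sharing(obj, shared_attribute_names, memo=None):
    # Deep-copy obj except for the named attributes, which end up shared
    # (referenced, not copied) between the original and the clone.
    assert isinstance(shared_attribute_names, (list, tuple))
    shared_attributes = {k: getattr(obj, k) for k in shared_attribute_names}

    # Detach the shared attributes so deepcopy does not descend into them.
    # If the object's class defines a __deepcopy__ that delegates to this
    # helper, it must also be disabled temporarily to avoid infinite recursion.
    for name in shared_attribute_names:
        delattr(obj, name)
    try:
        clone = copy.deepcopy(obj, memo)
    finally:
        # Always restore the originals on the source object.
        for name, value in shared_attributes.items():
            setattr(obj, name, value)

    # Share the original attribute objects with the clone instead of copies.
    for name, value in shared_attributes.items():
        setattr(clone, name, value)
    return clone
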