        return emb, pool
    else:
        return emb

def finalize_clip_regions(clip_regions, mask_token, strict_mask, start_from_masked, token_normalization='none', weight_interpretation='comfy'):
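    """Turn accumulated CLIP regions into a final conditioning tuple.

    clip_regions holds the clip model, the tokenized base prompt ("base_tokens")
    and parallel lists of region masks, target masks and weights. mask_token is
    the replacement for masked-out tokens; strict_mask and start_from_masked are
    blend factors between the full and the masked base embedding.
    """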
    clip = clip_regions["clip"]
    tokenizer = clip.tokenizer
    if hasattr(tokenizer, 'clip_g'):
        tokenizer = tokenizer.clip_g
    base_weighted_tokens = clip_regions["base_tokens"]

    # calc base embedding
    base_embedding_full, pool = encode_from_tokens(clip, base_weighted_tokens, token_normalization, weight_interpretation, True)

    # Avoid a numpy value error and pass through the base embeddings if no regions are set.
    if len(clip_regions["regions"]) == 0:
        return ([[base_embedding_full, {"pooled_output": pool}]], )

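    # Resolve the mask token: fall back to a fixed default token id when none is
    # given, otherwise tokenize the supplied string and keep only its first token.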
    if mask_token == "":
        mask_token = 266  # clip.tokenizer.end_token
    else:
        mask_token = tokenizer.tokenizer(mask_token)['input_ids'][1:-1]
        if len(mask_token) > 1:
            warnings.warn("mask_token does not map to a single token, using the first token instead")
        mask_token = mask_token[0]

    # calc global target mask
    global_target_mask = np.any(np.stack(clip_regions["targets"]), axis=0).astype(int)

    # calc global region mask
    global_region_mask = np.any(np.stack(clip_regions["regions"]), axis=0).astype(float)
    regions_sum = np.sum(np.stack(clip_regions["regions"]), axis=0)
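    # Per-token normalization factor: the reciprocal of the summed region masks,
    # so overlapping regions are averaged instead of double-counted (zero where
    # no region covers the token).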
    regions_normalized = np.divide(1, regions_sum, out=np.zeros_like(regions_sum), where=regions_sum != 0)

    # mask base embeddings
    base_embedding_masked = encode_from_tokens(clip, create_masked_prompt(base_weighted_tokens, global_target_mask, mask_token), token_normalization, weight_interpretation)
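    # Blend between the full and masked base embedding: start_from_masked sets the
    # starting point inside regions, strict_mask sets the area outside them.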
    base_embedding_start = base_embedding_full * (1 - start_from_masked) + base_embedding_masked * start_from_masked
    base_embedding_outer = base_embedding_full * (1 - strict_mask) + base_embedding_masked * strict_mask

    region_embeddings = []
    for region, target, weight in zip(clip_regions["regions"], clip_regions["targets"], clip_regions["weights"]):
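        # per-token weight for this region: its mask scaled by the region weight
        # and the overlap normalization, broadcast over the embedding dimension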
        region_masking = torch.tensor(regions_normalized * region * weight, dtype=base_embedding_full.dtype, device=base_embedding_full.device).unsqueeze(-1)

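        # Encode the prompt with every other region's target masked out, subtract
        # the base embedding to isolate this region's contribution, and scale it
        # by the per-token weight.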
        region_emb = encode_from_tokens(clip, create_masked_prompt(base_weighted_tokens, global_target_mask - target, mask_token), token_normalization, weight_interpretation)
        region_emb -= base_embedding_start
        region_emb *= region_masking

        region_embeddings.append(region_emb)
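    # Sum all per-region contributions into a single correction tensor.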
    region_embeddings = torch.stack(region_embeddings).sum(axis=0)

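    # Inside regions, start from the (optionally masked) base embedding; outside,
    # use the strict-mask blend. Then add the per-region contributions on top.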
    embeddings_final_mask = torch.tensor(global_region_mask, dtype=base_embedding_full.dtype, device=base_embedding_full.device).unsqueeze(-1)
    embeddings_final = base_embedding_start * embeddings_final_mask + base_embedding_outer * (1 - embeddings_final_mask)
    embeddings_final += region_embeddings
    return ([[embeddings_final, {"pooled_output": pool}]], )


class CLIPRegionsToConditioning: