146 # For single mask, expand it to match size of image batch size. 147 if mask.shape[0] == 1: 148 mask = mask.repeat(out.shape[0], 1, 1) 149 150 assert mask.ndim == 3, f"Mask should have shape [B, H, W]. {mask.shape}" 151 assert out.ndim == 4, f"Image should have shsape [B, C, H, W]. {out.shape}" 152 assert out.shape[-2:] == mask.shape[-2:], f"{out.shape[-2:]} != {mask.shape[-2:]}" 153 assert out.shape[0] == mask.shape[0], f"{out.shape[0]} != {mask.shape[0]}"
147 if mask.shape[0] == 1: 148 mask = mask.repeat(out.shape[0], 1, 1) 149 150 assert mask.ndim == 3, f"Mask should have shape [B, H, W]. {mask.shape}" 151 assert out.ndim == 4, f"Image should have shsape [B, C, H, W]. {out.shape}" 152 assert out.shape[-2:] == mask.shape[-2:], f"{out.shape[-2:]} != {mask.shape[-2:]}" 153 assert out.shape[0] == mask.shape[0], f"{out.shape[0]} != {mask.shape[0]}" 154 # Apply each mask in the batch to its corresponding image's alpha channel
148 mask = mask.repeat(out.shape[0], 1, 1) 149 150 assert mask.ndim == 3, f"Mask should have shape [B, H, W]. {mask.shape}" 151 assert out.ndim == 4, f"Image should have shsape [B, C, H, W]. {out.shape}" 152 assert out.shape[-2:] == mask.shape[-2:], f"{out.shape[-2:]} != {mask.shape[-2:]}" 153 assert out.shape[0] == mask.shape[0], f"{out.shape[0]} != {mask.shape[0]}" 154 # Apply each mask in the batch to its corresponding image's alpha channel 155 for i in range(out.shape[0]):
149 150 assert mask.ndim == 3, f"Mask should have shape [B, H, W]. {mask.shape}" 151 assert out.ndim == 4, f"Image should have shsape [B, C, H, W]. {out.shape}" 152 assert out.shape[-2:] == mask.shape[-2:], f"{out.shape[-2:]} != {mask.shape[-2:]}" 153 assert out.shape[0] == mask.shape[0], f"{out.shape[0]} != {mask.shape[0]}" 154 # Apply each mask in the batch to its corresponding image's alpha channel 155 for i in range(out.shape[0]): 156 out[i, 3, :, :] = mask[i]
54 current = self 55 while current is not None: 56 result.append(current) 57 current = current.previous 58 assert len(result) > 1, "At least 2 regions are required." 59 60 result = list(reversed(result)) 61 if result[0].mask is None: # BackgroundRegion
139 num_conds = len(region_list) 140 141 mask = torch.stack([r.mask for r in region_list], dim=0) 142 mask_sum = mask.sum(dim=0, keepdim=True) 143 assert mask_sum.sum() > 0, "There are areas that are zero in all masks." 144 self.mask = mask / mask_sum 145 146 self.conds = [r.conditioning[0][0] for r in region_list]
146 self.conds = [r.conditioning[0][0] for r in region_list] 147 num_tokens = [cond.shape[1] for cond in self.conds] 148 149 def attn2_patch(q: Tensor, k: Tensor, v: Tensor, extra_options: dict): 150 assert k.mean() == v.mean(), "k and v must be the same." 151 device, dtype = q.device, q.dtype 152 153 if self.conds[0].device != device:
33 self.init(image, min_tile_size, padding, blending) 34 return (self,) 35 36 def init(self, image: Tensor, min_tile_size: int, padding: int, blending: int): 37 assert all([x % 8 == 0 for x in image.shape[-3:-1]]), "Image size must be divisible by 8" 38 assert min_tile_size % 8 == 0, "Tile size must be divisible by 8" 39 assert blending < padding, "Blending must be smaller than padding" 40
34 return (self,) 35 36 def init(self, image: Tensor, min_tile_size: int, padding: int, blending: int): 37 assert all([x % 8 == 0 for x in image.shape[-3:-1]]), "Image size must be divisible by 8" 38 assert min_tile_size % 8 == 0, "Tile size must be divisible by 8" 39 assert blending < padding, "Blending must be smaller than padding" 40 41 self.image_size = np.array(image.shape[-3:-1])
35 36 def init(self, image: Tensor, min_tile_size: int, padding: int, blending: int): 37 assert all([x % 8 == 0 for x in image.shape[-3:-1]]), "Image size must be divisible by 8" 38 assert min_tile_size % 8 == 0, "Tile size must be divisible by 8" 39 assert blending < padding, "Blending must be smaller than padding" 40 41 self.image_size = np.array(image.shape[-3:-1]) 42 self.padding = padding
167 RETURN_TYPES = ("IMAGE",) 168 FUNCTION = "merge" 169 170 def merge(self, image: Tensor, layout: TileLayout, index: int, tile: Tensor): 171 assert index < layout.total_count, f"Index {index} out of range" 172 if index == 0: 173 image = image.clone() 174 layout.merge(image, index, tile)