131 132 def forward(self, x, style, noise_mode="random", gain=1): 133 x = self.conv(x, style) 134 135 assert noise_mode in ["random", "const", "none"] 136 137 if self.use_noise: 138 if noise_mode == "random":
321 322 # Apply truncation. 323 if truncation_psi != 1: 324 with torch.autograd.profiler.record_function("truncate"): 325 assert self.w_avg_beta is not None 326 if self.num_ws is None or truncation_cutoff is None: 327 x = self.w_avg.lerp(x, truncation_psi) 328 else:
643 if min(self.input_resolution) <= self.window_size: 644 # if window size is larger than input resolution, we don't partition windows 645 self.shift_size = 0 646 self.window_size = min(self.input_resolution) 647 assert ( 648 0 <= self.shift_size < self.window_size 649 ), "shift_size must in 0-window_size" 650 651 if self.shift_size > 0: 652 down_ratio = 1
1462 demodulate=True, 1463 ): 1464 super().__init__() 1465 resolution_log2 = int(np.log2(img_resolution)) 1466 assert img_resolution == 2**resolution_log2 and img_resolution >= 4 1467 1468 self.num_layers = resolution_log2 * 2 - 3 * 2 1469 self.img_resolution = img_resolution
108 109 110 def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None): 111 """Slow reference implementation of `bias_act()` using standard TensorFlow ops.""" 112 assert isinstance(x, torch.Tensor) 113 assert clamp is None or clamp >= 0 114 spec = activation_funcs[act] 115 alpha = float(alpha if alpha is not None else spec.def_alpha)
109 110 def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None): 111 """Slow reference implementation of `bias_act()` using standard TensorFlow ops.""" 112 assert isinstance(x, torch.Tensor) 113 assert clamp is None or clamp >= 0 114 spec = activation_funcs[act] 115 alpha = float(alpha if alpha is not None else spec.def_alpha) 116 gain = float(gain if gain is not None else spec.def_gain)
117 clamp = float(clamp if clamp is not None else -1) 118 119 # Add bias. 120 if b is not None: 121 assert isinstance(b, torch.Tensor) and b.ndim == 1 122 assert 0 <= dim < x.ndim 123 assert b.shape[0] == x.shape[dim] 124 x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]).to(x.device)
118 119 # Add bias. 120 if b is not None: 121 assert isinstance(b, torch.Tensor) and b.ndim == 1 122 assert 0 <= dim < x.ndim 123 assert b.shape[0] == x.shape[dim] 124 x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]).to(x.device) 125
119 # Add bias. 120 if b is not None: 121 assert isinstance(b, torch.Tensor) and b.ndim == 1 122 assert 0 <= dim < x.ndim 123 assert b.shape[0] == x.shape[dim] 124 x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]).to(x.device) 125 126 # Evaluate activation function.
166 impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 167 Returns: 168 Tensor of the same shape and datatype as `x`. 169 """ 170 assert isinstance(x, torch.Tensor) 171 assert impl in ["ref", "cuda"] 172 return _bias_act_ref( 173 x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp
167 Returns: 168 Tensor of the same shape and datatype as `x`. 169 """ 170 assert isinstance(x, torch.Tensor) 171 assert impl in ["ref", "cuda"] 172 return _bias_act_ref( 173 x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp 174 )
203 # Validate. 204 if f is None: 205 f = 1 206 f = torch.as_tensor(f, dtype=torch.float32) 207 assert f.ndim in [0, 1, 2] 208 assert f.numel() > 0 209 if f.ndim == 0: 210 f = f[np.newaxis]
204 if f is None: 205 f = 1 206 f = torch.as_tensor(f, dtype=torch.float32) 207 assert f.ndim in [0, 1, 2] 208 assert f.numel() > 0 209 if f.ndim == 0: 210 f = f[np.newaxis] 211
213 if separable is None: 214 separable = f.ndim == 1 and f.numel() >= 8 215 if f.ndim == 1 and not separable: 216 f = f.ger(f) 217 assert f.ndim == (1 if separable else 2) 218 219 # Apply normalize, flip, gain, and device. 220 if normalize:
229 def _get_filter_size(f): 230 if f is None: 231 return 1, 1 232 233 assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 234 fw = f.shape[-1] 235 fh = f.shape[0] 236
235 fh = f.shape[0] 236 237 fw = int(fw) 238 fh = int(fh) 239 assert fw >= 1 and fh >= 1 240 return fw, fh 241 242
247 248 def _parse_scaling(scaling): 249 if isinstance(scaling, int): 250 scaling = [scaling, scaling] 251 assert isinstance(scaling, (list, tuple)) 252 assert all(isinstance(x, int) for x in scaling) 253 sx, sy = scaling 254 assert sx >= 1 and sy >= 1
248 def _parse_scaling(scaling): 249 if isinstance(scaling, int): 250 scaling = [scaling, scaling] 251 assert isinstance(scaling, (list, tuple)) 252 assert all(isinstance(x, int) for x in scaling) 253 sx, sy = scaling 254 assert sx >= 1 and sy >= 1 255 return sx, sy
250 scaling = [scaling, scaling] 251 assert isinstance(scaling, (list, tuple)) 252 assert all(isinstance(x, int) for x in scaling) 253 sx, sy = scaling 254 assert sx >= 1 and sy >= 1 255 return sx, sy 256 257
257 258 def _parse_padding(padding): 259 if isinstance(padding, int): 260 padding = [padding, padding] 261 assert isinstance(padding, (list, tuple)) 262 assert all(isinstance(x, int) for x in padding) 263 if len(padding) == 2: 264 padx, pady = padding
258 def _parse_padding(padding): 259 if isinstance(padding, int): 260 padding = [padding, padding] 261 assert isinstance(padding, (list, tuple)) 262 assert all(isinstance(x, int) for x in padding) 263 if len(padding) == 2: 264 padx, pady = padding 265 padding = [padx, padx, pady, pady]
281 282 def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): 283 """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.""" 284 # Validate arguments. 285 assert isinstance(x, torch.Tensor) and x.ndim == 4 286 if f is None: 287 f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 288 assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
284 # Validate arguments. 285 assert isinstance(x, torch.Tensor) and x.ndim == 4 286 if f is None: 287 f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 288 assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 289 assert f.dtype == torch.float32 and not f.requires_grad 290 batch_size, num_channels, in_height, in_width = x.shape 291 # upx, upy = _parse_scaling(up)
285 assert isinstance(x, torch.Tensor) and x.ndim == 4 286 if f is None: 287 f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 288 assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 289 assert f.dtype == torch.float32 and not f.requires_grad 290 batch_size, num_channels, in_height, in_width = x.shape 291 # upx, upy = _parse_scaling(up) 292 # downx, downy = _parse_scaling(down)
518 Returns: 519 Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 520 """ 521 # Validate arguments. 522 assert isinstance(x, torch.Tensor) and (x.ndim == 4) 523 assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 524 assert f is None or ( 525 isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32
519 Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 520 """ 521 # Validate arguments. 522 assert isinstance(x, torch.Tensor) and (x.ndim == 4) 523 assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 524 assert f is None or ( 525 isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32 526 )
520 """ 521 # Validate arguments. 522 assert isinstance(x, torch.Tensor) and (x.ndim == 4) 523 assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 524 assert f is None or ( 525 isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32 526 ) 527 assert isinstance(up, int) and (up >= 1) 528 assert isinstance(down, int) and (down >= 1) 529 # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}"
523 assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 524 assert f is None or ( 525 isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32 526 ) 527 assert isinstance(up, int) and (up >= 1) 528 assert isinstance(down, int) and (down >= 1) 529 # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}" 530 out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
524 assert f is None or ( 525 isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32 526 ) 527 assert isinstance(up, int) and (up >= 1) 528 assert isinstance(down, int) and (down >= 1) 529 # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}" 530 out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 531 fw, fh = _get_filter_size(f)
181 182 def _input_block_patch(self, h: Tensor, transformer_options: dict): 183 if transformer_options["block"][1] == 0: 184 if self._inpaint_block is None or self._inpaint_block.shape != h.shape: 185 assert self._inpaint_head_feature is not None 186 batch = h.shape[0] // self._inpaint_head_feature.shape[0] 187 self._inpaint_block = self._inpaint_head_feature.to(h).repeat(batch, 1, 1, 1) 188 h = h + self._inpaint_block
236 237 def fill(self, image: Tensor, mask: Tensor, fill: str, falloff: int): 238 image = image.detach().clone() 239 alpha = mask_unsqueeze(mask_floor(mask)) 240 assert alpha.shape[0] == image.shape[0], "Image and mask batch size does not match" 241 242 falloff = make_odd(falloff) 243 if falloff > 0:
CATEGORY = "inpaint"
FUNCTION = "convert"

def convert(self, mask: Tensor, offset: float, threshold: float):
    """Linearly rescale `mask` so that `offset` maps to 0 and `threshold` to 1.

    Values are clamped to the range [0, 1] after rescaling.

    Args:
        mask: Tensor of mask values to remap.
        offset: Value that becomes 0; must satisfy `0.0 <= offset < threshold`.
        threshold: Value that becomes 1; must satisfy `threshold <= 1.0`.

    Returns:
        A one-element tuple containing the remapped mask tensor.

    Raises:
        AssertionError: If `0.0 <= offset < threshold <= 1.0` does not hold.
    """
    assert 0.0 <= offset < threshold <= 1.0, "Threshold must be higher than offset"

    # The assertion above guarantees a strictly positive denominator.
    scale = 1 / (threshold - offset)
    remapped = ((mask - offset) * scale).clamp(0, 1)
    return (remapped,)