    cfg: Config

    def configure(self) -> None:
        assert self.cfg.feature_reduction in ["concat", "mean"]
        self.chunk_size = 0

    def set_chunk_size(self, chunk_size: int):
        assert (
            chunk_size >= 0
        ), "chunk_size must be a non-negative integer (0 for no chunking)."
        self.chunk_size = chunk_size

    def query_triplane(
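# A hedged sketch (hypothetical helper, not this class's method) of how a
# chunk_size with the contract asserted above is typically consumed: 0 disables
# chunking, a positive value bounds how many query points the decoder sees at
# once, trading speed for peak memory.
import torch

def query_chunked(decode_fn, points: torch.Tensor, chunk_size: int) -> torch.Tensor:
    if chunk_size == 0:
        return decode_fn(points)  # no chunking
    return torch.cat(
        [
            decode_fn(points[i : i + chunk_size])
            for i in range(0, points.shape[0], chunk_size)
        ],
        dim=0,
    )

out = query_chunked(lambda p: p * 2, torch.randn(10, 3), chunk_size=4)
assert out.shape == (10, 3)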
        )

    def detokenize(self, tokens: torch.Tensor) -> torch.Tensor:
        batch_size, Ct, Nt = tokens.shape
        assert Nt == self.cfg.plane_size**2 * 3
        assert Ct == self.cfg.num_channels
        return rearrange(
            tokens,
            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
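# Standalone illustration of the shape contract enforced by the asserts above:
# a token grid of shape (B, Ct, 3 * Hp * Wp) unpacks into three Hp x Wp feature
# planes. The concrete sizes below are illustrative, not from a real config.
import torch
from einops import rearrange

B, Ct, plane_size = 2, 40, 32
tokens = torch.randn(B, Ct, 3 * plane_size**2)
planes = rearrange(
    tokens, "B Ct (Np Hp Wp) -> B Np Ct Hp Wp", Np=3, Hp=plane_size, Wp=plane_size
)
assert planes.shape == (B, 3, Ct, plane_size, plane_size)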
        Returns:
            `torch.Tensor`: The normalized encoder hidden states.
        """
        assert (
            self.norm_cross is not None
        ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # GroupNorm normalizes along dim 1 (channels), so move the hidden
            # dimension into that position and back around the norm.
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            assert False

        return encoder_hidden_states
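# Standalone demo of the transpose trick in the GroupNorm branch above: hidden
# states arrive as (batch, seq_len, hidden) while GroupNorm norms over dim 1.
# A sketch with arbitrary sizes, not the module's actual configuration.
import torch
import torch.nn as nn

norm = nn.GroupNorm(num_groups=4, num_channels=64)
hidden = torch.randn(2, 77, 64)  # (batch, seq_len, hidden)
out = norm(hidden.transpose(1, 2)).transpose(1, 2)
assert out.shape == hidden.shape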
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        assert norm_type == "layer_norm"

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
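# Schematic sketch of the three-block layout the comment above enumerates:
# pre-norm self-attention, cross-attention, and feed-forward, each with its own
# LayerNorm (hence the norm_type assert). This is an assumption-laden toy, not
# the real block, which carries many more options.
import torch.nn as nn

class MiniTransformerBlock(nn.Module):
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)  # 1. Self-Attn
        self.attn1 = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)  # 2. Cross-Attn
        self.attn2 = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm3 = nn.LayerNorm(dim)  # 3. Feed-forward
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x, ctx):
        h = self.norm1(x)
        x = x + self.attn1(h, h, h, need_weights=False)[0]
        h = self.norm2(x)
        x = x + self.attn2(h, ctx, ctx, need_weights=False)[0]
        return x + self.ff(self.norm3(x))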
    for arg in list(args) + list(kwargs.values()):
        if isinstance(arg, torch.Tensor):
            B = arg.shape[0]
            break
    assert (
        B is not None
    ), "No tensor found in args or kwargs, cannot determine batch size."
    out = defaultdict(list)
    out_type = None
    # max(1, B) to support B == 0
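# The batch-size probe above, isolated into a sketch: scan positional and
# keyword arguments for the first tensor and take its leading dim. The helper
# name `find_batch_size` is hypothetical.
from typing import Optional
import torch

def find_batch_size(*args, **kwargs) -> Optional[int]:
    for arg in list(args) + list(kwargs.values()):
        if isinstance(arg, torch.Tensor):
            return arg.shape[0]
    return None

assert find_batch_size(torch.zeros(8, 3), lr=0.1) == 8
assert find_batch_size("no tensors here") is None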
        inp_scale = (0, 1)
    if tgt_scale is None:
        tgt_scale = (0, 1)
    # torch.Tensor rather than torch.FloatTensor: isinstance against
    # torch.FloatTensor only matches CPU float32 tensors, so the shape check
    # would be silently skipped for CUDA tensors.
    if isinstance(tgt_scale, torch.Tensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]
    dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
    dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
    return dat
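# Worked example of the affine rescaling above: map data from inp_scale (0, 1)
# to tgt_scale (-1, 1). Plain arithmetic, following the normalize-then-
# denormalize steps in the function body.
import torch

dat = torch.tensor([0.0, 0.5, 1.0])
inp_scale, tgt_scale = (0, 1), (-1, 1)
dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
assert torch.equal(dat, torch.tensor([-1.0, 0.0, 1.0]))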
        fx, fy = focal, focal
        cx, cy = W / 2, H / 2
    else:
        fx, fy = focal
        assert principal is not None
        cx, cy = principal

    i, j = torch.meshgrid(
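# Minimal pinhole sketch of what the intrinsics above feed into: per-pixel
# camera-space ray directions from (fx, fy, cx, cy). An OpenCV-style convention
# (+x right, +y down, +z forward) is assumed here; the real function's
# convention may differ.
import torch

H, W, fx, fy = 4, 4, 100.0, 100.0
cx, cy = W / 2, H / 2
i, j = torch.meshgrid(
    torch.arange(W, dtype=torch.float32),
    torch.arange(H, dtype=torch.float32),
    indexing="xy",
)
directions = torch.stack([(i - cx) / fx, (j - cy) / fy, torch.ones_like(i)], dim=-1)
assert directions.shape == (H, W, 3)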
    keepdim=False,
    normalize=False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    # Rotate ray directions from camera coordinate to the world coordinate
    assert directions.shape[-1] == 3

    if directions.ndim == 2:  # (N_rays, 3)
        if c2w.ndim == 2:  # (4, 4)
            c2w = c2w[None, :, :]
        assert c2w.ndim == 3  # (N_rays, 4, 4) or (1, 4, 4)
        rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1)  # (N_rays, 3)
        rays_o = c2w[:, :3, 3].expand(rays_d.shape)
    elif directions.ndim == 3:  # (H, W, 3)
        assert c2w.ndim in [2, 3]
        if c2w.ndim == 2:  # (4, 4)
            rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
                -1
            )  # (H, W, 3)
            rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
        elif c2w.ndim == 3:  # (B, 4, 4)
            rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
                -1
            )  # (B, H, W, 3)
            rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
    elif directions.ndim == 4:  # (B, H, W, 3)
        assert c2w.ndim == 3  # (B, 4, 4)
        rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
            -1
        )  # (B, H, W, 3)
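# Hedged usage sketch of the (H, W, 3) + (4, 4) broadcasting case above: with
# an identity camera-to-world pose, world-space ray directions equal the
# camera-space directions and all origins sit at the world origin.
import torch

H, W = 4, 4
directions = torch.randn(H, W, 3)
c2w = torch.eye(4)
rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(-1)
rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
assert torch.allclose(rays_d, directions) and rays_o.shape == (H, W, 3)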
    image: PIL.Image.Image,
    ratio: float,
) -> PIL.Image.Image:
    image = np.array(image)
    assert image.shape[-1] == 4
    alpha = np.where(image[..., 3] > 0)
    y1, y2, x1, x2 = (
        alpha[0].min(),
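# Sketch of the alpha-channel bounding box being computed above: np.where on
# the alpha mask returns (row_indices, col_indices), whose extrema give the
# foreground's tight bbox in RGBA image coordinates.
import numpy as np

image = np.zeros((8, 8, 4), dtype=np.uint8)
image[2:5, 3:7, 3] = 255  # opaque foreground block
alpha = np.where(image[..., 3] > 0)
y1, y2, x1, x2 = alpha[0].min(), alpha[0].max(), alpha[1].min(), alpha[1].max()
assert (y1, y2, x1, x2) == (2, 4, 3, 6)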