42 import cupy 43 print("CuPy is already installed.") 44 except: 45 print("Uninstall cupy if existed...") 46 os.system(f'"{sys.executable}" {s_param} -m pip uninstall -y cupy-wheel cupy-cuda102 cupy-cuda110 cupy-cuda111 cupy-cuda11x cupy-cuda12x') 47 print("Installing cupy...") 48 cuda_ver = get_cuda_ver_from_dir(cuda_home) 49 cupy_package = f"cupy-cuda{cuda_ver}" if cuda_ver is not None else "cupy-wheel"
46 os.system(f'"{sys.executable}" {s_param} -m pip uninstall -y cupy-wheel cupy-cuda102 cupy-cuda110 cupy-cuda111 cupy-cuda11x cupy-cuda12x') 47 print("Installing cupy...") 48 cuda_ver = get_cuda_ver_from_dir(cuda_home) 49 cupy_package = f"cupy-cuda{cuda_ver}" if cuda_ver is not None else "cupy-wheel" 50 os.system(f'"{sys.executable}" {s_param} -m pip install {cupy_package}') 51 52 with open(Path(__file__).parent / "requirements-no-cupy.txt", 'r') as f: 53 for package in f.readlines():
52 with open(Path(__file__).parent / "requirements-no-cupy.txt", 'r') as f: 53 for package in f.readlines(): 54 package = package.strip() 55 print(f"Installing {package}...") 56 os.system(f'"{sys.executable}" {s_param} -m pip install {package}') 57 58 print("Checking cupy...") 59 install_cupy()
33 os.makedirs(f"test_result/video{i}", exist_ok=True) 34 for j, frame in enumerate(frames): 35 frame.save(f"test_result/video{i}/{j}.jpg") 36 frames[0].save(f"test_result/video{i}.gif", save_all=True, append_images=frames[1:], optimize=True, duration=1/3, loop=0) 37 os.startfile(f"test_result{os.path.sep}video{i}.gif") 38 #torchvision.io.video.write_video("test.mp4", einops.rearrange(result, "n c h w -> n h w c").cpu(), fps=1)
125 convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 126 Returns: 127 np.ndarray: Flow visualization image of shape [H,W,3] 128 """ 129 assert flow_uv.ndim == 3, 'input flow must have three dimensions' 130 assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 131 if clip_flow is not None: 132 flow_uv = np.clip(flow_uv, 0, clip_flow)
126 Returns: 127 np.ndarray: Flow visualization image of shape [H,W,3] 128 """ 129 assert flow_uv.ndim == 3, 'input flow must have three dimensions' 130 assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 131 if clip_flow is not None: 132 flow_uv = np.clip(flow_uv, 0, clip_flow) 133 u = flow_uv[:,:,0]
1099 def __call__(self, coords0, coords1): 1100 r = self.radius 1101 coords0 = coords0.permute(0, 2, 3, 1) 1102 coords1 = coords1.permute(0, 2, 3, 1) 1103 assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]" 1104 batch, h1, w1, _ = coords0.shape 1105 1106 out_pyramid = []
1028 if metric is not None and metric.dtype != torch.float32: 1029 metric = metric.type(torch.float32) 1030 1031 # move to gpu if necessary 1032 assert img.device == flow.device 1033 if metric is not None: 1034 assert img.device == metric.device 1035 was_cpu = img.device.type == "cpu"
1030 1031 # move to gpu if necessary 1032 assert img.device == flow.device 1033 if metric is not None: 1034 assert img.device == metric.device 1035 was_cpu = img.device.type == "cpu" 1036 if was_cpu: 1037 img = img.to("cuda")
1202 f = _deepflow 1203 a = a.convert("L").cv2() 1204 b = b.convert("L").cv2() 1205 else: 1206 assert 0 1207 ans = f(b, a) 1208 if back: 1209 ans = np.concatenate(
1222 # package 1223 url = f"http://localhost:8109/get-flow" 1224 if mode == "shm": 1225 t = time.time() 1226 fn_a = img_a.save(mkfile(f"/dev/shm/_flownet2/{t}/img_a.png")) 1227 fn_b = img_b.save(mkfile(f"/dev/shm/_flownet2/{t}/img_b.png")) 1228 elif mode == "net": 1229 assert False, "not impl"
1223 url = f"http://localhost:8109/get-flow" 1224 if mode == "shm": 1225 t = time.time() 1226 fn_a = img_a.save(mkfile(f"/dev/shm/_flownet2/{t}/img_a.png")) 1227 fn_b = img_b.save(mkfile(f"/dev/shm/_flownet2/{t}/img_b.png")) 1228 elif mode == "net": 1229 assert False, "not impl" 1230 q = u2d.img2uri(img.pil("RGB"))
1225 t = time.time() 1226 fn_a = img_a.save(mkfile(f"/dev/shm/_flownet2/{t}/img_a.png")) 1227 fn_b = img_b.save(mkfile(f"/dev/shm/_flownet2/{t}/img_b.png")) 1228 elif mode == "net": 1229 assert False, "not impl" 1230 q = u2d.img2uri(img.pil("RGB")) 1231 q.decode() 1232 resp = requests.get(
1228 elif mode == "net": 1229 assert False, "not impl" 1230 q = u2d.img2uri(img.pil("RGB")) 1231 q.decode() 1232 resp = requests.get( 1233 url, 1234 params={ 1235 "img_a": fn_a, 1236 "img_b": fn_b, 1237 "mode": mode, 1238 "back": back, 1239 # 'vis': vis, 1240 }, 1241 ) 1242 1243 # return 1244 ans = {"response": resp}
1250 } 1251 # if vis: 1252 # ans['output']['vis'] = I(j['fn_vis']) 1253 if mode == "shm": 1254 shutil.rmtree(f"/dev/shm/_flownet2/{t}") 1255 return ans 1256 1257
1411 1412 class GridnetTotalDropout(nn.Module): 1413 def __init__(self, p): 1414 super().__init__() 1415 assert 0 <= p < 1 1416 self.p = p 1417 self.weight = 1 / (1 - p) 1418 return
1470 x.view(bs * k, ch, h, w), 1471 is_flow=is_flow, 1472 ).view(bs, k, ch, *self.size) 1473 else: 1474 assert 0 1475 1476 1477 ###################### CANNY ######################
1543 bs, ch, h, w = img.shape 1544 if ch in [3, 4]: 1545 img = kornia.color.rgb_to_grayscale(img[:, :3]) 1546 else: 1547 assert ch == 1 1548 1549 # calculate dog 1550 kern0 = max(2 * int(sigma * kernel_factor) + 1, 3)
1582 return cd 1583 1584 1585 def batch_chamfer_distance_t(gt, pred, block=1024, return_more=False): 1586 assert gt.device == pred.device and gt.shape == pred.shape 1587 bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1] 1588 dpred = batch_edt(pred, block=block) 1589 cd = (gt * dpred).float().mean((-2, -1)) / np.sqrt(h**2 + w**2)
1587 bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1] 1588 dpred = batch_edt(pred, block=block) 1589 cd = (gt * dpred).float().mean((-2, -1)) / np.sqrt(h**2 + w**2) 1590 if len(cd.shape) == 2: 1591 assert cd.shape[1] == 1 1592 cd = cd.squeeze(1) 1593 return cd 1594
1593 return cd 1594 1595 1596 def batch_chamfer_distance_p(gt, pred, block=1024, return_more=False): 1597 assert gt.device == pred.device and gt.shape == pred.shape 1598 bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1] 1599 dgt = batch_edt(gt, block=block) 1600 cd = (pred * dgt).float().mean((-2, -1)) / np.sqrt(h**2 + w**2)
1598 bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1] 1599 dgt = batch_edt(gt, block=block) 1600 cd = (pred * dgt).float().mean((-2, -1)) / np.sqrt(h**2 + w**2) 1601 if len(cd.shape) == 2: 1602 assert cd.shape[1] == 1 1603 cd = cd.squeeze(1) 1604 return cd 1605
1606 1607 # normalized by diameter 1608 # always between [0,1] 1609 def batch_hausdorff_distance(gt, pred, block=1024, return_more=False): 1610 assert gt.device == pred.device and gt.shape == pred.shape 1611 bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1] 1612 dgt = batch_edt(gt, block=block) 1613 dpred = batch_edt(pred, block=block)
1617 (dpred * gt).amax(dim=(-2, -1)), 1618 ] 1619 ).amax(dim=0).float() / np.sqrt(h**2 + w**2) 1620 if len(hd.shape) == 2: 1621 assert hd.shape[1] == 1 1622 hd = hd.squeeze(1) 1623 return hd 1624
1662 if isinstance(x, torch.Tensor): 1663 return x.to(device) 1664 if isinstance(x, np.ndarray): 1665 return torch.tensor(x).to(device) 1666 assert 0, "data not understood" 1667 1668 1669 ################ PARSING ################
1737 with open(fn, mode) as handle: 1738 return handle.write(text) 1739 1740 1741 import pickle 1742 1743 1744 def dump(obj, fn, mode="wb"):
1748 1749 1750 def load(fn, mode="rb"): 1751 with open(fn, mode) as handle: 1752 return pickle.load(handle) 1753 1754 1755 import json
1777 def yread(fn, mode="r"): 1778 with open(fn, mode) as handle: 1779 return yaml.safe_load(handle) 1780 1781 except: 1782 pass 1783 1784 try: 1785 import pyunpack
1782 pass 1783 1784 try: 1785 import pyunpack 1786 except: 1787 pass 1788 1789 try: 1790 import mysql
1788 1789 try: 1790 import mysql 1791 import mysql.connector 1792 except: 1793 pass 1794 1795 1796 ################ MISC ################
1854 # calculate table size 1855 t = copy.deepcopy(self.t) 1856 totalrows = len(t) 1857 totalcols = [len(r) for r in t] 1858 assert min(totalcols) == max(totalcols) 1859 totalcols = totalcols[0] 1860 1861 # string-ify
1863 for j in range(totalcols): 1864 x, s = t[i][j] 1865 sp = s[11] 1866 if sp: 1867 x = eval(f'f"{{{x}{sp}}}"') 1868 Table._put((str(x), s), t, (i, j), empty) 1869 1870 # expand delimiters
780 kernel_size=size, 781 padding='same') 782 if activation is None: 783 return _conv 784 assert activation == 'relu' 785 return nn.Sequential( 786 _conv, 787 nn.LeakyReLU(.2)
105 self.test_branch_idx = test_branch_idx 106 self.norm = norm 107 self.activation = activation 108 109 assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 110 111 self.weight = nn.Parameter( 112 torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
123 def forward(self, inputs): 124 num_branch = ( 125 self.num_branch if self.training or self.test_branch_idx == -1 else 1 126 ) 127 assert len(inputs) == num_branch 128 129 if self.training or self.test_branch_idx == -1: 130 outputs = [
314 315 316 def single_head_full_attention(q, k, v): 317 # q, k, v: [B, L, C] 318 assert q.dim() == k.dim() == v.dim() == 3 319 320 scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** 0.5) # [B, L, L] 321 attn = torch.softmax(scores, dim=2) # [B, L, L]
376 attn_mask=None, 377 ): 378 # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py 379 # q, k, v: [B, L, C] 380 assert q.dim() == k.dim() == v.dim() == 3 381 382 assert h is not None and w is not None 383 assert q.size(1) == h * w
378 # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py 379 # q, k, v: [B, L, C] 380 assert q.dim() == k.dim() == v.dim() == 3 381 382 assert h is not None and w is not None 383 assert q.size(1) == h * w 384 385 b, _, c = q.size()
379 # q, k, v: [B, L, C] 380 assert q.dim() == k.dim() == v.dim() == 3 381 382 assert h is not None and w is not None 383 assert q.size(1) == h * w 384 385 b, _, c = q.size() 386
395 396 scale_factor = c**0.5 397 398 if with_shift: 399 assert attn_mask is not None # compute once 400 shift_size_h = window_size_h // 2 401 shift_size_w = window_size_w // 2 402
633 attn_num_splits=None, 634 **kwargs, 635 ): 636 b, c, h, w = feature0.shape 637 assert self.d_model == c 638 639 feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C] 640 feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
748 feature0, 749 flow, 750 local_window_radius=1, 751 ): 752 assert flow.size(1) == 2 753 assert local_window_radius > 0 754 755 b, c, h, w = feature0.size()
749 flow, 750 local_window_radius=1, 751 ): 752 assert flow.size(1) == 2 753 assert local_window_radius > 0 754 755 b, c, h, w = feature0.size() 756
933 return grid 934 935 936 def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): 937 assert device is not None 938 939 x, y = torch.meshgrid( 940 [
984 985 986 def flow_warp(feature, flow, mask=False, padding_mode="zeros"): 987 b, c, h, w = feature.size() 988 assert flow.size(1) == 2 989 990 grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] 991
994 995 def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5): 996 # fwd_flow, bwd_flow: [B, 2, H, W] 997 # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) 998 assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 999 assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 1000 flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] 1001
995 def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5): 996 # fwd_flow, bwd_flow: [B, 2, H, W] 997 # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) 998 assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 999 assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 1000 flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] 1001 1002 warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
1063 channel_last=False, 1064 ): 1065 if channel_last: # [B, H, W, C] 1066 b, h, w, c = feature.size() 1067 assert h % num_splits == 0 and w % num_splits == 0 1068 1069 b_new = b * num_splits * num_splits 1070 h_new = h // num_splits
1076 .reshape(b_new, h_new, w_new, c) 1077 ) # [B*K*K, H/K, W/K, C] 1078 else: # [B, C, H, W] 1079 b, c, h, w = feature.size() 1080 assert h % num_splits == 0 and w % num_splits == 0 1081 1082 b_new = b * num_splits * num_splits 1083 h_new = h // num_splits
1278 ) # list of features 1279 1280 flow = None 1281 1282 assert ( 1283 len(attn_splits_list) 1284 == len(corr_radius_list) 1285 == len(prop_radius_list) 1286 == self.num_scales 1287 ) 1288 1289 for scale_idx in range(self.num_scales): 1290 feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
104 self.test_branch_idx = test_branch_idx 105 self.norm = norm 106 self.activation = activation 107 108 assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 109 110 self.weight = nn.Parameter( 111 torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
122 def forward(self, inputs): 123 num_branch = ( 124 self.num_branch if self.training or self.test_branch_idx == -1 else 1 125 ) 126 assert len(inputs) == num_branch 127 128 if self.training or self.test_branch_idx == -1: 129 outputs = [
313 314 315 def single_head_full_attention(q, k, v): 316 # q, k, v: [B, L, C] 317 assert q.dim() == k.dim() == v.dim() == 3 318 319 scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** 0.5) # [B, L, L] 320 attn = torch.softmax(scores, dim=2) # [B, L, L]
375 attn_mask=None, 376 ): 377 # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py 378 # q, k, v: [B, L, C] 379 assert q.dim() == k.dim() == v.dim() == 3 380 381 assert h is not None and w is not None 382 assert q.size(1) == h * w
377 # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py 378 # q, k, v: [B, L, C] 379 assert q.dim() == k.dim() == v.dim() == 3 380 381 assert h is not None and w is not None 382 assert q.size(1) == h * w 383 384 b, _, c = q.size()
378 # q, k, v: [B, L, C] 379 assert q.dim() == k.dim() == v.dim() == 3 380 381 assert h is not None and w is not None 382 assert q.size(1) == h * w 383 384 b, _, c = q.size() 385
394 395 scale_factor = c**0.5 396 397 if with_shift: 398 assert attn_mask is not None # compute once 399 shift_size_h = window_size_h // 2 400 shift_size_w = window_size_w // 2 401
632 attn_num_splits=None, 633 **kwargs, 634 ): 635 b, c, h, w = feature0.shape 636 assert self.d_model == c 637 638 feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C] 639 feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
747 feature0, 748 flow, 749 local_window_radius=1, 750 ): 751 assert flow.size(1) == 2 752 assert local_window_radius > 0 753 754 b, c, h, w = feature0.size()
748 flow, 749 local_window_radius=1, 750 ): 751 assert flow.size(1) == 2 752 assert local_window_radius > 0 753 754 b, c, h, w = feature0.size() 755
932 return grid 933 934 935 def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): 936 assert device is not None 937 938 x, y = torch.meshgrid( 939 [
983 984 985 def flow_warp(feature, flow, mask=False, padding_mode="zeros"): 986 b, c, h, w = feature.size() 987 assert flow.size(1) == 2 988 989 grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] 990
993 994 def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5): 995 # fwd_flow, bwd_flow: [B, 2, H, W] 996 # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) 997 assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 998 assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 999 flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] 1000
994 def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5): 995 # fwd_flow, bwd_flow: [B, 2, H, W] 996 # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) 997 assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 998 assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 999 flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] 1000 1001 warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
1062 channel_last=False, 1063 ): 1064 if channel_last: # [B, H, W, C] 1065 b, h, w, c = feature.size() 1066 assert h % num_splits == 0 and w % num_splits == 0 1067 1068 b_new = b * num_splits * num_splits 1069 h_new = h // num_splits
1075 .reshape(b_new, h_new, w_new, c) 1076 ) # [B*K*K, H/K, W/K, C] 1077 else: # [B, C, H, W] 1078 b, c, h, w = feature.size() 1079 assert h % num_splits == 0 and w % num_splits == 0 1080 1081 b_new = b * num_splits * num_splits 1082 h_new = h // num_splits
1277 ) # list of features 1278 1279 flow = None 1280 1281 assert ( 1282 len(attn_splits_list) 1283 == len(corr_radius_list) 1284 == len(prop_radius_list) 1285 == self.num_scales 1286 ) 1287 1288 for scale_idx in range(self.num_scales): 1289 feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
613 self.maskconvx8 = nn.Conv2d(c, 8 * 8 * 9, 1, 1, 0) 614 self.maskconvx4 = nn.Conv2d(c, 4 * 4 * 9, 1, 1, 0) 615 616 self.level = level 617 assert self.level in [4, 8, 16], "Bitch" 618 619 def mask_conv(self, x): 620 if self.level == 4:
266 ) # https://github.com/pytorch/pytorch/issues/62854 267 268 # end 269 270 assert False # to make torchscript happy 271 272 # end 273
296 ] 297 fltStride *= 1.0 298 299 elif True: 300 assert False 301 302 # end 303 # end
384 385 # end 386 # end 387 388 assert len(intChans) == 1 389 390 # end 391
7 8 config_path = Path(Path(__file__).parent.parent.parent.resolve(), "config.yaml") 9 10 if os.path.exists(config_path): 11 config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader) 12 ops_backend = config["ops_backend"] 13 else: 14 ops_backend = "taichi"
12 ops_backend = config["ops_backend"] 13 else: 14 ops_backend = "taichi" 15 16 assert ops_backend in ["taichi", "cupy"] 17 18 if ops_backend == "taichi": 19 from .taichi_ops import softsplat, ModuleSoftsplat, FunctionSoftsplat, softsplat_func, costvol_func, sepconv_func, init, batch_edt, FunctionAdaCoF, ModuleCorrelation, FunctionCorrelation, _FunctionCorrelation
270 intFilterSize = int(math.sqrt(weight.size(1))) 271 intOutputHeight = weight.size(2) 272 intOutputWidth = weight.size(3) 273 274 assert ( 275 intInputHeight - ((intFilterSize - 1) * dilation + 1) == intOutputHeight - 1 276 ) 277 assert ( 278 intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1 279 )
273 274 assert ( 275 intInputHeight - ((intFilterSize - 1) * dilation + 1) == intOutputHeight - 1 276 ) 277 assert ( 278 intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1 279 ) 280 281 assert input.is_contiguous() == True 282 assert weight.is_contiguous() == True
277 assert ( 278 intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1 279 ) 280 281 assert input.is_contiguous() == True 282 assert weight.is_contiguous() == True 283 assert offset_i.is_contiguous() == True 284 assert offset_j.is_contiguous() == True
278 intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1 279 ) 280 281 assert input.is_contiguous() == True 282 assert weight.is_contiguous() == True 283 assert offset_i.is_contiguous() == True 284 assert offset_j.is_contiguous() == True 285
279 ) 280 281 assert input.is_contiguous() == True 282 assert weight.is_contiguous() == True 283 assert offset_i.is_contiguous() == True 284 assert offset_j.is_contiguous() == True 285 286 output = input.new_zeros(
280 281 assert input.is_contiguous() == True 282 assert weight.is_contiguous() == True 283 assert offset_i.is_contiguous() == True 284 assert offset_j.is_contiguous() == True 285 286 output = input.new_zeros( 287 intSample, intInputDepth, intOutputHeight, intOutputWidth
343 intFilterSize = int(math.sqrt(weight.size(1))) 344 intOutputHeight = weight.size(2) 345 intOutputWidth = weight.size(3) 346 347 assert ( 348 intInputHeight - ((intFilterSize - 1) * dilation + 1) == intOutputHeight - 1 349 ) 350 assert ( 351 intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1 352 )
346 347 assert ( 348 intInputHeight - ((intFilterSize - 1) * dilation + 1) == intOutputHeight - 1 349 ) 350 assert ( 351 intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1 352 ) 353 354 assert gradOutput.is_contiguous() == True 355
350 assert ( 351 intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1 352 ) 353 354 assert gradOutput.is_contiguous() == True 355 356 gradInput = ( 357 input.new_zeros(intSample, intInputDepth, intInputHeight, intInputWidth)
48 _batch_edt = cuda_launch(*_batch_edt_kernel) 49 50 # bookkeeppingg 51 if len(img.shape) == 4: 52 assert img.shape[1] == 1 53 img = img.squeeze(1) 54 expand = True 55 else:
241 242 self.save_for_backward(first, second, rbot0, rbot1) 243 244 first = first.contiguous() 245 assert first.is_cuda == True 246 second = second.contiguous() 247 assert second.is_cuda == True 248
243 244 first = first.contiguous() 245 assert first.is_cuda == True 246 second = second.contiguous() 247 assert second.is_cuda == True 248 249 output = first.new_zeros([first.shape[0], 81, first.shape[2], first.shape[3]]) 250
298 def backward(self, gradOutput): 299 first, second, rbot0, rbot1 = self.saved_tensors 300 301 gradOutput = gradOutput.contiguous() 302 assert gradOutput.is_cuda == True 303 304 gradFirst = ( 305 first.new_zeros(
189 def backward(self, tenOutgrad): 190 tenOne, tenTwo = self.saved_tensors 191 192 tenOutgrad = tenOutgrad.contiguous() 193 assert tenOutgrad.is_cuda == True 194 195 tenOnegrad = ( 196 tenOne.new_zeros(
189 ], 190 ) 191 192 elif tenIn.is_cuda != True: 193 assert False 194 195 # end 196
205 def backward(self, tenOutgrad): 206 tenIn, tenVer, tenHor = self.saved_tensors 207 208 tenOutgrad = tenOutgrad.contiguous() 209 assert tenOutgrad.is_cuda == True 210 211 tenIngrad = ( 212 tenIn.new_empty(
223 ), 224 ) 225 226 elif tenIn.is_cuda != True: 227 assert False 228 229 # end 230
239 def backward(self, tenOutgrad): 240 tenIn, tenFlow = self.saved_tensors 241 242 tenOutgrad = tenOutgrad.contiguous() 243 assert tenOutgrad.is_cuda == True 244 245 tenIngrad = ( 246 tenIn.new_zeros(
322 # end 323 324 325 def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): 326 assert tenMetric is None or tenMetric.shape[1] == 1 327 assert strType in ["summation", "average", "linear", "softmax"] 328 329 if strType == "average":
323 324 325 def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): 326 assert tenMetric is None or tenMetric.shape[1] == 1 327 assert strType in ["summation", "average", "linear", "softmax"] 328 329 if strType == "average": 330 tenInput = torch.cat(
381 382 def softsplat( 383 tenIn: torch.Tensor, tenFlow: torch.Tensor, tenMetric: torch.Tensor, strMode: str 384 ): 385 assert strMode.split("-")[0] in ["sum", "avg", "linear", "soft"] 386 387 if strMode == "sum": 388 assert tenMetric is None
384 ): 385 assert strMode.split("-")[0] in ["sum", "avg", "linear", "soft"] 386 387 if strMode == "sum": 388 assert tenMetric is None 389 if strMode == "avg": 390 assert tenMetric is None 391 if strMode.split("-")[0] == "linear":
386 387 if strMode == "sum": 388 assert tenMetric is None 389 if strMode == "avg": 390 assert tenMetric is None 391 if strMode.split("-")[0] == "linear": 392 assert tenMetric is not None 393 if strMode.split("-")[0] == "soft":
388 assert tenMetric is None 389 if strMode == "avg": 390 assert tenMetric is None 391 if strMode.split("-")[0] == "linear": 392 assert tenMetric is not None 393 if strMode.split("-")[0] == "soft": 394 assert tenMetric is not None 395
390 assert tenMetric is None 391 if strMode.split("-")[0] == "linear": 392 assert tenMetric is not None 393 if strMode.split("-")[0] == "soft": 394 assert tenMetric is not None 395 396 if strMode == "avg": 397 tenIn = torch.cat(
59 strKey += str(objValue.stride()) 60 61 elif True: 62 print(strVariable, type(objValue)) 63 assert False 64 65 # end 66 # end
105 strKernel = strKernel.replace("{{type}}", "long") 106 107 elif type(objValue) == torch.Tensor: 108 print(strVariable, objValue.dtype) 109 assert False 110 111 elif True: 112 print(strVariable, type(objValue))
109 assert False 110 111 elif True: 112 print(strVariable, type(objValue)) 113 assert False 114 115 # end 116 # end
180 181 intArgs = int(objMatch.group(2)) 182 strArgs = strKernel[intStart:intStop].split(",") 183 184 assert intArgs == len(strArgs) - 1 185 186 strTensor = strArgs[0] 187 intStrides = objVariables[strTensor].stride()
27 28 def softsplat( 29 tenIn: torch.Tensor, tenFlow: torch.Tensor, tenMetric: torch.Tensor, strMode: str 30 ): 31 assert strMode.split("-")[0] in ["sum", "avg", "linear", "soft"] 32 33 if strMode == "sum": 34 assert tenMetric is None
30 ): 31 assert strMode.split("-")[0] in ["sum", "avg", "linear", "soft"] 32 33 if strMode == "sum": 34 assert tenMetric is None 35 if strMode == "avg": 36 assert tenMetric is None 37 if strMode.split("-")[0] == "linear":
32 33 if strMode == "sum": 34 assert tenMetric is None 35 if strMode == "avg": 36 assert tenMetric is None 37 if strMode.split("-")[0] == "linear": 38 assert tenMetric is not None 39 if strMode.split("-")[0] == "soft":
34 assert tenMetric is None 35 if strMode == "avg": 36 assert tenMetric is None 37 if strMode.split("-")[0] == "linear": 38 assert tenMetric is not None 39 if strMode.split("-")[0] == "soft": 40 assert tenMetric is not None 41
36 assert tenMetric is None 37 if strMode.split("-")[0] == "linear": 38 assert tenMetric is not None 39 if strMode.split("-")[0] == "soft": 40 assert tenMetric is not None 41 42 if strMode == "avg": 43 tenIn = torch.cat(
80 81 return tenOut 82 83 def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): 84 assert tenMetric is None or tenMetric.shape[1] == 1 85 assert strType in ["summation", "average", "linear", "softmax"] 86 87 if strType == "average":
81 return tenOut 82 83 def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): 84 assert tenMetric is None or tenMetric.shape[1] == 1 85 assert strType in ["summation", "average", "linear", "softmax"] 86 87 if strType == "average": 88 tenInput = torch.cat(
28 import typing 29 30 ########################################################## 31 32 assert ( 33 int(str("").join(torch.__version__.split(".")[0:2])) >= 13 34 ) # requires at least pytorch version 1.3.0 35 36 torch.set_grad_enabled( 37 False
169 ) # https://github.com/pytorch/pytorch/issues/62854 170 171 # end 172 173 assert False # to make torchscript happy 174 175 # end 176
199 ] 200 fltStride *= 1.0 201 202 elif True: 203 assert False 204 205 # end 206 # end
287 288 # end 289 # end 290 291 assert len(intChans) == 1 292 293 # end 294
323 objScratch: typing.Dict[str, typing.List[int]], 324 ): 325 super().__init__() 326 327 assert len(intIns) == len(intOuts) 328 assert len(intOuts) == len(intIns) 329 330 self.intRows = len(intIns) and len(intOuts)
324 ): 325 super().__init__() 326 327 assert len(intIns) == len(intOuts) 328 assert len(intOuts) == len(intIns) 329 330 self.intRows = len(intIns) and len(intOuts) 331 self.intIns = intIns.copy()
418 objScratch: typing.Dict[str, typing.List[int]], 419 ): 420 super().__init__() 421 422 assert len(intIns) == len(intOuts) 423 assert len(intOuts) == len(intIns) 424 425 self.intRows = len(intIns) and len(intOuts)
419 ): 420 super().__init__() 421 422 assert len(intIns) == len(intOuts) 423 assert len(intOuts) == len(intIns) 424 425 self.intRows = len(intIns) and len(intOuts) 426 self.intIns = intIns.copy()
713 if netNetwork is None: 714 netNetwork = Network().to(get_torch_device()).eval() 715 # end 716 717 assert tenOne.shape[1] == tenTwo.shape[1] 718 assert tenOne.shape[2] == tenTwo.shape[2] 719 720 intWidth = tenOne.shape[2]
714 netNetwork = Network().to(get_torch_device()).eval() 715 # end 716 717 assert tenOne.shape[1] == tenTwo.shape[1] 718 assert tenOne.shape[2] == tenTwo.shape[2] 719 720 intWidth = tenOne.shape[2] 721 intHeight = tenOne.shape[1]
719 720 intWidth = tenOne.shape[2] 721 intHeight = tenOne.shape[1] 722 723 assert ( 724 intWidth <= 1280 725 ) # while our approach works with larger images, we do not recommend it unless you are aware of the implications 726 assert ( 727 intHeight <= 720 728 ) # while our approach works with larger images, we do not recommend it unless you are aware of the implications
722 723 assert ( 724 intWidth <= 1280 725 ) # while our approach works with larger images, we do not recommend it unless you are aware of the implications 726 assert ( 727 intHeight <= 720 728 ) # while our approach works with larger images, we do not recommend it unless you are aware of the implications 729 730 tenPreprocessedOne = tenOne.to(get_torch_device()).view(1, 3, intHeight, intWidth) 731 tenPreprocessedTwo = tenTwo.to(get_torch_device()).view(1, 3, intHeight, intWidth)
586 if netNetwork is None: 587 netNetwork = Network().cuda().eval() 588 # end 589 590 assert tenFirst.shape[1] == tenSecond.shape[1] 591 assert tenFirst.shape[2] == tenSecond.shape[2] 592 593 intWidth = tenFirst.shape[2]
587 netNetwork = Network().cuda().eval() 588 # end 589 590 assert tenFirst.shape[1] == tenSecond.shape[1] 591 assert tenFirst.shape[2] == tenSecond.shape[2] 592 593 intWidth = tenFirst.shape[2] 594 intHeight = tenFirst.shape[1]
592 593 intWidth = tenFirst.shape[2] 594 intHeight = tenFirst.shape[1] 595 596 assert ( 597 intWidth == 1024 598 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 599 assert ( 600 intHeight == 436 601 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
595 596 assert ( 597 intWidth == 1024 598 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 599 assert ( 600 intHeight == 436 601 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 602 603 tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth) 604 tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth)
1588 1589 self.n_row = n_row 1590 self.n_col = n_col 1591 self.n_chs = grid_chs 1592 assert ( 1593 len(grid_chs) == self.n_row 1594 ), "should give num channels for each row (scale stream)" 1595 assert ( 1596 len(in_chs) == self.n_row 1597 ), "should give input channels for each row (scale stream)"
1591 self.n_chs = grid_chs 1592 assert ( 1593 len(grid_chs) == self.n_row 1594 ), "should give num channels for each row (scale stream)" 1595 assert ( 1596 len(in_chs) == self.n_row 1597 ), "should give input channels for each row (scale stream)" 1598 1599 for r, n_ch in enumerate(self.n_chs): 1600 setattr(self, f"lateral_{r}_0", LateralBlock(in_chs[r], n_ch))
1611 1612 self.lateral_final = LateralBlock(self.n_chs[0], out_chs) 1613 1614 def forward(self, *args): 1615 assert len(args) == self.n_row 1616 1617 # extensible, memory-efficient 1618 cur_col = list(args)
1642 self.n_row = n_row 1643 self.n_col = n_col 1644 self.n_chs = grid_chs 1645 self.outrow = outrow 1646 assert ( 1647 len(grid_chs) == self.n_row 1648 ), "should give num channels for each row (scale stream)" 1649 assert ( 1650 len(in_chs) == self.n_row 1651 ), "should give input channels for each row (scale stream)"
1645 self.outrow = outrow 1646 assert ( 1647 len(grid_chs) == self.n_row 1648 ), "should give num channels for each row (scale stream)" 1649 assert ( 1650 len(in_chs) == self.n_row 1651 ), "should give input channels for each row (scale stream)" 1652 assert len(out_chs) == len( 1653 self.outrow 1654 ), "should give out channels for each output row (scale stream)"
1648 ), "should give num channels for each row (scale stream)" 1649 assert ( 1650 len(in_chs) == self.n_row 1651 ), "should give input channels for each row (scale stream)" 1652 assert len(out_chs) == len( 1653 self.outrow 1654 ), "should give out channels for each output row (scale stream)" 1655 1656 for r, n_ch in enumerate(self.n_chs): 1657 setattr(self, f"lateral_{r}_0", LateralBlock(in_chs[r], n_ch))
1669 for i, r in enumerate(outrow): 1670 setattr(self, f"lateral_final_{r}", LateralBlock(self.n_chs[r], out_chs[i])) 1671 1672 def forward(self, *args): 1673 assert len(args) == self.n_row 1674 1675 # extensible, memory-efficient 1676 cur_col = list(args)
1701 1702 self.n_row = n_row 1703 self.n_col = n_col 1704 self.n_chs = grid_chs 1705 assert ( 1706 len(grid_chs) == self.n_row 1707 ), "should give num channels for each row (scale stream)" 1708 1709 for r, n_ch in enumerate(self.n_chs): 1710 if r == 0:
1750 1751 self.n_row = 3 1752 self.n_col = 6 1753 self.n_chs = grid_chs 1754 assert ( 1755 len(grid_chs) == self.n_row 1756 ), "should give num channels for each row (scale stream)" 1757 1758 self.lateral_init = LateralBlock(in_chs, self.n_chs[0]) 1759
2339 if netNetwork is None: 2340 netNetwork = Network().cuda().eval() 2341 # end 2342 2343 assert tenFirst.shape[1] == tenSecond.shape[1] 2344 assert tenFirst.shape[2] == tenSecond.shape[2] 2345 2346 intWidth = tenFirst.shape[2]
2340 netNetwork = Network().cuda().eval() 2341 # end 2342 2343 assert tenFirst.shape[1] == tenSecond.shape[1] 2344 assert tenFirst.shape[2] == tenSecond.shape[2] 2345 2346 intWidth = tenFirst.shape[2] 2347 intHeight = tenFirst.shape[1]
2345 2346 intWidth = tenFirst.shape[2] 2347 intHeight = tenFirst.shape[1] 2348 2349 assert ( 2350 intWidth == 1024 2351 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 2352 assert ( 2353 intHeight == 436 2354 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
2348 2349 assert ( 2350 intWidth == 1024 2351 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 2352 assert ( 2353 intHeight == 436 2354 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 2355 2356 tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth) 2357 tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth)
44 t_value shape : [B,1] ############### 45 ''' 46 B, C, T, H, W = x.size() 47 B2, C2 = t_value.size() 48 assert C2 == 1, "t_value shape is [B,]" 49 assert T % 2 == 0, "T must be an even number" 50 t_value = t_value.view(B, 1, 1, 1) 51
45 ''' 46 B, C, T, H, W = x.size() 47 B2, C2 = t_value.size() 48 assert C2 == 1, "t_value shape is [B,]" 49 assert T % 2 == 0, "T must be an even number" 50 t_value = t_value.view(B, 1, 1, 1) 51 52 flow_l = None
142 x shape : [B,C,T,H,W] 143 t_value shape : [B,1] ############### 144 ''' 145 B, C, T, H, W = x.size() 146 assert T % 2 == 0, "T must be an even number" 147 148 ####################### For a single level 149 l = 2 ** level
18 ] 19 20 config_path = os.path.join(os.path.dirname(__file__), "./config.yaml") 21 if os.path.exists(config_path): 22 config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader) 23 else: 24 raise Exception("config.yaml file is neccessary, plz recreate the config file by downloading it from https://github.com/Fannovel16/ComfyUI-Frame-Interpolation") 25 DEVICE = get_torch_device()
118 return einops.rearrange(frames, "n c h w -> n h w c")[..., :3].cpu() 119 120 def assert_batch_size(frames, batch_size=2, vfi_name=None): 121 subject_verb = "Most VFI models require" if vfi_name is None else f"VFI model {vfi_name} requires" 122 assert len(frames) >= batch_size, f"{subject_verb} at least {batch_size} frames to work with, only found {frames.shape[0]}. Please check the frame input using PreviewImage." 123 124 def _generic_frame_loop( 125 frames,