42 import cupy
43 print("CuPy is already installed.")
44 except:
45 print("Uninstall cupy if existed...")
46 os.system(f'"{sys.executable}" {s_param} -m pip uninstall -y cupy-wheel cupy-cuda102 cupy-cuda110 cupy-cuda111 cupy-cuda11x cupy-cuda12x')
47 print("Installing cupy...")
48 cuda_ver = get_cuda_ver_from_dir(cuda_home)
49 cupy_package = f"cupy-cuda{cuda_ver}" if cuda_ver is not None else "cupy-wheel"
46 os.system(f'"{sys.executable}" {s_param} -m pip uninstall -y cupy-wheel cupy-cuda102 cupy-cuda110 cupy-cuda111 cupy-cuda11x cupy-cuda12x')
47 print("Installing cupy...")
48 cuda_ver = get_cuda_ver_from_dir(cuda_home)
49 cupy_package = f"cupy-cuda{cuda_ver}" if cuda_ver is not None else "cupy-wheel"
50 os.system(f'"{sys.executable}" {s_param} -m pip install {cupy_package}')
51
52 with open(Path(__file__).parent / "requirements-no-cupy.txt", 'r') as f:
53 for package in f.readlines():
52 with open(Path(__file__).parent / "requirements-no-cupy.txt", 'r') as f:
53 for package in f.readlines():
54 package = package.strip()
55 print(f"Installing {package}...")
56 os.system(f'"{sys.executable}" {s_param} -m pip install {package}')
57
58 print("Checking cupy...")
59 install_cupy()
33 os.makedirs(f"test_result/video{i}", exist_ok=True)
34 for j, frame in enumerate(frames):
35 frame.save(f"test_result/video{i}/{j}.jpg")
36 frames[0].save(f"test_result/video{i}.gif", save_all=True, append_images=frames[1:], optimize=True, duration=1/3, loop=0)
37 os.startfile(f"test_result{os.path.sep}video{i}.gif")
38 #torchvision.io.video.write_video("test.mp4", einops.rearrange(result, "n c h w -> n h w c").cpu(), fps=1)
125 convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 126 Returns: 127 np.ndarray: Flow visualization image of shape [H,W,3] 128 """ 129 assert flow_uv.ndim == 3, 'input flow must have three dimensions' 130 assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 131 if clip_flow is not None: 132 flow_uv = np.clip(flow_uv, 0, clip_flow)
126 Returns: 127 np.ndarray: Flow visualization image of shape [H,W,3] 128 """ 129 assert flow_uv.ndim == 3, 'input flow must have three dimensions' 130 assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 131 if clip_flow is not None: 132 flow_uv = np.clip(flow_uv, 0, clip_flow) 133 u = flow_uv[:,:,0]
1099 def __call__(self, coords0, coords1):
1100 r = self.radius
1101 coords0 = coords0.permute(0, 2, 3, 1)
1102 coords1 = coords1.permute(0, 2, 3, 1)
1103 assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
1104 batch, h1, w1, _ = coords0.shape
1105
1106 out_pyramid = []
1028 if metric is not None and metric.dtype != torch.float32: 1029 metric = metric.type(torch.float32) 1030 1031 # move to gpu if necessary 1032 assert img.device == flow.device 1033 if metric is not None: 1034 assert img.device == metric.device 1035 was_cpu = img.device.type == "cpu"
1030
1031 # move to gpu if necessary
1032 assert img.device == flow.device
1033 if metric is not None:
1034 assert img.device == metric.device
1035 was_cpu = img.device.type == "cpu"
1036 if was_cpu:
1037 img = img.to("cuda")
1202 f = _deepflow
1203 a = a.convert("L").cv2()
1204 b = b.convert("L").cv2()
1205 else:
1206 assert 0
1207 ans = f(b, a)
1208 if back:
1209 ans = np.concatenate(
1222 # package
1223 url = f"http://localhost:8109/get-flow"
1224 if mode == "shm":
1225 t = time.time()
1226 fn_a = img_a.save(mkfile(f"/dev/shm/_flownet2/{t}/img_a.png"))
1227 fn_b = img_b.save(mkfile(f"/dev/shm/_flownet2/{t}/img_b.png"))
1228 elif mode == "net":
1229 assert False, "not impl"
1223 url = f"http://localhost:8109/get-flow"
1224 if mode == "shm":
1225 t = time.time()
1226 fn_a = img_a.save(mkfile(f"/dev/shm/_flownet2/{t}/img_a.png"))
1227 fn_b = img_b.save(mkfile(f"/dev/shm/_flownet2/{t}/img_b.png"))
1228 elif mode == "net":
1229 assert False, "not impl"
1230 q = u2d.img2uri(img.pil("RGB"))
1225 t = time.time()
1226 fn_a = img_a.save(mkfile(f"/dev/shm/_flownet2/{t}/img_a.png"))
1227 fn_b = img_b.save(mkfile(f"/dev/shm/_flownet2/{t}/img_b.png"))
1228 elif mode == "net":
1229 assert False, "not impl"
1230 q = u2d.img2uri(img.pil("RGB"))
1231 q.decode()
1232 resp = requests.get(
1228 elif mode == "net":
1229 assert False, "not impl"
1230 q = u2d.img2uri(img.pil("RGB"))
1231 q.decode()
1232 resp = requests.get(
1233 url,
1234 params={
1235 "img_a": fn_a,
1236 "img_b": fn_b,
1237 "mode": mode,
1238 "back": back,
1239 # 'vis': vis,
1240 },
1241 )
1242
1243 # return
1244 ans = {"response": resp}
1250 }
1251 # if vis:
1252 # ans['output']['vis'] = I(j['fn_vis'])
1253 if mode == "shm":
1254 shutil.rmtree(f"/dev/shm/_flownet2/{t}")
1255 return ans
1256
1257
1411 1412 class GridnetTotalDropout(nn.Module): 1413 def __init__(self, p): 1414 super().__init__() 1415 assert 0 <= p < 1 1416 self.p = p 1417 self.weight = 1 / (1 - p) 1418 return
1470 x.view(bs * k, ch, h, w), 1471 is_flow=is_flow, 1472 ).view(bs, k, ch, *self.size) 1473 else: 1474 assert 0 1475 1476 1477 ###################### CANNY ######################
1543 bs, ch, h, w = img.shape 1544 if ch in [3, 4]: 1545 img = kornia.color.rgb_to_grayscale(img[:, :3]) 1546 else: 1547 assert ch == 1 1548 1549 # calculate dog 1550 kern0 = max(2 * int(sigma * kernel_factor) + 1, 3)
1582 return cd 1583 1584 1585 def batch_chamfer_distance_t(gt, pred, block=1024, return_more=False): 1586 assert gt.device == pred.device and gt.shape == pred.shape 1587 bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1] 1588 dpred = batch_edt(pred, block=block) 1589 cd = (gt * dpred).float().mean((-2, -1)) / np.sqrt(h**2 + w**2)
1587 bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1] 1588 dpred = batch_edt(pred, block=block) 1589 cd = (gt * dpred).float().mean((-2, -1)) / np.sqrt(h**2 + w**2) 1590 if len(cd.shape) == 2: 1591 assert cd.shape[1] == 1 1592 cd = cd.squeeze(1) 1593 return cd 1594
1593 return cd 1594 1595 1596 def batch_chamfer_distance_p(gt, pred, block=1024, return_more=False): 1597 assert gt.device == pred.device and gt.shape == pred.shape 1598 bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1] 1599 dgt = batch_edt(gt, block=block) 1600 cd = (pred * dgt).float().mean((-2, -1)) / np.sqrt(h**2 + w**2)
1598 bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1] 1599 dgt = batch_edt(gt, block=block) 1600 cd = (pred * dgt).float().mean((-2, -1)) / np.sqrt(h**2 + w**2) 1601 if len(cd.shape) == 2: 1602 assert cd.shape[1] == 1 1603 cd = cd.squeeze(1) 1604 return cd 1605
1606 1607 # normalized by diameter 1608 # always between [0,1] 1609 def batch_hausdorff_distance(gt, pred, block=1024, return_more=False): 1610 assert gt.device == pred.device and gt.shape == pred.shape 1611 bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1] 1612 dgt = batch_edt(gt, block=block) 1613 dpred = batch_edt(pred, block=block)
1617 (dpred * gt).amax(dim=(-2, -1)), 1618 ] 1619 ).amax(dim=0).float() / np.sqrt(h**2 + w**2) 1620 if len(hd.shape) == 2: 1621 assert hd.shape[1] == 1 1622 hd = hd.squeeze(1) 1623 return hd 1624
1662 if isinstance(x, torch.Tensor): 1663 return x.to(device) 1664 if isinstance(x, np.ndarray): 1665 return torch.tensor(x).to(device) 1666 assert 0, "data not understood" 1667 1668 1669 ################ PARSING ################
1737 with open(fn, mode) as handle: 1738 return handle.write(text) 1739 1740 1741 import pickle 1742 1743 1744 def dump(obj, fn, mode="wb"):
1748 1749 1750 def load(fn, mode="rb"): 1751 with open(fn, mode) as handle: 1752 return pickle.load(handle) 1753 1754 1755 import json
1777 def yread(fn, mode="r"): 1778 with open(fn, mode) as handle: 1779 return yaml.safe_load(handle) 1780 1781 except: 1782 pass 1783 1784 try: 1785 import pyunpack
1782 pass 1783 1784 try: 1785 import pyunpack 1786 except: 1787 pass 1788 1789 try: 1790 import mysql
1788 1789 try: 1790 import mysql 1791 import mysql.connector 1792 except: 1793 pass 1794 1795 1796 ################ MISC ################
1854 # calculate table size 1855 t = copy.deepcopy(self.t) 1856 totalrows = len(t) 1857 totalcols = [len(r) for r in t] 1858 assert min(totalcols) == max(totalcols) 1859 totalcols = totalcols[0] 1860 1861 # string-ify
1863 for j in range(totalcols):
1864 x, s = t[i][j]
1865 sp = s[11]
1866 if sp:
1867 x = eval(f'f"{{{x}{sp}}}"')
1868 Table._put((str(x), s), t, (i, j), empty)
1869
1870 # expand delimiters
780 kernel_size=size, 781 padding='same') 782 if activation is None: 783 return _conv 784 assert activation == 'relu' 785 return nn.Sequential( 786 _conv, 787 nn.LeakyReLU(.2)
105 self.test_branch_idx = test_branch_idx
106 self.norm = norm
107 self.activation = activation
108
109 assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1
110
111 self.weight = nn.Parameter(
112 torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
123 def forward(self, inputs): 124 num_branch = ( 125 self.num_branch if self.training or self.test_branch_idx == -1 else 1 126 ) 127 assert len(inputs) == num_branch 128 129 if self.training or self.test_branch_idx == -1: 130 outputs = [
314 315 316 def single_head_full_attention(q, k, v): 317 # q, k, v: [B, L, C] 318 assert q.dim() == k.dim() == v.dim() == 3 319 320 scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** 0.5) # [B, L, L] 321 attn = torch.softmax(scores, dim=2) # [B, L, L]
376 attn_mask=None, 377 ): 378 # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py 379 # q, k, v: [B, L, C] 380 assert q.dim() == k.dim() == v.dim() == 3 381 382 assert h is not None and w is not None 383 assert q.size(1) == h * w
378 # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py 379 # q, k, v: [B, L, C] 380 assert q.dim() == k.dim() == v.dim() == 3 381 382 assert h is not None and w is not None 383 assert q.size(1) == h * w 384 385 b, _, c = q.size()
379 # q, k, v: [B, L, C] 380 assert q.dim() == k.dim() == v.dim() == 3 381 382 assert h is not None and w is not None 383 assert q.size(1) == h * w 384 385 b, _, c = q.size() 386
395 396 scale_factor = c**0.5 397 398 if with_shift: 399 assert attn_mask is not None # compute once 400 shift_size_h = window_size_h // 2 401 shift_size_w = window_size_w // 2 402
633 attn_num_splits=None, 634 **kwargs, 635 ): 636 b, c, h, w = feature0.shape 637 assert self.d_model == c 638 639 feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C] 640 feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
748 feature0, 749 flow, 750 local_window_radius=1, 751 ): 752 assert flow.size(1) == 2 753 assert local_window_radius > 0 754 755 b, c, h, w = feature0.size()
749 flow, 750 local_window_radius=1, 751 ): 752 assert flow.size(1) == 2 753 assert local_window_radius > 0 754 755 b, c, h, w = feature0.size() 756
933 return grid 934 935 936 def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): 937 assert device is not None 938 939 x, y = torch.meshgrid( 940 [
984 985 986 def flow_warp(feature, flow, mask=False, padding_mode="zeros"): 987 b, c, h, w = feature.size() 988 assert flow.size(1) == 2 989 990 grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] 991
994 995 def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5): 996 # fwd_flow, bwd_flow: [B, 2, H, W] 997 # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) 998 assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 999 assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 1000 flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] 1001
995 def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5): 996 # fwd_flow, bwd_flow: [B, 2, H, W] 997 # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) 998 assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 999 assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 1000 flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] 1001 1002 warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
1063 channel_last=False, 1064 ): 1065 if channel_last: # [B, H, W, C] 1066 b, h, w, c = feature.size() 1067 assert h % num_splits == 0 and w % num_splits == 0 1068 1069 b_new = b * num_splits * num_splits 1070 h_new = h // num_splits
1076 .reshape(b_new, h_new, w_new, c) 1077 ) # [B*K*K, H/K, W/K, C] 1078 else: # [B, C, H, W] 1079 b, c, h, w = feature.size() 1080 assert h % num_splits == 0 and w % num_splits == 0 1081 1082 b_new = b * num_splits * num_splits 1083 h_new = h // num_splits
1278 ) # list of features 1279 1280 flow = None 1281 1282 assert ( 1283 len(attn_splits_list) 1284 == len(corr_radius_list) 1285 == len(prop_radius_list) 1286 == self.num_scales 1287 ) 1288 1289 for scale_idx in range(self.num_scales): 1290 feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
104 self.test_branch_idx = test_branch_idx
105 self.norm = norm
106 self.activation = activation
107
108 assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1
109
110 self.weight = nn.Parameter(
111 torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
122 def forward(self, inputs): 123 num_branch = ( 124 self.num_branch if self.training or self.test_branch_idx == -1 else 1 125 ) 126 assert len(inputs) == num_branch 127 128 if self.training or self.test_branch_idx == -1: 129 outputs = [
313 314 315 def single_head_full_attention(q, k, v): 316 # q, k, v: [B, L, C] 317 assert q.dim() == k.dim() == v.dim() == 3 318 319 scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** 0.5) # [B, L, L] 320 attn = torch.softmax(scores, dim=2) # [B, L, L]
375 attn_mask=None, 376 ): 377 # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py 378 # q, k, v: [B, L, C] 379 assert q.dim() == k.dim() == v.dim() == 3 380 381 assert h is not None and w is not None 382 assert q.size(1) == h * w
377 # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py 378 # q, k, v: [B, L, C] 379 assert q.dim() == k.dim() == v.dim() == 3 380 381 assert h is not None and w is not None 382 assert q.size(1) == h * w 383 384 b, _, c = q.size()
378 # q, k, v: [B, L, C] 379 assert q.dim() == k.dim() == v.dim() == 3 380 381 assert h is not None and w is not None 382 assert q.size(1) == h * w 383 384 b, _, c = q.size() 385
394 395 scale_factor = c**0.5 396 397 if with_shift: 398 assert attn_mask is not None # compute once 399 shift_size_h = window_size_h // 2 400 shift_size_w = window_size_w // 2 401
632 attn_num_splits=None, 633 **kwargs, 634 ): 635 b, c, h, w = feature0.shape 636 assert self.d_model == c 637 638 feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C] 639 feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
747 feature0, 748 flow, 749 local_window_radius=1, 750 ): 751 assert flow.size(1) == 2 752 assert local_window_radius > 0 753 754 b, c, h, w = feature0.size()
748 flow, 749 local_window_radius=1, 750 ): 751 assert flow.size(1) == 2 752 assert local_window_radius > 0 753 754 b, c, h, w = feature0.size() 755
932 return grid 933 934 935 def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): 936 assert device is not None 937 938 x, y = torch.meshgrid( 939 [
983 984 985 def flow_warp(feature, flow, mask=False, padding_mode="zeros"): 986 b, c, h, w = feature.size() 987 assert flow.size(1) == 2 988 989 grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] 990
993 994 def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5): 995 # fwd_flow, bwd_flow: [B, 2, H, W] 996 # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) 997 assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 998 assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 999 flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] 1000
994 def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5): 995 # fwd_flow, bwd_flow: [B, 2, H, W] 996 # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) 997 assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 998 assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 999 flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] 1000 1001 warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
1062 channel_last=False, 1063 ): 1064 if channel_last: # [B, H, W, C] 1065 b, h, w, c = feature.size() 1066 assert h % num_splits == 0 and w % num_splits == 0 1067 1068 b_new = b * num_splits * num_splits 1069 h_new = h // num_splits
1075 .reshape(b_new, h_new, w_new, c) 1076 ) # [B*K*K, H/K, W/K, C] 1077 else: # [B, C, H, W] 1078 b, c, h, w = feature.size() 1079 assert h % num_splits == 0 and w % num_splits == 0 1080 1081 b_new = b * num_splits * num_splits 1082 h_new = h // num_splits
1277 ) # list of features 1278 1279 flow = None 1280 1281 assert ( 1282 len(attn_splits_list) 1283 == len(corr_radius_list) 1284 == len(prop_radius_list) 1285 == self.num_scales 1286 ) 1287 1288 for scale_idx in range(self.num_scales): 1289 feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
613 self.maskconvx8 = nn.Conv2d(c, 8 * 8 * 9, 1, 1, 0) 614 self.maskconvx4 = nn.Conv2d(c, 4 * 4 * 9, 1, 1, 0) 615 616 self.level = level 617 assert self.level in [4, 8, 16], "Bitch" 618 619 def mask_conv(self, x): 620 if self.level == 4:
266 ) # https://github.com/pytorch/pytorch/issues/62854 267 268 # end 269 270 assert False # to make torchscript happy 271 272 # end 273
296 ] 297 fltStride *= 1.0 298 299 elif True: 300 assert False 301 302 # end 303 # end
384 385 # end 386 # end 387 388 assert len(intChans) == 1 389 390 # end 391
7 8 config_path = Path(Path(__file__).parent.parent.parent.resolve(), "config.yaml") 9 10 if os.path.exists(config_path): 11 config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader) 12 ops_backend = config["ops_backend"] 13 else: 14 ops_backend = "taichi"
12 ops_backend = config["ops_backend"] 13 else: 14 ops_backend = "taichi" 15 16 assert ops_backend in ["taichi", "cupy"] 17 18 if ops_backend == "taichi": 19 from .taichi_ops import softsplat, ModuleSoftsplat, FunctionSoftsplat, softsplat_func, costvol_func, sepconv_func, init, batch_edt, FunctionAdaCoF, ModuleCorrelation, FunctionCorrelation, _FunctionCorrelation
270 intFilterSize = int(math.sqrt(weight.size(1))) 271 intOutputHeight = weight.size(2) 272 intOutputWidth = weight.size(3) 273 274 assert ( 275 intInputHeight - ((intFilterSize - 1) * dilation + 1) == intOutputHeight - 1 276 ) 277 assert ( 278 intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1 279 )
273 274 assert ( 275 intInputHeight - ((intFilterSize - 1) * dilation + 1) == intOutputHeight - 1 276 ) 277 assert ( 278 intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1 279 ) 280 281 assert input.is_contiguous() == True 282 assert weight.is_contiguous() == True
277 assert ( 278 intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1 279 ) 280 281 assert input.is_contiguous() == True 282 assert weight.is_contiguous() == True 283 assert offset_i.is_contiguous() == True 284 assert offset_j.is_contiguous() == True
278 intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1 279 ) 280 281 assert input.is_contiguous() == True 282 assert weight.is_contiguous() == True 283 assert offset_i.is_contiguous() == True 284 assert offset_j.is_contiguous() == True 285
279 ) 280 281 assert input.is_contiguous() == True 282 assert weight.is_contiguous() == True 283 assert offset_i.is_contiguous() == True 284 assert offset_j.is_contiguous() == True 285 286 output = input.new_zeros(
280 281 assert input.is_contiguous() == True 282 assert weight.is_contiguous() == True 283 assert offset_i.is_contiguous() == True 284 assert offset_j.is_contiguous() == True 285 286 output = input.new_zeros( 287 intSample, intInputDepth, intOutputHeight, intOutputWidth
343 intFilterSize = int(math.sqrt(weight.size(1))) 344 intOutputHeight = weight.size(2) 345 intOutputWidth = weight.size(3) 346 347 assert ( 348 intInputHeight - ((intFilterSize - 1) * dilation + 1) == intOutputHeight - 1 349 ) 350 assert ( 351 intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1 352 )
346 347 assert ( 348 intInputHeight - ((intFilterSize - 1) * dilation + 1) == intOutputHeight - 1 349 ) 350 assert ( 351 intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1 352 ) 353 354 assert gradOutput.is_contiguous() == True 355
350 assert ( 351 intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1 352 ) 353 354 assert gradOutput.is_contiguous() == True 355 356 gradInput = ( 357 input.new_zeros(intSample, intInputDepth, intInputHeight, intInputWidth)
48 _batch_edt = cuda_launch(*_batch_edt_kernel) 49 50 # bookkeeppingg 51 if len(img.shape) == 4: 52 assert img.shape[1] == 1 53 img = img.squeeze(1) 54 expand = True 55 else:
241 242 self.save_for_backward(first, second, rbot0, rbot1) 243 244 first = first.contiguous() 245 assert first.is_cuda == True 246 second = second.contiguous() 247 assert second.is_cuda == True 248
243 244 first = first.contiguous() 245 assert first.is_cuda == True 246 second = second.contiguous() 247 assert second.is_cuda == True 248 249 output = first.new_zeros([first.shape[0], 81, first.shape[2], first.shape[3]]) 250
298 def backward(self, gradOutput): 299 first, second, rbot0, rbot1 = self.saved_tensors 300 301 gradOutput = gradOutput.contiguous() 302 assert gradOutput.is_cuda == True 303 304 gradFirst = ( 305 first.new_zeros(
189 def backward(self, tenOutgrad): 190 tenOne, tenTwo = self.saved_tensors 191 192 tenOutgrad = tenOutgrad.contiguous() 193 assert tenOutgrad.is_cuda == True 194 195 tenOnegrad = ( 196 tenOne.new_zeros(
189 ], 190 ) 191 192 elif tenIn.is_cuda != True: 193 assert False 194 195 # end 196
205 def backward(self, tenOutgrad): 206 tenIn, tenVer, tenHor = self.saved_tensors 207 208 tenOutgrad = tenOutgrad.contiguous() 209 assert tenOutgrad.is_cuda == True 210 211 tenIngrad = ( 212 tenIn.new_empty(
223 ), 224 ) 225 226 elif tenIn.is_cuda != True: 227 assert False 228 229 # end 230
239 def backward(self, tenOutgrad): 240 tenIn, tenFlow = self.saved_tensors 241 242 tenOutgrad = tenOutgrad.contiguous() 243 assert tenOutgrad.is_cuda == True 244 245 tenIngrad = ( 246 tenIn.new_zeros(
322 # end 323 324 325 def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): 326 assert tenMetric is None or tenMetric.shape[1] == 1 327 assert strType in ["summation", "average", "linear", "softmax"] 328 329 if strType == "average":
323 324 325 def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): 326 assert tenMetric is None or tenMetric.shape[1] == 1 327 assert strType in ["summation", "average", "linear", "softmax"] 328 329 if strType == "average": 330 tenInput = torch.cat(
381
382 def softsplat(
383 tenIn: torch.Tensor, tenFlow: torch.Tensor, tenMetric: torch.Tensor, strMode: str
384 ):
385 assert strMode.split("-")[0] in ["sum", "avg", "linear", "soft"]
386
387 if strMode == "sum":
388 assert tenMetric is None
384 ):
385 assert strMode.split("-")[0] in ["sum", "avg", "linear", "soft"]
386
387 if strMode == "sum":
388 assert tenMetric is None
389 if strMode == "avg":
390 assert tenMetric is None
391 if strMode.split("-")[0] == "linear":
386
387 if strMode == "sum":
388 assert tenMetric is None
389 if strMode == "avg":
390 assert tenMetric is None
391 if strMode.split("-")[0] == "linear":
392 assert tenMetric is not None
393 if strMode.split("-")[0] == "soft":
388 assert tenMetric is None
389 if strMode == "avg":
390 assert tenMetric is None
391 if strMode.split("-")[0] == "linear":
392 assert tenMetric is not None
393 if strMode.split("-")[0] == "soft":
394 assert tenMetric is not None
395
390 assert tenMetric is None
391 if strMode.split("-")[0] == "linear":
392 assert tenMetric is not None
393 if strMode.split("-")[0] == "soft":
394 assert tenMetric is not None
395
396 if strMode == "avg":
397 tenIn = torch.cat(
59 strKey += str(objValue.stride()) 60 61 elif True: 62 print(strVariable, type(objValue)) 63 assert False 64 65 # end 66 # end
105 strKernel = strKernel.replace("{{type}}", "long")
106
107 elif type(objValue) == torch.Tensor:
108 print(strVariable, objValue.dtype)
109 assert False
110
111 elif True:
112 print(strVariable, type(objValue))
109 assert False 110 111 elif True: 112 print(strVariable, type(objValue)) 113 assert False 114 115 # end 116 # end
180
181 intArgs = int(objMatch.group(2))
182 strArgs = strKernel[intStart:intStop].split(",")
183
184 assert intArgs == len(strArgs) - 1
185
186 strTensor = strArgs[0]
187 intStrides = objVariables[strTensor].stride()
27
28 def softsplat(
29 tenIn: torch.Tensor, tenFlow: torch.Tensor, tenMetric: torch.Tensor, strMode: str
30 ):
31 assert strMode.split("-")[0] in ["sum", "avg", "linear", "soft"]
32
33 if strMode == "sum":
34 assert tenMetric is None
30 ):
31 assert strMode.split("-")[0] in ["sum", "avg", "linear", "soft"]
32
33 if strMode == "sum":
34 assert tenMetric is None
35 if strMode == "avg":
36 assert tenMetric is None
37 if strMode.split("-")[0] == "linear":
32
33 if strMode == "sum":
34 assert tenMetric is None
35 if strMode == "avg":
36 assert tenMetric is None
37 if strMode.split("-")[0] == "linear":
38 assert tenMetric is not None
39 if strMode.split("-")[0] == "soft":
34 assert tenMetric is None
35 if strMode == "avg":
36 assert tenMetric is None
37 if strMode.split("-")[0] == "linear":
38 assert tenMetric is not None
39 if strMode.split("-")[0] == "soft":
40 assert tenMetric is not None
41
36 assert tenMetric is None
37 if strMode.split("-")[0] == "linear":
38 assert tenMetric is not None
39 if strMode.split("-")[0] == "soft":
40 assert tenMetric is not None
41
42 if strMode == "avg":
43 tenIn = torch.cat(
80 81 return tenOut 82 83 def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): 84 assert tenMetric is None or tenMetric.shape[1] == 1 85 assert strType in ["summation", "average", "linear", "softmax"] 86 87 if strType == "average":
81 return tenOut 82 83 def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): 84 assert tenMetric is None or tenMetric.shape[1] == 1 85 assert strType in ["summation", "average", "linear", "softmax"] 86 87 if strType == "average": 88 tenInput = torch.cat(
28 import typing
29
30 ##########################################################
31
32 assert (
33 int(str("").join(torch.__version__.split(".")[0:2])) >= 13
34 ) # requires at least pytorch version 1.3.0
35
36 torch.set_grad_enabled(
37 False
169 ) # https://github.com/pytorch/pytorch/issues/62854 170 171 # end 172 173 assert False # to make torchscript happy 174 175 # end 176
199 ] 200 fltStride *= 1.0 201 202 elif True: 203 assert False 204 205 # end 206 # end
287 288 # end 289 # end 290 291 assert len(intChans) == 1 292 293 # end 294
323 objScratch: typing.Dict[str, typing.List[int]], 324 ): 325 super().__init__() 326 327 assert len(intIns) == len(intOuts) 328 assert len(intOuts) == len(intIns) 329 330 self.intRows = len(intIns) and len(intOuts)
324 ): 325 super().__init__() 326 327 assert len(intIns) == len(intOuts) 328 assert len(intOuts) == len(intIns) 329 330 self.intRows = len(intIns) and len(intOuts) 331 self.intIns = intIns.copy()
418 objScratch: typing.Dict[str, typing.List[int]], 419 ): 420 super().__init__() 421 422 assert len(intIns) == len(intOuts) 423 assert len(intOuts) == len(intIns) 424 425 self.intRows = len(intIns) and len(intOuts)
419 ): 420 super().__init__() 421 422 assert len(intIns) == len(intOuts) 423 assert len(intOuts) == len(intIns) 424 425 self.intRows = len(intIns) and len(intOuts) 426 self.intIns = intIns.copy()
713 if netNetwork is None: 714 netNetwork = Network().to(get_torch_device()).eval() 715 # end 716 717 assert tenOne.shape[1] == tenTwo.shape[1] 718 assert tenOne.shape[2] == tenTwo.shape[2] 719 720 intWidth = tenOne.shape[2]
714 netNetwork = Network().to(get_torch_device()).eval() 715 # end 716 717 assert tenOne.shape[1] == tenTwo.shape[1] 718 assert tenOne.shape[2] == tenTwo.shape[2] 719 720 intWidth = tenOne.shape[2] 721 intHeight = tenOne.shape[1]
719 720 intWidth = tenOne.shape[2] 721 intHeight = tenOne.shape[1] 722 723 assert ( 724 intWidth <= 1280 725 ) # while our approach works with larger images, we do not recommend it unless you are aware of the implications 726 assert ( 727 intHeight <= 720 728 ) # while our approach works with larger images, we do not recommend it unless you are aware of the implications
722 723 assert ( 724 intWidth <= 1280 725 ) # while our approach works with larger images, we do not recommend it unless you are aware of the implications 726 assert ( 727 intHeight <= 720 728 ) # while our approach works with larger images, we do not recommend it unless you are aware of the implications 729 730 tenPreprocessedOne = tenOne.to(get_torch_device()).view(1, 3, intHeight, intWidth) 731 tenPreprocessedTwo = tenTwo.to(get_torch_device()).view(1, 3, intHeight, intWidth)
586 if netNetwork is None: 587 netNetwork = Network().cuda().eval() 588 # end 589 590 assert tenFirst.shape[1] == tenSecond.shape[1] 591 assert tenFirst.shape[2] == tenSecond.shape[2] 592 593 intWidth = tenFirst.shape[2]
587 netNetwork = Network().cuda().eval() 588 # end 589 590 assert tenFirst.shape[1] == tenSecond.shape[1] 591 assert tenFirst.shape[2] == tenSecond.shape[2] 592 593 intWidth = tenFirst.shape[2] 594 intHeight = tenFirst.shape[1]
592 593 intWidth = tenFirst.shape[2] 594 intHeight = tenFirst.shape[1] 595 596 assert ( 597 intWidth == 1024 598 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 599 assert ( 600 intHeight == 436 601 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
595 596 assert ( 597 intWidth == 1024 598 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 599 assert ( 600 intHeight == 436 601 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 602 603 tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth) 604 tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth)
1588 1589 self.n_row = n_row 1590 self.n_col = n_col 1591 self.n_chs = grid_chs 1592 assert ( 1593 len(grid_chs) == self.n_row 1594 ), "should give num channels for each row (scale stream)" 1595 assert ( 1596 len(in_chs) == self.n_row 1597 ), "should give input channels for each row (scale stream)"
1591 self.n_chs = grid_chs
1592 assert (
1593 len(grid_chs) == self.n_row
1594 ), "should give num channels for each row (scale stream)"
1595 assert (
1596 len(in_chs) == self.n_row
1597 ), "should give input channels for each row (scale stream)"
1598
1599 for r, n_ch in enumerate(self.n_chs):
1600 setattr(self, f"lateral_{r}_0", LateralBlock(in_chs[r], n_ch))
1611 1612 self.lateral_final = LateralBlock(self.n_chs[0], out_chs) 1613 1614 def forward(self, *args): 1615 assert len(args) == self.n_row 1616 1617 # extensible, memory-efficient 1618 cur_col = list(args)
1642 self.n_row = n_row 1643 self.n_col = n_col 1644 self.n_chs = grid_chs 1645 self.outrow = outrow 1646 assert ( 1647 len(grid_chs) == self.n_row 1648 ), "should give num channels for each row (scale stream)" 1649 assert ( 1650 len(in_chs) == self.n_row 1651 ), "should give input channels for each row (scale stream)"
1645 self.outrow = outrow 1646 assert ( 1647 len(grid_chs) == self.n_row 1648 ), "should give num channels for each row (scale stream)" 1649 assert ( 1650 len(in_chs) == self.n_row 1651 ), "should give input channels for each row (scale stream)" 1652 assert len(out_chs) == len( 1653 self.outrow 1654 ), "should give out channels for each output row (scale stream)"
1648 ), "should give num channels for each row (scale stream)"
1649 assert (
1650 len(in_chs) == self.n_row
1651 ), "should give input channels for each row (scale stream)"
1652 assert len(out_chs) == len(
1653 self.outrow
1654 ), "should give out channels for each output row (scale stream)"
1655
1656 for r, n_ch in enumerate(self.n_chs):
1657 setattr(self, f"lateral_{r}_0", LateralBlock(in_chs[r], n_ch))
1669 for i, r in enumerate(outrow):
1670 setattr(self, f"lateral_final_{r}", LateralBlock(self.n_chs[r], out_chs[i]))
1671
1672 def forward(self, *args):
1673 assert len(args) == self.n_row
1674
1675 # extensible, memory-efficient
1676 cur_col = list(args)
1701 1702 self.n_row = n_row 1703 self.n_col = n_col 1704 self.n_chs = grid_chs 1705 assert ( 1706 len(grid_chs) == self.n_row 1707 ), "should give num channels for each row (scale stream)" 1708 1709 for r, n_ch in enumerate(self.n_chs): 1710 if r == 0:
1750 1751 self.n_row = 3 1752 self.n_col = 6 1753 self.n_chs = grid_chs 1754 assert ( 1755 len(grid_chs) == self.n_row 1756 ), "should give num channels for each row (scale stream)" 1757 1758 self.lateral_init = LateralBlock(in_chs, self.n_chs[0]) 1759
2339 if netNetwork is None: 2340 netNetwork = Network().cuda().eval() 2341 # end 2342 2343 assert tenFirst.shape[1] == tenSecond.shape[1] 2344 assert tenFirst.shape[2] == tenSecond.shape[2] 2345 2346 intWidth = tenFirst.shape[2]
2340 netNetwork = Network().cuda().eval() 2341 # end 2342 2343 assert tenFirst.shape[1] == tenSecond.shape[1] 2344 assert tenFirst.shape[2] == tenSecond.shape[2] 2345 2346 intWidth = tenFirst.shape[2] 2347 intHeight = tenFirst.shape[1]
2345 2346 intWidth = tenFirst.shape[2] 2347 intHeight = tenFirst.shape[1] 2348 2349 assert ( 2350 intWidth == 1024 2351 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 2352 assert ( 2353 intHeight == 436 2354 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
2348 2349 assert ( 2350 intWidth == 1024 2351 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 2352 assert ( 2353 intHeight == 436 2354 ) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 2355 2356 tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth) 2357 tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth)
44 t_value shape : [B,1] ############### 45 ''' 46 B, C, T, H, W = x.size() 47 B2, C2 = t_value.size() 48 assert C2 == 1, "t_value shape is [B,]" 49 assert T % 2 == 0, "T must be an even number" 50 t_value = t_value.view(B, 1, 1, 1) 51
45 ''' 46 B, C, T, H, W = x.size() 47 B2, C2 = t_value.size() 48 assert C2 == 1, "t_value shape is [B,]" 49 assert T % 2 == 0, "T must be an even number" 50 t_value = t_value.view(B, 1, 1, 1) 51 52 flow_l = None
142 x shape : [B,C,T,H,W] 143 t_value shape : [B,1] ############### 144 ''' 145 B, C, T, H, W = x.size() 146 assert T % 2 == 0, "T must be an even number" 147 148 ####################### For a single level 149 l = 2 ** level
18 ]
19
# Path of the package's config.yaml, located next to this module.
config_path = os.path.join(os.path.dirname(__file__), "./config.yaml")
# config.yaml is required; abort the import immediately if it is missing.
# FileNotFoundError subclasses Exception, so existing `except Exception`
# handlers still catch it.
if not os.path.exists(config_path):
    raise FileNotFoundError("config.yaml file is neccessary, plz recreate the config file by downloading it from https://github.com/Fannovel16/ComfyUI-Frame-Interpolation")
# Use a context manager so the file handle is closed deterministically rather
# than leaked until garbage collection; read it as UTF-8 regardless of the
# platform's default locale encoding.
with open(config_path, "r", encoding="utf-8") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
DEVICE = get_torch_device()
118 return einops.rearrange(frames, "n c h w -> n h w c")[..., :3].cpu()
119
def assert_batch_size(frames, batch_size=2, vfi_name=None):
    """Ensure ``frames`` holds at least ``batch_size`` frames.

    Args:
        frames: Sized sequence of frames (anything supporting ``len``,
            e.g. a tensor whose first dimension indexes frames, or a list).
        batch_size: Minimum number of frames required. Defaults to 2.
        vfi_name: Optional model name used to make the error message specific.

    Raises:
        AssertionError: If ``len(frames) < batch_size``. The error is raised
            explicitly rather than via an ``assert`` statement, so the check
            still runs when Python is started with ``-O``.
    """
    frame_count = len(frames)
    if frame_count >= batch_size:
        return
    subject_verb = "Most VFI models require" if vfi_name is None else f"VFI model {vfi_name} requires"
    # Report the same count the condition checked (len), so non-tensor
    # sequences without a `.shape` attribute still get the intended message.
    raise AssertionError(
        f"{subject_verb} at least {batch_size} frames to work with, only found {frame_count}. "
        "Please check the frame input using PreviewImage."
    )
123
124 def _generic_frame_loop(
125 frames,