1 import errno 2 import shlex 3 import subprocess 4 5 __version__ = "0.3.2" 6 7 8 class FFmpeg(object):
91 :raise: `FFRuntimeError` in case FFmpeg command exits with a non-zero code; 92 `FFExecutableNotFoundError` in case the executable path passed was not valid 93 """ 94 try: 95 self.process = subprocess.Popen( 96 self._cmd, 97 stdin=subprocess.PIPE, 98 stdout=stdout, 99 stderr=stderr, 100 env=env, 101 **kwargs 102 ) 103 except OSError as e: 104 if e.errno == errno.ENOENT: 105 raise FFExecutableNotFoundError(
1 import errno 2 import shlex 3 import subprocess 4 5 __version__ = "0.3.2" 6 7 8 class FFmpeg(object):
91 :raise: `FFRuntimeError` in case FFmpeg command exits with a non-zero code; 92 `FFExecutableNotFoundError` in case the executable path passed was not valid 93 """ 94 try: 95 self.process = subprocess.Popen( 96 self._cmd, 97 stdin=subprocess.PIPE, 98 stdout=stdout, 99 stderr=stderr, 100 env=env, 101 **kwargs 102 ) 103 except OSError as e: 104 if e.errno == errno.ENOENT: 105 raise FFExecutableNotFoundError(
4 import os 5 import posixpath 6 import random 7 import re 8 import subprocess 9 import time 10 import torch 11 import torchaudio
41 is_hidden = os.path.basename(f.path).startswith(".")
42
43 if file_ext in ext and not is_hidden:
44 files.append(f.path)
45 except:
46 pass
47 except:
48 pass
49
43 if file_ext in ext and not is_hidden: 44 files.append(f.path) 45 except: 46 pass 47 except: 48 pass 49 50 for dir in list(subfolders): 51 sf, f = fast_scandir(dir, ext)
79 has_banned = any(
80 [banned_word in name_lower for banned_word in banned_words])
81 if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"):
82 files.append(f.path)
83 except:
84 pass
85 except:
86 pass
87
81 if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"):
82 files.append(f.path)
83 except:
84 pass
85 except:
86 pass
87
88 for dir in list(subfolders):
89 sf, f = keyword_scandir(dir, ext, keywords)
220 custom_metadata = custom_metadata_fn(info, audio) 221 info.update(custom_metadata) 222 223 if "__reject__" in info and info["__reject__"]: 224 return self[random.randrange(len(self))] 225 226 return (audio, info) 227 except Exception as e:
225
226 return (audio, info)
227 except Exception as e:
228 print(f'Couldn\'t load file {audio_filename}: {e}')
229 return self[random.randrange(len(self))]
230
231 def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None):
232 """Return function over iterator that groups key, value pairs into samples.
234 :param lcase: convert suffixes to lower case (Default value = True) 235 """ 236 current_sample = None 237 for filesample in data: 238 assert isinstance(filesample, dict) 239 fname, value = filesample["fname"], filesample["data"] 240 prefix, suffix = keys(fname) 241 if wds.tariterators.trace:
282 # Add the --recursive flag if requested
283 cmd.append('--recursive')
284
285 # Run the `aws s3 ls` command and capture the output
286 run_ls = subprocess.run(cmd, capture_output=True, check=True)
287 # Split the output into lines and strip whitespace from each line
288 contents = run_ls.stdout.decode('utf-8').split('\n')
289 contents = [x.strip() for x in contents if x]
556 def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4):
557
558 dataset_type = dataset_config.get("dataset_type", None)
559
560 assert dataset_type is not None, "Dataset type must be specified in dataset config"
561
562 if audio_channels == 1:
563 force_channels = "mono"
567 if dataset_type == "audio_dir":
568
569 audio_dir_configs = dataset_config.get("datasets", None)
570
571 assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
572
573 configs = []
574
573 configs = []
574
575 for audio_dir_config in audio_dir_configs:
576 audio_dir_path = audio_dir_config.get("path", None)
577 assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
578
579 custom_metadata_fn = None
580 custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)
38 39 # If randomize is False, always start at the beginning of the audio 40 offset = 0 41 if(self.randomize and n_samples > self.n_samples): 42 offset = random.randint(0, upper_bound) 43 44 # Calculate the start and end times of the chunk 45 t_start = offset / (upper_bound + self.n_samples)
74 def __init__(self, p=0.5): 75 super().__init__() 76 self.p = p 77 def __call__(self, signal): 78 return -signal if (random.random() < self.p) else signal 79 80 class Mono(nn.Module): 81 def __call__(self, signal):
146 torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 147 torch.backends.cudnn.benchmark = False 148 149 # Conditioning 150 assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors" 151 if conditioning_tensors is None: 152 conditioning_tensors = model.conditioner(conditioning, device) 153 conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors)
191 # This is helpful for forward and reverse outpainting 192 cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size) 193 pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size) 194 pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size) 195 assert pastefrom < pasteto, "Paste From should be less than Paste To" 196 croplen = pasteto - pastefrom 197 if cropfrom + croplen > sample_size: 198 croplen = sample_size - cropfrom
655 return ui 656 657 def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False): 658 659 assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both" 660 661 if model_config_path is not None: 662 # Load config from json file
99 If this is the case, we insert extra 0 padding to the right before the reflection happen. 100 """ 101 length = x.shape[-1] 102 padding_left, padding_right = paddings 103 assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) 104 if mode == 'reflect': 105 max_pad = max(padding_left, padding_right) 106 extra_pad = 0
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip left/right padding from the last dimension of ``x``. Only for 1d!

    The right edge is sliced with an explicit end index rather than a negative
    one, so a zero right padding is handled correctly (``x[..., a:-0]`` would
    yield an empty slice).

    Args:
        x: tensor whose last dimension holds the padded samples.
        paddings: ``(left, right)`` number of samples to drop from each side.
            Both must be non-negative and their sum must not exceed
            ``x.shape[-1]``.

    Returns:
        ``x`` sliced along its last dimension with the padding removed.

    Raises:
        AssertionError: if a padding is negative or the paddings together are
            longer than the last dimension of ``x``.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    total_len = x.shape[-1]
    assert (left + right) <= total_len
    return x[..., left:total_len - right]
173 174 def Downsample1d( 175 in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 176 ) -> nn.Module: 177 assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" 178 179 return Conv1d( 180 in_channels=in_channels,
305 use_snake=use_snake 306 ) 307 308 if self.use_mapping: 309 assert exists(context_mapping_features) 310 self.to_scale_shift = MappingToScaleShift( 311 features=context_mapping_features, channels=out_channels 312 )
326 ) 327 328 def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: 329 assert_message = "context mapping required if context_mapping_features > 0" 330 assert not (self.use_mapping ^ exists(mapping)), assert_message 331 332 h = self.block1(x, causal=causal) 333
350 use_snake: bool = False,
351 ):
352 super().__init__()
353 assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
354 assert out_channels % patch_size == 0, assert_message
355 self.patch_size = patch_size
356
357 self.block = ResnetBlock1d(
378 use_snake: bool = False
379 ):
380 super().__init__()
381 assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
382 assert in_channels % patch_size == 0, assert_message
383 self.patch_size = patch_size
384
385 self.block = ResnetBlock1d(
522 context_mask: Optional[Tensor] = None, # [b, m], false is masked, 523 causal: Optional[bool] = False, 524 ) -> Tensor: 525 assert_message = "You must provide a context when using context_features" 526 assert not self.context_features or exists(context), assert_message 527 # Use context if provided 528 context = default(context, x) 529 # Normalize then compute q from input and k,v from context
669 """Used for continuous time""" 670 671 def __init__(self, dim: int): 672 super().__init__() 673 assert (dim % 2) == 0 674 half_dim = dim // 2 675 self.weights = nn.Parameter(torch.randn(half_dim)) 676
745 ] 746 ) 747 748 if self.use_transformer: 749 assert ( 750 (exists(attention_heads) or exists(attention_features)) 751 and exists(attention_multiplier) 752 ) 753 754 if attention_features is None and attention_heads is not None: 755 attention_features = channels // attention_heads
857 ] 858 ) 859 860 if self.use_transformer: 861 assert ( 862 (exists(attention_heads) or exists(attention_features)) 863 and exists(attention_multiplier) 864 ) 865 866 if attention_features is None and attention_heads is not None: 867 attention_features = channels // attention_heads
953 use_snake=use_snake 954 ) 955 956 if self.use_transformer: 957 assert ( 958 (exists(attention_heads) or exists(attention_features)) 959 and exists(attention_multiplier) 960 ) 961 962 if attention_features is None and attention_heads is not None: 963 attention_features = channels // attention_heads
1055 has_context = [c > 0 for c in context_channels] 1056 self.has_context = has_context 1057 self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] 1058 1059 assert ( 1060 len(factors) == num_layers 1061 and len(attentions) >= num_layers 1062 and len(num_blocks) == num_layers 1063 ) 1064 1065 if use_context_time or use_context_features: 1066 context_mapping_features = channels * context_features_multiplier
1072 nn.GELU(), 1073 ) 1074 1075 if use_context_time: 1076 assert exists(context_mapping_features) 1077 self.to_time = nn.Sequential( 1078 TimePositionalEmbedding( 1079 dim=channels, out_features=context_mapping_features
1081 nn.GELU(), 1082 ) 1083 1084 if use_context_features: 1085 assert exists(context_features) and exists(context_mapping_features) 1086 self.to_features = nn.Sequential( 1087 nn.Linear( 1088 in_features=context_features, out_features=context_mapping_features
1091 )
1092
1093 if use_stft:
1094 stft_kwargs, kwargs = groupby("stft_", kwargs)
1095 assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True"
1096 stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
1097 in_channels *= stft_channels
1098 out_channels *= stft_channels
1096 stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
1097 in_channels *= stft_channels
1098 out_channels *= stft_channels
1099 context_channels[0] *= stft_channels if use_stft_context else 1
1100 assert exists(in_channels) and exists(out_channels)
1101 self.stft = STFT(**stft_kwargs)
1102
1103 assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"
1099 context_channels[0] *= stft_channels if use_stft_context else 1
1100 assert exists(in_channels) and exists(out_channels)
1101 self.stft = STFT(**stft_kwargs)
1102
1103 assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"
1104
1105 self.to_in = Patcher(
1106 in_channels=in_channels + context_channels[0],
1179 """Gets context channels at `layer` and checks that shape is correct""" 1180 use_context_channels = self.use_context_channels and self.has_context[layer] 1181 if not use_context_channels: 1182 return None 1183 assert exists(channels_list), "Missing context" 1184 # Get channels index (skipping zero channel contexts) 1185 channels_id = self.channels_ids[layer] 1186 # Get channels
1185 channels_id = self.channels_ids[layer]
1186 # Get channels
1187 channels = channels_list[channels_id]
1188 message = f"Missing context for layer {layer} at index {channels_id}"
1189 assert exists(channels), message
1190 # Check channels
1191 num_channels = self.context_channels[layer]
1192 message = f"Expected context with {num_channels} channels at idx {channels_id}"
1189 assert exists(channels), message
1190 # Check channels
1191 num_channels = self.context_channels[layer]
1192 message = f"Expected context with {num_channels} channels at idx {channels_id}"
1193 assert channels.shape[1] == num_channels, message
1194 # STFT channels if requested
1195 channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa
1196 return channels
1202 items, mapping = [], None 1203 # Compute time features 1204 if self.use_context_time: 1205 assert_message = "use_context_time=True but no time features provided" 1206 assert exists(time), assert_message 1207 items += [self.to_time(time)] 1208 # Compute features 1209 if self.use_context_features:
1207 items += [self.to_time(time)] 1208 # Compute features 1209 if self.use_context_features: 1210 assert_message = "context_features exists but no features provided" 1211 assert exists(features), assert_message 1212 items += [self.to_features(features)] 1213 # Compute joint mapping 1214 if self.use_context_time or self.use_context_features:
1268 1269 def forward(self, x: Tensor) -> Tensor: 1270 batch_size, length, device = *x.shape[0:2], x.device 1271 assert_message = "Input sequence length must be <= max_length" 1272 assert length <= self.max_length, assert_message 1273 position = torch.arange(length, device=device) 1274 fixed_embedding = self.embedding(position) 1275 fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
1302 1303 self.use_xattn_time = use_xattn_time 1304 1305 if use_xattn_time: 1306 assert exists(context_embedding_features) 1307 self.to_time_embedding = nn.Sequential( 1308 TimePositionalEmbedding( 1309 dim=kwargs["channels"], out_features=context_embedding_features
1491 def forward(self, x: Union[List[float], Tensor]) -> Tensor: 1492 if not torch.is_tensor(x): 1493 device = next(self.embedding.parameters()).device 1494 x = torch.tensor(x, device=device) 1495 assert isinstance(x, Tensor) 1496 shape = x.shape 1497 x = rearrange(x, "... -> (...)") 1498 embedding = self.embedding(x)
365 Decode discrete tokens to audio 366 Only works with discrete autoencoders 367 ''' 368 369 assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders" 370 371 latents = self.bottleneck.decode_tokens(tokens, **kwargs) 372
395 ''' 396 batch_size = len(audio_list) 397 if isinstance(in_sr_list, int): 398 in_sr_list = [in_sr_list]*batch_size 399 assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list" 400 new_audio = [] 401 max_length = 0 402 # resample & find the max length
408 audio = audio.squeeze(0) 409 elif len(audio.shape) == 1: 410 # Mono signal, channel dimension is missing, unsqueeze it in 411 audio = audio.unsqueeze(0) 412 assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension" 413 # Resample audio 414 if in_sr != self.sample_rate: 415 resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
609 # AE factories
610
611 def create_encoder_from_config(encoder_config: Dict[str, Any]):
612 encoder_type = encoder_config.get("type", None)
613 assert encoder_type is not None, "Encoder type must be specified"
614
615 if encoder_type == "oobleck":
616 encoder = OobleckEncoder(
649 return encoder
650
651 def create_decoder_from_config(decoder_config: Dict[str, Any]):
652 decoder_type = decoder_config.get("type", None)
653 assert decoder_type is not None, "Decoder type must be specified"
654
655 if decoder_type == "oobleck":
656 decoder = OobleckDecoder(
693
694 bottleneck = ae_config.get("bottleneck", None)
695
696 latent_dim = ae_config.get("latent_dim", None)
697 assert latent_dim is not None, "latent_dim must be specified in model config"
698 downsampling_ratio = ae_config.get("downsampling_ratio", None)
699 assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
700 io_channels = ae_config.get("io_channels", None)
695
696 latent_dim = ae_config.get("latent_dim", None)
697 assert latent_dim is not None, "latent_dim must be specified in model config"
698 downsampling_ratio = ae_config.get("downsampling_ratio", None)
699 assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
700 io_channels = ae_config.get("io_channels", None)
701 assert io_channels is not None, "io_channels must be specified in model config"
702 sample_rate = config.get("sample_rate", None)
697 assert latent_dim is not None, "latent_dim must be specified in model config"
698 downsampling_ratio = ae_config.get("downsampling_ratio", None)
699 assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
700 io_channels = ae_config.get("io_channels", None)
701 assert io_channels is not None, "io_channels must be specified in model config"
702 sample_rate = config.get("sample_rate", None)
703 assert sample_rate is not None, "sample_rate must be specified in model config"
704
699 assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
700 io_channels = ae_config.get("io_channels", None)
701 assert io_channels is not None, "io_channels must be specified in model config"
702 sample_rate = config.get("sample_rate", None)
703 assert sample_rate is not None, "sample_rate must be specified in model config"
704
705 in_channels = ae_config.get("in_channels", None)
706 out_channels = ae_config.get("out_channels", None)
752 elif diffusion_model_type == "dit":
753 diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
754
755 latent_dim = diffae_config.get("latent_dim", None)
756 assert latent_dim is not None, "latent_dim must be specified in model config"
757 downsampling_ratio = diffae_config.get("downsampling_ratio", None)
758 assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
759 io_channels = diffae_config.get("io_channels", None)
754
755 latent_dim = diffae_config.get("latent_dim", None)
756 assert latent_dim is not None, "latent_dim must be specified in model config"
757 downsampling_ratio = diffae_config.get("downsampling_ratio", None)
758 assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
759 io_channels = diffae_config.get("io_channels", None)
760 assert io_channels is not None, "io_channels must be specified in model config"
761 sample_rate = config.get("sample_rate", None)
756 assert latent_dim is not None, "latent_dim must be specified in model config"
757 downsampling_ratio = diffae_config.get("downsampling_ratio", None)
758 assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
759 io_channels = diffae_config.get("io_channels", None)
760 assert io_channels is not None, "io_channels must be specified in model config"
761 sample_rate = config.get("sample_rate", None)
762 assert sample_rate is not None, "sample_rate must be specified in model config"
763
758 assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
759 io_channels = diffae_config.get("io_channels", None)
760 assert io_channels is not None, "io_channels must be specified in model config"
761 sample_rate = config.get("sample_rate", None)
762 assert sample_rate is not None, "sample_rate must be specified in model config"
763
764 bottleneck = diffae_config.get("bottleneck", None)
765
33 34 class SelfAttention1d(nn.Module): 35 def __init__(self, c_in, n_head=1, dropout_rate=0.): 36 super().__init__() 37 assert c_in % n_head == 0 38 self.norm = nn.GroupNorm(1, c_in) 39 self.n_head = n_head 40 self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
83 84 class FourierFeatures(nn.Module): 85 def __init__(self, in_features, out_features, std=1.): 86 super().__init__() 87 assert out_features % 2 == 0 88 self.weight = nn.Parameter(torch.randn( 89 [out_features // 2, in_features]) * std) 90
153 154 def Downsample1d_2( 155 in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 156 ) -> nn.Module: 157 assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" 158 159 return nn.Conv1d( 160 in_channels=in_channels,
44 timesteps: int 45 n_q: int 46 47 def __post_init__(self): 48 assert len(self.layout) > 0 49 self._validate_layout() 50 self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes) 51 self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
64 qs = set()
65 for coord in seq_coords:
66 qs.add(coord.q)
67 last_q_timestep = q_timesteps[coord.q]
68 assert coord.t >= last_q_timestep, \
69 f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
70 q_timesteps[coord.q] = coord.t
71 # each sequence step contains at max 1 coordinate per codebook
72 assert len(qs) == len(seq_coords), \
68 assert coord.t >= last_q_timestep, \
69 f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
70 q_timesteps[coord.q] = coord.t
71 # each sequence step contains at max 1 coordinate per codebook
72 assert len(qs) == len(seq_coords), \
73 f"Multiple entries for a same codebook are found at step {s}"
74
75 @property
76 def num_sequence_steps(self):
96 """Get codebook coordinates in the layout that corresponds to the specified timestep t 97 and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step 98 and the actual codebook coordinates. 99 """ 100 assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps" 101 if q is not None: 102 assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks" 103 coords = []
98 and the actual codebook coordinates. 99 """ 100 assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps" 101 if q is not None: 102 assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks" 103 coords = [] 104 for s, seq_codes in enumerate(self.layout): 105 for code in seq_codes:
125 Returns:
126 indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
127 mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
128 """
129 assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
130 assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
131 # use the proper layout based on whether we limit ourselves to valid steps only or not,
132 # note that using the valid_layout will result in a truncated sequence up to the valid steps
126 indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
127 mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
128 """
129 assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
130 assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
131 # use the proper layout based on whether we limit ourselves to valid steps only or not,
132 # note that using the valid_layout will result in a truncated sequence up to the valid steps
133 ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
195 """
196 ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
197 # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
198 timesteps = self.timesteps
199 assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
200 assert sequence_steps <= len(ref_layout), \
201 f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
202
196 ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
197 # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
198 timesteps = self.timesteps
199 assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
200 assert sequence_steps <= len(ref_layout), \
201 f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
202
203 # ensure we take the appropriate indexes to keep the model output from the first special token as well
204 if is_model_output and self.starts_with_special_token():
284 cached (bool): if True, patterns for a given length are cached. In general 285 that should be true for efficiency reason to avoid synchronization points. 286 """ 287 def __init__(self, n_q: int, cached: bool = True): 288 assert n_q > 0 289 self.n_q = n_q 290 self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore 291
329 delays = list(range(n_q)) 330 self.delays = delays 331 self.flatten_first = flatten_first 332 self.empty_initial = empty_initial 333 assert len(self.delays) == self.n_q 334 assert sorted(self.delays) == self.delays 335 336 def get_pattern(self, timesteps: int) -> Pattern:
330 self.delays = delays 331 self.flatten_first = flatten_first 332 self.empty_initial = empty_initial 333 assert len(self.delays) == self.n_q 334 assert sorted(self.delays) == self.delays 335 336 def get_pattern(self, timesteps: int) -> Pattern: 337 omit_special_token = self.empty_initial < 0
423 if flattening is None: 424 flattening = list(range(n_q)) 425 if delays is None: 426 delays = [0] * n_q 427 assert len(flattening) == n_q 428 assert len(delays) == n_q 429 assert sorted(flattening) == flattening 430 assert sorted(delays) == delays
424 flattening = list(range(n_q)) 425 if delays is None: 426 delays = [0] * n_q 427 assert len(flattening) == n_q 428 assert len(delays) == n_q 429 assert sorted(flattening) == flattening 430 assert sorted(delays) == delays 431 self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
425 if delays is None: 426 delays = [0] * n_q 427 assert len(flattening) == n_q 428 assert len(delays) == n_q 429 assert sorted(flattening) == flattening 430 assert sorted(delays) == delays 431 self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening) 432 self.max_delay = max(delays)
426 delays = [0] * n_q 427 assert len(flattening) == n_q 428 assert len(delays) == n_q 429 assert sorted(flattening) == flattening 430 assert sorted(delays) == delays 431 self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening) 432 self.max_delay = max(delays) 433
441 if inner_step not in flattened_codebooks: 442 flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay) 443 else: 444 flat_codebook = flattened_codebooks[inner_step] 445 assert flat_codebook.delay == delay, ( 446 "Delay and flattening between codebooks is inconsistent: ", 447 "two codebooks flattened to the same position should have the same delay." 448 ) 449 flat_codebook.codebooks.append(q) 450 flattened_codebooks[inner_step] = flat_codebook 451 return flattened_codebooks
505 super().__init__(n_q) 506 if delays is None: 507 delays = [0] * (n_q - 1) 508 self.delays = delays 509 assert len(self.delays) == self.n_q - 1 510 assert sorted(self.delays) == self.delays 511 512 def get_pattern(self, timesteps: int) -> Pattern:
506 if delays is None: 507 delays = [0] * (n_q - 1) 508 self.delays = delays 509 assert len(self.delays) == self.n_q - 1 510 assert sorted(self.delays) == self.delays 511 512 def get_pattern(self, timesteps: int) -> Pattern: 513 out: PatternLayout = [[]]
265 max_length: str = 128,
266 enable_grad: bool = False,
267 project_out: bool = False
268 ):
269 assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}"
270 super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out)
271
272 from transformers import T5EncoderModel, AutoTokenizer
546 elif conditioner_type == "lut":
547 conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
548 elif conditioner_type == "pretransform":
549 sample_rate = conditioner_config.pop("sample_rate", None)
550 assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
551
552 pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
553
525 rescale_cfg: bool = False, 526 scale_phi: float = 0.0, 527 **kwargs): 528 529 assert batch_cfg, "batch_cfg must be True for DiTWrapper" 530 #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper" 531 532 return self.model(
571 model_type = diffusion_uncond_config.get('type', None)
572
573 diffusion_config = diffusion_uncond_config.get('config', {})
574
575 assert model_type is not None, "Must specify model type in config"
576
577 pretransform = diffusion_uncond_config.get("pretransform", None)
578
576
577 pretransform = diffusion_uncond_config.get("pretransform", None)
578
579 sample_size = config.get("sample_size", None)
580 assert sample_size is not None, "Must specify sample size in config"
581
582 sample_rate = config.get("sample_rate", None)
583 assert sample_rate is not None, "Must specify sample rate in config"
579 sample_size = config.get("sample_size", None)
580 assert sample_size is not None, "Must specify sample size in config"
581
582 sample_rate = config.get("sample_rate", None)
583 assert sample_rate is not None, "Must specify sample rate in config"
584
585 if pretransform is not None:
586 pretransform = create_pretransform_from_config(pretransform, sample_rate)
621
622 model_type = config["model_type"]
623
624 diffusion_config = model_config.get('diffusion', None)
625 assert diffusion_config is not None, "Must specify diffusion config"
626
627 diffusion_model_type = diffusion_config.get('type', None)
628 assert diffusion_model_type is not None, "Must specify diffusion model type"
624 diffusion_config = model_config.get('diffusion', None)
625 assert diffusion_config is not None, "Must specify diffusion config"
626
627 diffusion_model_type = diffusion_config.get('type', None)
628 assert diffusion_model_type is not None, "Must specify diffusion model type"
629
630 diffusion_model_config = diffusion_config.get('config', None)
631 assert diffusion_model_config is not None, "Must specify diffusion model config"
627 diffusion_model_type = diffusion_config.get('type', None)
628 assert diffusion_model_type is not None, "Must specify diffusion model type"
629
630 diffusion_model_config = diffusion_config.get('config', None)
631 assert diffusion_model_config is not None, "Must specify diffusion model config"
632
633 if diffusion_model_type == 'adp_cfg_1d':
634 diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
637 elif diffusion_model_type == 'dit':
638 diffusion_model = DiTWrapper(**diffusion_model_config)
639
640 io_channels = model_config.get('io_channels', None)
641 assert io_channels is not None, "Must specify io_channels in model config"
642
643 sample_rate = config.get('sample_rate', None)
644 assert sample_rate is not None, "Must specify sample_rate in config"
640 io_channels = model_config.get('io_channels', None)
641 assert io_channels is not None, "Must specify io_channels in model config"
642
643 sample_rate = config.get('sample_rate', None)
644 assert sample_rate is not None, "Must specify sample_rate in config"
645
646 diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
647
679 extra_kwargs["diffusion_objective"] = diffusion_objective
680
681 elif model_type == "diffusion_prior":
682 prior_type = model_config.get("prior_type", None)
683 assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
684
685 if prior_type == "mono_stereo":
686 from .diffusion_prior import MonoToStereoDiffusionPrior
248 mask=None, 249 return_info=False, 250 **kwargs): 251 252 assert causal == False, "Causal mode is not supported for DiffusionTransformer" 253 254 if cross_attn_cond_mask is not None: 255 cross_attn_cond_mask = cross_attn_cond_mask.bool()
2
3 def create_model_from_config(model_config):
4 model_type = model_config.get('model_type', None)
5
6 assert model_type is not None, 'model_type must be specified in model config'
7
8 if model_type == 'autoencoder':
9 from .autoencoders import create_autoencoder_from_config
31
32 def create_pretransform_from_config(pretransform_config, sample_rate):
33 pretransform_type = pretransform_config.get('type', None)
34
35 assert pretransform_type is not None, 'type must be specified in pretransform config'
36
37 if pretransform_type == 'autoencoder':
38 from .autoencoders import create_autoencoder_from_config
83
84 def create_bottleneck_from_config(bottleneck_config):
85 bottleneck_type = bottleneck_config.get('type', None)
86
87 assert bottleneck_type is not None, 'type must be specified in bottleneck config'
88
89 if bottleneck_type == 'tanh':
90 from .bottleneck import TanhBottleneck
66 ): 67 68 batch, num_quantizers, seq_len = sequence.shape 69 70 assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model" 71 72 backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)]) # [batch, seq_len, embed_dim] 73
150 global_cond_ids: tp.List[str] = [] 151 ): 152 super().__init__() 153 154 assert pretransform.is_discrete, "Pretransform must be discrete" 155 self.pretransform = pretransform 156 157 self.pretransform.requires_grad_(False)
370 possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0]) 371 else: 372 possible_batch_sizes.append(1) 373 374 assert [x == possible_batch_sizes[0] for x in possible_batch_sizes], "Batch size must be consistent across inputs" 375 376 batch_size = possible_batch_sizes[0] 377
376 batch_size = possible_batch_sizes[0] 377 378 if init_data is None: 379 # Initialize with zeros 380 assert batch_size > 0 381 init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long) 382 383 batch_size, num_quantizers, seq_len = init_data.shape
382 383 batch_size, num_quantizers, seq_len = init_data.shape 384 385 start_offset = seq_len 386 assert start_offset < max_gen_len, "init data longer than max gen length" 387 388 pattern = self.lm.pattern_provider.get_pattern(max_gen_len) 389
395 396 gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len] 397 398 start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) 399 assert start_offset_sequence is not None 400 401 # Generation 402 prev_offset = 0
439 # Callback to report progress 440 # Pass in the offset relative to the start of the sequence, and the length of the current sequence 441 callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence) 442 443 assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence" 444 445 out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) 446
444 445 out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) 446 447 # sanity checks over the returned codes and corresponding masks 448 assert (out_codes[..., :max_gen_len] != unknown_token).all() 449 assert (out_mask[..., :max_gen_len] == 1).all() 450 451 #out_codes = out_codes[..., 0:max_gen_len]
445 out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) 446 447 # sanity checks over the returned codes and corresponding masks 448 assert (out_codes[..., :max_gen_len] != unknown_token).all() 449 assert (out_mask[..., :max_gen_len] == 1).all() 450 451 #out_codes = out_codes[..., 0:max_gen_len] 452
469
470
471 def create_audio_lm_from_config(config):
472 model_config = config.get('model', None)
473 assert model_config is not None, 'model config must be specified in config'
474
475 sample_rate = config.get('sample_rate', None)
476 assert sample_rate is not None, "Must specify sample_rate in config"
472 model_config = config.get('model', None)
473 assert model_config is not None, 'model config must be specified in config'
474
475 sample_rate = config.get('sample_rate', None)
476 assert sample_rate is not None, "Must specify sample_rate in config"
477
478 lm_config = model_config.get('lm', None)
479 assert lm_config is not None, 'lm config must be specified in model config'
475 sample_rate = config.get('sample_rate', None)
476 assert sample_rate is not None, "Must specify sample_rate in config"
477
478 lm_config = model_config.get('lm', None)
479 assert lm_config is not None, 'lm config must be specified in model config'
480
481 codebook_pattern = lm_config.get("codebook_pattern", "delay")
482
490 pretransform_config = model_config.get("pretransform", None)
491
492 pretransform = create_pretransform_from_config(pretransform_config, sample_rate)
493
494 assert pretransform.is_discrete, "Pretransform must be discrete"
495
496 min_input_length = pretransform.downsampling_ratio
497
509
510 lm_type = lm_config.get("type", None)
511 lm_model_config = lm_config.get("config", None)
512
513 assert lm_type is not None, "Must specify lm type in lm config"
514 assert lm_model_config is not None, "Must specify lm model config in lm config"
515
516 if lm_type == "x-transformers":
510 lm_type = lm_config.get("type", None)
511 lm_model_config = lm_config.get("config", None)
512
513 assert lm_type is not None, "Must specify lm type in lm config"
514 assert lm_model_config is not None, "Must specify lm model config in lm config"
515
516 if lm_type == "x-transformers":
517 backbone = XTransformersAudioLMBackbone(**lm_model_config)
20 super(PQMF, self).__init__() 21 22 # Ensure num_bands is a power of 2 23 is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands))) 24 assert is_power_of_2, "'num_bands' must be a power of 2." 25 26 # Create the prototype filter 27 prototype_filter = design_prototype_filter(attenuation, num_bands)
74 75 return decoded 76 77 def tokenize(self, x, **kwargs): 78 assert self.model.is_discrete, "Cannot tokenize with a continuous model" 79 80 _, info = self.model.encode(x, return_info = True, **kwargs) 81
81 82 return info[self.model.bottleneck.tokens_id] 83 84 def decode_tokens(self, tokens, **kwargs): 85 assert self.model.is_discrete, "Cannot decode tokens with a continuous model" 86 87 return self.model.decode_tokens(tokens, **kwargs) 88
221 self.model.to(torch.float16).eval().requires_grad_(False) 222 223 def encode(self, x): 224 225 assert False, "Audiocraft compression models do not support continuous encoding" 226 227 # latents = self.model.encoder(x) 228
238 # return output 239 240 def decode(self, z): 241 242 assert False, "Audiocraft compression models do not support continuous decoding" 243 244 # if self.scale != 1.0: 245 # z = z * self.scale
49 self.emb = nn.Embedding(max_seq_len, dim)
50
51 def forward(self, x, pos = None, seq_start_pos = None):
52 seq_len, device = x.shape[1], x.device
53 assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
54
55 if pos is None:
56 pos = torch.arange(seq_len, device = device)
64 65 class ScaledSinusoidalEmbedding(nn.Module): 66 def __init__(self, dim, theta = 10000): 67 super().__init__() 68 assert (dim % 2) == 0, 'dimension must be divisible by 2' 69 self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) 70 71 half_dim = dim // 2
104
105 inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
106 self.register_buffer('inv_freq', inv_freq)
107
108 assert interpolation_factor >= 1.
109 self.interpolation_factor = interpolation_factor
110
111 if not use_xpos:
346 if q_len == 1 and causal: 347 causal = False 348 349 if mask is not None: 350 assert mask.ndim == 4 351 mask = mask.expand(batch, heads, q_len, k_len) 352 353 # handle kv cache - this should be bypassable in updated flash attention 2
478 out = natten.functional.natten1dav(attn, v, kernel_size = self.natten_kernel_size, dilation=1).to(dtype_in) 479 480 # Prioritize Flash Attention 2 481 elif self.use_fa_flash: 482 assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2' 483 # Flash Attention 2 requires FP16 inputs 484 fa_dtype_in = q.dtype 485 q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v))
768 769 if prepend_embeds is not None: 770 prepend_length, prepend_dim = prepend_embeds.shape[1:] 771 772 assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension' 773 774 x = torch.cat((prepend_embeds, x), dim = -2) 775
23 self.wavelet = wavelet 24 self.channels = channels 25 self.levels = levels 26 filt = get_filter_bank(wavelet) 27 assert filt.shape[-1] % 2 == 1 28 kernel = filt[:2, None] 29 kernel = torch.flip(kernel, dims=(-1,)) 30 index_i = torch.repeat_interleave(torch.arange(2), channels)
55 self.wavelet = wavelet 56 self.channels = channels 57 self.levels = levels 58 filt = get_filter_bank(wavelet) 59 assert filt.shape[-1] % 2 == 1 60 kernel = filt[2:, None] 61 index_i = torch.repeat_interleave(torch.arange(2), channels) 62 index_j = torch.tile(torch.arange(channels), (2,))
272 self.losses = MultiLoss(self.loss_modules)
273
274 self.log_loss_info = log_loss_info
275
276 assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
277
278 if optimizer_configs is None:
279 optimizer_configs = {
328 with torch.cuda.amp.autocast(): 329 conditioning = self.diffusion.conditioner(metadata, self.device) 330 331 # If mask_padding is on, randomly drop the padding masks to allow for learning silence padding 332 use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout 333 334 # Create batch tensor of attention masks from the "mask" field of the metadata array 335 if use_padding_mask:
618 self.losses = MultiLoss(self.loss_modules)
619
620 self.log_loss_info = log_loss_info
621
622 assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
623
624 if optimizer_configs is None:
625 optimizer_configs = {
660 # Create a mask tensor for each batch element 661 masks = [] 662 663 for i in range(b): 664 mask_type = random.randint(0, 2) 665 666 if mask_type == 0: # Random mask with multiple segments 667 num_segments = random.randint(1, self.max_mask_segments)
663 for i in range(b): 664 mask_type = random.randint(0, 2) 665 666 if mask_type == 0: # Random mask with multiple segments 667 num_segments = random.randint(1, self.max_mask_segments) 668 max_segment_length = max_mask_length // num_segments 669 670 segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments)
670 segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments) 671 672 mask = torch.ones((1, 1, sequence_length)) 673 for length in segment_lengths: 674 mask_start = random.randint(0, sequence_length - length) 675 mask[:, :, mask_start:mask_start + length] = 0 676 677 elif mask_type == 1: # Full mask
678 mask = torch.zeros((1, 1, sequence_length)) 679 680 elif mask_type == 2: # Causal mask 681 mask = torch.ones((1, 1, sequence_length)) 682 mask_length = random.randint(1, max_mask_length) 683 mask[:, :, -mask_length:] = 0 684 685 mask = mask.to(sequence.device)
1189 for j in js: 1190 if i == j or (i != j and sources_added < num_sources): 1191 # Randomly offset the mixed element between 0 and the length of the source 1192 seq_len = reals.shape[2] 1193 offset = random.randint(0, seq_len-1) 1194 source[i, :, offset:] += reals[j, :, :-offset] 1195 if i == j: 1196 # If this is the real one, shift the reals as well to ensure alignment
3 from ..models.factory import create_model_from_config
4
5 def create_training_wrapper_from_config(model_config, model):
6 model_type = model_config.get('model_type', None)
7 assert model_type is not None, 'model_type must be specified in model config'
8
9 training_config = model_config.get('training', None)
10 assert training_config is not None, 'training config must be specified in model config'
6 model_type = model_config.get('model_type', None)
7 assert model_type is not None, 'model_type must be specified in model config'
8
9 training_config = model_config.get('training', None)
10 assert training_config is not None, 'training config must be specified in model config'
11
12 if model_type == 'autoencoder':
13 from .autoencoders import AutoencoderTrainingWrapper
157 raise NotImplementedError(f'Unknown model type: {model_type}')
158
159 def create_demo_callback_from_config(model_config, **kwargs):
160 model_type = model_config.get('model_type', None)
161 assert model_type is not None, 'model_type must be specified in model config'
162
163 training_config = model_config.get('training', None)
164 assert training_config is not None, 'training config must be specified in model config'
160 model_type = model_config.get('model_type', None)
161 assert model_type is not None, 'model_type must be specified in model config'
162
163 training_config = model_config.get('training', None)
164 assert training_config is not None, 'training config must be specified in model config'
165
166 demo_config = training_config.get("demo", {})
167
36 self.model_ema = None
37 if use_ema:
38 self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10)
39
40 assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
41
42 if optimizer_configs is None:
43 optimizer_configs = {
93 ce (torch.Tensor): Cross entropy averaged over the codebooks 94 ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). 95 """ 96 B, K, T = targets.shape 97 assert logits.shape[:-1] == targets.shape 98 assert mask.shape == targets.shape 99 ce = torch.zeros([], device=targets.device) 100 ce_per_codebook: tp.List[torch.Tensor] = []
94 ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). 95 """ 96 B, K, T = targets.shape 97 assert logits.shape[:-1] == targets.shape 98 assert mask.shape == targets.shape 99 ce = torch.zeros([], device=targets.device) 100 ce_per_codebook: tp.List[torch.Tensor] = [] 101 for k in range(K):
327 print(e)
328 print("Try `pip install auraloss[all]`.")
329
330 if self.scale == "mel":
331 assert sample_rate != None # Must set sample rate to use mel scale
332 assert n_bins <= fft_size # Must be more FFT bins than Mel bins
333 fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins)
334 fb = torch.tensor(fb).unsqueeze(0)
328 print("Try `pip install auraloss[all]`.")
329
330 if self.scale == "mel":
331 assert sample_rate != None # Must set sample rate to use mel scale
332 assert n_bins <= fft_size # Must be more FFT bins than Mel bins
333 fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins)
334 fb = torch.tensor(fb).unsqueeze(0)
335
333 fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins) 334 fb = torch.tensor(fb).unsqueeze(0) 335 336 elif self.scale == "chroma": 337 assert sample_rate != None # Must set sample rate to use chroma scale 338 assert n_bins <= fft_size # Must be more FFT bins than chroma bins 339 fb = librosa.filters.chroma( 340 sr=sample_rate, n_fft=fft_size, n_chroma=n_bins
334 fb = torch.tensor(fb).unsqueeze(0) 335 336 elif self.scale == "chroma": 337 assert sample_rate != None # Must set sample rate to use chroma scale 338 assert n_bins <= fft_size # Must be more FFT bins than chroma bins 339 fb = librosa.filters.chroma( 340 sr=sample_rate, n_fft=fft_size, n_chroma=n_bins 341 )
481 scale_invariance: bool = False, 482 **kwargs, 483 ): 484 super().__init__() 485 assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) # must define all 486 self.fft_sizes = fft_sizes 487 self.hop_sizes = hop_sizes 488 self.win_lengths = win_lengths
588 loss (torch.Tensor): Aggreate loss term. Only returned if output='loss'. 589 loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor): 590 Aggregate and intermediate loss terms. Only returned if output='full'. 591 """ 592 assert input.shape == target.shape # must have same shape 593 bs, chs, seq_len = input.size() 594 595 # compute sum and difference signals for both