import errno
import shlex
import subprocess

__version__ = "0.3.2"


class FFmpeg(object):
        :raise: `FFRuntimeError` in case FFmpeg command exits with a non-zero code;
            `FFExecutableNotFoundError` in case the executable path passed was not valid
        """
        try:
            self.process = subprocess.Popen(
                self._cmd,
                stdin=subprocess.PIPE,
                stdout=stdout,
                stderr=stderr,
                env=env,
                **kwargs
            )
        except OSError as e:
            if e.errno == errno.ENOENT:
                raise FFExecutableNotFoundError(
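
# Hedged sketch (not part of ffmpy): the ENOENT-to-custom-exception pattern shown
# above, exercised in isolation. FakeNotFoundError and run_missing() are
# hypothetical names used only for this illustration.
import errno
import subprocess

class FakeNotFoundError(Exception):
    pass

def run_missing():
    try:
        subprocess.Popen(["definitely-not-a-real-binary"], stdin=subprocess.PIPE)
    except OSError as e:
        if e.errno == errno.ENOENT:
            raise FakeNotFoundError("executable not found on PATH") from e
        raise
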
import os
import posixpath
import random
import re
import subprocess
import time
import torch
import torchaudio
                    is_hidden = os.path.basename(f.path).startswith(".")

                    if file_ext in ext and not is_hidden:
                        files.append(f.path)
            except:
                pass
    except:
        pass

    for dir in list(subfolders):
        sf, f = fast_scandir(dir, ext)
                    has_banned = any(
                        [banned_word in name_lower for banned_word in banned_words])
                    if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"):
                        files.append(f.path)
            except:
                pass
    except:
        pass

    for dir in list(subfolders):
        sf, f = keyword_scandir(dir, ext, keywords)
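
# Hedged sketch of the filename filter used above, pulled out as a pure function
# so the individual conditions are easy to read. keep_file() is illustrative only
# and not part of the dataset module's API.
import os

def keep_file(path, ext, keywords, banned_words):
    name = os.path.basename(path)
    name_lower = name.lower()
    file_ext = os.path.splitext(name)[1].lower()
    has_ext = file_ext in ext
    has_keyword = any(keyword.lower() in name_lower for keyword in keywords)
    has_banned = any(banned_word.lower() in name_lower for banned_word in banned_words)
    is_hidden = name.startswith(".")
    return has_ext and has_keyword and not has_banned and not is_hidden and not name.startswith("._")
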
                custom_metadata = custom_metadata_fn(info, audio)
                info.update(custom_metadata)

            if "__reject__" in info and info["__reject__"]:
                return self[random.randrange(len(self))]

            return (audio, info)
        except Exception as e:
            print(f'Couldn\'t load file {audio_filename}: {e}')
            return self[random.randrange(len(self))]

def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.
    :param lcase: convert suffixes to lower case (Default value = True)
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if wds.tariterators.trace:
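
# Hedged sketch of the grouping idea above, independent of webdataset: split each
# fname into (prefix, suffix) and collect suffixes sharing a prefix into one sample.
# split_key() is a hypothetical stand-in for wds.tariterators.base_plus_ext.
import os

def split_key(fname):
    prefix, ext = os.path.splitext(fname)
    return prefix, ext.lstrip(".").lower()

def group_samples(fnames):
    samples, current = [], None
    for fname in fnames:
        prefix, suffix = split_key(fname)
        if current is None or current["__key__"] != prefix:
            if current is not None:
                samples.append(current)
            current = {"__key__": prefix}
        current[suffix] = fname
    if current is not None:
        samples.append(current)
    return samples

# group_samples(["a.flac", "a.json", "b.flac", "b.json"])
# -> [{'__key__': 'a', 'flac': 'a.flac', 'json': 'a.json'},
#     {'__key__': 'b', 'flac': 'b.flac', 'json': 'b.json'}]
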
        # Add the --recursive flag if requested
        cmd.append('--recursive')

    # Run the `aws s3 ls` command and capture the output
    run_ls = subprocess.run(cmd, capture_output=True, check=True)
    # Split the output into lines and strip whitespace from each line
    contents = run_ls.stdout.decode('utf-8').split('\n')
    contents = [x.strip() for x in contents if x]
def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4):

    dataset_type = dataset_config.get("dataset_type", None)

    assert dataset_type is not None, "Dataset type must be specified in dataset config"

    if audio_channels == 1:
        force_channels = "mono"
    if dataset_type == "audio_dir":

        audio_dir_configs = dataset_config.get("datasets", None)

        assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"

        configs = []

        for audio_dir_config in audio_dir_configs:
            audio_dir_path = audio_dir_config.get("path", None)
            assert audio_dir_path is not None, "Path must be set for local audio directory configuration"

            custom_metadata_fn = None
            custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)
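
# Hedged sketch of what a custom_metadata_module might contain: a file defining
# get_custom_metadata(info, audio), whose returned dict is merged into the
# per-sample info (and can set "__reject__" to resample, as in the __getitem__
# excerpt above). The keys used here are illustrative, not guaranteed by the repo.
def get_custom_metadata(info, audio):
    if audio is not None and audio.shape[-1] == 0:
        return {"__reject__": True}
    return {"prompt": info.get("relpath", "")}
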

        # If randomize is False, always start at the beginning of the audio
        offset = 0
        if self.randomize and n_samples > self.n_samples:
            offset = random.randint(0, upper_bound)

        # Calculate the start and end times of the chunk
        t_start = offset / (upper_bound + self.n_samples)
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p
    def __call__(self, signal):
        return -signal if (random.random() < self.p) else signal

class Mono(nn.Module):
    def __call__(self, signal):
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
    torch.backends.cudnn.benchmark = False

    # Conditioning
    assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors"
    if conditioning_tensors is None:
        conditioning_tensors = model.conditioner(conditioning, device)
    conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors)
    # This is helpful for forward and reverse outpainting
    cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size)
    pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size)
    pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size)
    assert pastefrom < pasteto, "Paste From should be less than Paste To"
    croplen = pasteto - pastefrom
    if cropfrom + croplen > sample_size:
        croplen = sample_size - cropfrom
    return ui

def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False):

    assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both"

    if model_config_path is not None:
        # Load config from json file
    If this is the case, we insert extra 0 padding to the right before the reflection happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0

def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left: end]

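
# Hedged sketch of why pad1d above adds extra zero padding before a 'reflect' pad:
# PyTorch requires the reflect pad to be smaller than the input length, so a very
# short input is first zero-padded on the right and the surplus is trimmed again.
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 3)                      # length 3
paddings = (4, 4)                             # reflect pad of 4 would fail: 4 >= 3
extra_pad = max(paddings) - x.shape[-1] + 1
x_safe = F.pad(x, (0, extra_pad))             # zero-pad right so reflection is legal
padded = F.pad(x_safe, paddings, mode="reflect")
padded = padded[..., : x.shape[-1] + sum(paddings)]  # drop the temporary zeros again
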

def Downsample1d(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"

    return Conv1d(
        in_channels=in_channels,
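
# Hedged sketch: Downsample1d above builds a strided convolution; a stride of
# `factor` reduces the sequence length by that factor. The exact kernel_size and
# padding are cut off in the excerpt, so the values below are only a plausible choice.
import torch
import torch.nn as nn

factor, kernel_multiplier = 2, 2
conv = nn.Conv1d(
    in_channels=8,
    out_channels=16,
    kernel_size=factor * kernel_multiplier + 1,
    stride=factor,
    padding=factor * (kernel_multiplier // 2),
)
x = torch.randn(1, 8, 64)
print(conv(x).shape)  # torch.Size([1, 16, 32])
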
            use_snake=use_snake
        )

        if self.use_mapping:
            assert exists(context_mapping_features)
            self.to_scale_shift = MappingToScaleShift(
                features=context_mapping_features, channels=out_channels
            )
        )

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
        assert_message = "context mapping required if context_mapping_features > 0"
        assert not (self.use_mapping ^ exists(mapping)), assert_message

        h = self.block1(x, causal=causal)

        use_snake: bool = False,
    ):
        super().__init__()
        assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
        assert out_channels % patch_size == 0, assert_message
        self.patch_size = patch_size

        self.block = ResnetBlock1d(
        use_snake: bool = False
    ):
        super().__init__()
        assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
        assert in_channels % patch_size == 0, assert_message
        self.patch_size = patch_size

        self.block = ResnetBlock1d(
        context_mask: Optional[Tensor] = None,  # [b, m], false is masked,
        causal: Optional[bool] = False,
    ) -> Tensor:
        assert_message = "You must provide a context when using context_features"
        assert not self.context_features or exists(context), assert_message
        # Use context if provided
        context = default(context, x)
        # Normalize then compute q from input and k,v from context
669 """Used for continuous time""" 670 671 def __init__(self, dim: int): 672 super().__init__() 673 assert (dim % 2) == 0 674 half_dim = dim // 2 675 self.weights = nn.Parameter(torch.randn(half_dim)) 676
            ]
        )

        if self.use_transformer:
            assert (
                (exists(attention_heads) or exists(attention_features))
                and exists(attention_multiplier)
            )

            if attention_features is None and attention_heads is not None:
                attention_features = channels // attention_heads
            ]
        )

        if self.use_transformer:
            assert (
                (exists(attention_heads) or exists(attention_features))
                and exists(attention_multiplier)
            )

            if attention_features is None and attention_heads is not None:
                attention_features = channels // attention_heads
            use_snake=use_snake
        )

        if self.use_transformer:
            assert (
                (exists(attention_heads) or exists(attention_features))
                and exists(attention_multiplier)
            )

            if attention_features is None and attention_heads is not None:
                attention_features = channels // attention_heads
        has_context = [c > 0 for c in context_channels]
        self.has_context = has_context
        self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]

        assert (
            len(factors) == num_layers
            and len(attentions) >= num_layers
            and len(num_blocks) == num_layers
        )

        if use_context_time or use_context_features:
            context_mapping_features = channels * context_features_multiplier
                nn.GELU(),
            )

        if use_context_time:
            assert exists(context_mapping_features)
            self.to_time = nn.Sequential(
                TimePositionalEmbedding(
                    dim=channels, out_features=context_mapping_features
                nn.GELU(),
            )

        if use_context_features:
            assert exists(context_features) and exists(context_mapping_features)
            self.to_features = nn.Sequential(
                nn.Linear(
                    in_features=context_features, out_features=context_mapping_features
            )

        if use_stft:
            stft_kwargs, kwargs = groupby("stft_", kwargs)
            assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True"
            stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
            in_channels *= stft_channels
            out_channels *= stft_channels
            context_channels[0] *= stft_channels if use_stft_context else 1
            assert exists(in_channels) and exists(out_channels)
            self.stft = STFT(**stft_kwargs)

        assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"

        self.to_in = Patcher(
            in_channels=in_channels + context_channels[0],
1179 """Gets context channels at `layer` and checks that shape is correct""" 1180 use_context_channels = self.use_context_channels and self.has_context[layer] 1181 if not use_context_channels: 1182 return None 1183 assert exists(channels_list), "Missing context" 1184 # Get channels index (skipping zero channel contexts) 1185 channels_id = self.channels_ids[layer] 1186 # Get channels
1185 channels_id = self.channels_ids[layer] 1186 # Get channels 1187 channels = channels_list[channels_id] 1188 message = f"Missing context for layer {layer} at index {channels_id}" 1189 assert exists(channels), message 1190 # Check channels 1191 num_channels = self.context_channels[layer] 1192 message = f"Expected context with {num_channels} channels at idx {channels_id}"
1189 assert exists(channels), message 1190 # Check channels 1191 num_channels = self.context_channels[layer] 1192 message = f"Expected context with {num_channels} channels at idx {channels_id}" 1193 assert channels.shape[1] == num_channels, message 1194 # STFT channels if requested 1195 channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa 1196 return channels
        items, mapping = [], None
        # Compute time features
        if self.use_context_time:
            assert_message = "use_context_time=True but no time features provided"
            assert exists(time), assert_message
            items += [self.to_time(time)]
        # Compute features
        if self.use_context_features:
            assert_message = "context_features exists but no features provided"
            assert exists(features), assert_message
            items += [self.to_features(features)]
        # Compute joint mapping
        if self.use_context_time or self.use_context_features:

    def forward(self, x: Tensor) -> Tensor:
        batch_size, length, device = *x.shape[0:2], x.device
        assert_message = "Input sequence length must be <= max_length"
        assert length <= self.max_length, assert_message
        position = torch.arange(length, device=device)
        fixed_embedding = self.embedding(position)
        fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)

        self.use_xattn_time = use_xattn_time

        if use_xattn_time:
            assert exists(context_embedding_features)
            self.to_time_embedding = nn.Sequential(
                TimePositionalEmbedding(
                    dim=kwargs["channels"], out_features=context_embedding_features
    def forward(self, x: Union[List[float], Tensor]) -> Tensor:
        if not torch.is_tensor(x):
            device = next(self.embedding.parameters()).device
            x = torch.tensor(x, device=device)
        assert isinstance(x, Tensor)
        shape = x.shape
        x = rearrange(x, "... -> (...)")
        embedding = self.embedding(x)
        Decode discrete tokens to audio
        Only works with discrete autoencoders
        '''

        assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"

        latents = self.bottleneck.decode_tokens(tokens, **kwargs)

        '''
        batch_size = len(audio_list)
        if isinstance(in_sr_list, int):
            in_sr_list = [in_sr_list]*batch_size
        assert len(in_sr_list) == batch_size, "list of sample rates must be the same length as audio_list"
        new_audio = []
        max_length = 0
        # resample & find the max length
                audio = audio.squeeze(0)
            elif len(audio.shape) == 1:
                # Mono signal, channel dimension is missing, unsqueeze it in
                audio = audio.unsqueeze(0)
            assert len(audio.shape) == 2, "Audio should be shape (Channels x Length) with no batch dimension"
            # Resample audio
            if in_sr != self.sample_rate:
                resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
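
# Hedged sketch of the shape handling above, as a standalone helper: squeeze an
# accidental batch dimension, add a channel dimension to mono input, and reject
# anything else. to_channels_x_length() is an illustrative name, not repo API.
import torch

def to_channels_x_length(audio: torch.Tensor) -> torch.Tensor:
    if audio.ndim == 3 and audio.shape[0] == 1:
        audio = audio.squeeze(0)    # batch of 1 given by accident
    elif audio.ndim == 1:
        audio = audio.unsqueeze(0)  # mono: add the missing channel dimension
    assert audio.ndim == 2, "Audio should be shape (Channels x Length) with no batch dimension"
    return audio
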
# AE factories

def create_encoder_from_config(encoder_config: Dict[str, Any]):
    encoder_type = encoder_config.get("type", None)
    assert encoder_type is not None, "Encoder type must be specified"

    if encoder_type == "oobleck":
        encoder = OobleckEncoder(
    return encoder

def create_decoder_from_config(decoder_config: Dict[str, Any]):
    decoder_type = decoder_config.get("type", None)
    assert decoder_type is not None, "Decoder type must be specified"

    if decoder_type == "oobleck":
        decoder = OobleckDecoder(

    bottleneck = ae_config.get("bottleneck", None)

    latent_dim = ae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = ae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = ae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    in_channels = ae_config.get("in_channels", None)
    out_channels = ae_config.get("out_channels", None)
    elif diffusion_model_type == "dit":
        diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])

    latent_dim = diffae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = diffae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = diffae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    bottleneck = diffae_config.get("bottleneck", None)


class SelfAttention1d(nn.Module):
    def __init__(self, c_in, n_head=1, dropout_rate=0.):
        super().__init__()
        assert c_in % n_head == 0
        self.norm = nn.GroupNorm(1, c_in)
        self.n_head = n_head
        self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)

class FourierFeatures(nn.Module):
    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        self.weight = nn.Parameter(torch.randn(
            [out_features // 2, in_features]) * std)

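
# Hedged sketch of the usual FourierFeatures forward pass (the forward itself is
# not in the excerpt): project the input onto random frequencies and concatenate
# cos and sin, which is why out_features must be even.
import math
import torch

def fourier_features(x, weight):       # x: [..., in_features], weight: [out/2, in_features]
    f = 2 * math.pi * x @ weight.T     # [..., out_features/2]
    return torch.cat([f.cos(), f.sin()], dim=-1)

weight = torch.randn(4, 1)             # out_features=8, in_features=1
print(fourier_features(torch.randn(3, 1), weight).shape)  # torch.Size([3, 8])
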

def Downsample1d_2(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"

    return nn.Conv1d(
        in_channels=in_channels,
    timesteps: int
    n_q: int

    def __post_init__(self):
        assert len(self.layout) > 0
        self._validate_layout()
        self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
        self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
                qs = set()
                for coord in seq_coords:
                    qs.add(coord.q)
                    last_q_timestep = q_timesteps[coord.q]
                    assert coord.t >= last_q_timestep, \
                        f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
                    q_timesteps[coord.q] = coord.t
                # each sequence step contains at max 1 coordinate per codebook
                assert len(qs) == len(seq_coords), \
                    f"Multiple entries for a same codebook are found at step {s}"

    @property
    def num_sequence_steps(self):
96 """Get codebook coordinates in the layout that corresponds to the specified timestep t 97 and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step 98 and the actual codebook coordinates. 99 """ 100 assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps" 101 if q is not None: 102 assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks" 103 coords = []
98 and the actual codebook coordinates. 99 """ 100 assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps" 101 if q is not None: 102 assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks" 103 coords = [] 104 for s, seq_codes in enumerate(self.layout): 105 for code in seq_codes:
        Returns:
            indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
        """
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
        # use the proper layout based on whether we limit ourselves to valid steps only or not,
        # note that using the valid_layout will result in a truncated sequence up to the valid steps
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
195 """ 196 ref_layout = self.valid_layout if keep_only_valid_steps else self.layout 197 # TODO(jade): Do we want to further truncate to only valid timesteps here as well? 198 timesteps = self.timesteps 199 assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" 200 assert sequence_steps <= len(ref_layout), \ 201 f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}" 202
196 ref_layout = self.valid_layout if keep_only_valid_steps else self.layout 197 # TODO(jade): Do we want to further truncate to only valid timesteps here as well? 198 timesteps = self.timesteps 199 assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" 200 assert sequence_steps <= len(ref_layout), \ 201 f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}" 202 203 # ensure we take the appropriate indexes to keep the model output from the first special token as well 204 if is_model_output and self.starts_with_special_token():
        cached (bool): if True, patterns for a given length are cached. In general
            that should be true for efficiency reasons to avoid synchronization points.
    """
    def __init__(self, n_q: int, cached: bool = True):
        assert n_q > 0
        self.n_q = n_q
        self.get_pattern = lru_cache(100)(self.get_pattern)  # type: ignore

            delays = list(range(n_q))
        self.delays = delays
        self.flatten_first = flatten_first
        self.empty_initial = empty_initial
        assert len(self.delays) == self.n_q
        assert sorted(self.delays) == self.delays

    def get_pattern(self, timesteps: int) -> Pattern:
        omit_special_token = self.empty_initial < 0
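
# Hedged sketch of the delay idea these pattern providers encode: with delays
# [0, 1, ..., n_q-1], codebook q is shifted right by q steps, so sequence step s
# holds timestep t = s - q for each codebook (the leading empty step is where the
# special token goes). Plain (t, q) tuples stand in for the real layout coordinates.
def delayed_layout(timesteps, n_q):
    layout = [[]]
    for s in range(timesteps + n_q - 1):
        layout.append([(s - q, q) for q in range(n_q) if 0 <= s - q < timesteps])
    return layout

# delayed_layout(3, 2) ->
# [[], [(0, 0)], [(1, 0), (0, 1)], [(2, 0), (1, 1)], [(2, 1)]]
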
        if flattening is None:
            flattening = list(range(n_q))
        if delays is None:
            delays = [0] * n_q
        assert len(flattening) == n_q
        assert len(delays) == n_q
        assert sorted(flattening) == flattening
        assert sorted(delays) == delays
        self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
        self.max_delay = max(delays)

            if inner_step not in flattened_codebooks:
                flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
            else:
                flat_codebook = flattened_codebooks[inner_step]
                assert flat_codebook.delay == delay, (
                    "Delay and flattening between codebooks is inconsistent: ",
                    "two codebooks flattened to the same position should have the same delay."
                )
                flat_codebook.codebooks.append(q)
            flattened_codebooks[inner_step] = flat_codebook
        return flattened_codebooks
        super().__init__(n_q)
        if delays is None:
            delays = [0] * (n_q - 1)
        self.delays = delays
        assert len(self.delays) == self.n_q - 1
        assert sorted(self.delays) == self.delays

    def get_pattern(self, timesteps: int) -> Pattern:
        out: PatternLayout = [[]]
            max_length: int = 128,
            enable_grad: bool = False,
            project_out: bool = False
    ):
        assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}"
        super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out)

        from transformers import T5EncoderModel, AutoTokenizer
        elif conditioner_type == "lut":
            conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
        elif conditioner_type == "pretransform":
            sample_rate = conditioner_config.pop("sample_rate", None)
            assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"

            pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)

                rescale_cfg: bool = False,
                scale_phi: float = 0.0,
                **kwargs):

        assert batch_cfg, "batch_cfg must be True for DiTWrapper"
        #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"

        return self.model(
    model_type = diffusion_uncond_config.get('type', None)

    diffusion_config = diffusion_uncond_config.get('config', {})

    assert model_type is not None, "Must specify model type in config"

    pretransform = diffusion_uncond_config.get("pretransform", None)

    sample_size = config.get("sample_size", None)
    assert sample_size is not None, "Must specify sample size in config"

    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "Must specify sample rate in config"

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)

    model_type = config["model_type"]

    diffusion_config = model_config.get('diffusion', None)
    assert diffusion_config is not None, "Must specify diffusion config"

    diffusion_model_type = diffusion_config.get('type', None)
    assert diffusion_model_type is not None, "Must specify diffusion model type"

    diffusion_model_config = diffusion_config.get('config', None)
    assert diffusion_model_config is not None, "Must specify diffusion model config"

    if diffusion_model_type == 'adp_cfg_1d':
        diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'dit':
        diffusion_model = DiTWrapper(**diffusion_model_config)

    io_channels = model_config.get('io_channels', None)
    assert io_channels is not None, "Must specify io_channels in model config"

    sample_rate = config.get('sample_rate', None)
    assert sample_rate is not None, "Must specify sample_rate in config"

    diffusion_objective = diffusion_config.get('diffusion_objective', 'v')

        extra_kwargs["diffusion_objective"] = diffusion_objective

    elif model_type == "diffusion_prior":
        prior_type = model_config.get("prior_type", None)
        assert prior_type is not None, "Must specify prior_type in diffusion prior model config"

        if prior_type == "mono_stereo":
            from .diffusion_prior import MonoToStereoDiffusionPrior
        mask=None,
        return_info=False,
        **kwargs):

        assert causal == False, "Causal mode is not supported for DiffusionTransformer"

        if cross_attn_cond_mask is not None:
            cross_attn_cond_mask = cross_attn_cond_mask.bool()

def create_model_from_config(model_config):
    model_type = model_config.get('model_type', None)

    assert model_type is not None, 'model_type must be specified in model config'

    if model_type == 'autoencoder':
        from .autoencoders import create_autoencoder_from_config

def create_pretransform_from_config(pretransform_config, sample_rate):
    pretransform_type = pretransform_config.get('type', None)

    assert pretransform_type is not None, 'type must be specified in pretransform config'

    if pretransform_type == 'autoencoder':
        from .autoencoders import create_autoencoder_from_config

def create_bottleneck_from_config(bottleneck_config):
    bottleneck_type = bottleneck_config.get('type', None)

    assert bottleneck_type is not None, 'type must be specified in bottleneck config'

    if bottleneck_type == 'tanh':
        from .bottleneck import TanhBottleneck
        ):

        batch, num_quantizers, seq_len = sequence.shape

        assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model"

        backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)])  # [batch, seq_len, embed_dim]

            global_cond_ids: tp.List[str] = []
        ):
        super().__init__()

        assert pretransform.is_discrete, "Pretransform must be discrete"
        self.pretransform = pretransform

        self.pretransform.requires_grad_(False)
            possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0])
        else:
            possible_batch_sizes.append(1)

        assert all(x == possible_batch_sizes[0] for x in possible_batch_sizes), "Batch size must be consistent across inputs"

        batch_size = possible_batch_sizes[0]

        if init_data is None:
            # Initialize with zeros
            assert batch_size > 0
            init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long)

        batch_size, num_quantizers, seq_len = init_data.shape

        start_offset = seq_len
        assert start_offset < max_gen_len, "init data longer than max gen length"

        pattern = self.lm.pattern_provider.get_pattern(max_gen_len)


        gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id)  # [batch, num_quantizers, gen_sequence_len]

        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
        assert start_offset_sequence is not None

        # Generation
        prev_offset = 0
                # Callback to report progress
                # Pass in the offset relative to the start of the sequence, and the length of the current sequence
                callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)

        assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence"

        out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)

        # sanity checks over the returned codes and corresponding masks
        assert (out_codes[..., :max_gen_len] != unknown_token).all()
        assert (out_mask[..., :max_gen_len] == 1).all()

        #out_codes = out_codes[..., 0:max_gen_len]



def create_audio_lm_from_config(config):
    model_config = config.get('model', None)
    assert model_config is not None, 'model config must be specified in config'

    sample_rate = config.get('sample_rate', None)
    assert sample_rate is not None, "Must specify sample_rate in config"

    lm_config = model_config.get('lm', None)
    assert lm_config is not None, 'lm config must be specified in model config'

    codebook_pattern = lm_config.get("codebook_pattern", "delay")

    pretransform_config = model_config.get("pretransform", None)

    pretransform = create_pretransform_from_config(pretransform_config, sample_rate)

    assert pretransform.is_discrete, "Pretransform must be discrete"

    min_input_length = pretransform.downsampling_ratio


    lm_type = lm_config.get("type", None)
    lm_model_config = lm_config.get("config", None)

    assert lm_type is not None, "Must specify lm type in lm config"
    assert lm_model_config is not None, "Must specify lm model config in lm config"

    if lm_type == "x-transformers":
        backbone = XTransformersAudioLMBackbone(**lm_model_config)
        super(PQMF, self).__init__()

        # Ensure num_bands is a power of 2
        is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands)))
        assert is_power_of_2, "'num_bands' must be a power of 2."

        # Create the prototype filter
        prototype_filter = design_prototype_filter(attenuation, num_bands)
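
# Hedged sketch: the power-of-two check above on its own, compared against the
# usual bit-trick formulation (n & (n - 1) == 0) to show they agree.
import math

for n in (1, 2, 3, 8, 12, 16):
    by_log = math.log2(n) == int(math.log2(n))
    by_bits = n > 0 and (n & (n - 1)) == 0
    assert by_log == by_bits, n
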

        return decoded

    def tokenize(self, x, **kwargs):
        assert self.model.is_discrete, "Cannot tokenize with a continuous model"

        _, info = self.model.encode(x, return_info = True, **kwargs)

        return info[self.model.bottleneck.tokens_id]

    def decode_tokens(self, tokens, **kwargs):
        assert self.model.is_discrete, "Cannot decode tokens with a continuous model"

        return self.model.decode_tokens(tokens, **kwargs)

        self.model.to(torch.float16).eval().requires_grad_(False)

    def encode(self, x):

        assert False, "Audiocraft compression models do not support continuous encoding"

        # latents = self.model.encoder(x)

        #     return output

    def decode(self, z):

        assert False, "Audiocraft compression models do not support continuous decoding"

        # if self.scale != 1.0:
        #     z = z * self.scale
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        seq_len, device = x.shape[1], x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if pos is None:
            pos = torch.arange(seq_len, device = device)

class ScaledSinusoidalEmbedding(nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        assert (dim % 2) == 0, 'dimension must be divisible by 2'
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2

        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
        if q_len == 1 and causal:
            causal = False

        if mask is not None:
            assert mask.ndim == 4
            mask = mask.expand(batch, heads, q_len, k_len)

        # handle kv cache - this should be bypassable in updated flash attention 2
            out = natten.functional.natten1dav(attn, v, kernel_size = self.natten_kernel_size, dilation=1).to(dtype_in)

        # Prioritize Flash Attention 2
        elif self.use_fa_flash:
            assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2'
            # Flash Attention 2 requires FP16 inputs
            fa_dtype_in = q.dtype
            q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v))

        if prepend_embeds is not None:
            prepend_length, prepend_dim = prepend_embeds.shape[1:]

            assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'

            x = torch.cat((prepend_embeds, x), dim = -2)

        self.wavelet = wavelet
        self.channels = channels
        self.levels = levels
        filt = get_filter_bank(wavelet)
        assert filt.shape[-1] % 2 == 1
        kernel = filt[:2, None]
        kernel = torch.flip(kernel, dims=(-1,))
        index_i = torch.repeat_interleave(torch.arange(2), channels)
        self.wavelet = wavelet
        self.channels = channels
        self.levels = levels
        filt = get_filter_bank(wavelet)
        assert filt.shape[-1] % 2 == 1
        kernel = filt[2:, None]
        index_i = torch.repeat_interleave(torch.arange(2), channels)
        index_j = torch.tile(torch.arange(channels), (2,))
        self.losses = MultiLoss(self.loss_modules)

        self.log_loss_info = log_loss_info

        assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"

        if optimizer_configs is None:
            optimizer_configs = {
        with torch.cuda.amp.autocast():
            conditioning = self.diffusion.conditioner(metadata, self.device)

        # If mask_padding is on, randomly drop the padding masks to allow for learning silence padding
        use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout

        # Create batch tensor of attention masks from the "mask" field of the metadata array
        if use_padding_mask:
        self.losses = MultiLoss(self.loss_modules)

        self.log_loss_info = log_loss_info

        assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"

        if optimizer_configs is None:
            optimizer_configs = {
        # Create a mask tensor for each batch element
        masks = []

        for i in range(b):
            mask_type = random.randint(0, 2)

            if mask_type == 0:  # Random mask with multiple segments
                num_segments = random.randint(1, self.max_mask_segments)
                max_segment_length = max_mask_length // num_segments

                segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments)

                mask = torch.ones((1, 1, sequence_length))
                for length in segment_lengths:
                    mask_start = random.randint(0, sequence_length - length)
                    mask[:, :, mask_start:mask_start + length] = 0

            elif mask_type == 1:  # Full mask
                mask = torch.zeros((1, 1, sequence_length))

            elif mask_type == 2:  # Causal mask
                mask = torch.ones((1, 1, sequence_length))
                mask_length = random.randint(1, max_mask_length)
                mask[:, :, -mask_length:] = 0

            mask = mask.to(sequence.device)
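
# Hedged sketch: the three mask flavours above on a length-12 example
# (1 keeps the original signal, 0 marks the span that is masked out).
import torch

sequence_length = 12
segment_mask = torch.ones(1, 1, sequence_length)   # mask_type 0: zero out one segment
segment_mask[:, :, 3:7] = 0
full_mask = torch.zeros(1, 1, sequence_length)      # mask_type 1: everything masked
causal_mask = torch.ones(1, 1, sequence_length)     # mask_type 2: mask the tail
causal_mask[:, :, -4:] = 0
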
            for j in js:
                if i == j or (i != j and sources_added < num_sources):
                    # Randomly offset the mixed element between 0 and the length of the source
                    seq_len = reals.shape[2]
                    offset = random.randint(0, seq_len-1)
                    source[i, :, offset:] += reals[j, :, :-offset]
                    if i == j:
                        # If this is the real one, shift the reals as well to ensure alignment
from ..models.factory import create_model_from_config

def create_training_wrapper_from_config(model_config, model):
    model_type = model_config.get('model_type', None)
    assert model_type is not None, 'model_type must be specified in model config'

    training_config = model_config.get('training', None)
    assert training_config is not None, 'training config must be specified in model config'

    if model_type == 'autoencoder':
        from .autoencoders import AutoencoderTrainingWrapper
        raise NotImplementedError(f'Unknown model type: {model_type}')

def create_demo_callback_from_config(model_config, **kwargs):
    model_type = model_config.get('model_type', None)
    assert model_type is not None, 'model_type must be specified in model config'

    training_config = model_config.get('training', None)
    assert training_config is not None, 'training config must be specified in model config'

    demo_config = training_config.get("demo", {})

        self.model_ema = None
        if use_ema:
            self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10)

        assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"

        if optimizer_configs is None:
            optimizer_configs = {
            ce (torch.Tensor): Cross entropy averaged over the codebooks
            ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
        """
        B, K, T = targets.shape
        assert logits.shape[:-1] == targets.shape
        assert mask.shape == targets.shape
        ce = torch.zeros([], device=targets.device)
        ce_per_codebook: tp.List[torch.Tensor] = []
        for k in range(K):
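
# Hedged sketch of the masked, per-codebook cross entropy described above: flatten
# each codebook, keep only positions where the mask is True, and average over the
# K codebooks. Shapes follow the docstring; the numbers are arbitrary.
import torch
import torch.nn.functional as F

B, K, T, card = 2, 4, 8, 16
logits = torch.randn(B, K, T, card)
targets = torch.randint(card, (B, K, T))
mask = torch.rand(B, K, T) > 0.2

ce = torch.zeros(())
ce_per_codebook = []
for k in range(K):
    logits_k = logits[:, k].reshape(-1, card)   # [B*T, card]
    targets_k = targets[:, k].reshape(-1)       # [B*T]
    mask_k = mask[:, k].reshape(-1)             # [B*T]
    ce_k = F.cross_entropy(logits_k[mask_k], targets_k[mask_k])
    ce = ce + ce_k
    ce_per_codebook.append(ce_k.detach())
ce = ce / K
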
                print(e)
                print("Try `pip install auraloss[all]`.")

        if self.scale == "mel":
            assert sample_rate != None  # Must set sample rate to use mel scale
            assert n_bins <= fft_size  # Must be more FFT bins than Mel bins
            fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins)
            fb = torch.tensor(fb).unsqueeze(0)

        elif self.scale == "chroma":
            assert sample_rate != None  # Must set sample rate to use chroma scale
            assert n_bins <= fft_size  # Must be more FFT bins than chroma bins
            fb = librosa.filters.chroma(
                sr=sample_rate, n_fft=fft_size, n_chroma=n_bins
            )
        scale_invariance: bool = False,
        **kwargs,
    ):
        super().__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)  # must define all
        self.fft_sizes = fft_sizes
        self.hop_sizes = hop_sizes
        self.win_lengths = win_lengths
            loss (torch.Tensor): Aggregate loss term. Only returned if output='loss'.
            loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor):
                Aggregate and intermediate loss terms. Only returned if output='full'.
        """
        assert input.shape == target.shape  # must have same shape
        bs, chs, seq_len = input.size()

        # compute sum and difference signals for both
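
# Hedged sketch of the sum and difference signals this loss operates on
# (a mid/side decomposition of the stereo pair, per the comment above).
import torch

x = torch.randn(2, 2, 1024)     # (batch, channels=2, samples)
left, right = x[:, 0], x[:, 1]
sum_signal = left + right       # "sum" / mid
diff_signal = left - right      # "difference" / side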