# NOTE(review): truncated fragment (original file lines 562-569) from inside a
# caption-preprocessing loop. Visible behavior: a `str` caption is appended
# as-is; a list/np.ndarray yields `random.choice(caption)` when `is_train`
# else `caption[0]` (deterministic first caption for eval); any other type
# raises ValueError. The `raise ValueError(` call is cut mid-statement here —
# its closing paren and the enclosing loop/function live outside this view.
# Assumes `captions`, `caption`, `is_train`, `caption_column` are defined by
# the surrounding (unseen) scope — TODO confirm against the full file.
562 if isinstance(caption, str): 563 captions.append(caption) 564 elif isinstance(caption, (list, np.ndarray)): 565 # take a random caption if there are multiple 566 captions.append(random.choice(caption) if is_train else caption[0]) 567 else: 568 raise ValueError( 569 f"Caption column `{caption_column}` should contain either strings or lists of strings."
# NOTE(review): truncated fragment (original file lines 713-720) of a method
# on a class holding `self.noise_scheduler` and `self.accelerator`. Visible
# behavior: configures the scheduler for 40 inference steps on the
# accelerator device, draws a random cut point `mid_timestep` in [30, 39]
# (random.randint is inclusive on both ends), then iterates the first
# `mid_timestep` scheduler timesteps under torch.no_grad(). The loop body is
# cut off immediately after `with torch.no_grad():` — what happens per step
# is outside this view. Presumably a partial-denoising pass up to a random
# intermediate timestep — TODO confirm against the full method.
713 714 self.noise_scheduler.set_timesteps(40, device=self.accelerator.device) 715 timesteps = self.noise_scheduler.timesteps 716 717 mid_timestep = random.randint(30, 39) 718 719 for i, t in enumerate(timesteps[:mid_timestep]): 720 with torch.no_grad():
# NOTE(review): truncated fragment (original file lines 23-30): the header and
# start of `create_vit(vit, image_size, use_grad_checkpointing=False,
# ckpt_layer=0, drop_path_rate=0)`. Visible behavior: asserts `vit` is
# 'base' or 'large' (assert is stripped under -O; a raise would be safer, but
# that change belongs to the full definition, not this fragment); for 'base'
# sets vision_width=768 and begins constructing a VisionTransformer with
# patch_size=16, depth=12. The constructor call is cut mid-argument-list and
# the 'large' branch / return statement are outside this view. The remaining
# parameters (use_grad_checkpointing, ckpt_layer, drop_path_rate) are
# presumably forwarded to VisionTransformer — TODO confirm in the full file.
23 24 25 def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0): 26 27 assert vit in ['base', 'large'], "vit parameter must be base or large" 28 if vit=='base': 29 vision_width = 768 30 visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
# NOTE(review): truncated fragment (original file lines 347-354) from inside a
# transformer-layer forward pass (BERT-style, with a `mode` switch). Visible
# behavior: slices `self_attention_outputs[1:-1]` into `outputs` (presumably
# attention probabilities / intermediate outputs — confirm upstream return
# shape) and takes the last element as `present_key_value` (KV cache entry);
# when mode=='multimodal', requires `encoder_hidden_states` (image features,
# presumably) and invokes `self.crossattention(attention_output, ...)`. The
# crossattention call is cut mid-argument-list; the rest of the forward pass
# is outside this view.
347 outputs = self_attention_outputs[1:-1] 348 present_key_value = self_attention_outputs[-1] 349 350 if mode=='multimodal': 351 assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" 352 353 cross_attention_outputs = self.crossattention( 354 attention_output,
# NOTE(review): truncated fragment (original file lines 114-121) of a
# cached-download helper. Visible behavior: if `download_target` already
# exists as a file it is returned immediately (no checksum revalidation
# visible in this fragment); otherwise the URL is streamed to disk in
# 8192-byte chunks with a tqdm progress bar sized from the Content-Length
# header (int(...) will raise TypeError if the server omits Content-Length —
# can't tell from here whether callers guarantee it). Both the urlopen
# response and the output file are managed by `with`, so they close on any
# exit path. The read/write loop is cut off after `source.read(8192)`; the
# write + loop-exit logic and the return are outside this view.
114 115 if os.path.isfile(download_target): 116 return download_target 117 118 with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 119 with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 120 while True: 121 buffer = source.read(8192)