Hi,

I'm trying to capture the model's prompt cross-attention so I can apply some latent optimization techniques during inference, but `latents.grad` comes back as `None` no matter what I try. Here are the relevant code snippets:
```python
class Optimizer:
    def __init__(
        self,
        loss_fn: LayoutLoss,
        num_refinements: int = 3,
        lr_start: float = 0.01,
        lr_end: float = 0.05,
        betas: tuple[float, float] = (0.4, 0.9),
        weight_decay: float = 0.0,
    ):
        self.loss_fn = loss_fn
        self.num_refinements = num_refinements
        self.lr_start = lr_start
        self.lr_end = lr_end
        self.betas = betas
        self.weight_decay = weight_decay

    def optimize(
        self,
        transformer: LTX2VideoTransformer3DModel,
        latents: torch.Tensor,
        audio_latents: torch.Tensor,
        prompt_embeds: torch.Tensor,
        audio_prompt_embeds: torch.Tensor,
        timestep: torch.Tensor,
        attention_mask: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
        fps: float,
        audio_num_frames: int,
        video_coords: torch.Tensor,
        audio_coords: torch.Tensor,
        attention_kwargs: Dict[str, Any],
        store: AttentionStore,
        progress_bar: tqdm.tqdm,
    ) -> torch.Tensor:
        latents = latents.clone().detach()
        latents = latents.to(transformer.dtype)

        optimizer = torch.optim.AdamW(
            [latents],
            lr=self.lr_start,
            betas=self.betas,
            weight_decay=self.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.num_refinements, eta_min=self.lr_end
        )

        transformer.zero_grad(set_to_none=True)
        first_loss = None

        with torch.enable_grad():
            for i in range(self.num_refinements):
                latents = latents.requires_grad_(True)
                store.reset()

                latent_model_input = latents.to(transformer.dtype)
                _ = transformer(
                    hidden_states=latent_model_input,
                    audio_hidden_states=audio_latents,
                    encoder_hidden_states=prompt_embeds,
                    audio_encoder_hidden_states=audio_prompt_embeds,
                    timestep=timestep,
                    encoder_attention_mask=attention_mask,
                    audio_encoder_attention_mask=attention_mask,
                    num_frames=num_frames,
                    height=height,
                    width=width,
                    fps=fps,
                    audio_num_frames=audio_num_frames,
                    video_coords=video_coords,
                    audio_coords=audio_coords,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]

                attn = store.get_avg_attention().unsqueeze(0)
                loss = self.loss_fn(attn)

                # Backward
                loss.backward()

                # FIX 4: Verify gradients exist before stepping
                if latents.grad is None:
                    print(f"WARNING: latents.grad is None at iteration {i + 1}!")
                    print("  Gradient flow is broken. Check:")
                    print("  1. AttentionStore doesn't use .clone()")
                    print("  2. No dtype conversion breaks the computation graph")
                    print("  3. Gradient checkpointing is disabled")
                    break

                # Only step if we have gradients
                optimizer.step()
                scheduler.step()

                if i == 0:
                    first_loss = loss.item()

                current_lr = scheduler.get_last_lr()[0]
                progress_bar.set_postfix(
                    loss=f"{first_loss:.2f}→{loss.item():.2f}",
                    grad=f"{latents.grad.norm().item():.2e}",
                    lr=f"{current_lr:.2e}",
                    refine_step=f"{i + 1}/{self.num_refinements}",
                )

        store.reset()
        return latents.detach()


class AttentionStore:
    def __init__(self):
        self.accumulator = None
        self.count = 0
        self.keep_heads = False

    def __call__(self, probs: torch.Tensor) -> torch.Tensor:
        # Drop the unconditional half of a CFG batch, keep the conditional one.
        if probs.shape[0] == 2:
            probs = probs[1:]
        if not self.keep_heads:
            probs = probs.mean(dim=1)
        if self.accumulator is None:
            self.accumulator = probs
        else:
            self.accumulator = self.accumulator + probs
        self.count += 1
        return probs

    def reset(self):
        self.accumulator = None
        self.count = 0

    def get_avg_attention(self) -> torch.Tensor:
        return self.accumulator / self.count


class AttnProcessor:
    r"""
    Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0
    model. Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate
    so that we can support audio-to-video (a2v) and video-to-audio (v2a) cross attention.

    FIXED: Now uses manual attention output for cross-attention to maintain gradient flow.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self, store: AttentionStore, name: str):
        if is_torch_version("<", "2.0"):
            raise ValueError(
                "LTX attention processors require a minimum PyTorch version of 2.0. "
                "Please upgrade your PyTorch installation."
            )
        self.store = store
        self.name = name

    def __call__(
        self,
        attn: "LTX2Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        original_encoder_hidden_states = encoder_hidden_states
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if query_rotary_emb is not None:
            if attn.rope_type == "interleaved":
                query = apply_interleaved_rotary_emb(query, query_rotary_emb)
                key = apply_interleaved_rotary_emb(
                    key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
                )
            elif attn.rope_type == "split":
                query = apply_split_rotary_emb(query, query_rotary_emb)
                key = apply_split_rotary_emb(
                    key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
                )

        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        is_cross = (
            original_encoder_hidden_states is not None
            and original_encoder_hidden_states is not hidden_states
        )
        if is_cross:
            # Recompute the cross-attention probabilities manually so they can be
            # stored without going through the fused SDPA kernel.
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            scale_factor = 1.0 / math.sqrt(q.size(-1))
            scores = torch.matmul(q, k.transpose(-1, -2)) * scale_factor
            if attention_mask is not None:
                scores = scores + attention_mask
            probs = F.softmax(scores, dim=-1)
            self.store(probs)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
```
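For completeness, the processors are wired into the transformer roughly along these lines (a minimal sketch, assuming the transformer exposes diffusers' usual `attn_processors` / `set_attn_processor` API; the `"attn2"` name filter is illustrative and depends on the actual module names):

```python
# Sketch: attach an AttnProcessor (sharing one AttentionStore) to every prompt
# cross-attention layer, leaving the remaining layers on their defaults.
store = AttentionStore()
processors = {}
for name, processor in transformer.attn_processors.items():
    if "attn2" in name:  # illustrative filter for the text cross-attention layers
        processors[name] = AttnProcessor(store, name)
    else:
        processors[name] = processor
transformer.set_attn_processor(processors)
```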
I'm hoping someone can point me to the issue. I have a feeling the captured attention maps are not connected to the latents in the forward pass's computation graph.
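To narrow it down, I checked in isolation that the `AttentionStore` itself preserves the graph (a self-contained toy sketch; the linear layers and shapes are stand-ins for the real model, and `AttentionStore` is the class defined above):

```python
import torch
import torch.nn.functional as F

# Stand-ins: x plays the role of the latents, ctx the prompt embeddings.
x = torch.randn(1, 8, 16, requires_grad=True)
ctx = torch.randn(1, 4, 16)

to_q = torch.nn.Linear(16, 16)
to_k = torch.nn.Linear(16, 16)

store = AttentionStore()
scores = to_q(x) @ to_k(ctx).transpose(-1, -2) / 16**0.5
probs = F.softmax(scores.unsqueeze(1), dim=-1)  # (batch, heads=1, q_len, k_len)
store(probs)

# If the store kept the graph intact, this gradient is not None.
loss = store.get_avg_attention().square().mean()
(grad,) = torch.autograd.grad(loss, x, allow_unused=True)
print("grad is None:", grad is None)  # prints False here
```

Since this toy path does produce a gradient, my suspicion is that something inside the real forward (the attention backend, a dtype cast, or gradient checkpointing) cuts the graph before `loss.backward()` reaches the latents.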