@@ -54,7 +54,7 @@ def prepare_dynamic_mask(
     hidden_states: torch.Tensor,
     zoh_states: torch.Tensor,
     keep_window_size: int = 2048,
-    attention_mask: torch.Tensor | None = None,
+    cache_position: torch.Tensor = None,
 ):
     """
     Calculate dynamic attention mask to mask tokens for sparse attention.
@@ -65,28 +65,23 @@ def prepare_dynamic_mask(
         hidden_states: Input hidden states to determine dtype minimum value
         zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length)
         keep_window_size: Window size of tokens not dynamically masked
-        attention_mask: Optional attention mask of shape (batch_size, 1, query_len, key_len)
+        cache_position: Optional cache position for causal masking

     Returns:
         tuple: (attn_bias, attn_mask)
     """
-    min_dtype = torch.finfo(hidden_states.dtype).min
     dtype = hidden_states.dtype
+    min_dtype = torch.finfo(dtype).min
     attn_bias = zoh_states[:, :, None, :].expand(
         -1, -1, hidden_states.shape[2], -1
-    )  # [batch_size, num_kv_heads, query_len, key_len]
+    ).to(dtype)  # [batch_size, num_kv_heads, query_len, key_len]

-    if attention_mask is not None:
-        if attention_mask.dtype == torch.bool:
-            attention_mask = torch.where(
-                attention_mask,
-                torch.tensor(0.0, device=attention_mask.device, dtype=dtype),
-                min_dtype
-            )
+    if cache_position is not None:
         attn_bias = attn_bias.masked_fill(
-            attention_mask[:, :, :, : attn_bias.shape[-1]] != 0, min_dtype
+            torch.arange(attn_bias.shape[-1], device=attn_bias.device) > cache_position.reshape(-1, 1),
+            min_dtype
         )
-
+
     if attn_bias.shape[-1] > keep_window_size:
         topk_values, topk_indices = torch.topk(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
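For reference, a minimal standalone sketch (not part of the diff; sizes and variable names are illustrative) of what the new `cache_position` path in `prepare_dynamic_mask` computes: key index `j` is masked for query row `q` whenever `j > cache_position[q]`, which reproduces the causal pattern that the old `attention_mask` argument had to carry in explicitly.

import torch

# Illustrative sizes only; assumed here, not taken from the benchmark.
batch_size, num_kv_heads, query_len, key_len = 1, 2, 3, 8
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# Decoding the last `query_len` positions of a `key_len`-long sequence.
cache_position = torch.arange(key_len - query_len, key_len)  # tensor([5, 6, 7])

attn_bias = torch.zeros(batch_size, num_kv_heads, query_len, key_len, dtype=dtype)
# Broadcasting (key_len,) > (query_len, 1) gives a (query_len, key_len) bool
# mask that masked_fill further broadcasts over the batch and head dimensions.
attn_bias = attn_bias.masked_fill(
    torch.arange(key_len) > cache_position.reshape(-1, 1), min_dtype
)

The `keep_window_size` top-k selection that follows in the function then runs on this biased tensor exactly as before.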
@@ -150,7 +145,7 @@ def dynamic_mask_attention_python(
     dt_proj: torch.Tensor,
     A: torch.Tensor,
     scaling: float,
-    causal_mask: torch.Tensor,
+    cache_position: torch.Tensor,
     dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
@@ -165,7 +160,7 @@ def dynamic_mask_attention_python(
         dt_proj: [num_kv_heads, num_kv_heads * head_dim]
         A: [num_kv_heads]
         scaling: Attention scaling factor
-        causal_mask: Causal attention mask
+        cache_position: Cache position for causal masking
         dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking
@@ -188,7 +183,7 @@ def dynamic_mask_attention_python(
         query_states,
         zoh_states,
         keep_window_size,
-        causal_mask if is_causal else None
+        cache_position if is_causal else None
     )
     attn_bias_leaf = attn_bias
     attn_bias_leaf.retain_grad()
@@ -218,7 +213,7 @@ def dynamic_mask_attention_cuda(
     dt_proj: torch.Tensor,
     A: torch.Tensor,
     scaling: float,
-    causal_mask: torch.Tensor,
+    cache_position: torch.Tensor,
     dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
@@ -233,7 +228,7 @@ def dynamic_mask_attention_cuda(
         dt_proj: [num_kv_heads, num_kv_heads * head_dim]
         A: [num_kv_heads]
         scaling: Attention scaling factor
-        causal_mask: Causal attention mask
+        cache_position: Cache position for causal masking
         dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking
@@ -256,7 +251,7 @@ def dynamic_mask_attention_cuda(
         query_states,
         zoh_states,
         keep_window_size,
-        causal_mask if is_causal else None
+        cache_position if is_causal else None
     )  # [batch_size, num_kv_heads, query_len, key_len]
     attn_bias_leaf = attn_bias
     attn_bias_leaf.retain_grad()
@@ -294,7 +289,7 @@ def dynamic_mask_attention_triton(
     dt_proj: torch.Tensor,
     A: torch.Tensor,
     scaling: float,
-    causal_mask: torch.Tensor,
+    cache_position: torch.Tensor,
     dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
@@ -309,7 +304,7 @@ def dynamic_mask_attention_triton(
         dt_proj: [num_kv_heads, num_kv_heads * head_dim]
         A: [num_kv_heads]
         scaling: Attention scaling factor
-        causal_mask: Causal attention mask
+        cache_position: Cache position for causal masking
         dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking
@@ -336,7 +331,7 @@ def dynamic_mask_attention_triton(
         query_states,
         zoh_states,
         keep_window_size,
-        causal_mask if is_causal else None
+        cache_position if is_causal else None
     )  # [batch_size, num_kv_heads, query_len, key_len]
     attn_bias_leaf = attn_bias
     attn_bias_leaf.retain_grad()
@@ -378,7 +373,7 @@ def dynamic_mask_attention_flex(
     dt_proj: torch.Tensor,
     A: torch.Tensor,
     scaling: float,
-    causal_mask: torch.Tensor,
+    cache_position: torch.Tensor,
     dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
@@ -393,7 +388,7 @@ def dynamic_mask_attention_flex(
         dt_proj: [num_kv_heads, num_kv_heads * head_dim]
         A: [num_kv_heads]
         scaling: Attention scaling factor
-        causal_mask: Causal attention mask
+        cache_position: Cache position for causal masking
         dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking
@@ -416,7 +411,7 @@ def dynamic_mask_attention_flex(
         query_states,
         zoh_states,
         keep_window_size,
-        causal_mask if is_causal else None
+        cache_position if is_causal else None
     )  # [batch_size, num_kv_heads, query_len, key_len]
     attn_bias.retain_grad()

@@ -673,16 +668,8 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         )
         A = torch.randn(num_kv_heads, device=device, dtype=dtype, requires_grad=True)

-        # Create custom causal mask with cache position
+        # Create cache position
         cache_position = torch.arange(key_len - query_len, key_len, device=device)
-        min_type = torch.finfo(value_states.dtype).min
-        causal_mask = torch.full(
-            (query_len, key_len), fill_value=min_type,
-            device=device, dtype=value_states.dtype
-        )
-        causal_mask = torch.triu(causal_mask, diagonal=1)
-        causal_mask *= torch.arange(key_len, device=device) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

         # Set scaling factor and keep window size
         scaling = head_dim ** -0.5
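As a quick sanity sketch (outside the diff, CPU-only, purely illustrative), the explicit mask this hunk deletes and the `cache_position` check now applied inside `prepare_dynamic_mask` flag the same positions:

import torch

query_len, key_len = 4, 16
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(key_len - query_len, key_len)

# Old construction (the lines removed above), without the final batch expand.
causal_mask = torch.full((query_len, key_len), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(key_len) > cache_position.reshape(-1, 1)

# New check, built internally from cache_position.
new_mask = torch.arange(key_len) > cache_position.reshape(-1, 1)

assert torch.equal(causal_mask != 0, new_mask)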
@@ -705,7 +692,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         start_time = time.time()
         attn_outputs_python, dq_python, dk_python, dv_python, dbias_python = dynamic_mask_attention_python(
             query_python, key_python, value_python, dt_proj_python, A_python,
-            scaling, causal_mask, dout.clone(), keep_window_size, is_causal
+            scaling, cache_position, dout.clone(), keep_window_size, is_causal
         )
         torch.cuda.synchronize()
         py_time = time.time() - start_time
@@ -722,7 +709,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         start_time = time.time()
         attn_outputs_cuda, dq_cuda, dk_cuda, dv_cuda, dbias_cuda = dynamic_mask_attention_cuda(
             query_cuda, key_cuda, value_cuda, dt_proj_cuda, A_cuda,
-            scaling, causal_mask, dout.clone(), keep_window_size, is_causal
+            scaling, cache_position, dout.clone(), keep_window_size, is_causal
         )
         torch.cuda.synchronize()
         cuda_time = time.time() - start_time
@@ -787,7 +774,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         if not is_close and max_dbias_diff > 1e-2:
             print(" ⚠️ Difference too large, stopping subsequent tests.")
             break
-        del query_states, key_states, value_states, dt_proj, A, causal_mask, dout, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda
+        del query_states, key_states, value_states, dt_proj, A, cache_position, dout, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda
         torch.cuda.empty_cache()
         gc.collect()
         torch.cuda.synchronize()