
Commit a629142

Merge pull request #140 from SmallDoges/support-integration
Replace attention_mask with cache_position for improved efficiency
2 parents e41efbf + 61de0e4 commit a629142

File tree

4 files changed, +237 -292 lines changed


benchmarks/backward_equivalence.py

Lines changed: 24 additions & 37 deletions
@@ -54,7 +54,7 @@ def prepare_dynamic_mask(
     hidden_states: torch.Tensor,
     zoh_states: torch.Tensor,
     keep_window_size: int = 2048,
-    attention_mask: torch.Tensor | None = None,
+    cache_position: torch.Tensor = None,
 ):
     """
     Calculate dynamic attention mask to mask tokens for sparse attention.
@@ -65,28 +65,23 @@ def prepare_dynamic_mask(
         hidden_states: Input hidden states to determine dtype minimum value
         zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length)
         keep_window_size: Window size of tokens not dynamically masked
-        attention_mask: Optional attention mask of shape (batch_size, 1, query_len, key_len)
+        cache_position: Optional cache position for causal masking

     Returns:
         tuple: (attn_bias, attn_mask)
     """
-    min_dtype = torch.finfo(hidden_states.dtype).min
     dtype = hidden_states.dtype
+    min_dtype = torch.finfo(dtype).min
     attn_bias = zoh_states[:, :, None, :].expand(
         -1, -1, hidden_states.shape[2], -1
-    )  # [batch_size, num_kv_heads, query_len, key_len]
+    ).to(dtype)  # [batch_size, num_kv_heads, query_len, key_len]

-    if attention_mask is not None:
-        if attention_mask.dtype == torch.bool:
-            attention_mask = torch.where(
-                attention_mask,
-                torch.tensor(0.0, device=attention_mask.device, dtype=dtype),
-                min_dtype
-            )
+    if cache_position is not None:
         attn_bias = attn_bias.masked_fill(
-            attention_mask[:, :, :, : attn_bias.shape[-1]] != 0, min_dtype
+            torch.arange(attn_bias.shape[-1], device=attn_bias.device) > cache_position.reshape(-1, 1),
+            min_dtype
         )
-
+
     if attn_bias.shape[-1] > keep_window_size:
         topk_values, topk_indices = torch.topk(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
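
For reference, the new masking path above derives the causal pattern directly from cache_position by broadcasting a key-position index against it, rather than consuming a pre-built additive attention mask. The snippet below is a minimal, self-contained sketch of that broadcast with hypothetical sizes; it is not code from this repository.

import torch

# Hypothetical sizes: 3 new query tokens appended after 2 cached tokens.
batch_size, num_kv_heads, query_len, key_len = 1, 1, 3, 5
cache_position = torch.arange(key_len - query_len, key_len)  # tensor([2, 3, 4])

attn_bias = torch.randn(batch_size, num_kv_heads, query_len, key_len)
min_dtype = torch.finfo(attn_bias.dtype).min

# (key_len,) > (query_len, 1) broadcasts to (query_len, key_len): True marks
# key positions strictly after each query's cache position.
future = torch.arange(key_len) > cache_position.reshape(-1, 1)
attn_bias = attn_bias.masked_fill(future, min_dtype)

print(future)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False,  True],
#         [False, False, False, False, False]])

The same broadcast is what the masked_fill call in the hunk above applies to the expanded attn_bias before the top-k window selection.
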
@@ -150,7 +145,7 @@ def dynamic_mask_attention_python(
     dt_proj: torch.Tensor,
     A: torch.Tensor,
     scaling: float,
-    causal_mask: torch.Tensor,
+    cache_position: torch.Tensor,
     dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
@@ -165,7 +160,7 @@ def dynamic_mask_attention_python(
         dt_proj: [num_kv_heads, num_kv_heads * head_dim]
         A: [num_kv_heads]
         scaling: Attention scaling factor
-        causal_mask: Causal attention mask
+        cache_position: Cache position for causal masking
         dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking
@@ -188,7 +183,7 @@ def dynamic_mask_attention_python(
         query_states,
         zoh_states,
         keep_window_size,
-        causal_mask if is_causal else None
+        cache_position if is_causal else None
     )
     attn_bias_leaf = attn_bias
     attn_bias_leaf.retain_grad()
@@ -218,7 +213,7 @@ def dynamic_mask_attention_cuda(
     dt_proj: torch.Tensor,
     A: torch.Tensor,
     scaling: float,
-    causal_mask: torch.Tensor,
+    cache_position: torch.Tensor,
     dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
@@ -233,7 +228,7 @@ def dynamic_mask_attention_cuda(
         dt_proj: [num_kv_heads, num_kv_heads * head_dim]
         A: [num_kv_heads]
         scaling: Attention scaling factor
-        causal_mask: Causal attention mask
+        cache_position: Cache position for causal masking
         dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking
@@ -256,7 +251,7 @@ def dynamic_mask_attention_cuda(
         query_states,
         zoh_states,
         keep_window_size,
-        causal_mask if is_causal else None
+        cache_position if is_causal else None
     )  # [batch_size, num_kv_heads, query_len, key_len]
     attn_bias_leaf = attn_bias
     attn_bias_leaf.retain_grad()
@@ -294,7 +289,7 @@ def dynamic_mask_attention_triton(
     dt_proj: torch.Tensor,
     A: torch.Tensor,
     scaling: float,
-    causal_mask: torch.Tensor,
+    cache_position: torch.Tensor,
     dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
@@ -309,7 +304,7 @@ def dynamic_mask_attention_triton(
         dt_proj: [num_kv_heads, num_kv_heads * head_dim]
         A: [num_kv_heads]
         scaling: Attention scaling factor
-        causal_mask: Causal attention mask
+        cache_position: Cache position for causal masking
         dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking
@@ -336,7 +331,7 @@ def dynamic_mask_attention_triton(
         query_states,
         zoh_states,
         keep_window_size,
-        causal_mask if is_causal else None
+        cache_position if is_causal else None
     )  # [batch_size, num_kv_heads, query_len, key_len]
     attn_bias_leaf = attn_bias
     attn_bias_leaf.retain_grad()
@@ -378,7 +373,7 @@ def dynamic_mask_attention_flex(
     dt_proj: torch.Tensor,
     A: torch.Tensor,
     scaling: float,
-    causal_mask: torch.Tensor,
+    cache_position: torch.Tensor,
     dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
@@ -393,7 +388,7 @@ def dynamic_mask_attention_flex(
         dt_proj: [num_kv_heads, num_kv_heads * head_dim]
         A: [num_kv_heads]
         scaling: Attention scaling factor
-        causal_mask: Causal attention mask
+        cache_position: Cache position for causal masking
         dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking
@@ -416,7 +411,7 @@ def dynamic_mask_attention_flex(
         query_states,
         zoh_states,
         keep_window_size,
-        causal_mask if is_causal else None
+        cache_position if is_causal else None
     )  # [batch_size, num_kv_heads, query_len, key_len]
     attn_bias.retain_grad()

@@ -673,16 +668,8 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         )
         A = torch.randn(num_kv_heads, device=device, dtype=dtype, requires_grad=True)

-        # Create custom causal mask with cache position
+        # Create cache position
         cache_position = torch.arange(key_len - query_len, key_len, device=device)
-        min_type = torch.finfo(value_states.dtype).min
-        causal_mask = torch.full(
-            (query_len, key_len), fill_value=min_type,
-            device=device, dtype=value_states.dtype
-        )
-        causal_mask = torch.triu(causal_mask, diagonal=1)
-        causal_mask *= torch.arange(key_len, device=device) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

         # Set scaling factor and keep window size
         scaling = head_dim ** -0.5
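
The setup removed above materialized a dense additive causal_mask from the same cache_position; the test can drop it because the broadcast inside prepare_dynamic_mask encodes the same causal structure. Below is a small equivalence check under illustrative sizes, a standalone sketch rather than part of the patch.

import torch

query_len, key_len = 4, 8  # illustrative sizes
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(key_len - query_len, key_len)

# Old construction (removed above): additive mask with min_dtype on masked slots.
causal_mask = torch.full((query_len, key_len), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(key_len) > cache_position.reshape(-1, 1)

# New construction: boolean mask built directly from cache_position.
future = torch.arange(key_len) > cache_position.reshape(-1, 1)

# Both mark the same key positions as unattendable for each query row.
assert torch.equal(causal_mask != 0, future)

The positions flagged by the boolean broadcast are exactly the positions the old additive mask set to the dtype minimum, which is why only cache_position needs to be threaded through the benchmark functions.
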
@@ -705,7 +692,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         start_time = time.time()
         attn_outputs_python, dq_python, dk_python, dv_python, dbias_python = dynamic_mask_attention_python(
             query_python, key_python, value_python, dt_proj_python, A_python,
-            scaling, causal_mask, dout.clone(), keep_window_size, is_causal
+            scaling, cache_position, dout.clone(), keep_window_size, is_causal
         )
         torch.cuda.synchronize()
         py_time = time.time() - start_time
@@ -722,7 +709,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         start_time = time.time()
         attn_outputs_cuda, dq_cuda, dk_cuda, dv_cuda, dbias_cuda = dynamic_mask_attention_cuda(
             query_cuda, key_cuda, value_cuda, dt_proj_cuda, A_cuda,
-            scaling, causal_mask, dout.clone(), keep_window_size, is_causal
+            scaling, cache_position, dout.clone(), keep_window_size, is_causal
         )
         torch.cuda.synchronize()
         cuda_time = time.time() - start_time
@@ -787,7 +774,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         if not is_close and max_dbias_diff > 1e-2:
             print(" ⚠️ Difference too large, stopping subsequent tests.")
             break
-        del query_states, key_states, value_states, dt_proj, A, causal_mask, dout, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda
+        del query_states, key_states, value_states, dt_proj, A, cache_position, dout, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda
         torch.cuda.empty_cache()
         gc.collect()
         torch.cuda.synchronize()
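
On the commit title's efficiency claim, one plausible accounting is mask-side memory and plumbing: the old signatures carried a dense (batch_size, 1, query_len, key_len) floating-point mask through every variant, while the new ones pass a 1-D index tensor of length query_len and rebuild the causal pattern where it is needed. A rough comparison under assumed benchmark sizes, for illustration only:

import torch

# Assumed sizes, not taken from the benchmark configs.
batch_size, query_len, key_len = 1, 4096, 4096
dtype = torch.bfloat16

# Old: dense additive mask materialized up front and passed to each variant.
old_mask = torch.zeros(batch_size, 1, query_len, key_len, dtype=dtype)

# New: only the cache positions are passed; the mask is derived on the fly.
cache_position = torch.arange(key_len - query_len, key_len)

print(f"dense mask    : {old_mask.numel() * old_mask.element_size() / 2**20:.1f} MiB")
print(f"cache_position: {cache_position.numel() * cache_position.element_size()} bytes")
# Roughly 32 MiB versus 32 KiB for these sizes.
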
