2 changes: 1 addition & 1 deletion benchmarks/backward_equivalence.py
@@ -191,7 +191,7 @@ def dynamic_mask_attention_python(
value_states = repeat_kv(value_states, num_queries_per_kv)
attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)

attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1))
attn_weights = attn_weights * scaling + attn_bias # Apply scaling and zoh
attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization
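For orientation, the reference path this benchmark checks against broadcasts K/V (and the mask/bias) from the KV heads to all query heads with `repeat_kv`, then scales the scores, adds the bias, and applies softmax. Below is a minimal, self-contained sketch of that math; the simplified `repeat_kv`, the masking step, and all shapes are illustrative assumptions rather than the benchmark's exact helpers.

```python
import torch
import torch.nn.functional as F

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Simplified stand-in: repeat each KV head n_rep times along the head dim.
    # (B, H_kv, S, D) -> (B, H_kv * n_rep, S, D)
    return x.repeat_interleave(n_rep, dim=1)

B, H, H_kv, Sq, Sk, D = 1, 8, 2, 16, 32, 64
scaling = D ** -0.5
query_states = torch.randn(B, H, Sq, D)
key_states = torch.randn(B, H_kv, Sk, D)
value_states = torch.randn(B, H_kv, Sk, D)
attn_mask = torch.ones(B, H_kv, Sq, Sk, dtype=torch.bool)
attn_bias = torch.randn(B, H_kv, Sq, Sk)

num_queries_per_kv = H // H_kv
key_states = repeat_kv(key_states, num_queries_per_kv)
value_states = repeat_kv(value_states, num_queries_per_kv)
attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
attn_bias = repeat_kv(attn_bias, num_queries_per_kv)

attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1))
attn_weights = attn_weights * scaling + attn_bias                    # scale, then add bias
attn_weights = attn_weights.masked_fill(~attn_mask, float("-inf"))   # assumed masking step
attn_weights = F.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value_states)
```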
1,002 changes: 542 additions & 460 deletions csrc/flash_dmattn/flash_api.cpp

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions csrc/flash_dmattn/src/flash.h
@@ -50,6 +50,9 @@ struct Mask_params {
index_t mask_batch_stride; // Stride between batches of attention mask
index_t mask_head_stride; // Stride between heads of attention mask
index_t mask_row_stride; // Stride between rows of attention mask

+ // The number of heads in the mask.
+ int h_mask;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -61,6 +64,9 @@ struct Bias_params {
index_t bias_batch_stride; // Stride between batches of attention bias
index_t bias_head_stride; // Stride between heads of attention bias
index_t bias_row_stride; // Stride between rows of attention bias

+ // The number of heads in the bias.
+ int h_bias;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
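The new `h_mask` and `h_bias` fields record how many heads the attention mask and bias tensors actually carry, so the kernels can broadcast them when they provide a single head or only the KV heads. A rough host-side sketch of the constraint this implies (the helper below is hypothetical, not the extension's actual check):

```python
def check_broadcast_heads(src_heads: int, num_heads: int, num_kv_heads: int) -> int:
    # attn_mask / attn_bias may carry 1, num_kv_heads, or num_heads heads.
    if src_heads not in (1, num_kv_heads, num_heads):
        raise ValueError(
            f"mask/bias head count must be 1, {num_kv_heads}, or {num_heads}; got {src_heads}"
        )
    return src_heads
```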
6 changes: 4 additions & 2 deletions csrc/flash_dmattn/src/flash_bwd_kernel.h
@@ -107,10 +107,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
+ n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+ n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+ const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
const index_t row_offset_mask = binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)
- + (bidh / params.h_h_k_ratio) * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN;
+ + h_idx_mask * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN;
+ const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
const index_t row_offset_bias = binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)
- + (bidh / params.h_h_k_ratio) * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN;
+ + h_idx_bias * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN;
const index_t row_offset_dbias = binfo.bias_offset(params.dbias_batch_stride, params.dbias_row_stride, bidb)
+ bidh * params.dbias_head_stride + (m_block_max - 1) * kBlockM * params.dbias_row_stride + n_block * kBlockN;
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
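The head index used for the mask/bias offsets now depends on how many heads they carry: a single head always maps to index 0, `h_k` heads map each query head to its KV head via `h_h_k_ratio`, and a full set of heads uses `bidh` directly. A Python paraphrase of that device-side selection (the function name and example numbers are illustrative):

```python
def head_index(bidh: int, h: int, h_k: int, h_src: int) -> int:
    # Mirrors: (h_src == 1) ? 0 : ((h_src == h_k) ? (bidh / h_h_k_ratio) : bidh)
    h_h_k_ratio = h // h_k
    if h_src == 1:        # single broadcast head
        return 0
    if h_src == h_k:      # mask/bias stored per KV head (MQA/GQA)
        return bidh // h_h_k_ratio
    return bidh           # one mask/bias head per query head

# 8 query heads, 2 KV heads, mask stored per KV head: query head 5 reads KV head 1.
assert head_index(bidh=5, h=8, h_k=2, h_src=2) == 1
```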
20 changes: 12 additions & 8 deletions csrc/flash_dmattn/src/flash_fwd_kernel.h
@@ -136,6 +136,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).

const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
+ const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
+ const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);

// Global memory tensor configuration
Tensor mQ = make_tensor(
@@ -170,21 +172,21 @@
); // (kBlockN, kHeadDim, nblocksN)
Tensor mMask = make_tensor(
make_gmem_ptr(reinterpret_cast<const bool*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
- make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
+ make_shape(params.h_mask, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
make_stride(params.mask_head_stride, params.mask_row_stride, _1{})
);
Tensor gMask = local_tile(
- mMask(bidh / params.h_h_k_ratio, _, _),
+ mMask(h_idx_mask, _, _),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_coord(m_block, _)
); // (kBlockM, kBlockN, nblocksN)
Tensor mBias = make_tensor(
make_gmem_ptr(reinterpret_cast<Element*>(params.bias_ptr) + binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)),
- make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
+ make_shape(params.h_bias, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
make_stride(params.bias_head_stride, params.bias_row_stride, _1{})
);
Tensor gBias = local_tile(
- mBias(bidh / params.h_h_k_ratio, _, _),
+ mBias(h_idx_bias, _, _),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_coord(m_block, _)
); // (kBlockM, kBlockN, nblocksN)
@@ -840,16 +842,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+ const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
const index_t col_offset_mask = (block_table == nullptr)
? binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb_cache)
- + (bidh / params.h_h_k_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN
+ + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN
: binfo.q_offset(/*batch_stride=*/index_t(0), params.mask_row_stride, bidb_cache)
- + (bidh / params.h_h_k_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset;
+ + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset;
+ const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
const index_t col_offset_bias = (block_table == nullptr)
? binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb_cache)
- + (bidh / params.h_h_k_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN
+ + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN
: binfo.q_offset(/*batch_stride=*/index_t(0), params.bias_row_stride, bidb_cache)
- + (bidh / params.h_h_k_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset;
+ + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset;

// Global memory tensor configuration
Tensor mQ = make_tensor(
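Once the head index is resolved, the mask/bias tile offsets are assembled from the batch offset plus head, row, and column contributions. A sketch of the non-paged branch of `col_offset_mask` above, written out as plain integer arithmetic (the function wrapper is hypothetical):

```python
def col_offset_mask(
    batch_offset: int,      # binfo.mask_offset(...) for this batch
    h_idx_mask: int,        # head index chosen as in the snippet above
    mask_head_stride: int,
    mask_row_stride: int,
    m_block: int,
    kBlockM: int,
    n_block_max: int,
    kBlockN: int,
) -> int:
    return (
        batch_offset
        + h_idx_mask * mask_head_stride
        + m_block * kBlockM * mask_row_stride
        + (n_block_max - 1) * kBlockN
    )
```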
8 changes: 6 additions & 2 deletions flash_dmattn/flash_dmattn_interface.py
@@ -361,10 +361,14 @@ def flash_dmattn_func(
key: torch.Tensor. The key tensor of shape (batch_size, seqlen, nheads_k, headdim)
value: torch.Tensor. The value tensor of shape (batch_size, seqlen, nheads_k, headdim)
attn_mask: torch.Tensor, optional. The attention mask boolean tensor of
- shape (batch_size, nheads_k, seqlen_q, seqlen_k) to apply to the attention scores.
+ shape (batch_size, nheads, seqlen_q, seqlen_k) to apply to the attention scores.
+ Also supports shape (batch_size, nheads_k, seqlen_q, seqlen_k) or
+ (batch_size, 1, seqlen_q, seqlen_k) for MQA/GQA.
If None, no mask is applied.
attn_bias: torch.Tensor, optional. The attention bias float tensor of
- shape (batch_size, nheads_k, seqlen_q, seqlen_k) to add to the attention scores.
+ shape (batch_size, nheads, seqlen_q, seqlen_k) to add to the attention scores.
+ Also supports shape (batch_size, nheads_k, seqlen_q, seqlen_k) or
+ (batch_size, 1, seqlen_q, seqlen_k) for MQA/GQA.
If None, no bias is applied.
is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
scale: float. The scaling of QK^T before applying softmax.
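Pulling the docstring together, a usage sketch with broadcastable mask/bias shapes; the import path, keyword names, and defaults are inferred from this file's layout and docstring, so treat them as assumptions and check the actual signature (a CUDA build of the extension is required):

```python
import torch
from flash_dmattn.flash_dmattn_interface import flash_dmattn_func  # assumed import path

B, Sq, Sk, H, H_k, D = 2, 128, 128, 8, 2, 64
device, dtype = "cuda", torch.bfloat16

q = torch.randn(B, Sq, H, D, device=device, dtype=dtype)
k = torch.randn(B, Sk, H_k, D, device=device, dtype=dtype)
v = torch.randn(B, Sk, H_k, D, device=device, dtype=dtype)

# Mask/bias may carry nheads, nheads_k, or a single broadcast head.
attn_mask = torch.ones(B, 1, Sq, Sk, device=device, dtype=torch.bool)
attn_bias = torch.zeros(B, H_k, Sq, Sk, device=device, dtype=dtype)

out = flash_dmattn_func(
    q, k, v,
    attn_mask=attn_mask,
    attn_bias=attn_bias,
    is_causal=True,
    scale=D ** -0.5,
)
print(out.shape)  # (B, Sq, H, D)
```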
4 changes: 2 additions & 2 deletions flash_dmattn/integrations/flash_dynamic_mask_attention.py
@@ -29,8 +29,8 @@ def flash_dynamic_mask_attention_forward(
query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim).
key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
- attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_kv_heads, query_len, key_len).
- attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_kv_heads, query_len, key_len).
+ attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA.
+ attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA.
scaling (Optional[float]): The scaling factor for the attention scores.
softcap (Optional[float]): The softcap value for the attention scores.
**kwargs: Additional keyword arguments.
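For callers that start from a 2-D padding mask, one illustrative way to build a boolean mask in the smallest accepted broadcast shape, (batch_size, 1, query_len, key_len); the helper name is hypothetical and not part of the integration:

```python
import torch

def make_broadcast_mask(padding_mask: torch.Tensor, query_len: int) -> torch.Tensor:
    # padding_mask: (batch_size, key_len), True for tokens that may be attended to.
    # Returns (batch_size, 1, query_len, key_len); the single head broadcasts over all heads.
    batch_size, key_len = padding_mask.shape
    return padding_mask[:, None, None, :].expand(batch_size, 1, query_len, key_len)

padding = torch.tensor([[True, True, True, False]])
print(make_broadcast_mask(padding, query_len=2).shape)  # torch.Size([1, 1, 2, 4])
```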