
Conversation


@ssam18 ssam18 commented Nov 12, 2025

Fixes #6390

Problem

When use_fp8=True is enabled in HybridParallelPlugin and the model has output layers with dimensions not divisible by 16 (e.g., binary classification with 2 outputs), the training fails with:

```
Expected both dimensions of mat2 to be divisible by 16 but got torch.Size([768, 2])
```

Root Cause

torch._scaled_mm requires both dimensions of the weight matrix to be divisible by 16. The existing check in linear_fp8() only validated:

  • Input dimension (input.shape[-1])
  • Batch dimensions (np.prod(input.shape[:-1]))

But it did not check the output dimension (weight.shape[0]).
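For reference, the pre-fix guard looked roughly like this (a paraphrase of linear_fp8(), not the verbatim source; np is numpy, F is torch.nn.functional):

```python
# Pre-fix guard: only the reduction dim and the flattened batch dims
# of the input were checked before taking the FP8 path.
if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
    return F.linear(input, weight, bias)
```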

When using GPT2ForSequenceClassification with num_labels=2, the score layer's weight reaches torch._scaled_mm as mat2 with shape [768, 2], and 2 is not divisible by 16.
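A minimal, runnable sketch of the offending shape, assuming nn.Linear's [out_features, in_features] weight layout and the transposed mat2 that torch._scaled_mm receives:

```python
import torch
import torch.nn as nn

score = nn.Linear(768, 2, bias=False)  # GPT2ForSequenceClassification's score head
mat2 = score.weight.t()                # weight is [2, 768], so mat2 is [768, 2]
print(mat2.shape[1] % 16 == 0)         # False: 2 is not divisible by 16
```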

Solution

Added a check for weight.shape[0] % 16 != 0 so that linear_fp8() falls back to regular F.linear when the output dimension is not compatible with FP8:

```python
if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0 or weight.shape[0] % 16 != 0:
    return F.linear(input, weight, bias)
```
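In context, the guard sits at the top of linear_fp8() before any FP8 casting; a simplified sketch under that assumption (the FP8 GEMM body is elided):

```python
import numpy as np
import torch
import torch.nn.functional as F

def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
    # torch._scaled_mm rejects GEMM dimensions that are not multiples of 16,
    # so such shapes are routed through the ordinary linear instead.
    if (
        input.shape[-1] % 16 != 0
        or np.prod(input.shape[:-1]) % 16 != 0
        or weight.shape[0] % 16 != 0  # new: output-dimension check
    ):
        return F.linear(input, weight, bias)
    ...  # FP8 quantization + torch._scaled_mm path (unchanged)
```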

Testing

This fix allows the model to:

  • Use FP8 for layers with compatible dimensions (performance benefit)
  • Fall back to standard FP16/BF16 for incompatible layers (correctness)
  • Run successfully with small output dimensions (e.g., binary classification)

The change is backward compatible and doesn't affect existing working configurations.
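For completeness, a hedged end-to-end sketch of the configuration that used to crash; it assumes a single-GPU torchrun launch, and every argument besides use_fp8=True is illustrative:

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import GPT2Config, GPT2ForSequenceClassification

colossalai.launch_from_torch()  # assumes torchrun-style env vars are set

# Binary classification head: the score layer has out_features == 2.
model = GPT2ForSequenceClassification(GPT2Config(num_labels=2))

# With the fix, use_fp8=True no longer trips torch._scaled_mm on the score
# head; FP8 is still used for the divisible-by-16 transformer layers.
plugin = HybridParallelPlugin(tp_size=1, pp_size=1, precision="bf16", use_fp8=True)
booster = Booster(plugin=plugin)
model, *_ = booster.boost(model)
```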

@ssam18 ssam18 requested a review from a team as a code owner November 12, 2025 17:19
@ssam18 ssam18 (Author) commented Nov 30, 2025

Can someone take a look at this PR?

@SamareshSingh

@ryanrussell @gothicx @tiansiyuan @jeffra Can someone take a look at this PR? I am happy to help and contribute to this repo!

