Conversation

@MacDue (Member) commented Nov 26, 2025

Assuming the predicate is hoisted, this should have slightly better throughput: https://godbolt.org/z/jb7aP7Efc

Note: SVE must be used to convert back to bf16, as the bfmlalb/t instructions operate on even/odd lanes, but the NEON bfcvtn/2 instructions process the bottom/top halves of vectors.
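To make the lane bookkeeping concrete, here is a scalar sketch of why the BFMLALB/T results pair with SVE bfcvt/bfcvtnt rather than NEON bfcvtn/2 (illustration only, not LLVM code; the helper name and stand-in types are hypothetical):

// Sketch: with a zero accumulator, BFMLALB computes products of the even
// bf16 lanes (0,2,4,6) and BFMLALT of the odd lanes (1,3,5,7), each giving
// a 4 x f32 vector. SVE bfcvt writes even bf16 lanes and bfcvtnt writes odd
// lanes, so the pair reassembles the result directly. NEON bfcvtn/bfcvtn2
// instead write lanes 0-3 and 4-7 (bottom/top halves), which would place
// the products in the wrong lanes.
#include <array>
#include <cstdint>

using BF16 = std::uint16_t; // bf16 bit pattern, illustration only

std::array<BF16, 8> reassemble(const std::array<BF16, 4> &evens, // BFMLALB results
                               const std::array<BF16, 4> &odds) { // BFMLALT results
  std::array<BF16, 8> out{};
  for (int i = 0; i < 8; ++i)
    out[i] = (i % 2 == 0) ? evens[i / 2] : odds[i / 2];
  return out;
}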

@llvmbot (Member) commented Nov 26, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Benjamin Maxwell (MacDue)

Changes

Assuming the predicate is hoisted, this should have slightly better throughput: https://godbolt.org/z/jb7aP7Efc

Note: SVE must be used to convert back to bf16, as the bfmlalb/t instructions operate on even/odd lanes, but the NEON bfcvtn/2 instructions process the bottom/top halves of vectors.


Full diff: https://github.com/llvm/llvm-project/pull/169655.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+44-18)
  • (modified) llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll (+25-12)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 83ce39fa314d1..9451a508033a1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1824,6 +1824,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       else
         setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
     }
+
+    if (Subtarget->hasBF16() && Subtarget->isNeonAvailable())
+      setOperationAction(ISD::FMUL, MVT::v8bf16, Custom);
   }
 
   setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7688,7 +7691,8 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
 
   assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering");
-  assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16) && "Unexpected FMUL VT");
+  assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16 || VT == MVT::v8bf16) &&
+         "Unexpected FMUL VT");
 
   auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
     return [&, IID](EVT VT, auto... Ops) {
@@ -7697,37 +7701,59 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
     };
   };
 
-  auto ReinterpretCast = [&](SDValue Value, EVT VT) {
-    if (VT == Value.getValueType())
+  auto Reinterpret = [&](SDValue Value, EVT VT) {
+    EVT SrcVT = Value.getValueType();
+    if (VT == SrcVT)
       return Value;
+    if (SrcVT.isFixedLengthVector())
+      return convertToScalableVector(DAG, VT, Value);
+    if (VT.isFixedLengthVector())
+      return convertFromScalableVector(DAG, VT, Value);
     return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value);
   };
 
-  // Create helpers for building intrinsic calls.
-  auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
-  auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
   auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
   auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
 
-  // All intrinsics expect to operate on full bf16 vector types.
-  SDValue LHS = ReinterpretCast(Op.getOperand(0), MVT::nxv8bf16);
-  SDValue RHS = ReinterpretCast(Op.getOperand(1), MVT::nxv8bf16);
-
-  SDValue Zero =
-      DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags());
-  SDValue Pg = DAG.getConstant(1, DL, MVT::nxv4i1);
+  EVT AccVT = VT.isFixedLengthVector() ? MVT::v4f32 : MVT::nxv4f32;
+  SDValue Zero = DAG.getNeutralElement(ISD::FADD, DL, AccVT, Op->getFlags());
+  SDValue Pg = getPredicateForVector(DAG, DL, AccVT);
 
-  // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
+  // Lower bf16 FMUL as a pair (VT == [nx]v8bf16) of BFMLAL top/bottom
   // instructions. These result in two f32 vectors, which can be converted back
   // to bf16 with FCVT and FCVTNT.
-  SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+  SDValue TopF32;
+  SDValue BottomF32;
+  if (VT == MVT::v8bf16) {
+    SDValue LHS = Op.getOperand(0);
+    SDValue RHS = Op.getOperand(1);
+
+    auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_neon_bfmlalb);
+    auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_neon_bfmlalt);
+
+    // Note: The NEON BFMLAL[BT] reads even/odd lanes like the SVE variant.
+    // This does not match BFCVTN[2], so we use SVE to convert back to bf16.
+    BottomF32 = Reinterpret(BFMLALB(MVT::v4f32, Zero, LHS, RHS), MVT::nxv4f32);
+    TopF32 = Reinterpret(BFMLALT(MVT::v4f32, Zero, LHS, RHS), MVT::nxv4f32);
+  } else {
+    // All SVE intrinsics expect to operate on full bf16 vector types.
+    SDValue LHS = Reinterpret(Op.getOperand(0), MVT::nxv8bf16);
+    SDValue RHS = Reinterpret(Op.getOperand(1), MVT::nxv8bf16);
+
+    auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
+    auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
+
+    BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
+    TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
+  }
+
   SDValue BottomBF16 =
       FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32);
 
   // Note: nxv4bf16 only uses even lanes.
   if (VT == MVT::nxv4bf16)
-    return ReinterpretCast(BottomBF16, VT);
-  SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
-  return FCVTNT(VT, BottomBF16, Pg, TopF32);
+    return Reinterpret(BottomBF16, VT);
+  SDValue TopBF16 = FCVTNT(MVT::nxv8bf16, BottomBF16, Pg, TopF32);
+  return Reinterpret(TopBF16, VT);
 }
 
 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
diff --git a/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll b/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
index 6a7a4cbd8b20a..e3c0d97c08f54 100644
--- a/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
+++ b/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll
@@ -1,6 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc < %s -mtriple=aarch64 -mattr=-bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-CVT
-; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16
+; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-NOSVE-BF16
+; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16,+sve | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-SVE-BF16
 
 define <8 x bfloat> @add_h(<8 x bfloat> %a, <8 x bfloat> %b) {
 ; CHECK-CVT-LABEL: add_h:
@@ -117,17 +118,29 @@ define <8 x bfloat> @mul_h(<8 x bfloat> %a, <8 x bfloat> %b) {
 ; CHECK-CVT-NEXT:    uzp2 v0.8h, v0.8h, v2.8h
 ; CHECK-CVT-NEXT:    ret
 ;
-; CHECK-BF16-LABEL: mul_h:
-; CHECK-BF16:       // %bb.0: // %entry
-; CHECK-BF16-NEXT:    shll v2.4s, v1.4h, #16
-; CHECK-BF16-NEXT:    shll v3.4s, v0.4h, #16
-; CHECK-BF16-NEXT:    shll2 v1.4s, v1.8h, #16
-; CHECK-BF16-NEXT:    shll2 v0.4s, v0.8h, #16
-; CHECK-BF16-NEXT:    fmul v2.4s, v3.4s, v2.4s
-; CHECK-BF16-NEXT:    fmul v1.4s, v0.4s, v1.4s
-; CHECK-BF16-NEXT:    bfcvtn v0.4h, v2.4s
-; CHECK-BF16-NEXT:    bfcvtn2 v0.8h, v1.4s
-; CHECK-BF16-NEXT:    ret
+; CHECK-NOSVE-BF16-LABEL: mul_h:
+; CHECK-NOSVE-BF16:       // %bb.0: // %entry
+; CHECK-NOSVE-BF16-NEXT:    shll v2.4s, v1.4h, #16
+; CHECK-NOSVE-BF16-NEXT:    shll v3.4s, v0.4h, #16
+; CHECK-NOSVE-BF16-NEXT:    shll2 v1.4s, v1.8h, #16
+; CHECK-NOSVE-BF16-NEXT:    shll2 v0.4s, v0.8h, #16
+; CHECK-NOSVE-BF16-NEXT:    fmul v2.4s, v3.4s, v2.4s
+; CHECK-NOSVE-BF16-NEXT:    fmul v1.4s, v0.4s, v1.4s
+; CHECK-NOSVE-BF16-NEXT:    bfcvtn v0.4h, v2.4s
+; CHECK-NOSVE-BF16-NEXT:    bfcvtn2 v0.8h, v1.4s
+; CHECK-NOSVE-BF16-NEXT:    ret
+;
+; CHECK-SVE-BF16-LABEL: mul_h:
+; CHECK-SVE-BF16:       // %bb.0: // %entry
+; CHECK-SVE-BF16-NEXT:    movi v2.4s, #128, lsl #24
+; CHECK-SVE-BF16-NEXT:    movi v3.4s, #128, lsl #24
+; CHECK-SVE-BF16-NEXT:    ptrue p0.s, vl4
+; CHECK-SVE-BF16-NEXT:    bfmlalb v2.4s, v0.8h, v1.8h
+; CHECK-SVE-BF16-NEXT:    bfmlalt v3.4s, v0.8h, v1.8h
+; CHECK-SVE-BF16-NEXT:    bfcvt z2.h, p0/m, z2.s
+; CHECK-SVE-BF16-NEXT:    bfcvtnt z2.h, p0/m, z3.s
+; CHECK-SVE-BF16-NEXT:    mov v0.16b, v2.16b
+; CHECK-SVE-BF16-NEXT:    ret
 entry:
   %0 = fmul <8 x bfloat> %a, %b
   ret <8 x bfloat> %0
+    if (VT == SrcVT)
       return Value;
+    if (SrcVT.isFixedLengthVector())
+      return convertToScalableVector(DAG, VT, Value);
Collaborator
Should this be SrcVT instead of VT? If so, then it looks like we need some test coverage for this line.

MacDue (Member, Author)

I don't think so. If SrcVT is a fixed-length vector, then we want to convert to the scalable vector VT.

Note: If one of VT/SrcVT is fixed-length, the other is assumed to be scalable (this is checked with an assert in convertTo/FromScalableVector()).
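For illustration, a self-contained model of the dispatch in the updated Reinterpret helper (stand-in types; the real lambda operates on SDValue/EVT inside LowerFMUL, as shown in the diff above):

// Mirrors the lambda in the diff: same type -> no-op; fixed-length source ->
// convertToScalableVector; fixed-length destination ->
// convertFromScalableVector; otherwise scalable-to-scalable REINTERPRET_CAST.
struct ModelVT {
  bool Fixed; // fixed-length vector?
  int Id;     // identity, for the equality check
  bool operator==(const ModelVT &O) const {
    return Fixed == O.Fixed && Id == O.Id;
  }
};

enum class Action { None, ToScalable, FromScalable, ReinterpretCast };

Action reinterpretAction(ModelVT SrcVT, ModelVT VT) {
  if (VT == SrcVT)
    return Action::None;
  if (SrcVT.Fixed)
    return Action::ToScalable;
  if (VT.Fixed)
    return Action::FromScalable;
  return Action::ReinterpretCast;
}

// e.g. a fixed-length source (like v4f32) converting to a scalable VT
// (like nxv4f32) takes the convertToScalableVector path:
//   reinterpretAction({/*Fixed=*/true, 0}, {/*Fixed=*/false, 1})
//     == Action::ToScalable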
