- Notifications
You must be signed in to change notification settings - Fork 15.3k
[AArch64] Lower v8bf16 FMUL to BFMLAL top/bottom with +sve #169655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Assuming the predicate is hoisted, this should have a slightly better throughput: https://godbolt.org/z/jb7aP7Efc Note: SVE must be used to convert back to bf16 because the BFMLALB/BFMLALT instructions operate on even/odd lanes, whereas the NEON BFCVTN/BFCVTN2 instructions process the bottom/top halves of vectors.
| @llvm/pr-subscribers-backend-aarch64 Author: Benjamin Maxwell (MacDue) ChangesAssuming the predicate is hoisted, this should have a slightly better throughput: https://godbolt.org/z/jb7aP7Efc Note: SVE must be used to convert back to bf16 as the bfmlalb/t instructions operate on even/odd lanes, but the neon bfcvtn/2 process the top/bottom halves of vectors. Full diff: https://github.com/llvm/llvm-project/pull/169655.diff 2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 83ce39fa314d1..9451a508033a1 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1824,6 +1824,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, else setOperationPromotedToType(ISD::FMUL, VT, PromotedVT); } + + if (Subtarget->hasBF16() && Subtarget->isNeonAvailable()) + setOperationAction(ISD::FMUL, MVT::v8bf16, Custom); } setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom); @@ -7688,7 +7691,8 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const { return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED); assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering"); - assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16) && "Unexpected FMUL VT"); + assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16 || VT == MVT::v8bf16) && + "Unexpected FMUL VT"); auto MakeGetIntrinsic = [&](Intrinsic::ID IID) { return [&, IID](EVT VT, auto... Ops) { @@ -7697,37 +7701,59 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const { }; }; - auto ReinterpretCast = [&](SDValue Value, EVT VT) { - if (VT == Value.getValueType()) + auto Reinterpret = [&](SDValue Value, EVT VT) { + EVT SrcVT = Value.getValueType(); + if (VT == SrcVT) return Value; + if (SrcVT.isFixedLengthVector()) + return convertToScalableVector(DAG, VT, Value); + if (VT.isFixedLengthVector()) + return convertFromScalableVector(DAG, VT, Value); return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value); }; - // Create helpers for building intrinsic calls. 
- auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb); - auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt); auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2); auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2); - // All intrinsics expect to operate on full bf16 vector types. - SDValue LHS = ReinterpretCast(Op.getOperand(0), MVT::nxv8bf16); - SDValue RHS = ReinterpretCast(Op.getOperand(1), MVT::nxv8bf16); - - SDValue Zero = - DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags()); - SDValue Pg = DAG.getConstant(1, DL, MVT::nxv4i1); + EVT AccVT = VT.isFixedLengthVector() ? MVT::v4f32 : MVT::nxv4f32; + SDValue Zero = DAG.getNeutralElement(ISD::FADD, DL, AccVT, Op->getFlags()); + SDValue Pg = getPredicateForVector(DAG, DL, AccVT); - // Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom + // Lower bf16 FMUL as a pair (VT == [nx]v8bf16) of BFMLAL top/bottom // instructions. These result in two f32 vectors, which can be converted back // to bf16 with FCVT and FCVTNT. - SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS); + SDValue TopF32; + SDValue BottomF32; + if (VT == MVT::v8bf16) { + SDValue LHS = Op.getOperand(0); + SDValue RHS = Op.getOperand(1); + + auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_neon_bfmlalb); + auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_neon_bfmlalt); + + // Note: The NEON BFMLAL[BT] reads even/odd lanes like the SVE variant. + // This does not match BFCVTN[2], so we use SVE to convert back to bf16. + BottomF32 = Reinterpret(BFMLALB(MVT::v4f32, Zero, LHS, RHS), MVT::nxv4f32); + TopF32 = Reinterpret(BFMLALT(MVT::v4f32, Zero, LHS, RHS), MVT::nxv4f32); + } else { + // All SVE intrinsics expect to operate on full bf16 vector types. 
+ SDValue LHS = Reinterpret(Op.getOperand(0), MVT::nxv8bf16); + SDValue RHS = Reinterpret(Op.getOperand(1), MVT::nxv8bf16); + + auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb); + auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt); + + BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS); + TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS); + } + SDValue BottomBF16 = FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32); // Note: nxv4bf16 only uses even lanes. if (VT == MVT::nxv4bf16) - return ReinterpretCast(BottomBF16, VT); - SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS); - return FCVTNT(VT, BottomBF16, Pg, TopF32); + return Reinterpret(BottomBF16, VT); + SDValue TopBF16 = FCVTNT(MVT::nxv8bf16, BottomBF16, Pg, TopF32); + return Reinterpret(TopBF16, VT); } SDValue AArch64TargetLowering::LowerOperation(SDValue Op, diff --git a/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll b/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll index 6a7a4cbd8b20a..e3c0d97c08f54 100644 --- a/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll +++ b/llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll @@ -1,6 +1,7 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py ; RUN: llc < %s -mtriple=aarch64 -mattr=-bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-CVT -; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16 +; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-NOSVE-BF16 +; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16,+sve | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-SVE-BF16 define <8 x bfloat> @add_h(<8 x bfloat> %a, <8 x bfloat> %b) { ; CHECK-CVT-LABEL: add_h: @@ -117,17 +118,29 @@ define <8 x bfloat> @mul_h(<8 x bfloat> %a, <8 x bfloat> %b) { ; CHECK-CVT-NEXT: uzp2 v0.8h, v0.8h, v2.8h ; CHECK-CVT-NEXT: ret ; -; CHECK-BF16-LABEL: mul_h: -; CHECK-BF16: // %bb.0: // %entry -; CHECK-BF16-NEXT: shll v2.4s, v1.4h, #16 -; 
CHECK-BF16-NEXT: shll v3.4s, v0.4h, #16 -; CHECK-BF16-NEXT: shll2 v1.4s, v1.8h, #16 -; CHECK-BF16-NEXT: shll2 v0.4s, v0.8h, #16 -; CHECK-BF16-NEXT: fmul v2.4s, v3.4s, v2.4s -; CHECK-BF16-NEXT: fmul v1.4s, v0.4s, v1.4s -; CHECK-BF16-NEXT: bfcvtn v0.4h, v2.4s -; CHECK-BF16-NEXT: bfcvtn2 v0.8h, v1.4s -; CHECK-BF16-NEXT: ret +; CHECK-NOSVE-BF16-LABEL: mul_h: +; CHECK-NOSVE-BF16: // %bb.0: // %entry +; CHECK-NOSVE-BF16-NEXT: shll v2.4s, v1.4h, #16 +; CHECK-NOSVE-BF16-NEXT: shll v3.4s, v0.4h, #16 +; CHECK-NOSVE-BF16-NEXT: shll2 v1.4s, v1.8h, #16 +; CHECK-NOSVE-BF16-NEXT: shll2 v0.4s, v0.8h, #16 +; CHECK-NOSVE-BF16-NEXT: fmul v2.4s, v3.4s, v2.4s +; CHECK-NOSVE-BF16-NEXT: fmul v1.4s, v0.4s, v1.4s +; CHECK-NOSVE-BF16-NEXT: bfcvtn v0.4h, v2.4s +; CHECK-NOSVE-BF16-NEXT: bfcvtn2 v0.8h, v1.4s +; CHECK-NOSVE-BF16-NEXT: ret +; +; CHECK-SVE-BF16-LABEL: mul_h: +; CHECK-SVE-BF16: // %bb.0: // %entry +; CHECK-SVE-BF16-NEXT: movi v2.4s, #128, lsl #24 +; CHECK-SVE-BF16-NEXT: movi v3.4s, #128, lsl #24 +; CHECK-SVE-BF16-NEXT: ptrue p0.s, vl4 +; CHECK-SVE-BF16-NEXT: bfmlalb v2.4s, v0.8h, v1.8h +; CHECK-SVE-BF16-NEXT: bfmlalt v3.4s, v0.8h, v1.8h +; CHECK-SVE-BF16-NEXT: bfcvt z2.h, p0/m, z2.s +; CHECK-SVE-BF16-NEXT: bfcvtnt z2.h, p0/m, z3.s +; CHECK-SVE-BF16-NEXT: mov v0.16b, v2.16b +; CHECK-SVE-BF16-NEXT: ret entry: %0 = fmul <8 x bfloat> %a, %b ret <8 x bfloat> %0 |
| if (VT == SrcVT) | ||
| return Value; | ||
| if (SrcVT.isFixedLengthVector()) | ||
| return convertToScalableVector(DAG, VT, Value); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be SrcVT instead of VT? If so then it looks like we need some test coverage for this line.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so. If SrcVT is a fixed-length vector, then we want to convert to the scalable vector VT.
Note: If one of VT/SrcVT is fixed-length, the other is assumed to be scalable (this is checked with an assert in convertToScalableVector()/convertFromScalableVector()).
Assuming the predicate is hoisted, this should have a slightly better throughput: https://godbolt.org/z/jb7aP7Efc
Note: SVE must be used to convert back to bf16 because the BFMLALB/BFMLALT instructions operate on even/odd lanes, whereas the NEON BFCVTN/BFCVTN2 instructions process the bottom/top halves of vectors.