- Notifications
You must be signed in to change notification settings - Fork 15.4k
[flang] implement VECTOR VECTORLENGTH directive #170114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This should match exactly the llvm attributes generated by classic flang.
| @llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-parser Author: Tom Eccles (tblah) ChangesThis should match exactly the llvm attributes generated by classic flang. Full diff: https://github.com/llvm/llvm-project/pull/170114.diff 10 Files Affected:
diff --git a/flang/docs/Directives.md b/flang/docs/Directives.md index 128d8f9b6b707..5640e44e16bae 100644 --- a/flang/docs/Directives.md +++ b/flang/docs/Directives.md @@ -57,6 +57,15 @@ A list of non-standard directives supported by Flang * `!dir$ vector always` forces vectorization on the following loop regardless of cost model decisions. The loop must still be vectorizable. [This directive currently only works on plain do loops without labels]. +* `!dir$ vector vectorlength({fixed|scalable|<num>|<num>,fixed|<num>,scalable})` + specifies a hint to the compiler about the desired vectorization factor. If + `fixed` is used, the compiler should prefer fixed-width vectorization. + Scalable vectorization instructions may still be used with a fixed-width + predicate. If `scalable` is used the compiler should prefer scalable + vectorization, though it can choose to use fixed length vectorization or not + at all. `<num>` means that the compiler should consider using this specific + vectorization factor, which should be an integer literal. This directive + currently has the same limitations as `!dir$ vector always`. * `!dir$ unroll [n]` specifies that the compiler ought to unroll the immediately following loop `n` times. When `n` is `0` or `1`, the loop should not be unrolled at all. When `n` is `2` or greater, the loop should be unrolled exactly `n` diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h index f460e61fbb915..18218cee40ec9 100644 --- a/flang/include/flang/Parser/dump-parse-tree.h +++ b/flang/include/flang/Parser/dump-parse-tree.h @@ -229,6 +229,8 @@ class ParseTreeDumper { NODE(CompilerDirective, NoInline) NODE(CompilerDirective, Unrecognized) NODE(CompilerDirective, VectorAlways) + NODE_ENUM(CompilerDirective::VectorLength, VectorLength::Kind) + NODE(CompilerDirective, VectorLength) NODE(CompilerDirective, Unroll) NODE(CompilerDirective, UnrollAndJam) NODE(CompilerDirective, NoVector) diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h index dd928e1244a2f..10820b61a9ee8 100644 --- a/flang/include/flang/Parser/parse-tree.h +++ b/flang/include/flang/Parser/parse-tree.h @@ -3384,6 +3384,12 @@ struct CompilerDirective { std::tuple<common::Indirection<Designator>, uint64_t> t; }; EMPTY_CLASS(VectorAlways); + struct VectorLength { + TUPLE_CLASS_BOILERPLATE(VectorLength); + ENUM_CLASS(Kind, Auto, Fixed, Scalable); + + std::tuple<std::uint64_t, Kind> t; + }; struct NameValue { TUPLE_CLASS_BOILERPLATE(NameValue); std::tuple<Name, std::optional<std::uint64_t>> t; @@ -3408,9 +3414,9 @@ struct CompilerDirective { EMPTY_CLASS(Unrecognized); CharBlock source; std::variant<std::list<IgnoreTKR>, LoopCount, std::list<AssumeAligned>, - VectorAlways, std::list<NameValue>, Unroll, UnrollAndJam, Unrecognized, - NoVector, NoUnroll, NoUnrollAndJam, ForceInline, Inline, NoInline, - Prefetch, IVDep> + VectorAlways, VectorLength, std::list<NameValue>, Unroll, UnrollAndJam, + Unrecognized, NoVector, NoUnroll, NoUnrollAndJam, ForceInline, Inline, + NoInline, Prefetch, IVDep> u; }; diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 6f9dc32297272..8335fdd9a3b16 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -2576,12 +2576,16 @@ class FirConverter : public Fortran::lower::AbstractConverter { // Enabling loop vectorization attribute. mlir::LLVM::LoopVectorizeAttr - genLoopVectorizeAttr(mlir::BoolAttr disableAttr) { + genLoopVectorizeAttr(mlir::BoolAttr disableAttr, + mlir::BoolAttr scalableEnable, + mlir::IntegerAttr vectorWidth) { mlir::LLVM::LoopVectorizeAttr va; if (disableAttr) - va = mlir::LLVM::LoopVectorizeAttr::get(builder->getContext(), - /*disable=*/disableAttr, {}, {}, - {}, {}, {}, {}); + va = mlir::LLVM::LoopVectorizeAttr::get( + builder->getContext(), + /*disable=*/disableAttr, /*predicate=*/{}, + /*scalableEnable=*/scalableEnable, + /*vectorWidth=*/vectorWidth, {}, {}, {}); return va; } @@ -2589,6 +2593,8 @@ class FirConverter : public Fortran::lower::AbstractConverter { IncrementLoopInfo &info, llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs) { mlir::BoolAttr disableVecAttr; + mlir::BoolAttr scalableEnable; + mlir::IntegerAttr vectorWidth; mlir::LLVM::LoopUnrollAttr ua; mlir::LLVM::LoopUnrollAndJamAttr uja; llvm::SmallVector<mlir::LLVM::AccessGroupAttr> aga; @@ -2601,6 +2607,30 @@ class FirConverter : public Fortran::lower::AbstractConverter { mlir::BoolAttr::get(builder->getContext(), false); has_attrs = true; }, + [&](const Fortran::parser::CompilerDirective::VectorLength &vl) { + using Kind = + Fortran::parser::CompilerDirective::VectorLength::Kind; + Kind kind = std::get<Kind>(vl.t); + uint64_t length = std::get<uint64_t>(vl.t); + disableVecAttr = + mlir::BoolAttr::get(builder->getContext(), false); + if (length != 0) + vectorWidth = + builder->getIntegerAttr(builder->getI64Type(), length); + switch (kind) { + case Kind::Scalable: + scalableEnable = + mlir::BoolAttr::get(builder->getContext(), true); + break; + case Kind::Fixed: + scalableEnable = + mlir::BoolAttr::get(builder->getContext(), false); + break; + case Kind::Auto: + break; + } + has_attrs = true; + }, [&](const Fortran::parser::CompilerDirective::Unroll &u) { ua = genLoopUnrollAttr(u.v); has_attrs = true; @@ -2632,7 +2662,8 @@ class FirConverter : public Fortran::lower::AbstractConverter { [&](const auto &) {}}, dir->u); } - mlir::LLVM::LoopVectorizeAttr va = genLoopVectorizeAttr(disableVecAttr); + mlir::LLVM::LoopVectorizeAttr va = + genLoopVectorizeAttr(disableVecAttr, scalableEnable, vectorWidth); mlir::LLVM::LoopAnnotationAttr la = mlir::LLVM::LoopAnnotationAttr::get( builder->getContext(), {}, /*vectorize=*/va, {}, /*unroll*/ ua, /*unroll_and_jam*/ uja, {}, {}, {}, {}, {}, {}, {}, {}, {}, @@ -3339,6 +3370,9 @@ class FirConverter : public Fortran::lower::AbstractConverter { [&](const Fortran::parser::CompilerDirective::VectorAlways &) { attachDirectiveToLoop(dir, &eval); }, + [&](const Fortran::parser::CompilerDirective::VectorLength &) { + attachDirectiveToLoop(dir, &eval); + }, [&](const Fortran::parser::CompilerDirective::Unroll &) { attachDirectiveToLoop(dir, &eval); }, diff --git a/flang/lib/Parser/Fortran-parsers.cpp b/flang/lib/Parser/Fortran-parsers.cpp index fccb9d82f4fc9..988db5450abc9 100644 --- a/flang/lib/Parser/Fortran-parsers.cpp +++ b/flang/lib/Parser/Fortran-parsers.cpp @@ -1295,6 +1295,7 @@ TYPE_PARSER(construct<StatOrErrmsg>("STAT =" >> statVariable) || // Directives, extensions, and deprecated statements // !DIR$ IGNORE_TKR [ [(tkrdmac...)] name ]... // !DIR$ LOOP COUNT (n1[, n2]...) +// !DIR$ VECTOR VECTORLENGTH ({FIXED|SCALABLE|<num>|<num>,FIXED|<num>,SCALABLE}) // !DIR$ name[=value] [, name[=value]]... // !DIR$ UNROLL [n] // !DIR$ PREFETCH designator[, designator]... @@ -1311,6 +1312,15 @@ constexpr auto assumeAligned{"ASSUME_ALIGNED" >> indirect(designator), ":"_tok >> digitString64))}; constexpr auto vectorAlways{ "VECTOR ALWAYS" >> construct<CompilerDirective::VectorAlways>()}; +constexpr auto vectorLengthKind{ + "FIXED" >> pure(CompilerDirective::VectorLength::Kind::Fixed) || + "SCALABLE" >> pure(CompilerDirective::VectorLength::Kind::Scalable)}; +constexpr auto vectorLength{"VECTOR VECTORLENGTH" >> + parenthesized(construct<CompilerDirective::VectorLength>( + digitString64, ","_tok >> vectorLengthKind) || + construct<CompilerDirective::VectorLength>(pure(0), vectorLengthKind) || + construct<CompilerDirective::VectorLength>( + digitString64, pure(CompilerDirective::VectorLength::Kind::Auto)))}; constexpr auto unroll{ "UNROLL" >> construct<CompilerDirective::Unroll>(maybe(digitString64))}; constexpr auto prefetch{"PREFETCH" >> @@ -1332,6 +1342,7 @@ TYPE_PARSER(beginDirective >> "DIR$ "_tok >> construct<CompilerDirective>(loopCount) || construct<CompilerDirective>(assumeAligned) || construct<CompilerDirective>(vectorAlways) || + construct<CompilerDirective>(vectorLength) || construct<CompilerDirective>(unrollAndJam) || construct<CompilerDirective>(unroll) || construct<CompilerDirective>(prefetch) || diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp index 8e9c7d04bc522..5421ee6f948a2 100644 --- a/flang/lib/Parser/unparse.cpp +++ b/flang/lib/Parser/unparse.cpp @@ -1848,6 +1848,25 @@ class UnparseVisitor { [&](const CompilerDirective::VectorAlways &valways) { Word("!DIR$ VECTOR ALWAYS"); }, + [&](const CompilerDirective::VectorLength &vlength) { + using Kind = CompilerDirective::VectorLength::Kind; + std::uint64_t length = std::get<std::uint64_t>(vlength.t); + Kind kind = std::get<Kind>(vlength.t); + + Word("!DIR$ VECTOR VECTORLENGTH ("); + // || kind == Kind::Auto handles the case of VECTORLENGTH(0) so we + // don't print nothing + if (length != 0 || kind == Kind::Auto) { + Walk(length); + } + if (length != 0 && kind != Kind::Auto) { + Word(", "); + } + if (kind != Kind::Auto) { + Word(CompilerDirective::VectorLength::EnumToString(kind)); + } + Word(")"); + }, [&](const std::list<CompilerDirective::NameValue> &names) { Walk("!DIR$ ", names, " "); }, diff --git a/flang/lib/Semantics/canonicalize-directives.cpp b/flang/lib/Semantics/canonicalize-directives.cpp index b21da4d041a97..f32a3d34c6572 100644 --- a/flang/lib/Semantics/canonicalize-directives.cpp +++ b/flang/lib/Semantics/canonicalize-directives.cpp @@ -56,6 +56,7 @@ bool CanonicalizeDirectives( static bool IsExecutionDirective(const parser::CompilerDirective &dir) { return std::holds_alternative<parser::CompilerDirective::VectorAlways>( dir.u) || + std::holds_alternative<parser::CompilerDirective::VectorLength>(dir.u) || std::holds_alternative<parser::CompilerDirective::Unroll>(dir.u) || std::holds_alternative<parser::CompilerDirective::UnrollAndJam>(dir.u) || std::holds_alternative<parser::CompilerDirective::NoVector>(dir.u) || @@ -121,6 +122,9 @@ void CanonicalizationOfDirectives::Post(parser::Block &block) { common::visitors{[&](parser::CompilerDirective::VectorAlways &) { CheckLoopDirective(*dir, block, it); }, + [&](parser::CompilerDirective::VectorLength &) { + CheckLoopDirective(*dir, block, it); + }, [&](parser::CompilerDirective::Unroll &) { CheckLoopDirective(*dir, block, it); }, diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index 2a487a6d39d51..5814841053132 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -10075,6 +10075,7 @@ void ResolveNamesVisitor::Post(const parser::AssignedGotoStmt &x) { void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) { if (std::holds_alternative<parser::CompilerDirective::VectorAlways>(x.u) || + std::holds_alternative<parser::CompilerDirective::VectorLength>(x.u) || std::holds_alternative<parser::CompilerDirective::Unroll>(x.u) || std::holds_alternative<parser::CompilerDirective::UnrollAndJam>(x.u) || std::holds_alternative<parser::CompilerDirective::NoVector>(x.u) || diff --git a/flang/test/Lower/vectorlength.f90 b/flang/test/Lower/vectorlength.f90 new file mode 100644 index 0000000000000..95753c3f78090 --- /dev/null +++ b/flang/test/Lower/vectorlength.f90 @@ -0,0 +1,67 @@ +! RUN: %flang_fc1 -emit-hlfir -o - %s | FileCheck %s + +! CHECK: #[[FIXED:.*]] = #llvm.loop_vectorize<disable = false, scalableEnable = false> +! CHECK: #[[SCALABLE:.*]] = #llvm.loop_vectorize<disable = false, scalableEnable = true> +! CHECK: #[[WIDTH2:.*]] = #llvm.loop_vectorize<disable = false, width = 2 : i64> +! CHECK: #[[FIXED_WIDTH2:.*]] = #llvm.loop_vectorize<disable = false, scalableEnable = false, width = 2 : i64> +! CHECK: #[[SCALABLE_WIDTH2:.*]] = #llvm.loop_vectorize<disable = false, scalableEnable = true, width = 2 : i64> +! CHECK: #[[FIXED_TAG:.*]] = #llvm.loop_annotation<vectorize = #[[FIXED]]> +! CHECK: #[[SCALABLE_TAG:.*]] = #llvm.loop_annotation<vectorize = #[[SCALABLE]]> +! CHECK: #[[WIDTH2_TAG:.*]] = #llvm.loop_annotation<vectorize = #[[WIDTH2]]> +! CHECK: #[[FIXED_WIDTH2_TAG:.*]] = #llvm.loop_annotation<vectorize = #[[FIXED_WIDTH2]]> +! CHECK: #[[SCALABLE_WIDTH2_TAG:.*]] = #llvm.loop_annotation<vectorize = #[[SCALABLE_WIDTH2]]> + +! CHECK-LABEL: func.func @_QPfixed( +subroutine fixed(a, b, m) + integer :: i, m, a(m), b(m) + + !dir$ vector vectorlength(fixed) + ! CHECK: fir.do_loop {{.*}} attributes {loopAnnotation = #[[FIXED_TAG]]} + do i = 1, m + b(i) = a(i) + 1 + end do +end subroutine + +! CHECK-LABEL: func.func @_QPscalable( +subroutine scalable(a, b, m) + integer :: i, m, a(m), b(m) + + !dir$ vector vectorlength(scalable) + ! CHECK: fir.do_loop {{.*}} attributes {loopAnnotation = #[[SCALABLE_TAG]]} + do i = 1, m + b(i) = a(i) + 1 + end do +end subroutine + +! CHECK-LABEL: func.func @_QPlen2( +subroutine len2(a, b, m) + integer :: i, m, a(m), b(m) + + !dir$ vector vectorlength(2) + ! CHECK: fir.do_loop {{.*}} attributes {loopAnnotation = #[[WIDTH2_TAG]]} + do i = 1, m + b(i) = a(i) + 1 + end do +end subroutine + +! CHECK-LABEL: func.func @_QPlen2fixed( +subroutine len2fixed(a, b, m) + integer :: i, m, a(m), b(m) + + !dir$ vector vectorlength(2,fixed) + ! CHECK: fir.do_loop {{.*}} attributes {loopAnnotation = #[[FIXED_WIDTH2_TAG]]} + do i = 1, m + b(i) = a(i) + 1 + end do +end subroutine + +! CHECK-LABEL: func.func @_QPlen2scalable( +subroutine len2scalable(a, b, m) + integer :: i, m, a(m), b(m) + + !dir$ vector vectorlength(2,scalable) + ! CHECK: fir.do_loop {{.*}} attributes {loopAnnotation = #[[SCALABLE_WIDTH2_TAG]]} + do i = 1, m + b(i) = a(i) + 1 + end do +end subroutine diff --git a/flang/test/Parser/compiler-directives.f90 b/flang/test/Parser/compiler-directives.f90 index 56a10f9177997..ce592692cfc67 100644 --- a/flang/test/Parser/compiler-directives.f90 +++ b/flang/test/Parser/compiler-directives.f90 @@ -36,6 +36,28 @@ subroutine vector_always enddo end subroutine +subroutine vector_vectorlength + !dir$ vector vectorlength(fixed) + ! CHECK: !DIR$ VECTOR VECTORLENGTH (FIXED) + do i=1,10 + enddo + + !dir$ vector vectorlength(scalable) + ! CHECK: !DIR$ VECTOR VECTORLENGTH (SCALABLE) + do i=1,10 + enddo + + !dir$ vector vectorlength(8,scalable) + ! CHECK: !DIR$ VECTOR VECTORLENGTH (8, SCALABLE) + do i=1,10 + enddo + + !dir$ vector vectorlength(4) + ! CHECK: !DIR$ VECTOR VECTORLENGTH (4) + do i=1,10 + enddo +end subroutine + subroutine unroll !dir$ unroll ! CHECK: !DIR$ UNROLL |
JDPailleux left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for proposing this directive! LGTM :)
This should match exactly the llvm attributes generated by classic flang.