@giuseros
Contributor

@giuseros giuseros commented Jun 5, 2024

This PR adds the global.atomic.fadd intrinsic in ROCDL (which supports f32 and vector<2xf16>)

@llvmbot
Member

llvmbot commented Jun 5, 2024

@llvm/pr-subscribers-mlir-llvm

Author: Giuseppe Rossini (giuseros)

Changes

This PR adds the global.atomic.fadd intrinsic in ROCDL (which supports f32 and vector<2xf16>)


Full diff: https://github.com/llvm/llvm-project/pull/94486.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+15-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp (+20)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+9)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 1dabf5d7979b7..c8d4e4c03486e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
   let summary = "Vote across thread group";

   let description = [{
-    Ballot provides a bit mask containing the 1-bit predicate value from each lane. 
+    Ballot provides a bit mask containing the 1-bit predicate value from each lane.
     The nth bit of the result contains the 1 bit contributed by the nth warp lane.
   }];

@@ -516,7 +516,7 @@ def ROCDL_RawBufferAtomicCmpSwap :
 }

 //===---------------------------------------------------------------------===//
-// MI-100 and MI-200 buffer atomic floating point add intrinsic
+// MI-100, MI-200 and MI-300 global/buffer atomic floating point add intrinsic

 def ROCDL_RawBufferAtomicFAddOp :
   ROCDL_Op<"raw.buffer.atomic.fadd">,
@@ -534,6 +534,19 @@ def ROCDL_RawBufferAtomicFAddOp :
   let hasCustomAssemblyFormat = 1;
 }

+def ROCDL_GlobalAtomicFAddOp :
+  ROCDL_Op<"global.atomic.fadd">,
+  Arguments<(ins LLVM_Type:$ptr,
+                 LLVM_Type:$vdata)>{
+  string llvmBuilder = [{
+      auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
+      auto ptrType = moduleTranslation.convertType(op.getPtr().getType());
+      createIntrinsicCall(builder,
+          llvm::Intrinsic::amdgcn_global_atomic_fadd, {$ptr, $vdata}, {vdataType, ptrType, vdataType});
+  }];
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===---------------------------------------------------------------------===//
 // Buffer atomic floating point max intrinsic. GFX9 does not support fp32.

diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 65b770ae32610..34ebdb2ffd3d0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -157,6 +157,26 @@ void RawBufferAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
   p << " " << getOperands() << " : " << getVdata().getType();
 }

+// <operation> ::=
+//     `llvm.amdgcn.global.atomic.fadd.* %vdata, %ptr
+ParseResult GlobalAtomicFAddOp::parse(OpAsmParser &parser,
+                                      OperationState &result) {
+  SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
+  Type type;
+  if (parser.parseOperandList(ops, 2) || parser.parseColonType(type))
+    return failure();
+
+  auto ptrType = LLVM::LLVMPointerType::get(parser.getContext());
+  if (parser.resolveOperands(ops, {ptrType, type}, parser.getNameLoc(),
+                             result.operands))
+    return failure();
+  return success();
+}
+
+void GlobalAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
+  p << " " << getOperands() << " : " << getVdata().getType();
+}
+
 // <operation> ::=
 //     `llvm.amdgcn.raw.buffer.atomic.fmax.* %vdata, %rsrc, %offset,
 //       %soffset, %aux : result_type`
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index ce6b56d48437a..9d22b80748e14 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -494,6 +494,15 @@ llvm.func @rocdl.raw.buffer.atomic.f32(%rsrc : vector<4xi32>,
   llvm.return
 }

+// CHECK-LABEL: rocdl.global.atomic
+llvm.func @rocdl.global.atomic(%vdata0 : f32, %vdata1 : vector<2xf16>, %ptr : !llvm.ptr) {
+  // CHECK: call float @llvm.amdgcn.global.atomic.fadd.f32.p0.f32(ptr %{{.*}}, float %{{.*}}
+  rocdl.global.atomic.fadd %ptr, %vdata0: f32
+  // CHECK: call <2 x half> @llvm.amdgcn.global.atomic.fadd.v2f16.p0.v2f16(ptr %{{.*}}, <2 x half> %{{.*}})
+  rocdl.global.atomic.fadd %ptr, %vdata1: vector<2xf16>
+  llvm.return
+}
+
 llvm.func @rocdl.raw.buffer.atomic.i32(%rsrc : vector<4xi32>,
                         %offset : i32, %soffset : i32,
                         %vdata1 : i32) {
@llvmbot
Member

llvmbot commented Jun 5, 2024

@llvm/pr-subscribers-mlir

@krzysz00 krzysz00 self-requested a review June 5, 2024 18:38
@giuseros giuseros marked this pull request as draft June 6, 2024 10:56
@giuseros
Contributor Author

giuseros commented Jun 6, 2024

I noticed that my use case also needs the output of the global.fadd (i.e., the original value in global memory). I am converting this to a draft so that it is not merged accidentally, and I will ping you once I have sorted this out. Thanks!

@giuseros giuseros marked this pull request as ready for review June 6, 2024 11:32
@giuseros
Contributor Author

giuseros commented Jun 6, 2024

Done

Contributor

@arsenm arsenm left a comment

Please just use atomicrmw fadd. I will shortly be pushing to remove the intrinsic

@giuseros
Contributor Author

giuseros commented Jun 6, 2024

Please just use atomicrmw fadd. I will shortly be pushing to remove the intrinsic

Hi @arsenm, the problem is that atomicrmw fadd does not support vectors. So, in the case of fp16, this gets translated into a CAS loop, which is very slow.
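
For readers unfamiliar with the term, a CAS-loop lowering of a vector atomic add looks roughly like the sketch below. This is only an illustration in portable C++, not the code LLVM emits: `Pair`, `pack`, `unpack`, and `atomicPairFAdd` are hypothetical names, and two f32 lanes in a 64-bit word stand in for vector<2xf16> because standard C++ has no portable half type. The helper also returns the previous memory contents, which is the output mentioned earlier in the thread.

```cpp
#include <atomic>
#include <cstdint>
#include <cstring>

// Two f32 lanes packed into one 64-bit word (an assumed stand-in for
// vector<2xf16>, chosen only because it is expressible in standard C++).
struct Pair {
  float lo;
  float hi;
};

// Reinterpret the two lanes as raw bits so they can live in an atomic word.
static std::uint64_t pack(Pair p) {
  std::uint64_t bits;
  static_assert(sizeof(Pair) == sizeof(bits), "Pair must be 8 bytes");
  std::memcpy(&bits, &p, sizeof(bits));
  return bits;
}

static Pair unpack(std::uint64_t bits) {
  Pair p;
  std::memcpy(&p, &bits, sizeof(p));
  return p;
}

// Hypothetical helper: lane-wise atomic add emulated with a
// compare-and-swap loop. Returns the value that was in memory
// immediately before the successful update.
Pair atomicPairFAdd(std::atomic<std::uint64_t> &mem, Pair v) {
  std::uint64_t expected = mem.load(std::memory_order_relaxed);
  for (;;) {
    Pair old = unpack(expected);
    Pair sum{old.lo + v.lo, old.hi + v.hi};
    // On failure, compare_exchange_weak stores the current memory value
    // into `expected`, so the next iteration recomputes the sum from it.
    if (mem.compare_exchange_weak(expected, pack(sum),
                                  std::memory_order_relaxed))
      return old;
  }
}
```

Every failed exchange costs another read-modify-write round trip, so the loop gets slower as more threads contend for the same address; a single hardware vector atomic instruction avoids the retries.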

@giuseros
Contributor Author

giuseros commented Jun 6, 2024

Please just use atomicrmw fadd. I will shortly be pushing to remove the intrinsic

Hi @arsenm, the problem is that atomicrmw fadd does not support vectors. So, in the case of fp16, this gets translated into a CAS loop, which is very slow.

Or maybe it does?

@krzysz00
Contributor

krzysz00 commented Jun 6, 2024

@giuseros I wonder if it's that MLIR's wrappers around atomicrmw don't support vectors ... which seems like an extension we could do

@arsenm
Contributor

arsenm commented Jun 6, 2024

Please just use atomicrmw fadd. I will shortly be pushing to remove the intrinsic

Hi @arsenm, the problem is that atomicrmw fadd does not support vectors. So, in the case of fp16, this gets translated into a CAS loop, which is very slow.

Or maybe it does?

atomicrmw FP operations do since 4cb110a. I still need to implement the AMDGPU codegen changes to start using the vector instructions, though (plus, eventually, the new metadata from #85052 will be needed).

@giuseros
Contributor Author

giuseros commented Jun 6, 2024

Ok, given I am exactly after that vector instruction, how about we merge this PR and then we enable vector support for atomicrmw in MLIR (like @krzysz00 was suggesting) once it emits the vector instruction?

@giuseros
Contributor Author

giuseros commented Jun 7, 2024

Hi @arsenm, is it OK for this to merge?

@arsenm
Contributor

arsenm commented Jun 7, 2024

Hi @arsenm, is it OK for this to merge?

I guess, though I always prefer to just do whatever is needed to move towards the end goal instead of adding new throwaway code.

@giuseros
Contributor Author

giuseros commented Jun 7, 2024

Ok, after a chat with Matthew, we agreed to close this for now and to try emitting the vectorized atomicrmw intrinsic instead. If we urgently need the feature, we will get back to this.

@giuseros giuseros closed this Jun 7, 2024
@arsenm
Contributor

arsenm commented Jun 8, 2024

Part 1 to start supporting the vector selection is in #94845

@arsenm
Contributor

arsenm commented Jun 13, 2024

#95393 for the LDS case, #95394 for global and flat
