@@ -32,6 +32,95 @@ using namespace mlir;
3232
3333namespace {
3434
35+ // / Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
36+ static Operation *createLoadTileSliceIntrinsic (
37+ RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
38+ arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
39+ IntegerAttr tileId, Value tileSliceI32) {
40+ if (layout == arm_sme::TileSliceLayout::Horizontal) {
41+ switch (type) {
42+ case arm_sme::ArmSMETileType::ZAB:
43+ return rewriter.create <arm_sme::aarch64_sme_ld1b_horiz>(
44+ loc, maskOp, ptr, tileId, tileSliceI32);
45+ case arm_sme::ArmSMETileType::ZAH:
46+ return rewriter.create <arm_sme::aarch64_sme_ld1h_horiz>(
47+ loc, maskOp, ptr, tileId, tileSliceI32);
48+ case arm_sme::ArmSMETileType::ZAS:
49+ return rewriter.create <arm_sme::aarch64_sme_ld1w_horiz>(
50+ loc, maskOp, ptr, tileId, tileSliceI32);
51+ case arm_sme::ArmSMETileType::ZAD:
52+ return rewriter.create <arm_sme::aarch64_sme_ld1d_horiz>(
53+ loc, maskOp, ptr, tileId, tileSliceI32);
54+ case arm_sme::ArmSMETileType::ZAQ:
55+ return rewriter.create <arm_sme::aarch64_sme_ld1q_horiz>(
56+ loc, maskOp, ptr, tileId, tileSliceI32);
57+ }
58+ } else {
59+ switch (type) {
60+ case arm_sme::ArmSMETileType::ZAB:
61+ return rewriter.create <arm_sme::aarch64_sme_ld1b_vert>(
62+ loc, maskOp, ptr, tileId, tileSliceI32);
63+ case arm_sme::ArmSMETileType::ZAH:
64+ return rewriter.create <arm_sme::aarch64_sme_ld1h_vert>(
65+ loc, maskOp, ptr, tileId, tileSliceI32);
66+ case arm_sme::ArmSMETileType::ZAS:
67+ return rewriter.create <arm_sme::aarch64_sme_ld1w_vert>(
68+ loc, maskOp, ptr, tileId, tileSliceI32);
69+ case arm_sme::ArmSMETileType::ZAD:
70+ return rewriter.create <arm_sme::aarch64_sme_ld1d_vert>(
71+ loc, maskOp, ptr, tileId, tileSliceI32);
72+ case arm_sme::ArmSMETileType::ZAQ:
73+ return rewriter.create <arm_sme::aarch64_sme_ld1q_vert>(
74+ loc, maskOp, ptr, tileId, tileSliceI32);
75+ break ;
76+ }
77+ }
78+ }
79+
80+ // / Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic.
81+ static Operation *createStoreTileSliceIntrinsic (
82+ RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
83+ arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
84+ IntegerAttr tileId, Value tileSliceI32) {
85+ if (layout == arm_sme::TileSliceLayout::Horizontal) {
86+ switch (type) {
87+ case arm_sme::ArmSMETileType::ZAB:
88+ return rewriter.create <arm_sme::aarch64_sme_st1b_horiz>(
89+ loc, maskOp, ptr, tileId, tileSliceI32);
90+ case arm_sme::ArmSMETileType::ZAH:
91+ return rewriter.create <arm_sme::aarch64_sme_st1h_horiz>(
92+ loc, maskOp, ptr, tileId, tileSliceI32);
93+ case arm_sme::ArmSMETileType::ZAS:
94+ return rewriter.create <arm_sme::aarch64_sme_st1w_horiz>(
95+ loc, maskOp, ptr, tileId, tileSliceI32);
96+ case arm_sme::ArmSMETileType::ZAD:
97+ return rewriter.create <arm_sme::aarch64_sme_st1d_horiz>(
98+ loc, maskOp, ptr, tileId, tileSliceI32);
99+ case arm_sme::ArmSMETileType::ZAQ:
100+ return rewriter.create <arm_sme::aarch64_sme_st1q_horiz>(
101+ loc, maskOp, ptr, tileId, tileSliceI32);
102+ }
103+ } else {
104+ switch (type) {
105+ case arm_sme::ArmSMETileType::ZAB:
106+ return rewriter.create <arm_sme::aarch64_sme_st1b_vert>(
107+ loc, maskOp, ptr, tileId, tileSliceI32);
108+ case arm_sme::ArmSMETileType::ZAH:
109+ return rewriter.create <arm_sme::aarch64_sme_st1h_vert>(
110+ loc, maskOp, ptr, tileId, tileSliceI32);
111+ case arm_sme::ArmSMETileType::ZAS:
112+ return rewriter.create <arm_sme::aarch64_sme_st1w_vert>(
113+ loc, maskOp, ptr, tileId, tileSliceI32);
114+ case arm_sme::ArmSMETileType::ZAD:
115+ return rewriter.create <arm_sme::aarch64_sme_st1d_vert>(
116+ loc, maskOp, ptr, tileId, tileSliceI32);
117+ case arm_sme::ArmSMETileType::ZAQ:
118+ return rewriter.create <arm_sme::aarch64_sme_st1q_vert>(
119+ loc, maskOp, ptr, tileId, tileSliceI32);
120+ }
121+ }
122+ }
123+
35124IntegerAttr getTileIdOrError (arm_sme::ArmSMETileOpInterface op) {
36125 auto tileId = op.getTileId ();
37126 if (!tileId)
@@ -75,9 +164,6 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
75164 ConversionPatternRewriter &rewriter) const override {
76165 auto loc = zero.getLoc ();
77166
78- unsigned tileElementWidth =
79- zero.getVectorType ().getElementType ().getIntOrFloatBitWidth ();
80-
81167 auto tileId = getTileIdOrError (zero);
82168 if (!tileId)
83169 return failure ();
@@ -86,23 +172,24 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
86172 // The base mask is just the mask to zero the first tile (of a size).
87173 // These masks are derived from:
88174 // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
175+ arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType ();
89176 auto baseMaskForSize = [&] {
90- switch (tileElementWidth ) {
91- case 8 :
177+ switch (tileType ) {
178+ case arm_sme::ArmSMETileType::ZAB :
92179 // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
93180 // 64-bit element tiles named ZA0.D to ZA7.D.
94181 return 0b1111'1111 ;
95- case 16 :
96- // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit element
97- // tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D.
98- // Shift this left once for ZA1.H.
182+ case arm_sme::ArmSMETileType::ZAH :
183+ // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
184+ // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
185+ // once for ZA1.H.
99186 return 0b0101'0101 ;
100- case 32 :
187+ case arm_sme::ArmSMETileType::ZAS :
101188 // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
102189 // element tiles named ZA0.D and ZA4.D.
103190 // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
104191 return 0b0001'0001 ;
105- case 64 :
192+ case arm_sme::ArmSMETileType::ZAD :
106193 // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
107194 // setting the bit for that tile.
108195 return 0b0000'0001 ;
@@ -172,63 +259,13 @@ struct LoadTileSliceConversion
172259 // Create all active predicate mask.
173260 auto maskOp = loadTileSliceOp.getMask ();
174261
175- auto tileType = loadTileSliceOp.getVectorType ();
176- auto tileElementType = tileType.getElementType ();
177- unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth ();
262+ auto tileVectorType = loadTileSliceOp.getVectorType ();
263+ arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType (tileVectorType);
178264 arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout ();
179265
180266 // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
181- if (layout == arm_sme::TileSliceLayout::Horizontal) {
182- switch (tileElementWidth) {
183- default :
184- llvm_unreachable (" unexpected element type!" );
185- case 8 :
186- rewriter.create <arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
187- tileId, tileSliceI32);
188- break ;
189- case 16 :
190- rewriter.create <arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
191- tileId, tileSliceI32);
192- break ;
193- case 32 :
194- rewriter.create <arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
195- tileId, tileSliceI32);
196- break ;
197- case 64 :
198- rewriter.create <arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
199- tileId, tileSliceI32);
200- break ;
201- case 128 :
202- rewriter.create <arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
203- tileId, tileSliceI32);
204- break ;
205- }
206- } else {
207- switch (tileElementWidth) {
208- default :
209- llvm_unreachable (" unexpected element type!" );
210- case 8 :
211- rewriter.create <arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
212- tileId, tileSliceI32);
213- break ;
214- case 16 :
215- rewriter.create <arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
216- tileId, tileSliceI32);
217- break ;
218- case 32 :
219- rewriter.create <arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
220- tileId, tileSliceI32);
221- break ;
222- case 64 :
223- rewriter.create <arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
224- tileId, tileSliceI32);
225- break ;
226- case 128 :
227- rewriter.create <arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
228- tileId, tileSliceI32);
229- break ;
230- }
231- }
267+ createLoadTileSliceIntrinsic (rewriter, loc, tileType, layout, maskOp, ptr,
268+ tileId, tileSliceI32);
232269
233270 // The load intrinsics have no result, replace 'arm_sme.tile_load' with
234271 // the input tile to preserve dataflow.
@@ -249,9 +286,7 @@ struct StoreTileSliceConversion
249286 arm_sme::StoreTileSliceOp::Adaptor adaptor,
250287 ConversionPatternRewriter &rewriter) const override {
251288 auto loc = storeTileSliceOp.getLoc ();
252- auto tileType = storeTileSliceOp.getVectorType ();
253- auto tileElementType = tileType.getElementType ();
254- unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth ();
289+ auto tileVectorType = storeTileSliceOp.getVectorType ();
255290
256291 auto tileId = getTileIdOrError (storeTileSliceOp);
257292 if (!tileId)
@@ -271,58 +306,12 @@ struct StoreTileSliceConversion
271306 auto maskOp = storeTileSliceOp.getMask ();
272307
273308 arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout ();
309+ arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType (tileVectorType);
274310
275- if (layout == arm_sme::TileSliceLayout::Horizontal) {
276- switch (tileElementWidth) {
277- default :
278- llvm_unreachable (" unexpected element type!" );
279- case 8 :
280- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1b_horiz>(
281- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
282- break ;
283- case 16 :
284- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1h_horiz>(
285- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
286- break ;
287- case 32 :
288- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1w_horiz>(
289- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
290- break ;
291- case 64 :
292- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1d_horiz>(
293- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
294- break ;
295- case 128 :
296- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1q_horiz>(
297- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
298- break ;
299- }
300- } else {
301- switch (tileElementWidth) {
302- default :
303- llvm_unreachable (" unexpected element type!" );
304- case 8 :
305- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1b_vert>(
306- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
307- break ;
308- case 16 :
309- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1h_vert>(
310- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
311- break ;
312- case 32 :
313- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1w_vert>(
314- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
315- break ;
316- case 64 :
317- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1d_vert>(
318- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
319- break ;
320- case 128 :
321- rewriter.replaceOpWithNewOp <arm_sme::aarch64_sme_st1q_vert>(
322- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
323- break ;
324- }
325- }
311+ rewriter.replaceOp (storeTileSliceOp,
312+ createStoreTileSliceIntrinsic (rewriter, loc, tileType,
313+ layout, maskOp, ptr,
314+ tileId, tileSliceI32));
326315
327316 return success ();
328317 }
0 commit comments