44
55try :
66 from ..ir import *
7- from ..dialects import pdl , transform
7+ from ..dialects import transform
88except ImportError as e :
99 raise RuntimeError ("Error loading imports from extension module" ) from e
1010
@@ -203,7 +203,8 @@ class DecomposeOp:
203203 """Specialization for DecomposeOp class."""
204204
205205 def __init__ (self , target : Union [Operation , Value ], * , loc = None , ip = None ):
206- super ().__init__ (pdl .OperationType .get (), target , loc = loc , ip = ip )
206+ transformed_type = transform .AnyOpType .get ()
207+ super ().__init__ (transformed_type , target , loc = loc , ip = ip )
207208
208209
209210class FuseIntoContainingOp :
@@ -274,7 +275,8 @@ class GeneralizeOp:
274275 """Specialization for GeneralizeOp class."""
275276
276277 def __init__ (self , target : Union [Operation , Value ], * , loc = None , ip = None ):
277- super ().__init__ (pdl .OperationType .get (), target , loc = loc , ip = ip )
278+ transformed_type = transform .AnyOpType .get ()
279+ super ().__init__ (transformed_type , target , loc = loc , ip = ip )
278280
279281
280282class InterchangeOp :
@@ -288,9 +290,9 @@ def __init__(
288290 loc = None ,
289291 ip = None ,
290292 ):
291- pdl_operation_type = pdl . OperationType .get ()
293+ transformed_type = transform . AnyOpType .get ()
292294 super ().__init__ (
293- pdl_operation_type ,
295+ transformed_type ,
294296 target ,
295297 iterator_interchange = iterator_interchange ,
296298 loc = loc ,
@@ -503,11 +505,11 @@ def __init__(
503505 ):
504506 transpose_paddings = _get_int_array_array_attr (transpose_paddings )
505507
506- pdl_operation_type = pdl . OperationType .get ()
508+ any_op_type = transform . AnyOpType .get ()
507509 super ().__init__ (
508- pdl_operation_type ,
509- pdl_operation_type ,
510- pdl_operation_type ,
510+ any_op_type ,
511+ any_op_type ,
512+ any_op_type ,
511513 target ,
512514 padding_values = padding_values ,
513515 padding_dimensions = padding_dimensions ,
@@ -524,8 +526,8 @@ class ScalarizeOp:
524526 """Specialization for ScalarizeOp class."""
525527
526528 def __init__ (self , target : Union [Operation , Value ], * , loc = None , ip = None ):
527- pdl_operation_type = pdl . OperationType .get ()
528- super ().__init__ (pdl_operation_type , target , loc = loc , ip = ip )
529+ result_type = transform . AnyOpType .get ()
530+ super ().__init__ (result_type , target , loc = loc , ip = ip )
529531
530532
531533class SplitOp :
@@ -736,9 +738,9 @@ def __init__(
736738 loc = None ,
737739 ip = None ,
738740 ):
739- pdl_operation_type = pdl . OperationType .get ()
741+ transformed_type = transform . AnyOpType .get ()
740742 super ().__init__ (
741- pdl_operation_type ,
743+ transformed_type ,
742744 target ,
743745 disable_multi_reduction_to_contract_patterns = disable_multi_reduction_to_contract_patterns ,
744746 disable_transfer_permutation_map_lowering_patterns = disable_transfer_permutation_map_lowering_patterns ,
0 commit comments