44
55try :
66 from ..ir import *
7- from ._ods_common import get_op_result_or_value as _get_op_result_or_value
87 from ..dialects import pdl , transform
98except ImportError as e :
109 raise RuntimeError ("Error loading imports from extension module" ) from e
@@ -101,7 +100,7 @@ def _dispatch_mixed_values(
101100 static_values .append (size )
102101 else :
103102 static_values .append (ShapedType .get_dynamic_size ())
104- dynamic_values .append (_get_op_result_or_value ( size ) )
103+ dynamic_values .append (size )
105104 static_values = DenseI64ArrayAttr .get (static_values )
106105
107106 return (dynamic_values , packed_values , static_values )
@@ -204,9 +203,7 @@ class DecomposeOp:
204203 """Specialization for DecomposeOp class."""
205204
206205 def __init__ (self , target : Union [Operation , Value ], * , loc = None , ip = None ):
207- super ().__init__ (
208- pdl .OperationType .get (), _get_op_result_or_value (target ), loc = loc , ip = ip
209- )
206+ super ().__init__ (pdl .OperationType .get (), target , loc = loc , ip = ip )
210207
211208
212209class FuseIntoContainingOp :
@@ -277,9 +274,7 @@ class GeneralizeOp:
277274 """Specialization for GeneralizeOp class."""
278275
279276 def __init__ (self , target : Union [Operation , Value ], * , loc = None , ip = None ):
280- super ().__init__ (
281- pdl .OperationType .get (), _get_op_result_or_value (target ), loc = loc , ip = ip
282- )
277+ super ().__init__ (pdl .OperationType .get (), target , loc = loc , ip = ip )
283278
284279
285280class InterchangeOp :
@@ -296,7 +291,7 @@ def __init__(
296291 pdl_operation_type = pdl .OperationType .get ()
297292 super ().__init__ (
298293 pdl_operation_type ,
299- _get_op_result_or_value ( target ) ,
294+ target ,
300295 iterator_interchange = iterator_interchange ,
301296 loc = loc ,
302297 ip = ip ,
@@ -415,7 +410,7 @@ def match_op_names(
415410 loc = None ,
416411 ip = None ,
417412 ):
418- ...
413+ ...
419414
420415 @overload
421416 @classmethod
@@ -428,7 +423,7 @@ def match_op_names(
428423 loc = None ,
429424 ip = None ,
430425 ):
431- ...
426+ ...
432427
433428 @classmethod
434429 def match_op_names (
@@ -441,20 +436,20 @@ def match_op_names(
441436 ip = None ,
442437 ):
443438 if isinstance (result_type_or_target , Type ):
444- result_type = result_type_or_target
445- target = target_or_names
446- names = names_or_none
439+ result_type = result_type_or_target
440+ target = target_or_names
441+ names = names_or_none
447442 else :
448- result_type = transform .AnyOpType .get ()
449- target = result_type_or_target
450- names = target_or_names
443+ result_type = transform .AnyOpType .get ()
444+ target = result_type_or_target
445+ names = target_or_names
451446
452447 if isinstance (names , str ):
453- names = [names ]
448+ names = [names ]
454449
455450 return cls (
456451 result_type ,
457- _get_op_result_or_value ( target ) ,
452+ target ,
458453 ops = ArrayAttr .get (list (map (lambda s : StringAttr .get (s ), names ))),
459454 loc = loc ,
460455 ip = ip ,
@@ -479,7 +474,7 @@ def __init__(
479474 result_type ,
480475 result_type ,
481476 result_type ,
482- _get_op_result_or_value ( target ) ,
477+ target ,
483478 dimension = dimension ,
484479 target_size = target_size ,
485480 divisor = divisor ,
@@ -530,9 +525,7 @@ class ScalarizeOp:
530525
531526 def __init__ (self , target : Union [Operation , Value ], * , loc = None , ip = None ):
532527 pdl_operation_type = pdl .OperationType .get ()
533- super ().__init__ (
534- pdl_operation_type , _get_op_result_or_value (target ), loc = loc , ip = ip
535- )
528+ super ().__init__ (pdl_operation_type , target , loc = loc , ip = ip )
536529
537530
538531class SplitOp :
@@ -552,9 +545,7 @@ def __init__(
552545 dynamic_split_point = None
553546 else :
554547 static_split_point = ShapedType .get_dynamic_size ()
555- dynamic_split_point = _get_op_result_or_value (split_point )
556-
557- target = _get_op_result_or_value (target )
548+ dynamic_split_point = split_point
558549
559550 super ().__init__ (
560551 target .type ,
@@ -626,8 +617,6 @@ def __init__(
626617 )
627618 target = target_or_none
628619
629- target = _get_op_result_or_value (target )
630-
631620 super ().__init__ (
632621 target .type ,
633622 loop_types ,
@@ -750,7 +739,7 @@ def __init__(
750739 pdl_operation_type = pdl .OperationType .get ()
751740 super ().__init__ (
752741 pdl_operation_type ,
753- _get_op_result_or_value ( target ) ,
742+ target ,
754743 disable_multi_reduction_to_contract_patterns = disable_multi_reduction_to_contract_patterns ,
755744 disable_transfer_permutation_map_lowering_patterns = disable_transfer_permutation_map_lowering_patterns ,
756745 vectorize_nd_extract = vectorize_nd_extract ,
0 commit comments