@@ -1422,7 +1422,7 @@ def infer_shape(self, fgraph, node, shapes):
14221422 def _c_all (self , node , name , input_names , output_names , sub ):
14231423 [inp ] = node .inputs
14241424 [out ] = node .outputs
1425- ndim = inp .type .ndim
1425+ inp_ndim = inp .type .ndim
14261426
14271427 [inp_name ] = input_names
14281428 [out_name ] = output_names
@@ -1454,10 +1454,10 @@ def _c_all(self, node, name, input_names, output_names, sub):
14541454 assert var .dtype == node .outputs [0 ].dtype
14551455 return var .owner .op ._c_all (var .owner , name , input_names , output_names , sub )
14561456
1457- inp_dims = list (range (ndim ))
1457+ inp_dims = list (range (inp_ndim ))
14581458 non_reduced_dims = [i for i in inp_dims if i not in axis ]
1459- counter = iter (range (ndim ))
1460- acc_dims = ["x" if i in axis else next (counter ) for i in range (ndim )]
1459+ counter = iter (range (inp_ndim ))
1460+ acc_dims = ["x" if i in axis else next (counter ) for i in range (inp_ndim )]
14611461
14621462 sub = sub .copy ()
14631463 sub ["lv0" ] = inp_name
@@ -1484,7 +1484,9 @@ def _c_all(self, node, name, input_names, output_names, sub):
14841484 cgen .make_declare (
14851485 [acc_dims ], [out_dtype ], out_sub , compute_stride_jump = False
14861486 )
1487- + cgen .make_alloc ([non_reduced_dims ], out_dtype , sub )
1487+ + cgen .make_careduce_alloc (
1488+ inp_name , out_name , inp_ndim , axis , out_dtype , sub ["fail" ]
1489+ )
14881490 + cgen .make_checks (
14891491 [acc_dims ], [out_dtype ], out_sub , compute_stride_jump = False
14901492 )
@@ -1500,7 +1502,10 @@ def _c_all(self, node, name, input_names, output_names, sub):
15001502 cgen .make_declare (
15011503 [acc_dims ], [acc_dtype ], acc_sub , compute_stride_jump = False
15021504 )
1503- + cgen .make_alloc ([non_reduced_dims ], acc_dtype , sub )
1505+ + cgen .make_careduce_alloc (
1506+ inp_name , acc_name , inp_ndim , axis , out_dtype , sub ["fail" ]
1507+ )
1508+ + cgen .make_careduce_alloc ([non_reduced_dims ], acc_dtype , sub )
15041509 + cgen .make_checks (
15051510 [acc_dims ], [acc_dtype ], acc_sub , compute_stride_jump = False
15061511 )
@@ -1524,8 +1529,6 @@ def _c_all(self, node, name, input_names, output_names, sub):
15241529 elif identity is None :
15251530 raise TypeError (f"The { self .scalar_op } does not define an identity." )
15261531
1527- initial_value = f"{ acc_name } _i = { identity } ;"
1528-
15291532 inner_task = self .scalar_op .c_code (
15301533 Apply (
15311534 self .scalar_op ,
@@ -1544,28 +1547,16 @@ def _c_all(self, node, name, input_names, output_names, sub):
15441547 sub ,
15451548 )
15461549
1547- if out .type .ndim == 0 :
1548- # Simple case where everything is reduced, no need for loop ordering
1549- loop = cgen .make_complete_loop_careduce (
1550- inp_var = inp_name ,
1551- acc_var = acc_name ,
1552- inp_dtype = inp_dtype ,
1553- acc_dtype = acc_dtype ,
1554- initial_value = initial_value ,
1555- inner_task = inner_task ,
1556- fail_code = sub ["fail" ],
1557- )
1558- else :
1559- loop = cgen .make_reordered_loop_careduce (
1560- inp_var = inp_name ,
1561- acc_var = acc_name ,
1562- inp_dtype = inp_dtype ,
1563- acc_dtype = acc_dtype ,
1564- inp_ndim = ndim ,
1565- reduction_axes = axis ,
1566- initial_value = initial_value ,
1567- inner_task = inner_task ,
1568- )
1550+ loop = cgen .make_reordered_loop_careduce (
1551+ inp_var = inp_name ,
1552+ acc_var = acc_name ,
1553+ inp_dtype = inp_dtype ,
1554+ acc_dtype = acc_dtype ,
1555+ inp_ndim = inp_ndim ,
1556+ reduction_axes = axis ,
1557+ initial_value = identity ,
1558+ inner_task = inner_task ,
1559+ )
15691560
15701561 if acc_dtype != out_dtype :
15711562 cast = dedent (
@@ -1589,7 +1580,7 @@ def c_headers(self, **kwargs):
15891580
15901581 def c_code_cache_version_apply (self , node ):
15911582 # the version corresponding to the c code in this Op
1592- version = [10 ]
1583+ version = [11 ]
15931584
15941585 # now we insert versions for the ops on which we depend...
15951586 scalar_node = Apply (
0 commit comments