@@ -3231,7 +3231,7 @@ def test_local_prod_of_div(self):
32313231class TestLocalReduce :
32323232 def setup_method (self ):
32333233 self .mode = get_default_mode ().including (
3234- "canonicalize" , "specialize" , "uncanonicalize" , "local_max_and_argmax"
3234+ "canonicalize" , "specialize" , "uncanonicalize"
32353235 )
32363236
32373237 def test_local_reduce_broadcast_all_0 (self ):
@@ -3304,62 +3304,92 @@ def test_local_reduce_broadcast_some_1(self):
33043304 isinstance (node .op , CAReduce ) for node in f .maker .fgraph .toposort ()
33053305 )
33063306
class TestReduceJoin:
    """Tests for the rewrite that replaces a reduction over a Join/stack of
    inputs with an equivalent Elemwise, plus cases where the rewrite must
    not be applied."""

    def setup_method(self):
        # Compile with the rewrite passes under test enabled.
        self.mode = get_default_mode().including("canonicalize", "specialize")

    @pytest.mark.parametrize(
        "op, nin", [(pt_sum, 3), (pt_max, 2), (pt_min, 2), (prod, 3)]
    )
    def test_local_reduce_join(self, op, nin):
        """Reducing `nin` stacked matrices along axis 0 should be rewritten so
        the final node of the compiled graph is an Elemwise."""
        mats = (matrix(), matrix(), matrix())
        vals = (
            np.asarray([[1, 0], [3, 4]], dtype=config.floatX),
            np.asarray([[4, 0], [2, 1]], dtype=config.floatX),
            np.asarray([[5, 0], [1, 2]], dtype=config.floatX),
        )

        inputs = mats[:nin]
        test_values = vals[:nin]

        out = op(inputs, axis=0)
        fn = function(inputs, out, mode=self.mode)
        # Compare against the matching NumPy reduction (np.sum, np.max, ...).
        expected = getattr(np, op.__name__)(test_values, axis=0)
        np.testing.assert_allclose(fn(*test_values), expected)

        nodes = fn.maker.fgraph.toposort()
        assert len(nodes) <= 2
        assert isinstance(nodes[-1].op, Elemwise)

    def test_type(self):
        """Check which join/reduction axis combinations are rewritten.

        Only a reduction along the join axis can be turned into an Elemwise;
        reductions along any other axis must be left alone.
        """
        # Test different axes for the join and the reduction. The dtype is
        # forced to int64, as otherwise this test would fail on 32-bit systems.
        shared_vec = shared(np.array([1, 2, 3, 4, 5], dtype="int64"))

        # Reduction along the stacking axis: rewritten to an Elemwise.
        fn = function(
            [], pt_sum(pt.stack([shared_vec, shared_vec]), axis=0), mode=self.mode
        )
        np.testing.assert_allclose(fn(), [2, 4, 6, 8, 10])
        nodes = fn.maker.fgraph.toposort()
        assert isinstance(nodes[-1].op, Elemwise)

        # Reduction along the other axis: a case that was bugged in an old
        # PyTensor bug — must not be rewritten.
        fn = function(
            [], pt_sum(pt.stack([shared_vec, shared_vec]), axis=1), mode=self.mode
        )
        np.testing.assert_allclose(fn(), [15, 15])
        nodes = fn.maker.fgraph.toposort()
        assert not isinstance(nodes[-1].op, Elemwise)

        # Join along axis 1, reduce along axis 1: this case could be rewritten,
        # but currently is not.
        shared_col = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
        fn = function(
            [],
            pt_sum(pt.concatenate((shared_col, shared_col), axis=1), axis=1),
            mode=self.mode,
        )
        np.testing.assert_allclose(fn(), [2, 4, 6, 8, 10])
        nodes = fn.maker.fgraph.toposort()
        assert not isinstance(nodes[-1].op, Elemwise)

        # Join along axis 1, reduce along axis 0: not rewritten either.
        shared_col = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
        fn = function(
            [],
            pt_sum(pt.concatenate((shared_col, shared_col), axis=1), axis=0),
            mode=self.mode,
        )
        np.testing.assert_allclose(fn(), [15, 15])
        nodes = fn.maker.fgraph.toposort()
        assert not isinstance(nodes[-1].op, Elemwise)

    def test_not_supported_axis_none(self):
        """The rewrite does not apply when reducing over all axes (axis=None);
        it must not crash in that case. Reported at
        https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
        """
        mats = (matrix(), matrix(), matrix())
        vals = (
            np.asarray([[1, 0], [3, 4]], dtype=config.floatX),
            np.asarray([[4, 0], [2, 1]], dtype=config.floatX),
            np.asarray([[5, 0], [1, 2]], dtype=config.floatX),
        )

        out = pt_sum(list(mats), axis=None)
        fn = function(list(mats), out, mode=self.mode)
        np.testing.assert_allclose(fn(*vals), np.sum(vals))

    def test_not_supported_unequal_shapes(self):
        """The rewrite does not apply when the joined inputs have different
        lengths along the join axis; the graph must still compute correctly."""
        vx = matrix(shape=(1, 3))
        vy = matrix(shape=(2, 3))
        x_val = np.asarray([[1, 0, 1]], dtype=config.floatX)
        y_val = np.asarray([[4, 0, 1], [2, 1, 1]], dtype=config.floatX)

        out = pt_sum(join(0, vx, vy), axis=0)
        fn = function([vx, vy], out, mode=self.mode)
        expected = np.sum(np.concatenate([x_val, y_val], axis=0), axis=0)
        np.testing.assert_allclose(fn(x_val, y_val), expected)
33633393
33643394
33653395def test_local_useless_adds ():
0 commit comments