@@ -3231,7 +3231,7 @@ def test_local_prod_of_div(self):
32313231class TestLocalReduce :
32323232 def setup_method (self ):
32333233 self .mode = get_default_mode ().including (
3234- "canonicalize" , "specialize" , "uncanonicalize" , "local_max_and_argmax"
3234+ "canonicalize" , "specialize" , "uncanonicalize"
32353235 )
32363236
32373237 def test_local_reduce_broadcast_all_0 (self ):
@@ -3304,62 +3304,94 @@ def test_local_reduce_broadcast_some_1(self):
33043304 isinstance (node .op , CAReduce ) for node in f .maker .fgraph .toposort ()
33053305 )
33063306
3307- def test_local_reduce_join (self ):
3307+
class TestReduceJoin:
    """Tests for the rewrite that fuses a reduction over a Join/stack.

    When every operand has the same shape along the join axis, reducing a
    joined graph can be rewritten into a plain Elemwise over the operands.
    The tests below check both the cases where the rewrite applies and the
    cases where it must safely decline.
    """

    def setup_method(self):
        # Enable the rewrite pipeline these tests depend on.
        rewrites = ("canonicalize", "specialize", "uncanonicalize")
        self.mode = get_default_mode().including(*rewrites)

    @pytest.mark.parametrize(
        "op, nin", [(pt_sum, 3), (pt_max, 2), (pt_min, 2), (prod, 3)]
    )
    def test_local_reduce_join(self, op, nin):
        # Symbolic matrices together with matching concrete values.
        vx, vy, vz = matrix(), matrix(), matrix()
        x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
        y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
        z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)

        # Each scalar op is exercised with as many operands as parametrized.
        sym_inputs = (vx, vy, vz)[:nin]
        values = (x, y, z)[:nin]

        fn = function(sym_inputs, op(sym_inputs, axis=0), mode=self.mode)
        # The PyTensor op name matches the NumPy reduction of the same name,
        # so the reference result is looked up dynamically.
        expected = getattr(np, op.__name__)(values, axis=0)
        np.testing.assert_allclose(fn(*values), expected)

        nodes = fn.maker.fgraph.toposort()
        # The rewrite collapses the join + reduce into at most two nodes,
        # the last of which is the fused Elemwise.
        assert len(nodes) <= 2
        assert isinstance(nodes[-1].op, Elemwise)

    def test_type(self):
        # Vary the join axis relative to the reduction axis.  The dtype is
        # pinned to int64 so the test also passes on 32-bit platforms.
        shared_vec = shared(np.array([1, 2, 3, 4, 5], dtype="int64"))

        # Reduction along the stacking axis: the rewrite applies.
        fn = function([], pt_sum(pt.stack([shared_vec, shared_vec]), axis=0), mode=self.mode)
        np.testing.assert_allclose(fn(), [2, 4, 6, 8, 10])
        nodes = fn.maker.fgraph.toposort()
        assert isinstance(nodes[-1].op, Elemwise)

        # Reduction along the other axis: regression case for an old
        # PyTensor bug; the rewrite must not apply here.
        fn = function([], pt_sum(pt.stack([shared_vec, shared_vec]), axis=1), mode=self.mode)
        np.testing.assert_allclose(fn(), [15, 15])
        nodes = fn.maker.fgraph.toposort()
        assert not isinstance(nodes[-1].op, Elemwise)

        # This case could in principle be rewritten, but currently is not.
        shared_col = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
        fn = function([], pt_sum(pt.concatenate((shared_col, shared_col), axis=1), axis=1), mode=self.mode)
        np.testing.assert_allclose(fn(), [2, 4, 6, 8, 10])
        nodes = fn.maker.fgraph.toposort()
        assert not isinstance(nodes[-1].op, Elemwise)

        shared_col = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
        fn = function([], pt_sum(pt.concatenate((shared_col, shared_col), axis=1), axis=0), mode=self.mode)
        np.testing.assert_allclose(fn(), [15, 15])
        nodes = fn.maker.fgraph.toposort()
        assert not isinstance(nodes[-1].op, Elemwise)

    def test_not_supported_axis_none(self):
        # The rewrite must not crash in a case where it does not apply
        # (full reduction, axis=None).  Reported at
        # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
        vx, vy, vz = matrix(), matrix(), matrix()
        x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
        y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
        z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)

        fn = function([vx, vy, vz], pt_sum([vx, vy, vz], axis=None), mode=self.mode)
        np.testing.assert_allclose(fn(x, y, z), np.sum([x, y, z]))

    def test_not_supported_unequal_shapes(self):
        # Operands differ along the join axis, so the rewrite cannot apply;
        # the graph must still compile and compute correctly.
        vx = matrix(shape=(1, 3))
        vy = matrix(shape=(2, 3))
        x = np.asarray([[1, 0, 1]], dtype=config.floatX)
        y = np.asarray([[4, 0, 1], [2, 1, 1]], dtype=config.floatX)

        fn = function([vx, vy], pt_sum(join(0, vx, vy), axis=0), mode=self.mode)
        np.testing.assert_allclose(
            fn(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
        )
33633395
33643396
33653397def test_local_useless_adds ():
0 commit comments