@@ -85,16 +85,12 @@ def test_jax_PosDefMatrix():
         pytest.param(1),
         pytest.param(
             2,
-            marks=pytest.mark.skipif(
-                len(jax.devices()) < 2, reason="not enough devices"
-            ),
+            marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices"),
         ),
     ],
 )
 @pytest.mark.parametrize("postprocessing_vectorize", ["scan", "vmap"])
-def test_transform_samples(
-    sampler, postprocessing_backend, chains, postprocessing_vectorize
-):
+def test_transform_samples(sampler, postprocessing_backend, chains, postprocessing_vectorize):
     pytensor.config.on_opt_error = "raise"
     np.random.seed(13244)
 
@@ -241,9 +237,7 @@ def test_replace_shared_variables():
     x = pytensor.shared(5, name="shared_x")
 
     new_x = _replace_shared_variables([x])
-    shared_variables = [
-        var for var in graph_inputs(new_x) if isinstance(var, SharedVariable)
-    ]
+    shared_variables = [var for var in graph_inputs(new_x) if isinstance(var, SharedVariable)]
     assert not shared_variables
 
     x.default_update = x + 1
@@ -333,30 +327,23 @@ def test_idata_kwargs(
 
     posterior = idata.get("posterior")
     assert posterior is not None
-    x_dim_expected = idata_kwargs.get(
-        "dims", model_test_idata_kwargs.named_vars_to_dims
-    )["x"][0]
+    x_dim_expected = idata_kwargs.get("dims", model_test_idata_kwargs.named_vars_to_dims)["x"][0]
     assert x_dim_expected is not None
     assert posterior["x"].dims[-1] == x_dim_expected
 
-    x_coords_expected = idata_kwargs.get("coords", model_test_idata_kwargs.coords)[
-        x_dim_expected
-    ]
+    x_coords_expected = idata_kwargs.get("coords", model_test_idata_kwargs.coords)[x_dim_expected]
     assert x_coords_expected is not None
     assert list(x_coords_expected) == list(posterior["x"].coords[x_dim_expected].values)
 
     assert posterior["z"].dims[2] == "z_coord"
     assert np.all(
-        posterior["z"].coords["z_coord"].values
-        == np.array(["apple", "banana", "orange"])
+        posterior["z"].coords["z_coord"].values == np.array(["apple", "banana", "orange"])
     )
 
 
 def test_get_batched_jittered_initial_points():
     with pm.Model() as model:
-        x = pm.MvNormal(
-            "x", mu=np.zeros(3), cov=np.eye(3), shape=(2, 3), initval=np.zeros((2, 3))
-        )
+        x = pm.MvNormal("x", mu=np.zeros(3), cov=np.eye(3), shape=(2, 3), initval=np.zeros((2, 3)))
 
     # No jitter
     ips = _get_batched_jittered_initial_points(
@@ -365,17 +352,13 @@ def test_get_batched_jittered_initial_points():
     assert np.all(ips[0] == 0)
 
     # Single chain
-    ips = _get_batched_jittered_initial_points(
-        model=model, chains=1, random_seed=1, initvals=None
-    )
+    ips = _get_batched_jittered_initial_points(model=model, chains=1, random_seed=1, initvals=None)
 
     assert ips[0].shape == (2, 3)
     assert np.all(ips[0] != 0)
 
     # Multiple chains
-    ips = _get_batched_jittered_initial_points(
-        model=model, chains=2, random_seed=1, initvals=None
-    )
+    ips = _get_batched_jittered_initial_points(model=model, chains=2, random_seed=1, initvals=None)
 
     assert ips[0].shape == (2, 2, 3)
     assert np.all(ips[0][0] != ips[0][1])
@@ -395,9 +378,7 @@ def test_get_batched_jittered_initial_points():
         pytest.param(1),
         pytest.param(
             2,
-            marks=pytest.mark.skipif(
-                len(jax.devices()) < 2, reason="not enough devices"
-            ),
+            marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices"),
         ),
     ],
 )
@@ -421,12 +402,8 @@ def test_seeding(chains, random_seed, sampler):
     assert all_equal
 
     if chains > 1:
-        assert np.all(
-            result1.posterior["x"].sel(chain=0) != result1.posterior["x"].sel(chain=1)
-        )
-        assert np.all(
-            result2.posterior["x"].sel(chain=0) != result2.posterior["x"].sel(chain=1)
-        )
+        assert np.all(result1.posterior["x"].sel(chain=0) != result1.posterior["x"].sel(chain=1))
+        assert np.all(result2.posterior["x"].sel(chain=0) != result2.posterior["x"].sel(chain=1))
 
 
 @mock.patch("numpyro.infer.MCMC")
@@ -541,7 +518,21 @@ def test_vi_sampling_jax(method):
         pm.fit(10, method=method, fn_kwargs=dict(mode="JAX"))
 
 
-@pytest.mark.xfail(reason="Due to https://github.com/pymc-devs/pytensor/issues/595")
+@pytest.mark.xfail(
+    reason="""
+    This error happens during the equilibrium rewriter; one of the routines in SVGD is probably problematic.
+
+    TypeError: The broadcast pattern of the output of scan
+    (Matrix(float64, shape=(?, 1))) is inconsistent with the one provided in `output_info`
+    (Vector(float64, shape=(?,))). The output on axis 0 is `True`, but it is `False` on axis
+    1 in `output_info`. This can happen if one of the dimension is fixed to 1 in the input,
+    while it is still variable in the output, or vice-verca. You have to make them consistent,
+    e.g. using pytensor.tensor.{unbroadcast, specify_broadcastable}.
+
+    Instead of fixing this error, it makes more sense to rework the internals of the variational
+    module to use pytensor vectorize instead of scan.
+    """
+)
 def test_vi_sampling_jax_svgd():
     with pm.Model():
         x = pm.Normal("x")
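
Aside (not part of the diff above): a minimal sketch of the consistency fix that the quoted TypeError suggests, using pytensor.tensor.specify_broadcastable to reconcile a (?, 1) matrix with a (?,) vector. The variable names are purely illustrative.

# Minimal sketch, not from this commit: the quoted scan error complains that a
# Matrix(float64, shape=(?, 1)) and a Vector(float64, shape=(?,)) have inconsistent
# broadcast patterns. Declaring the length-1 axis broadcastable lets it be dropped.
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")                      # static shape (?, ?): axis 1 not known to be length 1
x_col = pt.specify_broadcastable(x, 1)  # declare axis 1 broadcastable, i.e. fixed to length 1
y = x_col.dimshuffle(0)                 # the length-1 axis can now be dropped, yielding a vector

f = pytensor.function([x], y)
print(f([[1.0], [2.0]]))                # [1. 2.]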