@@ -86,16 +86,12 @@ def test_jax_PosDefMatrix():
         pytest.param(1),
         pytest.param(
             2,
-            marks=pytest.mark.skipif(
-                len(jax.devices()) < 2, reason="not enough devices"
-            ),
+            marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices"),
         ),
     ],
 )
 @pytest.mark.parametrize("postprocessing_vectorize", ["scan", "vmap"])
-def test_transform_samples(
-    sampler, postprocessing_backend, chains, postprocessing_vectorize
-):
+def test_transform_samples(sampler, postprocessing_backend, chains, postprocessing_vectorize):
     pytensor.config.on_opt_error = "raise"
     np.random.seed(13244)
 
@@ -242,9 +238,7 @@ def test_replace_shared_variables():
     x = pytensor.shared(5, name="shared_x")
 
     new_x = _replace_shared_variables([x])
-    shared_variables = [
-        var for var in graph_inputs(new_x) if isinstance(var, SharedVariable)
-    ]
+    shared_variables = [var for var in graph_inputs(new_x) if isinstance(var, SharedVariable)]
     assert not shared_variables
 
     x.default_update = x + 1
@@ -332,30 +326,23 @@ def test_idata_kwargs(
 
     posterior = idata.get("posterior")
     assert posterior is not None
-    x_dim_expected = idata_kwargs.get(
-        "dims", model_test_idata_kwargs.named_vars_to_dims
-    )["x"][0]
+    x_dim_expected = idata_kwargs.get("dims", model_test_idata_kwargs.named_vars_to_dims)["x"][0]
     assert x_dim_expected is not None
     assert posterior["x"].dims[-1] == x_dim_expected
 
-    x_coords_expected = idata_kwargs.get("coords", model_test_idata_kwargs.coords)[
-        x_dim_expected
-    ]
+    x_coords_expected = idata_kwargs.get("coords", model_test_idata_kwargs.coords)[x_dim_expected]
     assert x_coords_expected is not None
     assert list(x_coords_expected) == list(posterior["x"].coords[x_dim_expected].values)
 
     assert posterior["z"].dims[2] == "z_coord"
     assert np.all(
-        posterior["z"].coords["z_coord"].values
-        == np.array(["apple", "banana", "orange"])
+        posterior["z"].coords["z_coord"].values == np.array(["apple", "banana", "orange"])
     )
 
 
 def test_get_batched_jittered_initial_points():
     with pm.Model() as model:
-        x = pm.MvNormal(
-            "x", mu=np.zeros(3), cov=np.eye(3), shape=(2, 3), initval=np.zeros((2, 3))
-        )
+        x = pm.MvNormal("x", mu=np.zeros(3), cov=np.eye(3), shape=(2, 3), initval=np.zeros((2, 3)))
 
     # No jitter
     ips = _get_batched_jittered_initial_points(
@@ -364,17 +351,13 @@ def test_get_batched_jittered_initial_points():
     assert np.all(ips[0] == 0)
 
     # Single chain
-    ips = _get_batched_jittered_initial_points(
-        model=model, chains=1, random_seed=1, initvals=None
-    )
+    ips = _get_batched_jittered_initial_points(model=model, chains=1, random_seed=1, initvals=None)
 
     assert ips[0].shape == (2, 3)
     assert np.all(ips[0] != 0)
 
     # Multiple chains
-    ips = _get_batched_jittered_initial_points(
-        model=model, chains=2, random_seed=1, initvals=None
-    )
+    ips = _get_batched_jittered_initial_points(model=model, chains=2, random_seed=1, initvals=None)
 
     assert ips[0].shape == (2, 2, 3)
     assert np.all(ips[0][0] != ips[0][1])
@@ -394,9 +377,7 @@ def test_get_batched_jittered_initial_points():
         pytest.param(1),
         pytest.param(
             2,
-            marks=pytest.mark.skipif(
-                len(jax.devices()) < 2, reason="not enough devices"
-            ),
+            marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices"),
         ),
     ],
 )
@@ -420,12 +401,8 @@ def test_seeding(chains, random_seed, sampler):
     assert all_equal
 
     if chains > 1:
-        assert np.all(
-            result1.posterior["x"].sel(chain=0) != result1.posterior["x"].sel(chain=1)
-        )
-        assert np.all(
-            result2.posterior["x"].sel(chain=0) != result2.posterior["x"].sel(chain=1)
-        )
+        assert np.all(result1.posterior["x"].sel(chain=0) != result1.posterior["x"].sel(chain=1))
+        assert np.all(result2.posterior["x"].sel(chain=0) != result2.posterior["x"].sel(chain=1))
 
 
 @mock.patch("numpyro.infer.MCMC")
@@ -555,7 +532,21 @@ def test_vi_sampling_jax(method):
         pm.fit(10, method=method, fn_kwargs=dict(mode="JAX"))
 
 
-@pytest.mark.xfail(reason="Due to https://github.com/pymc-devs/pytensor/issues/595")
+@pytest.mark.xfail(
+    reason="""
+During equilibrium rewriter this error happens. Probably one of the routines in SVGD is problematic.
+
+TypeError: The broadcast pattern of the output of scan
+(Matrix(float64, shape=(?, 1))) is inconsistent with the one provided in `output_info`
+(Vector(float64, shape=(?,))). The output on axis 0 is `True`, but it is `False` on axis
+1 in `output_info`. This can happen if one of the dimension is fixed to 1 in the input,
+while it is still variable in the output, or vice-verca. You have to make them consistent,
+e.g. using pytensor.tensor.{unbroadcast, specify_broadcastable}.
+
+Instead of fixing this error it makes sense to rework the internals of the variational to utilize
+pytensor vectorize instead of scan.
+"""
+)
 def test_vi_sampling_jax_svgd():
     with pm.Model():
         x = pm.Normal("x")
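
Note on the TypeError quoted in the new xfail reason: the inconsistency is between the broadcast pattern of what each scan step returns and the pattern of the matching `outputs_info` entry. The listing below is not part of this diff; it is a minimal sketch, with illustrative names (`xs`, `init`, `step`), of how `pytensor.tensor.specify_broadcastable` — the consistency fix suggested in the error text — is applied so that a fixed length-1 state axis in `outputs_info` matches the per-step output accepted by `pytensor.scan`.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    xs = pt.matrix("xs")      # (n_steps, k) input sequence
    init = pt.vector("init")  # runtime length unknown -> axis 0 not broadcastable

    # Each step returns a length-1 vector, so its axis 0 is broadcastable; declaring
    # the same about the initial state keeps `outputs_info` consistent with the step output.
    init_fixed = pt.specify_broadcastable(init, 0)


    def step(row, acc):
        return acc + row.sum(keepdims=True)  # static shape (1,)


    outputs, _ = pytensor.scan(step, sequences=[xs], outputs_info=[init_fixed])
    f = pytensor.function([xs, init], outputs[-1])
    print(f(np.ones((4, 3)), np.zeros(1)))   # -> [12.]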