@@ -454,25 +454,44 @@ def test_random_concrete_shape():
454454 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
455455
456456
457- @pytest .mark .xfail (reason = "size argument specified as a tuple is a `DimShuffle` node" )
458457def test_random_concrete_shape_subtensor ():
458+ """JAX should compile when a concrete value is passed for the `size` parameter.
459+
460+ This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
461+ inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
462+ inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
463+ rewrite.
464+
465+ JAX does not accept scalars as `size` or `shape` arguments, so this is a
466+ slight improvement over their API.
467+
468+ """
459469 rng = shared (np .random .RandomState (123 ))
460470 x_at = at .dmatrix ()
461471 out = at .random .normal (0 , 1 , size = x_at .shape [1 ], rng = rng )
462472 jax_fn = function ([x_at ], out , mode = jax_mode )
463473 assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
464474
465475
466- @pytest .mark .xfail (reason = "size argument specified as a tuple is a `MakeVector` node" )
467476def test_random_concrete_shape_subtensor_tuple ():
477+ """JAX should compile when a tuple of concrete values is passed for the `size` parameter.
478+
479+ This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
480+ inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
481+ scalar inputs into tuples of concrete values using the
482+ `jax_size_parameter_as_tuple` rewrite.
483+
484+ """
468485 rng = shared (np .random .RandomState (123 ))
469486 x_at = at .dmatrix ()
470487 out = at .random .normal (0 , 1 , size = (x_at .shape [0 ],), rng = rng )
471488 jax_fn = function ([x_at ], out , mode = jax_mode )
472489 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
473490
474491
475- @pytest .mark .xfail (reason = "`size_at` should be specified as a static argument" )
492+ @pytest .mark .xfail (
493+ reason = "`size_at` should be specified as a static argument" , strict = True
494+ )
476495def test_random_concrete_shape_graph_input ():
477496 rng = shared (np .random .RandomState (123 ))
478497 size_at = at .scalar ()
0 commit comments