Description
When a `JAXOp` remains in the final graph compiled for a non-JAX backend, we may want to rewrite it for efficiency. For example, we could rewrite `Blockwise(JAXOp)` into a single `JAXOp` whose inner function is vectorized.
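A minimal sketch of the vectorization idea in plain JAX. It does not use PyTensor's API: `core_fn` is a hypothetical stand-in for a `JAXOp`'s inner function, and `looped` only roughly mimics what a `Blockwise` does outside JAX (one call into the function per batch element). `jax.vmap` turns the same function into a single batched call.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical core function wrapped by a JAXOp: maps a vector to a scalar.
def core_fn(x):
    return jnp.sum(jnp.sin(x) ** 2)

# Rough analogue of Blockwise(JAXOp) in a non-JAX backend:
# one separate JAX call per batch element (sketch, not PyTensor's code).
def looped(batch):
    return np.stack([np.asarray(core_fn(row)) for row in batch])

# Proposed rewrite: a single compiled JAX call over the whole batch.
vectorized = jax.jit(jax.vmap(core_fn))

batch = np.random.default_rng(0).normal(size=(8, 3)).astype(np.float32)
looped_out = looped(batch)
vectorized_out = np.asarray(vectorized(batch))
```

Both produce one value per batch row; the vectorized version crosses the Python/JAX boundary once instead of once per row.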
If the graph contains both the Op and its gradient, we could rewrite them into a single Op that uses `jax.value_and_grad` under the hood, so the forward pass is not computed twice.
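A plain-JAX sketch of the fusion, with a hypothetical scalar-valued `core_fn` (again not PyTensor's API). Keeping the value and gradient as separate Ops corresponds to two compiled programs, while `jax.value_and_grad` returns both from one. Note that `jax.value_and_grad` requires a scalar output; for non-scalar outputs, `jax.vjp` plays the analogous role.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar-valued inner function of a JAXOp.
def core_fn(x):
    return jnp.sum(x ** 3)

# Separate Ops: two compiled programs; the forward pass runs in both.
value_fn = jax.jit(core_fn)
grad_fn = jax.jit(jax.grad(core_fn))

# Fused Op: one compiled program returning (value, gradient).
fused_fn = jax.jit(jax.value_and_grad(core_fn))

x = jnp.array([1.0, 2.0, 3.0])
value, grad = fused_fn(x)  # value = 1 + 8 + 27, grad = 3 * x**2
```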
Similarly, if only the output shape is needed, we could rewrite into an Op whose inner function only computes the shape. This last rewrite is only relevant if the original Op does not otherwise remain in the graph.
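One way such a shape-only inner function could work, sketched in plain JAX under the assumption that the input shapes and dtypes are known: `jax.eval_shape` traces the function abstractly and returns a `jax.ShapeDtypeStruct` without performing any of the actual computation. `core_fn` is hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical inner function of a JAXOp.
def core_fn(x, y):
    return jnp.outer(x, y)

# Abstract input specs: only shape and dtype, no data.
x_spec = jax.ShapeDtypeStruct((3,), jnp.float32)
y_spec = jax.ShapeDtypeStruct((5,), jnp.float32)

# Abstract evaluation: no FLOPs are executed, only shapes/dtypes propagated.
out = jax.eval_shape(core_fn, x_spec, y_spec)
```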