Constant fold branches of variadic add/mul #1422
Merged
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
Refactoring and renaming:
local_add_mul_fusionfunction toflatten_nested_add_multo more precisely reflect how it works (one could also fuse non-nested add/mul, like the FusionOptimizer does). The function now explicitly tracksaddandmuloperations instead of relying on genericElemwisechecks. [1] [2] [3]New optimization for constant folding:
constant_fold_branches_of_add_mul, which folds constants in add/mul operations when it does not result in higher intermediate memory usage. This optimization is registered in a new sequence database,add_mul_flat_seqopt, which runs before generic elementwise fusion.The two rewrites are pulled out to a separate database so it's included in JAX rewrites (JAX does not include fusion rewrites). We've found this could help avoding XLA constant fold (CC @lucianopaz)
📚 Documentation preview 📚: https://pytensor--1422.org.readthedocs.build/en/1422/