Skip to content

Conversation

@ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented May 27, 2025

Refactoring and renaming:

  • Renamed the local_add_mul_fusion function to flatten_nested_add_mul to more precisely reflect how it works (one could also fuse non-nested add/mul, like the FusionOptimizer does). The function now explicitly tracks add and mul operations instead of relying on generic Elemwise checks. [1] [2] [3]

New optimization for constant folding:

  • Introduced a new rewrite function, constant_fold_branches_of_add_mul, which folds constants in add/mul operations when it does not result in higher intermediate memory usage. This optimization is registered in a new sequence database, add_mul_flat_seqopt, which runs before generic elementwise fusion.

The two rewrites are pulled out to a separate database so it's included in JAX rewrites (JAX does not include fusion rewrites). We've found this could help avoding XLA constant fold (CC @lucianopaz)


📚 Documentation preview 📚: https://pytensor--1422.org.readthedocs.build/en/1422/

Copy link
Member

@lucianopaz lucianopaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don’t know why the test fails but it looks like the fusion rewrite is applied only once. Maybe the equilibrium rewrite that you took out should be added back in?

@lucianopaz
Copy link
Member

lucianopaz commented May 29, 2025

@ricardoV94, I just went through your branch's code and found that the error is coming from the fact that the TestFusion class is including: "canonicalize", "fusion", and "inplace" rewrite databases. The add_mul flatten and fusion rewrites that you moved or added here are only included in "fast_run". My question then is whether your rewrites should also be added to fusion, or if you only want to add the fast_run database rewrites to the TestFusion includes?

@ricardoV94 ricardoV94 force-pushed the constant_fold_variadic_add_mul branch from ca12b58 to 082e1b7 Compare May 30, 2025 10:24
include=[
"canonicalize",
"fusion",
"add_mul_flat",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the change needed to get the test to pass @lucianopaz

@ricardoV94
Copy link
Member Author

@lucianopaz I came to the same conclusion, I just added the rewrite explicitly. Mentioned in an inline comment above

@ricardoV94 ricardoV94 requested a review from lucianopaz May 30, 2025 10:30
@ricardoV94 ricardoV94 force-pushed the constant_fold_variadic_add_mul branch from 082e1b7 to 70db72e Compare May 30, 2025 10:35
Copy link
Member

@lucianopaz lucianopaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @ricardoV94 !

@codecov
Copy link

codecov bot commented May 30, 2025

Codecov Report

❌ Patch coverage is 95.34884% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.13%. Comparing base (5a462e9) to head (70db72e).
⚠️ Report is 136 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/rewriting/elemwise.py 95.34% 1 Missing and 1 partial ⚠️
Additional details and impacted files

Impacted file tree graph

@@ Coverage Diff @@ ## main #1422 +/- ## ========================================== + Coverage 82.11% 82.13% +0.01%  ========================================== Files 211 211 Lines 49743 49773 +30 Branches 8824 8830 +6 ========================================== + Hits 40847 40879 +32  + Misses 6715 6714 -1  + Partials 2181 2180 -1 
Files with missing lines Coverage Δ
pytensor/tensor/rewriting/elemwise.py 91.77% <95.34%> (+0.74%) ⬆️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
@ricardoV94 ricardoV94 merged commit ff09268 into pymc-devs:main May 30, 2025
72 of 73 checks passed
@ricardoV94 ricardoV94 deleted the constant_fold_variadic_add_mul branch May 30, 2025 11:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment