
Conversation

Member

@ricardoV94 ricardoV94 commented Apr 24, 2025

Reimplementing the core logic in the numba overload of convolve/correlate gives a speedup of 6x in the benchmarked test with relatively small inputs. I guess the overloads don't optimize/propagate constant checks as well? It's a bit surprising but the results are crystal clear.
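
As an illustration of the kind of hand-written loop that numba compiles well (a sketch only, not the code added in this PR), a direct "valid" 1D convolution can be written as:

    import numpy as np
    from numba import njit

    @njit
    def valid_convolve1d(x, y):
        # "valid" mode: the kernel y must fit entirely inside x,
        # so the output has len(x) - len(y) + 1 entries.
        nx, ny = x.shape[0], y.shape[0]
        out = np.zeros(nx - ny + 1, dtype=x.dtype)
        for i in range(out.shape[0]):
            for j in range(ny):
                # convolution flips the kernel relative to correlation
                out[i] += x[i + j] * y[ny - 1 - j]
        return out

    # matches np.convolve(x, y, mode="valid") whenever len(x) >= len(y)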

Also added a rewrite to optimize the gradient of valid convolutions wrt the smallest input, in which case we don't need a full convolve. This is done at the rewrite level because the static shape may not be known at the time the gradient is taken.
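
For context on why the full convolution can be avoided: the "valid" result is exactly the central slice of the "full" result, which is the pattern the rewrite can look for in the gradient graph. A quick NumPy check of that identity (illustrative only, not PR code):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=10)   # larger input
    y = rng.normal(size=4)    # smaller input

    full = np.convolve(x, y, mode="full")    # length len(x) + len(y) - 1
    valid = np.convolve(x, y, mode="valid")  # length len(x) - len(y) + 1

    # The valid output is full[start:stop] with
    #   start = len(y) - 1
    #   stop  = full.shape[-1] - (len(y) - 1)   (equivalently, stop = len(x))
    start = len(y) - 1
    stop = full.shape[-1] - (len(y) - 1)
    np.testing.assert_allclose(full[start:stop], valid)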

Finally, renamed Conv1d to Convolve1d, which is more in line with the user-facing function.
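
For reference, a minimal usage sketch of the user-facing function the new Op name mirrors, assuming convolve1d is exposed from pytensor.tensor.signal with a NumPy-style mode argument and that the numba backend is selected via mode="NUMBA":

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.signal import convolve1d  # user-facing wrapper around the Convolve1d Op

    x = pt.vector("x")
    y = pt.vector("y")
    z = convolve1d(x, y, mode="valid")

    # compile with the numba backend so the specialized implementation is used
    fn = pytensor.function([x, y], z, mode="NUMBA")
    print(fn(np.arange(10.0), np.ones(3)))  # moving sums of length 3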


codecov bot commented Apr 25, 2025

Codecov Report

❌ Patch coverage is 50.53763% with 46 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.01%. Comparing base (e98cbbc) to head (f0ef8fb).
⚠️ Report is 187 commits behind head on main.

Files with missing lines | Patch % | Lines
pytensor/link/numba/dispatch/signal/conv.py | 26.08% | 33 Missing and 1 partial ⚠️
pytensor/tensor/rewriting/conv.py | 70.00% | 6 Missing and 6 partials ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1378      +/-   ##
==========================================
- Coverage   82.07%   82.01%   -0.06%
==========================================
  Files         206      207       +1
  Lines       49174    49250      +76
  Branches     8720     8734      +14
==========================================
+ Hits        40359    40394      +35
- Misses       6656     6692      +36
- Partials     2159     2164       +5
Files with missing lines | Coverage Δ
pytensor/link/jax/dispatch/signal/conv.py | 100.00% <100.00%> (ø)
pytensor/tensor/signal/conv.py | 97.05% <100.00%> (ø)
pytensor/tensor/rewriting/conv.py | 70.00% <70.00%> (ø)
pytensor/link/numba/dispatch/signal/conv.py | 32.00% <26.08%> (-58.91%) ⬇️

... and 1 file with indirect coverage changes


Copilot AI left a comment


Pull Request Overview

This PR reimplements the core logic of convolve1d in the numba backend for a 6× speedup in benchmarks with small inputs, while also optimizing the gradient computation for valid convolutions when the smaller input’s shape is known statically. In addition, the PR renames Conv1d to Convolve1d for improved consistency in function naming and updates various test and dispatch files to reflect these changes.

  • Renames Conv1d to Convolve1d across modules.
  • Adds new tests for gradient optimization and benchmarks for numba convolve1d.
  • Updates rewriting and dispatch code to support the new implementation.

Reviewed Changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.

Show a summary per file
File | Description
tests/tensor/signal/test_conv.py | Updated to import Convolve1d and added a test for gradient rewrite optimization.
tests/link/numba/signal/test_conv.py | Adjusted tests to optionally swap inputs, and added a benchmark test.
pytensor/tensor/signal/conv.py | Renamed Conv1d to Convolve1d and updated internal variable naming for clarity.
pytensor/tensor/rewriting/conv.py | Added a rewrite rule to optimize valid convolution gradients for static shapes.
pytensor/tensor/rewriting/__init__.py | Imported the new conv rewriting module.
pytensor/link/numba/dispatch/signal/conv.py | Updated to register Convolve1d and implemented specialized numba functions.
pytensor/link/jax/dispatch/signal/conv.py | Updated to register Convolve1d.
Member

@jessegrabowski jessegrabowski left a comment


lgtm, left ignorable suggestions


if (
start == len_y - 1
# equivalent to stop = conv.shape[-1] - len_y - 1
Member


Why not use that form then? I don't understand this comment

Member Author


Because I already extracted len_x, and I can use that directly

@ricardoV94 ricardoV94 force-pushed the faster_conv1d_numba branch from 02823cc to f1102ba Compare April 27, 2025 08:44
@ricardoV94 ricardoV94 force-pushed the faster_conv1d_numba branch from f1102ba to e2c8464 Compare April 27, 2025 08:46
@ricardoV94 ricardoV94 force-pushed the faster_conv1d_numba branch from e2c8464 to f0ef8fb Compare April 27, 2025 08:54
@ricardoV94 ricardoV94 merged commit 4378d48 into pymc-devs:main Apr 27, 2025
72 of 73 checks passed