Skip to content

Conversation

@junpenglao
Copy link
Member

@junpenglao junpenglao commented Jul 9, 2024

Description

Reduce blackjax sampling memory usage by not outputting the warm up diagnostics

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7407.org.readthedocs.build/en/7407/

... by not outputing the warmup diagnositics
@junpenglao
Copy link
Member Author

Will need to upgrade jaxlib requirement first (conda-forge/jaxlib-feedstock#272)

@codecov
Copy link

codecov bot commented Jul 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.18%. Comparing base (641a60b) to head (e81d828).
Report is 97 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@ Coverage Diff @@ ## main #7407 +/- ## ======================================= Coverage 92.18% 92.18% ======================================= Files 103 103 Lines 17258 17259 +1 ======================================= + Hits 15909 15910 +1  Misses 1349 1349 
Files with missing lines Coverage Δ
pymc/sampling/jax.py 94.03% <100.00%> (+0.02%) ⬆️
@junpenglao junpenglao merged commit c8b22df into main Jul 13, 2024
@junpenglao junpenglao deleted the blackjax_memory branch July 13, 2024 07:33
vandalt pushed a commit to vandalt/pymc that referenced this pull request May 14, 2025
* Reduce blackjax sampling memory usage ... by not outputing the warmup diagnositics * Update jax env * fix pre-commit * skip also RuntimeWarning * ping jax versions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

2 participants