Description
Describe the issue:
The JAX-based samplers crash after sampling finishes, right after the "Transforming variables..." message, on medium-to-large models (thousands of rows, hundreds of parameters). This happens on both GPU and CPU systems, and with either the numpyro or the blackjax sampler. On GPU, the failure produces a traceback that isolates the problem to the vmap call in _postprocess_samples. On CPU (MacBook Pro M1), the process is simply killed without any error message. I have tried running the GPU model with the postprocessing_backend="cpu" argument for the numpyro sampler, but this does not seem to make a difference. Should it be using vmap at all when the postprocessing backend is CPU?
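To make the vmap question concrete, here is a minimal JAX sketch of the two patterns I have in mind. It is not PyMC's actual postprocessing code: `transform`, `raw_samples`, and the array shapes are hypothetical stand-ins, and it is only my assumption that batching every draw at once is what exhausts memory. A nested `jax.vmap` evaluates the transform for all chains and draws in a single batched call, while `jax.lax.map` applies it one element at a time, which may keep peak memory lower at the cost of speed.

```python
import jax
import jax.numpy as jnp


def transform(raw):
    # Hypothetical per-draw transform (e.g. back-transforming a log-scale value).
    # This stands in for whatever the model computes; it is not PyMC's function.
    return jnp.exp(raw)


# Hypothetical raw samples with shape (chains, draws, parameters).
raw_samples = jnp.zeros((2, 100, 5))

# Pattern 1: nested vmap over chains and draws.
# All chains and draws are processed in one batched computation.
all_at_once = jax.vmap(jax.vmap(transform))(raw_samples)


# Pattern 2: lax.map over chains, then over draws within each chain.
# Each draw is transformed sequentially instead of in one large batch.
def per_chain(chain):
    return jax.lax.map(transform, chain)


sequential = jax.lax.map(per_chain, raw_samples)
```

Both patterns return arrays of the same shape with the same values; they differ only in how much work is batched at once.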
Reproducible code example:
Will add example when I can come up with one
Error message:
CPU machine error:
Compilation time = 0:00:09.225151
Sampling...
Running chain 0: 100%|██████████| 2000/2000 [12:00:33<00:00, 21.62s/it]
Running chain 1: 100%|██████████| 2000/2000 [12:00:33<00:00, 21.62s/it]
Sampling time = 12:00:35.215191
Transforming variables...
Killed: 9
/Users/cfonnesbeck/mambaforge/envs/pymc/lib/python3.11/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
PyMC version information:
PyMC 5.3.0
PyTensor 2.11.1
Context for the issue:
The numpyro sampler is currently unusable for moderate-sized models due to this issue.