
Expose MLX memory management APIs#98

Merged
polvalente merged 4 commits into elixir-nx:main from dannote:add-memory-management-apis on Feb 22, 2026

Conversation

Contributor

@dannote dannote commented Feb 22, 2026

Adds bindings for MLX's memory management functions:

  • EMLX.memory_info/0 — returns %{active_memory, peak_memory, cache_memory} in bytes
  • EMLX.clear_cache/0 — releases unused GPU memory back to the system
  • EMLX.reset_peak_memory/0 — resets the peak memory counter
  • EMLX.set_memory_limit/1 — sets the memory limit guideline
  • EMLX.set_cache_limit/1 — sets the cache size limit

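A minimal usage sketch of the five functions above (the limit values here are illustrative choices, not defaults from this PR):

```elixir
# Inspect current allocator state (all values in bytes).
info = EMLX.memory_info()
IO.inspect(info.active_memory, label: "active")

# Cap the buffer cache at 512 MiB and set a 4 GiB memory guideline.
EMLX.set_cache_limit(512 * 1024 * 1024)
EMLX.set_memory_limit(4 * 1024 * 1024 * 1024)

# Start a fresh peak measurement and drop any cached buffers.
EMLX.reset_peak_memory()
EMLX.clear_cache()
```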
Why this is needed

Without clear_cache, repeated model inference causes GPU memory to grow unbounded as MLX caches freed buffers. On a 24 GB Apple M5 running ai-forever/FRIDA (823M parameter T5 encoder), memory grows from 3 GB to 18 GB after just 4 inference batches, causing severe system-wide slowdowns as the GPU starts swapping:

Batch 1:  4873ms = 13.1 sent/s
Batch 2:  5619ms = 11.4 sent/s
Batch 3: 31385ms =  2.0 sent/s  ← GPU memory exhausted
Batch 4: 69177ms =  0.9 sent/s

With EMLX.clear_cache() + :erlang.garbage_collect() between batches:

Batch 1: 4517ms = 14.2 sent/s
Batch 2: 4587ms = 14.0 sent/s
Batch 3: 4517ms = 14.2 sent/s
Batch 4: 4556ms = 14.0 sent/s  ← stable
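The pattern behind these numbers can be sketched as follows (`batches` and `embed_batch/1` are hypothetical stand-ins for the caller's data and model invocation, not part of this PR):

```elixir
for batch <- batches do
  _embeddings = embed_batch(batch)

  # Let the BEAM release tensor references first, then return
  # MLX's cached (but unused) buffers to the system.
  :erlang.garbage_collect()
  EMLX.clear_cache()
end
```

The order matters: garbage collection drops the Elixir-side references so the underlying MLX buffers become unused, and `clear_cache/0` then frees them instead of leaving them in MLX's buffer cache.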

All 2117 tests pass (2115 existing + 6 new memory tests).

}

NIF(clear_cache) {
  mlx::core::clear_cache();
Collaborator

Do these return void?

t = Nx.iota({1024, 1024}, type: :f32, backend: EMLX.Backend)
EMLX.eval(EMLX.Backend.from_nx(t))
after_alloc = EMLX.memory_info().active_memory
assert after_alloc > before
Collaborator

This assertion could be more strict since we know the tensor size
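One way to tighten it, assuming the f32 tensor's 1024 × 1024 × 4 bytes is reflected in `active_memory` (using `>=` rather than `==` since the allocator may round allocations up):

```elixir
before = EMLX.memory_info().active_memory

t = Nx.iota({1024, 1024}, type: :f32, backend: EMLX.Backend)
EMLX.eval(EMLX.Backend.from_nx(t))

after_alloc = EMLX.memory_info().active_memory
# 1024 * 1024 f32 elements = exactly 4 MiB; rounding can only add to that.
assert after_alloc >= before + 1024 * 1024 * 4
```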

EMLX.eval(EMLX.Backend.from_nx(t))
Nx.backend_deallocate(t)
EMLX.clear_cache()
info = EMLX.memory_info()
Collaborator

Do we have assertions on what info returns?
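A sketch of what such an assertion could look like, with the key set taken from the PR description:

```elixir
info = EMLX.memory_info()

assert %{active_memory: active, peak_memory: peak, cache_memory: cache} = info
assert is_integer(active) and active >= 0
assert is_integer(peak) and peak >= 0
assert is_integer(cache) and cache >= 0
```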

{EMLX.Backend, device: device}
end

@doc """
Collaborator

These docs could use some examples, even if they aren't doctests per se.
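For instance, the `clear_cache/0` docs could show the batch-loop pattern from the PR description (illustrative rather than a doctest, since the memory numbers are hardware-dependent):

```elixir
@doc """
Releases unused cached GPU buffers back to the system.

## Example

Between inference batches, to keep memory bounded:

    :erlang.garbage_collect()
    EMLX.clear_cache()
"""
```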

@polvalente polvalente merged commit 2d5742d into elixir-nx:main Feb 22, 2026
11 of 12 checks passed
