Add MLX backend #1365
Merged
Changes from all commits (88 commits)
9446d80 mlx poc
williambdean 8a38e2f add test for dot
williambdean d6feeba restore pytorch
williambdean c8f959e wrap in mx.array
williambdean 513ee3a modify the pytorch jit
williambdean 59f4a88 move file
williambdean 07e21e4 dont wrap
williambdean 0583bf7 attempt to fix github action
williambdean 9022edd change the rtol
williambdean ebe96e0 add init file
williambdean c859db0 skip if not installed
williambdean 90321ba remove torch related code / comments
williambdean 5e51402 simplify the fgraph_convert
williambdean 488ea5a assert type
williambdean d714fbc simplify the internal
williambdean 081806f remove the language
williambdean 6e312e0 Adding operations in pytensor
cetagostini b8a95ea add extension
williambdean 4083fe1 make compare function
williambdean 71ad63d rename function
williambdean 9f67c2c correct the function name
williambdean fa47b1a tests for elemwise
williambdean 292c01b Changes
cetagostini 9133b3c Toma tu tomate William ["Take your tomato, William"]
cetagostini 1d68d5e Pushing changes with the core shit.
cetagostini 2014390 add more tests
williambdean 89567aa additional tests
williambdean dccee53 test for switch with mlx
williambdean ff871c2 Pushing code
cetagostini 0c2fec1 Changes
cetagostini 5275bf5 A lot of new code
cetagostini 004ed73 almost there baby william
cetagostini 5257c96 Another push small
cetagostini 323045c fix for all
williambdean 0abac67 fix for carlos
williambdean 199f17c just return the compiled func
williambdean 7b6a3d2 A change for willy may!
cetagostini 710b563 FINALLY BABY LETS PARTY! (IF YOU ARE READING THIS MAKE MORE PRs)
cetagostini 1e1d8f9 THE SUPER BLOCKWISEE YA YA YA YA JUUUUU
cetagostini b5c02a7 refactor to use getattr
williambdean 8df5b09 bring argmax test
williambdean 454fda9 use deepcopy
williambdean 9299a28 move some tests
williambdean b4a9642 Guys, I'm getting sad. We need help yisus!!!!!
cetagostini 30850f3 WILLIAM YOU NEED TO GO ANOTHER MILE! GO ON MY MATEEEEEEE, GO PHILLIES!
cetagostini 84665e5 RETURN, WHAT A SHAME! Sad times are coming.
cetagostini 3041340 AI COULD BE COOL? OR WE ARE JUST FUCKING AROUND?
cetagostini 36f886b AI RULES BABY MY MATE
cetagostini ca7c77f I'm going for pizzas, it was an incredible day!
cetagostini 9688407 test conv1d case
williambdean 5d34fa6 SUUUUUUUUU!!!!!! LIFE IS GOING WELL. MLX FOR MEDIA MIX MODELS BAY
cetagostini b2fac8e pre-commit
cetagostini 10e1a40 Almost working
cetagostini c6841cb Last PR sampling working
cetagostini e81ba94 Requested changes by Ricardo
cetagostini 1da4530 Pre commit changes
cetagostini d6f6e2a More changes from Ricardo
cetagostini 3d144db Pre Commit RUN
cetagostini 8300fd4 Adding more operations for complex model
cetagostini d47de98 Working with simple model
cetagostini cfcb910 Change bad name
cetagostini 481e3ad Correcting test by Ricardo
cetagostini 9527f6c Changing synth test
cetagostini 13a700a Optimizing reshape
cetagostini fb46008 Comment
cetagostini 70734c9 Small changes and adding small benchmark
cetagostini a43f1cf Changes with Ricardo
cetagostini 5e53537 improving benchmark
cetagostini 127b896 pre commit
cetagostini 26a6d14 benchs
cetagostini b2e924d Changes on the branch
cetagostini a550919 Feedback from Ricardo
cetagostini e54c32f update test based on llm recommendation
cetagostini d5a4bf8 Streamline Blockwise impl
jessegrabowski 11faf7a clean up imports
jessegrabowski 2a86028 adjust github test.yml
jessegrabowski 53cdf49 adjust github test.yml
jessegrabowski d22851a skip mlx tests in benchmark ci
jessegrabowski ed0d687 Absolute imports
jessegrabowski 2589de4 Use `importorskip` in mlx tests
jessegrabowski d16d245 address feedback
jessegrabowski 9f41a4e Add function names and remove wrappers
jessegrabowski bad7c90 Copy jax CARReduce test
jessegrabowski 433a2cb Move alloc tests to test_core.py
jessegrabowski 2421a6f Handle dynamic shapes to AllocEmpty in non-compiled mode
jessegrabowski d33cda0 Simplify mlx_funcify_CAReduce
jessegrabowski 5940630 Delete AI cruft
jessegrabowski e484ba4 move all elemwise dispatches to elemwise.py
@@ -27,7 +27,6 @@ __pycache__

```
\#*\#
build
compiled/*.cpp
core.*
cutils_ext.cpp
dist
doc/.build/
```
Large diffs are not rendered by default.
@@ -0,0 +1 @@

```python
from pytensor.link.mlx.linker import MLXLinker
```
@@ -0,0 +1,13 @@

```python
# isort: off
from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify

import pytensor.link.mlx.dispatch.math
import pytensor.link.mlx.dispatch.basic
import pytensor.link.mlx.dispatch.elemwise
import pytensor.link.mlx.dispatch.shape
import pytensor.link.mlx.dispatch.subtensor
import pytensor.link.mlx.dispatch.core
import pytensor.link.mlx.dispatch.signal
import pytensor.link.mlx.dispatch.signal.conv
import pytensor.link.mlx.dispatch.blockwise
# isort: on
```
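The `# isort: off` guard and the bare submodule imports exist because each dispatch module registers its handlers on `mlx_funcify`/`mlx_typify` as an import side effect; nothing in those modules is referenced by name, but the dispatcher must be defined before the registering modules execute. A minimal sketch of this registration-by-import pattern, using only the standard library (all names are illustrative, not the real PyTensor API):

```python
from functools import singledispatch


# "basic.py": define the dispatcher first, before any registrations.
@singledispatch
def funcify(op):
    raise NotImplementedError(f"No conversion for {type(op).__name__}")


# "math.py": executing this module body registers a handler as a side effect;
# in the real package this happens when the submodule is imported.
class Exp:
    pass


@funcify.register(Exp)
def _funcify_exp(op):
    # Return the backend implementation for this op.
    return "exp-impl"


# Once the registering module has run, dispatch on the op type succeeds:
print(funcify(Exp()))  # prints "exp-impl"
```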
@@ -0,0 +1,101 @@

```python
import warnings
from copy import deepcopy
from functools import singledispatch
from types import NoneType

import mlx.core as mx
import numpy as np

from pytensor.compile.ops import DeepCopyOp
from pytensor.graph import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import Assert, CheckAndRaise


@singledispatch
def mlx_typify(data, **kwargs):
    raise NotImplementedError(f"mlx_typify is not implemented for {type(data)}")


@mlx_typify.register(np.ndarray)
def mlx_typify_tensor(data, dtype=None, **kwargs):
    return mx.array(data, dtype=dtype)


@mlx_typify.register(slice)
@mlx_typify.register(NoneType)
@mlx_typify.register(mx.array)
def mlx_typify_no_conversion_needed(data, **kwargs):
    return data


@mlx_typify.register(int)
@mlx_typify.register(float)
def mlx_typify_python_scalar(data, **kwargs):
    return mx.array(data)


@mlx_typify.register(bool)
@mlx_typify.register(np.bool_)
def mlx_typify_bool(data, **kwargs):
    return bool(data)


@mlx_typify.register(np.integer)
@mlx_typify.register(np.floating)
@mlx_typify.register(np.complexfloating)
def mlx_typify_numpy_scalar(data, **kwargs):
    return mx.array(data)


@singledispatch
def mlx_funcify(op, node=None, storage_map=None, **kwargs):
    """Create an MLX-compatible function from a PyTensor `Op`."""
    raise NotImplementedError(
        f"No MLX conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/1350` for progress or to request we prioritize this operation"
    )


@mlx_funcify.register(FunctionGraph)
def mlx_funcify_FunctionGraph(
    fgraph,
    node=None,
    fgraph_name="mlx_funcified_fgraph",
    conversion_func=mlx_funcify,
    **kwargs,
):
    built_kwargs = {"conversion_func": conversion_func, **kwargs}
    return fgraph_to_python(
        fgraph,
        conversion_func,
        type_conversion_fn=mlx_typify,
        fgraph_name=fgraph_name,
        **built_kwargs,
    )


@mlx_funcify.register(DeepCopyOp)
def mlx_funcify_DeepCopyOp(op, **kwargs):
    def deepcopyop(x):
        return deepcopy(x)

    return deepcopyop


@mlx_funcify.register(Assert)
@mlx_funcify.register(CheckAndRaise)
def mlx_funcify_CheckAndRaise(op, node, **kwargs):
    conds = node.inputs[1:]
    if any(isinstance(cond, Constant) and not bool(cond.data) for cond in conds):
        raise op.exc_type(op.msg)

    warnings.warn(
        f"""Skipping `{type(op).__name__}` Op (assertion: {op.msg}) as MLX tracing would remove it.""",
        stacklevel=2,
    )

    def assert_fn(x, *inputs):
        return x

    return assert_fn
```
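The `CheckAndRaise` handler above deserves a note: because MLX tracing would silently drop a runtime assertion, the dispatch either raises immediately, when a condition is already a constant `False` at funcify time, or warns and compiles the check away to an identity function. A standalone sketch of that compile-time/identity split, with toy classes standing in for the real PyTensor objects:

```python
import warnings


class ToyConstant:
    """Stand-in for a pytensor Constant: a value known at compile time."""

    def __init__(self, data):
        self.data = data


def funcify_check_and_raise(msg, conds, exc_type=AssertionError):
    # If any condition is already known False at "compile" time, fail now.
    if any(isinstance(c, ToyConstant) and not bool(c.data) for c in conds):
        raise exc_type(msg)

    # Otherwise the check cannot survive tracing: warn and become identity.
    warnings.warn(f"Skipping assertion: {msg}", stacklevel=2)

    def assert_fn(x, *inputs):
        return x

    return assert_fn


with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    f = funcify_check_and_raise("x > 0", [ToyConstant(True)])
print(f(42))  # prints 42
```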
@@ -0,0 +1,35 @@

```python
import mlx.core as mx

from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.blockwise import Blockwise


@mlx_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, **kwargs):
    # 1) Get the core python function for this Blockwise
    core_node = op._create_dummy_core_node(node.inputs)
    core_f = mlx_funcify(op.core_op, core_node)

    # 2) Determine how many dimensions are batch dimensions
    n_batch = op.batch_ndim(node)

    # 3) Handle the case where no vectorization is needed
    if n_batch == 0:
        return core_f

    # 4) Vectorize using mx.vmap over any batched inputs
    in_axes: list[int | None] = []
    for inp, sig in zip(node.inputs, op.inputs_sig):
        batch_ndim = inp.type.ndim - len(sig)
        if batch_ndim == 0:
            in_axes.append(None)
            continue

        batch_bcast = inp.type.broadcastable[:batch_ndim]
        # If all batch dims are broadcastable (size 1), treat input as static
        in_axes.append(0 if not all(batch_bcast) else None)

    if not any(axis == 0 for axis in in_axes):
        return core_f

    return mx.vmap(core_f, in_axes=tuple(in_axes))
```
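The `in_axes` bookkeeping above can be exercised without MLX itself: an input is mapped over axis 0 only when it has at least one non-broadcastable batch dimension; otherwise it is passed through unbatched (`None`). A standalone sketch of that decision in plain Python, with hypothetical shape metadata in place of real PyTensor types:

```python
def compute_in_axes(ndims, core_ndims, broadcastable):
    """Mirror the in_axes decision in funcify_Blockwise on plain data."""
    in_axes = []
    for ndim, core_ndim, bcast in zip(ndims, core_ndims, broadcastable):
        batch_ndim = ndim - core_ndim
        if batch_ndim == 0:
            # No batch dimensions at all: pass through unbatched.
            in_axes.append(None)
            continue
        # If every batch dim is broadcastable (size 1), treat input as static.
        batch_bcast = bcast[:batch_ndim]
        in_axes.append(0 if not all(batch_bcast) else None)
    return tuple(in_axes)


# A (batch, m, n) matrix is vmapped over axis 0; a shared (n,) vector is not.
print(compute_in_axes(
    ndims=[3, 1],
    core_ndims=[2, 1],
    broadcastable=[(False, False, False), (False,)],
))  # prints (0, None)
```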