Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
fbd7f3c
Define more linker requirements
ricardoV94 Nov 28, 2025
fa7ed79
Remove old GPU related check
ricardoV94 Nov 28, 2025
9f877ba
More informative Scan error message
ricardoV94 Nov 28, 2025
0a596f4
Validate compatible linker in Scan make_thunk
ricardoV94 Nov 28, 2025
ff16133
Numba Scan: correct handling of signed mitmot taps
ricardoV94 Nov 27, 2025
76d0d84
Numba Scan: zero out unwritten buffers
ricardoV94 Nov 27, 2025
d68665a
Numba Scan: prevent alias of outputs
ricardoV94 Nov 24, 2025
930d819
Numba Scan: make codegen more readable
ricardoV94 Nov 24, 2025
065e90b
Handle upcasting of scalar to vector arrays by scipy vector optimizers
ricardoV94 Nov 21, 2025
cecb6a5
Numba DimShuffle: validate squeeze
ricardoV94 Nov 25, 2025
80c3f03
Numba Split: Validate sizes
ricardoV94 Nov 19, 2025
0664fe5
Numba Argmax: Fix axis=None
ricardoV94 Nov 21, 2025
39cac02
Numba Dot: Handle complex inputs
ricardoV94 Nov 21, 2025
f5f1d31
Numba Unique: align with Python implementation
ricardoV94 Nov 24, 2025
402c965
Numba CAReduce: respect acc_dtype
ricardoV94 Nov 21, 2025
5f40dbf
Numba uint: fix Sigmoid and Softplus with uint inputs
ricardoV94 Nov 28, 2025
d1ee4b2
Numba: inplace list_type Ops failing with caching
ricardoV94 Nov 28, 2025
363b125
Numba ListType: Use correct type
ricardoV94 Nov 29, 2025
bc5e8e6
Numba UnravelIndex and RavelMultiIndex is incorrect
ricardoV94 Nov 30, 2025
330bb73
Numba deepcopy: Support boolean
ricardoV94 Nov 30, 2025
febd311
More informative NotImplementedError
ricardoV94 Nov 21, 2025
461ea85
Numba Alloc: Patch so it works inside a Blockwise
ricardoV94 Nov 30, 2025
543ddc1
Remove tensor/io
ricardoV94 Nov 30, 2025
dc479b0
Fix output dtype of LstSQ Op
ricardoV94 Nov 30, 2025
d69a8a4
Numba int_to_float: Remove buggy helper
ricardoV94 Nov 30, 2025
7acff0a
Numba linalg: Handle empty inputs
ricardoV94 Nov 30, 2025
68f7dbc
Numba FillDiagonal: Copy input
ricardoV94 Nov 30, 2025
60b94dd
Numba linalg: Fallback to objmode with complex inputs
ricardoV94 Nov 30, 2025
c8ee772
Numba eigh: Must cast to promised dtype
ricardoV94 Nov 30, 2025
8183442
Try to run full test suite in Numba backend
ricardoV94 Nov 19, 2025
05de712
Split linalg tests into their own job
ricardoV94 Nov 30, 2025
7f6d4ca
Revert mode for tests that are C-specific
ricardoV94 Nov 24, 2025
682bd3a
Temporarily remove numba object mode warning
ricardoV94 Jun 13, 2024
dcc3e33
Tweak expected errors
ricardoV94 Nov 24, 2025
3407d02
Tweak Blockwise/RandomVariable tests
ricardoV94 Nov 24, 2025
9b01f57
Align numba reciprocal with C backend
ricardoV94 Nov 21, 2025
0ed50dd
XFAIL conv tests of Ops without Python implementation
ricardoV94 Jun 13, 2024
2bf252c
XFAIL/SKIP float16 tests
ricardoV94 Jun 7, 2024
2cadf88
XFAIL TypedList global constant
ricardoV94 Nov 25, 2025
e45aaf7
XFAIL/SKIP Sparse tests
ricardoV94 Nov 24, 2025
f718510
Fix Eye test
ricardoV94 Nov 19, 2025
bc7027d
Numba does not output numpy scalars
ricardoV94 Nov 19, 2025
09b0f99
Tweak test tolerances
ricardoV94 Nov 19, 2025
28d1d19
Tweak RandomGenerator tests
ricardoV94 Nov 30, 2025
5307039
Test wasn't actually covering rewrite
ricardoV94 Nov 24, 2025
5962f8a
Allow burn-in in memory leak test
ricardoV94 Nov 28, 2025
0ed3464
Test SVD allow benign sign change
ricardoV94 Nov 30, 2025
77c7a80
XFAIL eig test due to domain change
ricardoV94 Nov 30, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,14 @@ jobs:
install-mlx: [0]
install-xarray: [0]
part:
- "tests --ignore=tests/scan --ignore=tests/tensor --ignore=tests/xtensor"
- "tests --ignore=tests/scan --ignore=tests/tensor --ignore=tests/xtensor --ignore=tests/link/numba"
- "tests/scan"
- "tests/tensor --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting"
- "tests/tensor --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/linalg --ignore=tests/tensor/test_nlinalg.py --ignore=tests/tensor/test_slinalg.py"
- "tests/tensor/test_basic.py tests/tensor/test_elemwise.py"
- "tests/tensor/test_math.py"
- "tests/tensor/test_math_scipy.py tests/tensor/test_blas.py tests/tensor/conv"
- "tests/tensor/rewriting"
- "tests/tensor/linalg tests/tensor/test_nlinalg.py tests/tensor/test_slinalg.py"
exclude:
- python-version: "3.11"
fast-compile: 1
Expand Down Expand Up @@ -202,7 +203,7 @@ jobs:
else
micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx;
fi
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "mlx<0.29.4"; fi
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ tag_prefix = "rel-"
addopts = "--durations=50 --doctest-modules --ignore=pytensor/link --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/ipython.py"
testpaths = ["pytensor/", "tests/"]
xfail_strict = true
filterwarnings =[
'ignore:^Numba will use object mode to run.*perform method\.:UserWarning',
]

[tool.ruff]
line-length = 88
Expand Down
33 changes: 10 additions & 23 deletions pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ def register_linker(name, linker):
predefined_linkers[name] = linker


# If a string is passed as the optimizer argument in the constructor
# for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary
exclude = []
if not config.cxx:
exclude = ["cxx_only"]
Expand Down Expand Up @@ -451,30 +448,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
return new_mode


# If a string is passed as the mode argument in function or
# FunctionMaker, the Mode will be taken from this dictionary using the
# string as the key
# Use VM_linker to allow lazy evaluation by default.
FAST_COMPILE = Mode(
VMLinker(use_cloop=False, c_thunks=False),
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
)
if config.cxx:
FAST_RUN = Mode("cvm", "fast_run")
else:
FAST_RUN = Mode(
"vm",
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
)

C = Mode("c", "fast_run")
C_VM = Mode("cvm", "fast_run")

NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(include=["fast_run", "numba"]),
)

FAST_COMPILE = Mode(
NumbaLinker(),
RewriteDatabaseQuery(include=["fast_compile"]),
)
FAST_RUN = NUMBA

C = Mode("c", "fast_run")
CVM = Mode("cvm", "fast_run")

JAX = Mode(
JAXLinker(),
RewriteDatabaseQuery(include=["fast_run", "jax"]),
Expand All @@ -494,7 +481,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
"FAST_COMPILE": FAST_COMPILE,
"FAST_RUN": FAST_RUN,
"C": C,
"C_VM": C_VM,
"CVM": CVM,
"JAX": JAX,
"NUMBA": NUMBA,
"PYTORCH": PYTORCH,
Expand Down
7 changes: 3 additions & 4 deletions pytensor/configdefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,25 +371,24 @@ def add_compile_configvars():
)
del param

default_linker = "cvm"
default_linker = "numba"

if rc == 0 and config.cxx != "":
# Keep the default linker the same as the one for the mode FAST_RUN
linker_options = [
"c|py",
"cvmc|py",
"py",
"c",
"c|py_nogc",
"vm",
"vm_nogc",
"cvm_nogc",
"numba",
"jax",
]
else:
# g++ is not present or the user disabled it,
# linker should default to python only.
linker_options = ["py", "vm", "vm_nogc", "numba", "jax"]
linker_options = ["py", "vm", "vm_nogc", "jax"]
if type(config).cxx.is_default:
# If the user provided an empty value for cxx, do not warn.
_logger.warning(
Expand Down
6 changes: 6 additions & 0 deletions pytensor/link/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@ class PerformLinker(LocalLinker):

"""

required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only")
incompatible_rewrites: tuple[str, ...] = ("cxx",)

def __init__(
self, allow_gc: bool | None = None, schedule: Callable | None = None
) -> None:
Expand Down Expand Up @@ -584,6 +587,9 @@ class JITLinker(PerformLinker):

"""

required_rewrites: tuple[str, ...] = ("minimum_compile",)
incompatible_rewrites: tuple[str, ...] = ()

@abstractmethod
def fgraph_convert(
self, fgraph, order, input_storage, output_storage, storage_map, **kwargs
Expand Down
101 changes: 37 additions & 64 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def get_numba_type(
elif isinstance(pytensor_type, RandomGeneratorType):
return numba.types.NumPyRandomGeneratorType("NumPyRandomGeneratorType")
elif isinstance(pytensor_type, TypedListType):
return numba.types.List(get_numba_type(pytensor_type.ttype))
return numba.types.ListType(get_numba_type(pytensor_type.ttype))
else:
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")

Expand Down Expand Up @@ -199,13 +199,10 @@ def creator(args, creator=creator, i=i):


def create_tuple_string(x):
args = ", ".join(x + ([""] if len(x) == 1 else []))
return f"({args})"


def create_arg_string(x):
args = ", ".join(x)
return args
if len(x) == 1:
return f"({x[0]},)"
else:
return f"({', '.join(x)})"


@numba.extending.intrinsic
Expand All @@ -227,36 +224,6 @@ def codegen(context, builder, signature, args):
return sig, codegen


def int_to_float_fn(inputs, out_dtype):
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""

if (
all(inp.type.dtype == out_dtype for inp in inputs)
and np.dtype(out_dtype).kind == "f"
):

@numba_njit(inline="always")
def inputs_cast(x):
return x

elif any(i.type.numpy_dtype.kind in "uib" for i in inputs):
args_dtype = np.dtype(f"f{out_dtype.itemsize}")

@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)

else:
args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs)
args_dtype = np.dtype(f"f{args_dtype_sz}")

@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)

return inputs_cast


@singledispatch
def numba_typify(data, dtype=None, **kwargs):
return data
Expand Down Expand Up @@ -368,6 +335,36 @@ def dispatch_func_wrapper(*args, **kwargs):
return decorator


def default_hash_key_from_props(op, **extra_fields):
props_dict = op._props_dict()
if not props_dict:
# Simple op, just use the type string as key
hash = sha256(
f"({type(op)}, {tuple(extra_fields.items())})".encode()
).hexdigest()
else:
# Simple props, can use string representation of props as key
simple_types = (str, bool, int, type(None), float)
container_types = (tuple, frozenset)
if all(
isinstance(v, simple_types)
or (
isinstance(v, container_types)
and all(isinstance(i, simple_types) for i in v)
)
for v in props_dict.values()
):
hash = sha256(
f"({type(op)}, {tuple(props_dict.items())}, {tuple(extra_fields.items())})".encode()
).hexdigest()
else:
# Complex props, use pickle to serialize them
hash = hash_from_pickle_dump(
(str(type(op)), tuple(props_dict.items()), tuple(extra_fields.items())),
)
return hash


@singledispatch
def numba_funcify_and_cache_key(op, node=None, **kwargs) -> tuple[Callable, str | None]:
"""Funcify an Op and return a unique cache key that can be used by numba caching.
Expand Down Expand Up @@ -411,36 +408,12 @@ def numba_funcify_and_cache_key(op, node=None, **kwargs) -> tuple[Callable, str
else:
func, integer_str = func_and_int, "None"

try:
props_dict = op._props_dict()
except AttributeError:
if not hasattr(op, "__props__"):
raise ValueError(
"The function wrapped by `numba_funcify_default_op_cache_key` can only be used with Ops with `_props`, "
f"but {op} of type {type(op)} has no _props defined (not even empty)."
)
if not props_dict:
# Simple op, just use the type string as key
hash = sha256(f"({type(op)}, {integer_str})".encode()).hexdigest()
else:
# Simple props, can use string representation of props as key
simple_types = (str, bool, int, type(None), float)
container_types = (tuple, frozenset)
if all(
isinstance(v, simple_types)
or (
isinstance(v, container_types)
and all(isinstance(i, simple_types) for i in v)
)
for v in props_dict.values()
):
hash = sha256(
f"({type(op)}, {tuple(props_dict.items())}, {integer_str})".encode()
).hexdigest()
else:
# Complex props, use pickle to serialize them
hash = hash_from_pickle_dump(
(str(type(op)), tuple(props_dict.items()), integer_str),
)
hash = default_hash_key_from_props(op, cache_version=integer_str)
return func, hash


Expand Down
3 changes: 2 additions & 1 deletion pytensor/link/numba/dispatch/compile_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def numba_deepcopy(x):

@numba.extending.overload(numba_deepcopy)
def numba_deepcopy_tensor(x):
if isinstance(x, numba.types.Number):
if isinstance(x, numba.types.Number | numba.types.Boolean):

def number_deepcopy(x):
return x
Expand Down Expand Up @@ -61,6 +61,7 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
accept_inplace=True,
)
# TODO: Prevent output aliasing like we do for Scan/outer function
NUMBA.optimizer(fgraph)
fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key(
op.fgraph, squeeze_output=True, **kwargs
Expand Down
Loading
Loading