Skip to content

Commit c1df05f

Browse files
authored
[PYTHON] Further streamline number handling (#242)
This PR further streamlines number handling by introducing two custom protocols and move Integral and Real handling to more conservative path.
1 parent 8ee0e49 commit c1df05f

File tree

3 files changed

+103
-5
lines changed

3 files changed

+103
-5
lines changed

python/tvm_ffi/_convert.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ def convert(value: Any) -> Any: # noqa: PLR0911,PLR0912
129129
return value
130130
elif hasattr(value, "__dlpack_device__"):
131131
return value
132+
elif hasattr(value, "__tvm_ffi_int__"):
133+
return value
134+
elif hasattr(value, "__tvm_ffi_float__"):
135+
return value
132136
else:
133137
# in this case, it is an opaque python object
134138
return core._convert_to_opaque_object(value)

python/tvm_ffi/cython/function.pxi

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import ctypes
1818
import threading
1919
import os
20-
from numbers import Real, Integral
20+
from numbers import Integral, Real
2121
from typing import Any, Callable
2222

2323

@@ -276,7 +276,33 @@ cdef int TVMFFIPyArgSetterDLPack_(
276276
return 0
277277

278278

279-
cdef int TVMFFIPyArgSetterFFIObjectCompatible_(
279+
cdef int TVMFFIPyArgSetterIntegral_(
280+
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
281+
PyObject* py_arg, TVMFFIAny* out
282+
) except -1:
283+
"""Setter for Integral"""
284+
cdef object arg = <object>py_arg
285+
out.type_index = kTVMFFIInt
286+
# keep it in cython so it will also check for fallback cases
287+
# where the arg is not exactly the int class
288+
out.v_int64 = <long long>arg
289+
return 0
290+
291+
292+
cdef int TVMFFIPyArgSetterReal_(
293+
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
294+
PyObject* py_arg, TVMFFIAny* out
295+
) except -1:
296+
"""Setter for Real"""
297+
cdef object arg = <object>py_arg
298+
out.type_index = kTVMFFIFloat
299+
# keep it in cython so it will also check for fallback cases
300+
# where the arg is not exactly the float class
301+
out.v_float64 = <double>arg
302+
return 0
303+
304+
305+
cdef int TVMFFIPyArgSetterFFIObjectProtocol_(
280306
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
281307
PyObject* py_arg, TVMFFIAny* out
282308
) except -1:
@@ -608,6 +634,7 @@ cdef int TVMFFIPyArgSetterDTypeFromNumpy_(
608634
out.v_dtype = NUMPY_DTYPE_TO_DL_DATA_TYPE[py_obj]
609635
return 0
610636

637+
611638
cdef int TVMFFIPyArgSetterDLPackDataTypeProtocol_(
612639
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
613640
PyObject* py_arg, TVMFFIAny* out
@@ -621,6 +648,29 @@ cdef int TVMFFIPyArgSetterDLPackDataTypeProtocol_(
621648
out.v_dtype.lanes = <long long>dltype_data_type[2]
622649
return 0
623650

651+
652+
cdef int TVMFFIPyArgSetterIntProtocol_(
653+
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
654+
PyObject* py_arg, TVMFFIAny* out
655+
) except -1:
656+
"""Setter for class with __tvm_ffi_int__() method"""
657+
cdef object arg = <object>py_arg
658+
out.type_index = kTVMFFIInt
659+
out.v_int64 = <long long>(arg.__tvm_ffi_int__())
660+
return 0
661+
662+
663+
cdef int TVMFFIPyArgSetterFloatProtocol_(
664+
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
665+
PyObject* py_arg, TVMFFIAny* out
666+
) except -1:
667+
"""Setter for class with __tvm_ffi_float__() method"""
668+
cdef object arg = <object>py_arg
669+
out.type_index = kTVMFFIFloat
670+
out.v_float64 = <double>(arg.__tvm_ffi_float__())
671+
return 0
672+
673+
624674
cdef _DISPATCH_TYPE_KEEP_ALIVE = set()
625675
cdef _DISPATCH_TYPE_KEEP_ALIVE_LOCK = threading.Lock()
626676

@@ -668,7 +718,7 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
668718
# can directly map to tvm ffi object
669719
# usually used for solutions that takes subclass of ffi.Object
670720
# as a member variable
671-
out.func = TVMFFIPyArgSetterFFIObjectCompatible_
721+
out.func = TVMFFIPyArgSetterFFIObjectProtocol_
672722
return 0
673723
if os.environ.get("TVM_FFI_SKIP_C_DLPACK_EXCHANGE_API", "0") != "1":
674724
# Check for DLPackExchangeAPI struct (new approach)
@@ -698,10 +748,15 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
698748
out.func = TVMFFIPyArgSetterBool_
699749
return 0
700750
if isinstance(arg, Integral):
701-
out.func = TVMFFIPyArgSetterInt_
751+
# must occur before Real check
752+
# cannot simply use TVMFFIPyArgSetterInt
753+
# because Integral may not be exactly the int class
754+
out.func = TVMFFIPyArgSetterIntegral_
702755
return 0
703756
if isinstance(arg, Real):
704-
out.func = TVMFFIPyArgSetterFloat_
757+
# cannot simply use TVMFFIPyArgSetterFloat
758+
# because Real may not be exactly the float class
759+
out.func = TVMFFIPyArgSetterReal_
705760
return 0
706761
# dtype is a subclass of str, so this check must occur before str
707762
if isinstance(arg, _CLASS_DTYPE):
@@ -760,6 +815,12 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
760815
# then it is a DLPack device protocol
761816
out.func = TVMFFIPyArgSetterDLPackDeviceProtocol_
762817
return 0
818+
if hasattr(arg_class, "__tvm_ffi_int__"):
819+
out.func = TVMFFIPyArgSetterIntProtocol_
820+
return 0
821+
if hasattr(arg_class, "__tvm_ffi_float__"):
822+
out.func = TVMFFIPyArgSetterFloatProtocol_
823+
return 0
763824
if isinstance(arg, Exception):
764825
out.func = TVMFFIPyArgSetterException_
765826
return 0

tests/python/test_function.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,3 +350,36 @@ def __dlpack_device__(self) -> tuple[int, int]:
350350
x = DLPackDeviceProtocol(device)
351351
y = fecho(x)
352352
assert y == device
353+
354+
355+
def test_integral_float_variants_passing() -> None:
356+
fecho = tvm_ffi.get_global_func("testing.echo")
357+
y = fecho(np.int32(1))
358+
assert isinstance(y, int)
359+
assert y == 1
360+
361+
y = fecho(np.float64(2.0))
362+
assert isinstance(y, float)
363+
assert y == 2.0
364+
365+
class IntProtocol:
366+
def __init__(self, value: int) -> None:
367+
self.value = value
368+
369+
def __tvm_ffi_int__(self) -> int:
370+
return self.value
371+
372+
y = fecho(IntProtocol(10))
373+
assert isinstance(y, int)
374+
assert y == 10
375+
376+
class FloatProtocol:
377+
def __init__(self, value: float) -> None:
378+
self.value = value
379+
380+
def __tvm_ffi_float__(self) -> float:
381+
return self.value
382+
383+
y = fecho(FloatProtocol(10))
384+
assert isinstance(y, float)
385+
assert y == 10

0 commit comments

Comments
 (0)