1717import ctypes
1818import threading
1919import os
20- from numbers import Real, Integral
20+ from numbers import Integral, Real
2121from typing import Any, Callable
2222
2323
@@ -276,7 +276,33 @@ cdef int TVMFFIPyArgSetterDLPack_(
276276 return 0
277277
278278
279- cdef int TVMFFIPyArgSetterFFIObjectCompatible_(
279+ cdef int TVMFFIPyArgSetterIntegral_(
280+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
281+ PyObject* py_arg, TVMFFIAny* out
282+ ) except - 1 :
283+ """ Setter for Integral"""
284+ cdef object arg = < object > py_arg
285+ out.type_index = kTVMFFIInt
286+ # keep it in cython so it will also check for fallback cases
287+ # where the arg is not exactly the int class
288+ out.v_int64 = < long long > arg
289+ return 0
290+
291+
292+ cdef int TVMFFIPyArgSetterReal_(
293+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
294+ PyObject* py_arg, TVMFFIAny* out
295+ ) except - 1 :
296+ """ Setter for Real"""
297+ cdef object arg = < object > py_arg
298+ out.type_index = kTVMFFIFloat
299+ # keep it in cython so it will also check for fallback cases
300+ # where the arg is not exactly the float class
301+ out.v_float64 = < double > arg
302+ return 0
303+
304+
305+ cdef int TVMFFIPyArgSetterFFIObjectProtocol_(
280306 TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
281307 PyObject* py_arg, TVMFFIAny* out
282308) except - 1 :
@@ -608,6 +634,7 @@ cdef int TVMFFIPyArgSetterDTypeFromNumpy_(
608634 out.v_dtype = NUMPY_DTYPE_TO_DL_DATA_TYPE[py_obj]
609635 return 0
610636
637+
611638cdef int TVMFFIPyArgSetterDLPackDataTypeProtocol_(
612639 TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
613640 PyObject* py_arg, TVMFFIAny* out
@@ -621,6 +648,29 @@ cdef int TVMFFIPyArgSetterDLPackDataTypeProtocol_(
621648 out.v_dtype.lanes = < long long > dltype_data_type[2 ]
622649 return 0
623650
651+
652+ cdef int TVMFFIPyArgSetterIntProtocol_(
653+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
654+ PyObject* py_arg, TVMFFIAny* out
655+ ) except - 1 :
656+ """ Setter for class with __tvm_ffi_int__() method"""
657+ cdef object arg = < object > py_arg
658+ out.type_index = kTVMFFIInt
659+ out.v_int64 = < long long > (arg.__tvm_ffi_int__())
660+ return 0
661+
662+
663+ cdef int TVMFFIPyArgSetterFloatProtocol_(
664+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
665+ PyObject* py_arg, TVMFFIAny* out
666+ ) except - 1 :
667+ """ Setter for class with __tvm_ffi_float__() method"""
668+ cdef object arg = < object > py_arg
669+ out.type_index = kTVMFFIFloat
670+ out.v_float64 = < double > (arg.__tvm_ffi_float__())
671+ return 0
672+
673+
624674cdef _DISPATCH_TYPE_KEEP_ALIVE = set ()
625675cdef _DISPATCH_TYPE_KEEP_ALIVE_LOCK = threading.Lock()
626676
@@ -668,7 +718,7 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
668718 # can directly map to tvm ffi object
669719 # usually used for solutions that takes subclass of ffi.Object
670720 # as a member variable
671- out.func = TVMFFIPyArgSetterFFIObjectCompatible_
721+ out.func = TVMFFIPyArgSetterFFIObjectProtocol_
672722 return 0
673723 if os.environ.get(" TVM_FFI_SKIP_C_DLPACK_EXCHANGE_API" , " 0" ) != " 1" :
674724 # Check for DLPackExchangeAPI struct (new approach)
@@ -698,10 +748,15 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
698748 out.func = TVMFFIPyArgSetterBool_
699749 return 0
700750 if isinstance (arg, Integral):
701- out.func = TVMFFIPyArgSetterInt_
751+ # must occur before Real check
752+ # cannot simply use TVMFFIPyArgSetterInt
753+ # because Integral may not be exactly the int class
754+ out.func = TVMFFIPyArgSetterIntegral_
702755 return 0
703756 if isinstance (arg, Real):
704- out.func = TVMFFIPyArgSetterFloat_
757+ # cannot simply use TVMFFIPyArgSetterFloat
758+ # because Real may not be exactly the float class
759+ out.func = TVMFFIPyArgSetterReal_
705760 return 0
706761 # dtype is a subclass of str, so this check must occur before str
707762 if isinstance (arg, _CLASS_DTYPE):
@@ -760,6 +815,12 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
760815 # then it is a DLPack device protocol
761816 out.func = TVMFFIPyArgSetterDLPackDeviceProtocol_
762817 return 0
818+ if hasattr (arg_class, " __tvm_ffi_int__" ):
819+ out.func = TVMFFIPyArgSetterIntProtocol_
820+ return 0
821+ if hasattr (arg_class, " __tvm_ffi_float__" ):
822+ out.func = TVMFFIPyArgSetterFloatProtocol_
823+ return 0
763824 if isinstance (arg, Exception ):
764825 out.func = TVMFFIPyArgSetterException_
765826 return 0
0 commit comments