1616
1717
@functools.cache
def make_looper(func, result_dtype, is_grouped_kernel, nopython, nogil, parallel):
    """
    Build a jitted function that applies ``func`` to every row of a 2D block.

    The built looper is memoized by ``functools.cache``, so calling this again
    with the same arguments returns the same looper object.

    Parameters
    ----------
    func : callable
        Kernel called once per row as
        ``func(row, result_dtype, labels, ngroups, min_periods, *args)`` when
        ``is_grouped_kernel`` is True, otherwise as
        ``func(row, result_dtype, start, end, min_periods, *args)``.
        It must return a pair ``(output, na_pos)``.
    result_dtype : np.dtype
        dtype of the allocated result array; also forwarded to ``func``.
    is_grouped_kernel : bool
        Selects the labels/ngroups looper (True) or the start/end looper
        (False).
    nopython, nogil, parallel : bool
        Forwarded unchanged to ``numba.jit``.

    Returns
    -------
    callable
        ``column_looper``, which returns ``(result, na_positions)``.
        ``result`` has shape ``(values.shape[0], ngroups)`` for grouped
        kernels and ``(values.shape[0], len(start))`` otherwise.
        ``na_positions`` maps a row index to an array built from the
        ``na_pos`` reported by ``func``, for rows where it is non-empty.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    # Only one of the two branches below runs, so one decorator object is enough.
    jit = numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)

    if is_grouped_kernel:

        @jit
        def column_looper(
            values: np.ndarray,
            labels: np.ndarray,
            ngroups: int,
            min_periods: int,
            *args,
        ):
            nrows = values.shape[0]
            result = np.empty((nrows, ngroups), dtype=result_dtype)
            na_positions = {}
            for row in numba.prange(nrows):
                row_out, row_na = func(
                    values[row], result_dtype, labels, ngroups, min_periods, *args
                )
                result[row] = row_out
                # Record only rows for which the kernel reported positions.
                if len(row_na) > 0:
                    na_positions[row] = np.array(row_na)
            return result, na_positions

    else:

        @jit
        def column_looper(
            values: np.ndarray,
            start: np.ndarray,
            end: np.ndarray,
            min_periods: int,
            *args,
        ):
            nrows = values.shape[0]
            result = np.empty((nrows, len(start)), dtype=result_dtype)
            na_positions = {}
            for row in numba.prange(nrows):
                row_out, row_na = func(
                    values[row], result_dtype, start, end, min_periods, *args
                )
                result[row] = row_out
                # Record only rows for which the kernel reported positions.
                if len(row_na) > 0:
                    na_positions[row] = np.array(row_na)
            return result, na_positions

    return column_looper
4565
@@ -96,6 +116,7 @@ def column_looper(
96116def generate_shared_aggregator (
97117 func : Callable [..., Scalar ],
98118 dtype_mapping : dict [np .dtype , np .dtype ],
119+ is_grouped_kernel : bool ,
99120 nopython : bool ,
100121 nogil : bool ,
101122 parallel : bool ,
@@ -111,6 +132,11 @@ def generate_shared_aggregator(
111132 dtype_mapping: dict or None
112133 If not None, maps a dtype to a result dtype.
113134 Otherwise, will fall back to default mapping.
135+ is_grouped_kernel : bool
136+ Whether func operates using the group labels (True)
137+ or using starts/ends arrays (False).
138+
139+ If True, ``labels`` and ``ngroups`` must be passed to ``looper_wrapper`` in place of ``start`` and ``end``.
114140 nopython : bool
115141 nopython to be passed into numba.jit
116142 nogil : bool
@@ -130,13 +156,18 @@ def generate_shared_aggregator(
130156 # is less than min_periods
131157 # Cannot do this in numba nopython mode
132158 # (you'll run into type-unification error when you cast int -> float)
133- def looper_wrapper (values , start , end , min_periods , ** kwargs ):
159+ def looper_wrapper (values , start = None , end = None , labels = None , ngroups = None , min_periods = 0 , ** kwargs ):
134160 result_dtype = dtype_mapping [values .dtype ]
135- column_looper = make_looper (func , result_dtype , nopython , nogil , parallel )
161+ column_looper = make_looper (func , result_dtype , is_grouped_kernel , nopython , nogil , parallel )
136162 # Need to unpack kwargs since numba only supports *args
137- result , na_positions = column_looper (
138- values , start , end , min_periods , * kwargs .values ()
139- )
163+ if is_grouped_kernel :
164+ result , na_positions = column_looper (
165+ values , labels , ngroups , min_periods , * kwargs .values ()
166+ )
167+ else :
168+ result , na_positions = column_looper (
169+ values , start , end , min_periods , * kwargs .values ()
170+ )
140171 if result .dtype .kind == "i" :
141172 # Look if na_positions is not empty
142173 # If so, convert the whole block
0 commit comments