|
3 | 3 | import datetime |
4 | 4 | from functools import partial |
5 | 5 | from textwrap import dedent |
6 | | -from typing import TYPE_CHECKING, Optional, Union |
| 6 | +from typing import Optional, Union |
7 | 7 | import warnings |
8 | 8 |
|
9 | 9 | import numpy as np |
10 | 10 |
|
11 | 11 | from pandas._libs.tslibs import Timedelta |
12 | 12 | import pandas._libs.window.aggregations as window_aggregations |
13 | | -from pandas._typing import FrameOrSeries, TimedeltaConvertibleTypes |
| 13 | +from pandas._typing import FrameOrSeries, FrameOrSeriesUnion, TimedeltaConvertibleTypes |
14 | 14 | from pandas.compat.numpy import function as nv |
15 | 15 | from pandas.util._decorators import doc |
16 | 16 |
|
|
19 | 19 |
|
20 | 20 | import pandas.core.common as common |
21 | 21 | from pandas.core.util.numba_ import maybe_use_numba |
22 | | -from pandas.core.window.common import flex_binary_moment, zsqrt |
| 22 | +from pandas.core.window.common import zsqrt |
23 | 23 | from pandas.core.window.doc import ( |
24 | 24 | _shared_docs, |
25 | 25 | args_compat, |
|
35 | 35 | GroupbyIndexer, |
36 | 36 | ) |
37 | 37 | from pandas.core.window.numba_ import generate_numba_groupby_ewma_func |
38 | | -from pandas.core.window.rolling import BaseWindow, BaseWindowGroupby, dispatch |
39 | | - |
40 | | -if TYPE_CHECKING: |
41 | | - from pandas import Series |
| 38 | +from pandas.core.window.rolling import BaseWindow, BaseWindowGroupby |
42 | 39 |
|
43 | 40 |
|
44 | 41 | def get_center_of_mass( |
@@ -74,13 +71,20 @@ def get_center_of_mass( |
74 | 71 | return float(comass) |
75 | 72 |
|
76 | 73 |
|
77 | | -def wrap_result(obj: Series, result: np.ndarray) -> Series: |
| 74 | +def dispatch(name: str, *args, **kwargs): |
78 | 75 | """ |
79 | | - Wrap a single 1D result. |
| 76 | + Dispatch to groupby apply. |
80 | 77 | """ |
81 | | - obj = obj._selected_obj |
82 | 78 |
|
83 | | - return obj._constructor(result, obj.index, name=obj.name) |
| 79 | + def outer(self, *args, **kwargs): |
| 80 | + def f(x): |
| 81 | + x = self._shallow_copy(x, groupby=self._groupby) |
| 82 | + return getattr(x, name)(*args, **kwargs) |
| 83 | + |
| 84 | + return self._groupby.apply(f) |
| 85 | + |
| 86 | + outer.__name__ = name |
| 87 | + return outer |
84 | 88 |
|
85 | 89 |
|
86 | 90 | class ExponentialMovingWindow(BaseWindow): |
@@ -443,36 +447,30 @@ def var_func(values, begin, end, min_periods): |
443 | 447 | ) |
444 | 448 | def cov( |
445 | 449 | self, |
446 | | - other: Optional[Union[np.ndarray, FrameOrSeries]] = None, |
| 450 | + other: Optional[FrameOrSeriesUnion] = None, |
447 | 451 | pairwise: Optional[bool] = None, |
448 | 452 | bias: bool = False, |
449 | 453 | **kwargs, |
450 | 454 | ): |
451 | | - if other is None: |
452 | | - other = self._selected_obj |
453 | | - # only default unset |
454 | | - pairwise = True if pairwise is None else pairwise |
455 | | - other = self._shallow_copy(other) |
456 | | - |
457 | | - def _get_cov(X, Y): |
458 | | - X = self._shallow_copy(X) |
459 | | - Y = self._shallow_copy(Y) |
460 | | - cov = window_aggregations.ewmcov( |
461 | | - X._prep_values(), |
| 455 | + from pandas import Series |
| 456 | + |
| 457 | + def cov_func(x, y): |
| 458 | + x_array = self._prep_values(x) |
| 459 | + y_array = self._prep_values(y) |
| 460 | + result = window_aggregations.ewmcov( |
| 461 | + x_array, |
462 | 462 | np.array([0], dtype=np.int64), |
463 | 463 | np.array([0], dtype=np.int64), |
464 | 464 | self.min_periods, |
465 | | - Y._prep_values(), |
| 465 | + y_array, |
466 | 466 | self.com, |
467 | 467 | self.adjust, |
468 | 468 | self.ignore_na, |
469 | 469 | bias, |
470 | 470 | ) |
471 | | - return wrap_result(X, cov) |
| 471 | + return Series(result, index=x.index, name=x.name) |
472 | 472 |
|
473 | | - return flex_binary_moment( |
474 | | - self._selected_obj, other._selected_obj, _get_cov, pairwise=bool(pairwise) |
475 | | - ) |
| 473 | + return self._apply_pairwise(self._selected_obj, other, pairwise, cov_func) |
476 | 474 |
|
477 | 475 | @doc( |
478 | 476 | template_header, |
@@ -502,45 +500,37 @@ def _get_cov(X, Y): |
502 | 500 | ) |
503 | 501 | def corr( |
504 | 502 | self, |
505 | | - other: Optional[Union[np.ndarray, FrameOrSeries]] = None, |
| 503 | + other: Optional[FrameOrSeriesUnion] = None, |
506 | 504 | pairwise: Optional[bool] = None, |
507 | 505 | **kwargs, |
508 | 506 | ): |
509 | | - if other is None: |
510 | | - other = self._selected_obj |
511 | | - # only default unset |
512 | | - pairwise = True if pairwise is None else pairwise |
513 | | - other = self._shallow_copy(other) |
| 507 | + from pandas import Series |
514 | 508 |
|
515 | | - def _get_corr(X, Y): |
516 | | - X = self._shallow_copy(X) |
517 | | - Y = self._shallow_copy(Y) |
| 509 | + def cov_func(x, y): |
| 510 | + x_array = self._prep_values(x) |
| 511 | + y_array = self._prep_values(y) |
518 | 512 |
|
519 | | - def _cov(x, y): |
| 513 | + def _cov(X, Y): |
520 | 514 | return window_aggregations.ewmcov( |
521 | | - x, |
| 515 | + X, |
522 | 516 | np.array([0], dtype=np.int64), |
523 | 517 | np.array([0], dtype=np.int64), |
524 | 518 | self.min_periods, |
525 | | - y, |
| 519 | + Y, |
526 | 520 | self.com, |
527 | 521 | self.adjust, |
528 | 522 | self.ignore_na, |
529 | 523 | 1, |
530 | 524 | ) |
531 | 525 |
|
532 | | - x_values = X._prep_values() |
533 | | - y_values = Y._prep_values() |
534 | 526 | with np.errstate(all="ignore"): |
535 | | - cov = _cov(x_values, y_values) |
536 | | - x_var = _cov(x_values, x_values) |
537 | | - y_var = _cov(y_values, y_values) |
538 | | - corr = cov / zsqrt(x_var * y_var) |
539 | | - return wrap_result(X, corr) |
540 | | - |
541 | | - return flex_binary_moment( |
542 | | - self._selected_obj, other._selected_obj, _get_corr, pairwise=bool(pairwise) |
543 | | - ) |
| 527 | + cov = _cov(x_array, y_array) |
| 528 | + x_var = _cov(x_array, x_array) |
| 529 | + y_var = _cov(y_array, y_array) |
| 530 | + result = cov / zsqrt(x_var * y_var) |
| 531 | + return Series(result, index=x.index, name=x.name) |
| 532 | + |
| 533 | + return self._apply_pairwise(self._selected_obj, other, pairwise, cov_func) |
544 | 534 |
|
545 | 535 |
|
546 | 536 | class ExponentialMovingWindowGroupby(BaseWindowGroupby, ExponentialMovingWindow): |
|
0 commit comments