Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
use pd.factorize
  • Loading branch information
TLouf committed Aug 8, 2021
commit e09d760abf852b723ab0c1d18ef07581dd385918
41 changes: 11 additions & 30 deletions pandas/core/arrays/sparse/scipy_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
import numpy as np

from pandas.core.algorithms import factorize
from pandas.core.indexes.api import MultiIndex
from pandas.core.series import Series

Expand All @@ -18,31 +19,17 @@ def _check_is_partition(parts, whole):
raise ValueError("Is not a partition because union is not the whole.")


def _levels_to_axis(levels_codes, levels_labels, valid_ilocs, sort_labels=False):
if sort_labels and levels_codes.shape[0] == 1:
ax_coords = levels_codes[0][valid_ilocs]
ax_labels = levels_labels[0].tolist()
def _levels_to_axis(ss, levels, valid_ilocs, sort_labels=False):
if sort_labels and len(levels) == 1:
ax_coords = ss.index.codes[levels[0]][valid_ilocs]
ax_labels = ss.index.levels[levels[0]]

else:
# Why return_index anyway : https://github.com/numpy/numpy/issues/16923
ucodes, ucodes_idx, ucodes_inv = np.unique(
levels_codes.T, axis=0, return_index=True, return_inverse=True
)

if sort_labels:
ax_coords = ucodes_inv[valid_ilocs]

else:
og_order = np.argsort(ucodes_idx)
ucodes = ucodes[og_order, :]
ax_coords = og_order.argsort()[ucodes_inv[valid_ilocs]]

ax_labels = list(
zip(
*(tuple(lbls[ucodes[:, lvl]]) for lvl, lbls in enumerate(levels_labels))
)
)
levels_values = list(zip(*(ss.index.get_level_values(lvl) for lvl in levels)))
codes, ax_labels = factorize(levels_values, sort=sort_labels)
ax_coords = codes[valid_ilocs]

ax_labels = ax_labels.tolist()
return ax_coords, ax_labels


Expand All @@ -57,20 +44,14 @@ def _to_ijv(ss, row_levels=(0,), column_levels=(1,), sort_labels=False):
# from the sparse Series: get the labels and data for non-null entries
values = ss.array._valid_sp_values

codes = ss.index.codes
labels = ss.index.levels
valid_ilocs = np.where(ss.notnull())[0]

row_labels = [labels[lvl] for lvl in row_levels]
row_codes = np.asarray([codes[lvl] for lvl in row_levels])
i_coords, i_labels = _levels_to_axis(
row_codes, row_labels, valid_ilocs, sort_labels=sort_labels
ss, row_levels, valid_ilocs, sort_labels=sort_labels
)

col_labels = [labels[lvl] for lvl in column_levels]
col_codes = np.asarray([codes[lvl] for lvl in column_levels])
j_coords, j_labels = _levels_to_axis(
col_codes, col_labels, valid_ilocs, sort_labels=sort_labels
ss, column_levels, valid_ilocs, sort_labels=sort_labels
)

return values, i_coords, j_coords, i_labels, j_labels
Expand Down