Skip to content
Prev Previous commit
Next Next commit
Address PR comments
  • Loading branch information
Roger Thomas committed Apr 29, 2022
commit e07f02c3a4ad0e9ceadcbd7ded9bf7c53eb5025c
47 changes: 26 additions & 21 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1909,22 +1909,26 @@ def to_dict(self, orient: str = "dict", into=dict):
elif orient.startswith("i"):
orient = "index"

object_dtype_cols = {
col for col, dtype in self.dtypes.items() if is_object_dtype(dtype)
}
are_all_object_dtype_cols = len(object_dtype_cols) == len(self.dtypes)
object_dtype_indices = [
i
for i, col_dtype in enumerate(self.dtypes.values)
if is_object_dtype(col_dtype)
]
are_all_object_dtype_cols = len(object_dtype_indices) == len(self.dtypes)

if orient == "dict":
return into_c((k, v.to_dict(into)) for k, v in self.items())

elif orient == "list":
object_dtype_indices = set(object_dtype_indices)
return into_c(
(
k,
list(map(maybe_box_native, v.tolist()))
if k in object_dtype_cols
if i in object_dtype_indices
else v.tolist(),
)
for k, v in self.items()
for i, (k, v) in enumerate(self.items())
)

elif orient == "split":
Expand All @@ -1935,12 +1939,9 @@ def to_dict(self, orient: str = "dict", into=dict):
]
else:
data = [list(t) for t in self.itertuples(index=False, name=None)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you share code between any of these cases? e.g. make a helper function

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jreback done

if object_dtype_cols:
object_dtype_indices = [
i
for i, col in enumerate(self.columns)
if col in object_dtype_cols
]
if object_dtype_indices:
# If there are object-dtype columns, apply maybe_box_native after the
# list comprehension for perf
for row in data:
for i in object_dtype_indices:
row[i] = maybe_box_native(row[i])
Expand All @@ -1960,12 +1961,9 @@ def to_dict(self, orient: str = "dict", into=dict):
]
else:
data = [list(t) for t in self.itertuples(index=False, name=None)]
if object_dtype_cols:
object_dtype_indices = [
i
for i, col in enumerate(self.columns)
if col in object_dtype_cols
]
if object_dtype_indices:
# If there are object-dtype columns, apply maybe_box_native after the
# list comprehension for perf
for row in data:
for i in object_dtype_indices:
row[i] = maybe_box_native(row[i])
Expand Down Expand Up @@ -1998,7 +1996,13 @@ def to_dict(self, orient: str = "dict", into=dict):
into_c(zip(columns, t))
for t in self.itertuples(index=False, name=None)
]
if object_dtype_cols:
if object_dtype_indices:
object_dtype_indices = set(object_dtype_indices)
object_dtype_cols = {
col
for i, col in enumerate(self.columns)
if i in object_dtype_indices
}
for row in data:
for col in object_dtype_cols:
row[col] = maybe_box_native(row[col])
Expand All @@ -2013,9 +2017,10 @@ def to_dict(self, orient: str = "dict", into=dict):
(t[0], dict(zip(self.columns, map(maybe_box_native, t[1:]))))
for t in self.itertuples(name=None)
)
elif object_dtype_cols:
elif object_dtype_indices:
object_dtype_indices = set(object_dtype_indices)
is_object_dtype_by_index = [
col in object_dtype_cols for col in self.columns
i in object_dtype_indices for i in range(len(self.columns))
]
return into_c(
(
Expand Down