- [x] I have checked that this issue has not already been reported.
- [x] I have confirmed this bug exists on the latest version of pandas.
- [x] (optional) I have confirmed this bug exists on the master branch of pandas.
Code Sample, a copy-pastable example
import pandas as pd
df = pd.DataFrame({"str_col": ["a", "b", "c", "a"], "num_col": [1, 2, 3, 2]})
df["str_col"] = df["str_col"].astype("string")
print(df.dtypes)
avg = df.groupby("str_col", as_index=False).mean()
print(avg.dtypes)
Problem description
After grouping, the string column loses its string dtype and becomes object. This is unexpected: when using a string column as the grouping column, one has to manually convert the dtype back to string. The output of avg.dtypes is:
str_col object
num_col float64
dtype: object
Expected Output
str_col string
num_col float64
dtype: object
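The manual conversion mentioned in the problem description could look roughly like the following. This is only a workaround sketch that reuses the example frame from above; it does not fix the underlying behavior.

```python
import pandas as pd

df = pd.DataFrame({"str_col": ["a", "b", "c", "a"], "num_col": [1, 2, 3, 2]})
df["str_col"] = df["str_col"].astype("string")

avg = df.groupby("str_col", as_index=False).mean()

# Workaround: cast the grouping column back to the dtype it had before grouping.
# On versions that already preserve the dtype this cast is a no-op.
avg["str_col"] = avg["str_col"].astype(df["str_col"].dtype)
print(avg.dtypes)
```

Casting to `df["str_col"].dtype` rather than a hard-coded `"string"` keeps the workaround in sync with whatever dtype the grouping column had originally.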
Output of pd.show_versions()
Comment From: rhshadrach
Thanks for the report, I've confirmed this on master. Investigations and PR to fix are most welcome!
Comment From: pspachtholz
I did a little more exploration, and it seems that the dtype is not preserved for any of the ExtensionDtypes I tried. I also tested datetime64, where it works:
import pandas as pd
df = pd.DataFrame(
    {
        "str_col": ["a", "b", "c", "a"],
        "bool_col": [False] * 3 + [True],
        "date_col": pd.date_range("2021-01-01", periods=4),
        "int_col": [1, 2, 3, 2],
        "num_col": [1.0, 2.2, 3.1, 2.25],
    }
)
df = df.convert_dtypes()
for col in ["str_col", "bool_col", "date_col", "int_col"]:
    avg = df.groupby(col, as_index=False)["num_col"].mean()
    print(f"{df[col].dtype} -> {avg[col].dtype}")
This prints:
string -> object
boolean -> bool
datetime64[ns] -> datetime64[ns]
Int64 -> int64
I have very limited knowledge of the inner workings of pandas, but looking at the source code, it seems a grouping is created, e.g. for str_col:
group_index:Index(['a', 'b', 'c'], dtype='object', name='str_col')
grouper:<StringArray>
['a', 'b', 'c', 'a']
Length: 4, dtype: string
groups:{'a': [0, 3], 'b': [1], 'c': [2]}
In Grouping._make_codes an index is created for the grouping: https://github.com/pandas-dev/pandas/blob/348d43f7bf63465dd8f6cca4e1bd4b608fb58597/pandas/core/groupby/grouper.py#L622
This index (as returned by self.grouper.get_group_levels()) is then merged back into the aggregated result in _insert_inaxis_grouper_inplace:
Index(['a', 'b', 'c'], dtype='object', name='str_col')
This leads to the resulting dtypes:
str_col object
num_col Float64
When the index is created, the StringArray is converted to a NumPy array of dtype object:
https://github.com/pandas-dev/pandas/blob/2d51ebb77f5dc7f7824fd0b7b7edd538f2eaa819/pandas/core/indexes/base.py#L359-L361
So this is where the dtype information is lost.
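The loss described above can be illustrated in isolation. The snippet below is a minimal sketch and not the actual pandas code path: it only shows that converting a string extension array to a NumPy array yields dtype object, so anything built from the NumPy array alone no longer knows about the extension dtype.

```python
import numpy as np
import pandas as pd

# A string extension array like the one the grouper holds.
arr = pd.array(["a", "b", "c", "a"], dtype="string")

# Converting to a plain NumPy array drops the extension dtype.
as_numpy = np.asarray(arr)
print(arr.dtype)
print(as_numpy.dtype)

# The dtype can only be recovered by passing it explicitly again.
restored = pd.array(as_numpy, dtype="string")
print(restored.dtype)
```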
Without deeper knowledge of pandas, it is hard for me to tell where and how the extension dtype should be handled. I am happy to work on a pull request if you could give me some guidance.
Comment From: rhshadrach
Thanks for digging into this @pspachtholz, very helpful! It seems the blocker here would be #39133 in the current implementation.
Comment From: jbrockmendel
This now retains string dtype. Could use a test (or confirm that one exists)
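A regression test for this could look roughly like the sketch below. The test name is hypothetical and not taken from the pandas test suite, and the test assumes a pandas version in which the string dtype is retained; it reuses the example from the original report.

```python
import pandas as pd


def test_groupby_as_index_false_preserves_string_dtype():
    # Hypothetical test name; the data mirrors the original report.
    df = pd.DataFrame({"str_col": ["a", "b", "c", "a"], "num_col": [1, 2, 3, 2]})
    df["str_col"] = df["str_col"].astype("string")

    result = df.groupby("str_col", as_index=False).mean()

    expected = pd.DataFrame(
        {
            "str_col": pd.array(["a", "b", "c"], dtype="string"),
            "num_col": [1.5, 2.0, 3.0],
        }
    )
    # assert_frame_equal checks dtypes by default, so an object-dtype
    # grouping column in the result would make this fail.
    pd.testing.assert_frame_equal(result, expected)
```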
Comment From: srkds
take