public note

pandas MultiIndex DataFrameからの抽出方法、ついでにMatplotlib

pandasのマルチインデックスなDataFrame(以下、MultiIndex)からデータを抽出する方法と、Matplotlibでの可視化について書きます。 これまであまり経験がなく、別にいつもと変わらないっしょ!とやってみたところ(案の定)うまくいかなかったので、 そこで得た学びをここに記します。

pandas.DataFrame とは

DataFrame is a 2-dimensional labeled data structure with columns of potentially different types. You can think of it like a spreadsheet or SQL table, or a dict of Series objects. It is generally the most commonly used pandas object. pandas.pydata.org

Pythonにおいてpandas.DataFrame を使うと、データを表形式として扱うことができます。格納したデータ行はIndexというキー情報で識別するのですが、そのIndexを複数設定したDataFrameを、ここではMultiIndexと呼ぶことにします。MultiIndexにすると、データを階層的に表現できたり、特定のグループを抽出・集計する処理を可読性高くコーディングすることができる、といったメリットがあります。

Matplotlib とは

Pythonで使える可視化ライブラリです。 matplotlib.org

本題: なにをしようとしたか

pandas.DataFrameにMultiIndexなデータを入れて良い感じに整形して、Matplotlibで可視化しようとしました。が、うまくいかなかったので以下のような試行錯誤をしました。

コードサンプル

適当に作ったCSVデータをもとに、Jupyter Notebook で再現してみます。 github.com

CSV -> DataFrame

CSVのデータをDataFrameへ格納します。Indexは設定していないので、連番が振られています。

os.chdir(os.path.dirname(os.path.abspath("__file__")))
df = pd.read_csv('./test.csv', header=0, encoding='utf-8')
df['date'] = pd.to_datetime(df['date'])
df['col_3'] = df['col_1'] * df['col_2']
df
 date     id col_1   col_2   col_3
0   2021-01-09  1   50  1000    50000
1   2021-01-10  2   100 1000    100000
2   2021-01-11  2   100 1000    100000
3   2021-01-12  3   150 1000    150000
4   2021-01-13  3   150 1000    150000
5   2021-01-14  3   150 1000    150000
6   2021-01-15  1   50  1000    50000
7   2021-01-16  1   50  2000    100000
8   2021-01-17  2   100 2000    200000
9   2021-01-18  3   150 2000    300000
10  2021-01-19  2   100 2000    200000
...
30  2021-02-08  3   150 3000    450000
31  2021-02-09  3   150 3000    450000
32  2021-02-10  1   50  3000    150000

MultiIndexを設定

以下のように、date列とid列をIndexとして指定します。年月日で条件抽出できるように、date列を分解して先頭に追加しています。

df = df.set_index(['date'])
df = df.set_index([df.index.year, df.index.month, df.index.day, df.index, 'id'])
df.index.names = ['year', 'month', 'day', 'date', 'id']
df.sort_index()
df
                                     col_1   col_2   col_3
year    month   day date        id          
2021    1       9   2021-01-09  1   50      1000    50000
                10  2021-01-10  2   100     1000    100000
                11  2021-01-11  2   100     1000    100000
                12  2021-01-12  3   150     1000    150000
                13  2021-01-13  3   150     1000    150000
                14  2021-01-14  3   150     1000    150000
                15  2021-01-15  1   50      1000    50000
...

Indexを指定して抽出

こういった方法があるようです。

  • groupby -> get_group
  • loc
  • xs

groupby -> get_group

groupby()によって、指定したlevelでグルーピングできます。get_group()でレベル別に値を指定してデータを抽出します。戻り値のtypeはDataFrameで、MultiIndexの設定もそのまま維持されています。

df.groupby(['year', 'month']).get_group((2021, 2))
print(df.groupby(['year', 'month']).get_group((2021, 2)).index.names)
                 col_1   col_2   col_3
year    month   day date    id          
2021    2   1   2021-02-01  2   100 3000    300000
            2   2021-02-02  1   50  3000    150000
            3   2021-02-03  3   150 3000    450000
            4   2021-02-04  1   50  3000    150000
            5   2021-02-05  2   100 3000    300000
            6   2021-02-06  2   100 3000    300000
            7   2021-02-07  3   150 3000    450000
            8   2021-02-08  3   150 3000    450000
            9   2021-02-09  3   150 3000    450000
            10  2021-02-10  1   50  3000    150000


['year', 'month', 'day', 'date', 'id']

ちなみに、groupby()しただけだとtypeはpandas.core.groupby.generic.DataFrameGroupByで、get_group()しないとDataFrameになりません。

deepage.net

xs

level別に抽出したいindexを指定する方法です。戻り値のtypeは同じくDataFrameですが、指定したIndexがなくなっています。

df.xs(2021, level='year').xs(1, level='month')
print(df.xs(2021, level='year').xs(2, level='month').index.names)
                 col_1   col_2   col_3
day date        id          
1   2021-02-01  2   100 3000    300000
2   2021-02-02  1   50  3000    150000
3   2021-02-03  3   150 3000    450000
4   2021-02-04  1   50  3000    150000
5   2021-02-05  2   100 3000    300000
6   2021-02-06  2   100 3000    300000
7   2021-02-07  3   150 3000    450000
8   2021-02-08  3   150 3000    450000
9   2021-02-09  3   150 3000    450000
10  2021-02-10  1   50  3000    150000

['day', 'date', 'id']

loc

tupleでインデックスを指定する方法です。locなので戻り値のtypeはSeries、指定したIndexはなくなっています。

df.loc[(2021, 1), 'col_3']
print(df.loc[(2021, 2), 'col_3'].index.names)
day  date        id
1    2021-02-01  2     300000
2    2021-02-02  1     150000
3    2021-02-03  3     450000
4    2021-02-04  1     150000
5    2021-02-05  2     300000
6    2021-02-06  2     300000
7    2021-02-07  3     450000
8    2021-02-08  3     450000
9    2021-02-09  3     450000
10   2021-02-10  1     150000
Name: col_3, dtype: int64

['day', 'date', 'id']

グラフにする

以下のようにしてみました。Matplotlib難しい…

ヒストグラム

import matplotlib.pyplot as plt

# histogram
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

for i, month in enumerate(df.index.unique(level='month')):
    ax[i].set_title(str(month))
    pd.cut(
        # 2021年の月ごとに、col_3をヒストグラムで表示
        df.groupby(['year', 'month']).get_group((2021, month))['col_3'],
        bins=3,     # 階級数
        right=False # True: aより大きくb以下; False: a以上b未満
    ) \
    .value_counts() \
    .sort_index() \
    .plot.bar(color='indigo', ax=ax[i], sharex=True, sharey=True)
plt.show()

f:id:ts223:20210110010336p:plain

箱ひげ図

df.groupby('year').get_group(2021).boxplot(
    column='col_3',
    by='month',
    figsize=(10, 5),
    meanline=True,
    showmeans=True,
    showcaps=True,
    showbox=True,
    showfliers=False
)

f:id:ts223:20210110010521p:plain

まとめ

抽出結果が、方法によって異なるところが面白かったポイントです。そのあたりを理解したうえで使い分けができるとかなりおしゃれです。 一方で、グルーピングがそこまで重要でない場面では、開き直ってreset_indexして、フラットな構造にするのもありです。