pandas MultiIndex DataFrameからの抽出方法、ついでにMatplotlib
pandasのマルチインデックスなDataFrame(以下、MultiIndex)からデータを抽出する方法と、Matplotlibでの可視化について書きます。 これまであまり経験がなく、別にいつもと変わらないっしょ!とやってみたところ(案の定)うまくいかなかったので、 そこで得た学びをここに記します。
pandas.DataFrame とは
DataFrame is a 2-dimensional labeled data structure with columns of potentially different types. You can think of it like a spreadsheet or SQL table, or a dict of Series objects. It is generally the most commonly used pandas object. pandas.pydata.org
Pythonにおいてpandas.DataFrame を使うと、データを表形式として扱うことができます。格納したデータ行はIndexというキー情報で識別するのですが、そのIndexを複数設定したDataFrameを、ここではMultiIndexと呼ぶことにします。MultiIndexにすると、データを階層的に表現できたり、特定のグループを抽出・集計する処理を可読性高くコーディングすることができる、といったメリットがあります。
Matplotlib とは
Pythonで使える可視化ライブラリです。 matplotlib.org
本題: なにをしようとしたか
pandas.DataFrameにMultiIndexなデータを入れて良い感じに整形して、Matplotlibで可視化しようとしました。が、うまくいかなかったので以下のような試行錯誤をしました。
コードサンプル
適当に作ったCSVデータをもとに、Jupyter Notebook で再現してみます。 github.com
CSV -> DataFrame
CSVのデータをDataFrameへ格納します。Indexは設定していないので、連番が振られています。
os.chdir(os.path.dirname(os.path.abspath("__file__"))) df = pd.read_csv('./test.csv', header=0, encoding='utf-8') df['date'] = pd.to_datetime(df['date']) df['col_3'] = df['col_1'] * df['col_2'] df
date id col_1 col_2 col_3 0 2021-01-09 1 50 1000 50000 1 2021-01-10 2 100 1000 100000 2 2021-01-11 2 100 1000 100000 3 2021-01-12 3 150 1000 150000 4 2021-01-13 3 150 1000 150000 5 2021-01-14 3 150 1000 150000 6 2021-01-15 1 50 1000 50000 7 2021-01-16 1 50 2000 100000 8 2021-01-17 2 100 2000 200000 9 2021-01-18 3 150 2000 300000 10 2021-01-19 2 100 2000 200000 ... 30 2021-02-08 3 150 3000 450000 31 2021-02-09 3 150 3000 450000 32 2021-02-10 1 50 3000 150000
MultiIndexを設定
以下のように、date列とid列をIndexとして指定します。年月日で条件抽出できるように、date列を分解して先頭に追加しています。
df = df.set_index(['date']) df = df.set_index([df.index.year, df.index.month, df.index.day, df.index, 'id']) df.index.names = ['year', 'month', 'day', 'date', 'id'] df.sort_index() df
col_1 col_2 col_3 year month day date id 2021 1 9 2021-01-09 1 50 1000 50000 10 2021-01-10 2 100 1000 100000 11 2021-01-11 2 100 1000 100000 12 2021-01-12 3 150 1000 150000 13 2021-01-13 3 150 1000 150000 14 2021-01-14 3 150 1000 150000 15 2021-01-15 1 50 1000 50000 ...
Indexを指定して抽出
こういった方法があるようです。
- groupby -> get_group
- loc
- xs
groupby -> get_group
groupby()
によって、指定したlevelでグルーピングできます。get_group()
でレベル別に値を指定してデータを抽出します。戻り値のtypeはDataFrame
で、MultiIndexの設定もそのまま維持されています。
df.groupby(['year', 'month']).get_group((2021, 2)) print(df.groupby(['year', 'month']).get_group((2021, 2)).index.names)
col_1 col_2 col_3 year month day date id 2021 2 1 2021-02-01 2 100 3000 300000 2 2021-02-02 1 50 3000 150000 3 2021-02-03 3 150 3000 450000 4 2021-02-04 1 50 3000 150000 5 2021-02-05 2 100 3000 300000 6 2021-02-06 2 100 3000 300000 7 2021-02-07 3 150 3000 450000 8 2021-02-08 3 150 3000 450000 9 2021-02-09 3 150 3000 450000 10 2021-02-10 1 50 3000 150000 ['year', 'month', 'day', 'date', 'id']
ちなみに、groupby()
しただけだとtypeはpandas.core.groupby.generic.DataFrameGroupBy
で、get_group()
しないとDataFrame
になりません。
xs
level別に抽出したいindexを指定する方法です。戻り値のtypeは同じくDataFrame
ですが、指定したIndexがなくなっています。
df.xs(2021, level='year').xs(1, level='month') print(df.xs(2021, level='year').xs(2, level='month').index.names)
col_1 col_2 col_3 day date id 1 2021-02-01 2 100 3000 300000 2 2021-02-02 1 50 3000 150000 3 2021-02-03 3 150 3000 450000 4 2021-02-04 1 50 3000 150000 5 2021-02-05 2 100 3000 300000 6 2021-02-06 2 100 3000 300000 7 2021-02-07 3 150 3000 450000 8 2021-02-08 3 150 3000 450000 9 2021-02-09 3 150 3000 450000 10 2021-02-10 1 50 3000 150000 ['day', 'date', 'id']
loc
tupleでインデックスを指定する方法です。locなので戻り値のtypeはSeries
、指定したIndexはなくなっています。
df.loc[(2021, 1), 'col_3'] print(df.loc[(2021, 2), 'col_3'].index.names)
day date id 1 2021-02-01 2 300000 2 2021-02-02 1 150000 3 2021-02-03 3 450000 4 2021-02-04 1 150000 5 2021-02-05 2 300000 6 2021-02-06 2 300000 7 2021-02-07 3 450000 8 2021-02-08 3 450000 9 2021-02-09 3 450000 10 2021-02-10 1 150000 Name: col_3, dtype: int64 ['day', 'date', 'id']
グラフにする
以下のようにしてみました。Matplotlib難しい…
ヒストグラム
import matplotlib.pyplot as plt # histogram fig, ax = plt.subplots(1, 2, figsize=(10, 5)) for i, month in enumerate(df.index.unique(level='month')): ax[i].set_title(str(month)) pd.cut( # 2021年の月ごとに、col_3をヒストグラムで表示 df.groupby(['year', 'month']).get_group((2021, month))['col_3'], bins=3, # 階級数 right=False # True: aより大きくb以下; False: a以上b未満 ) \ .value_counts() \ .sort_index() \ .plot.bar(color='indigo', ax=ax[i], sharex=True, sharey=True) plt.show()
箱ひげ図
df.groupby('year').get_group(2021).boxplot( column='col_3', by='month', figsize=(10, 5), meanline=True, showmeans=True, showcaps=True, showbox=True, showfliers=False )
まとめ
抽出結果が、方法によって異なるところが面白かったポイントです。そのあたりを理解したうえで使い分けができるとかなりおしゃれです。
一方で、グルーピングがそこまで重要でない場面では、開き直ってreset_index
して、フラットな構造にするのもありです。