Streamlitで折れ線グラフを見やすくする
-
2023/07/21
-
2023/07/21
折れ線グラフが複数重なると、下図のように見えにくくなることがあります。
そこでStreamlitを使って、選択したデータだけを色を付けで表示し、それ以外をグレーアウトさせて見やすくしてみたいと思います。
まず、プロットするデータを用意します。
今回は横軸と縦軸がそれぞれ0-100になるように、乱数のデータを5つ用意しました。
これをst.session_stateに保存しておきます。
if 'data1' not in st.session_state['data']:
st.session_state['data']['data1'] = rand(100) * 100
if 'data2' not in st.session_state['data']:
st.session_state['data']['data2'] = rand(100) * 100
if 'data3' not in st.session_state['data']:
st.session_state['data']['data3'] = rand(100) * 100
if 'data4' not in st.session_state['data']:
st.session_state['data']['data4'] = rand(100) * 100
if 'data5' not in st.session_state['data']:
st.session_state['data']['data5'] = rand(100) * 100
5つのデータに対応したチェックボックスを作成します。
また、st.session_stateにチェックを入れたデータを登録するキーを用意しておきます。
今回は、st.session_state['plots']をlist型とし、チェックを入れたデータの名前をこのリストに入れていきます。
登録する名前は先ほど用意したst.session_state['data']の中の名前と同じになるようにしています。
また、チェックがないデータの名前は、st.session_state['plots']から削除します。
def add_plot(name):
if name not in st.session_state['plots']:
st.session_state['plots'].append(name)
def del_plot(name):
if name in st.session_state['plots']:
st.session_state['plots'].remove(name)
for i in range(5):
name = "data" + str(i + 1)
check = col1.checkbox(name)
if check:
add_plot(name)
else:
del_plot(name)
st.session_state['data']とst.session_state['plots']を使ってプロットしていきます。
st.session_state['data']の中にあるデータをそれぞれプロットします。
データの名前がst.session_state['plots']に含まれていない場合は、線の色に'lightgray'を指定して目立たなくします。
st.session_state['plots']には明るい色を指定します。
今回は、matplotlib.colorsのBASE_COLORSを順に指定しています。
先にチェックを入れていないデータをすべてプロットしてしまうことで、チェックを入れた線が上に重なるようにしています。
def create_graph():
nums = np.array(range(100))
fig = plt.figure()
if 'plots' not in st.session_state:
return
for i, data in enumerate(st.session_state['data']):
if data not in st.session_state['plots']:
plt.plot(nums, st.session_state['data'][data], label=data, color='lightgray')
for i, data in enumerate(st.session_state['data']):
if data in st.session_state['plots']:
plt.plot(nums, st.session_state['data'][data], label=data, color=colors[i])
plt.legend(loc = 'upper right')
col2.pyplot(fig)
実装は以上です。
ブラウザで実行すると、下記のようにチェックを入れたデータだけを表示させることができるようになりました。
最後にコード全体を載せておきます。
import streamlit as st
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
from numpy.random import rand
colors = list(mcolors.BASE_COLORS.keys())
if 'data' not in st.session_state:
st.session_state['data'] = {}
if 'plots' not in st.session_state:
st.session_state['plots'] = []
if 'data1' not in st.session_state['data']:
st.session_state['data']['data1'] = rand(100) * 100
if 'data2' not in st.session_state['data']:
st.session_state['data']['data2'] = rand(100) * 100
if 'data3' not in st.session_state['data']:
st.session_state['data']['data3'] = rand(100) * 100
if 'data4' not in st.session_state['data']:
st.session_state['data']['data4'] = rand(100) * 100
if 'data5' not in st.session_state['data']:
st.session_state['data']['data5'] = rand(100) * 100
col1, col2 = st.columns([1, 3])
def create_graph():
nums = np.array(range(100))
fig = plt.figure()
if 'plots' not in st.session_state:
return
for i, data in enumerate(st.session_state['data']):
if data not in st.session_state['plots']:
plt.plot(nums, st.session_state['data'][data], label=data, color='lightgray')
for i, data in enumerate(st.session_state['data']):
if data in st.session_state['plots']:
plt.plot(nums, st.session_state['data'][data], label=data, color=colors[i])
plt.legend(loc = 'upper right')
col2.pyplot(fig)
def add_plot(name):
if name not in st.session_state['plots']:
st.session_state['plots'].append(name)
def del_plot(name):
if name in st.session_state['plots']:
st.session_state['plots'].remove(name)
for i in range(5):
name = "data" + str(i + 1)
check = col1.checkbox(name)
if check:
add_plot(name)
else:
del_plot(name)
create_graph()