Streamlitで折れ線グラフを見やすくする

Python
  • 2023/07/21

折れ線グラフが複数重なると、下図のように見えにくくなることがあります。 image.png

そこでStreamlitを使って、選択したデータだけを色を付けで表示し、それ以外をグレーアウトさせて見やすくしてみたいと思います。 image.png

まず、プロットするデータを用意します。
今回は横軸と縦軸がそれぞれ0-100になるように、乱数のデータを5つ用意しました。
これをst.session_stateに保存しておきます。

if 'data1' not in st.session_state['data']:
    st.session_state['data']['data1'] = rand(100) * 100
if 'data2' not in st.session_state['data']:
    st.session_state['data']['data2'] = rand(100) * 100
if 'data3' not in st.session_state['data']:
    st.session_state['data']['data3'] = rand(100) * 100
if 'data4' not in st.session_state['data']:
    st.session_state['data']['data4'] = rand(100) * 100
if 'data5' not in st.session_state['data']:
    st.session_state['data']['data5'] = rand(100) * 100

5つのデータに対応したチェックボックスを作成します。
また、st.session_stateにチェックを入れたデータを登録するキーを用意しておきます。
今回は、st.session_state['plots']をlist型とし、チェックを入れたデータの名前をこのリストに入れていきます。
登録する名前は先ほど用意したst.session_state['data']の中の名前と同じになるようにしています。
また、チェックがないデータの名前は、st.session_state['plots']から削除します。

def add_plot(name):
    if name not in st.session_state['plots']:
        st.session_state['plots'].append(name)

def del_plot(name):
    if name in st.session_state['plots']:
        st.session_state['plots'].remove(name)

for i in range(5):
    name = "data" + str(i + 1)
    check = col1.checkbox(name)
    if check:
        add_plot(name)
    else:
        del_plot(name)

st.session_state['data']とst.session_state['plots']を使ってプロットしていきます。
st.session_state['data']の中にあるデータをそれぞれプロットします。
データの名前がst.session_state['plots']に含まれていない場合は、線の色に'lightgray'を指定して目立たなくします。
st.session_state['plots']には明るい色を指定します。
今回は、matplotlib.colorsのBASE_COLORSを順に指定しています。
先にチェックを入れていないデータをすべてプロットしてしまうことで、チェックを入れた線が上に重なるようにしています。

def create_graph():
    nums = np.array(range(100))

    fig = plt.figure()

    if 'plots' not in st.session_state:
        return

    for i, data in enumerate(st.session_state['data']):
        if data not in st.session_state['plots']:
            plt.plot(nums, st.session_state['data'][data], label=data, color='lightgray')
    for i, data in enumerate(st.session_state['data']):
        if data in st.session_state['plots']:
            plt.plot(nums, st.session_state['data'][data], label=data, color=colors[i])

    plt.legend(loc = 'upper right')

    col2.pyplot(fig)

実装は以上です。
ブラウザで実行すると、下記のようにチェックを入れたデータだけを表示させることができるようになりました。

image.png

最後にコード全体を載せておきます。

import streamlit as st
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
from numpy.random import rand

colors = list(mcolors.BASE_COLORS.keys())

if 'data' not in st.session_state:
    st.session_state['data'] = {}

if 'plots' not in st.session_state:
    st.session_state['plots'] = []

if 'data1' not in st.session_state['data']:
    st.session_state['data']['data1'] = rand(100) * 100
if 'data2' not in st.session_state['data']:
    st.session_state['data']['data2'] = rand(100) * 100
if 'data3' not in st.session_state['data']:
    st.session_state['data']['data3'] = rand(100) * 100
if 'data4' not in st.session_state['data']:
    st.session_state['data']['data4'] = rand(100) * 100
if 'data5' not in st.session_state['data']:
    st.session_state['data']['data5'] = rand(100) * 100

col1, col2 = st.columns([1, 3])

def create_graph():
    nums = np.array(range(100))

    fig = plt.figure()

    if 'plots' not in st.session_state:
        return

    for i, data in enumerate(st.session_state['data']):
        if data not in st.session_state['plots']:
            plt.plot(nums, st.session_state['data'][data], label=data, color='lightgray')
    for i, data in enumerate(st.session_state['data']):
        if data in st.session_state['plots']:
            plt.plot(nums, st.session_state['data'][data], label=data, color=colors[i])

    plt.legend(loc = 'upper right')

    col2.pyplot(fig)

def add_plot(name):
    if name not in st.session_state['plots']:
        st.session_state['plots'].append(name)

def del_plot(name):
    if name in st.session_state['plots']:
        st.session_state['plots'].remove(name)

for i in range(5):
    name = "data" + str(i + 1)
    check = col1.checkbox(name)
    if check:
        add_plot(name)
    else:
        del_plot(name)

create_graph()

Profile

Hotaru

メーカーで組み込み系のソフトウェアやファームウェアの開発をしています。

仕事では主にC言語、Python、C#を使っています。 ...もっと見る