matplotlib – 図を numpy 配列、PIL.Image や base64 文字列として取得する方法

目次

概要

matplotlib の図をファイルに一度保存することなく、bytes オブジェクト、画像を表す numpy 配列、base64 文字列などに変換する方法を解説します。

図を画像化する

共通する作業として、次の手順で Figure オブジェクトをレンダリングし、作成された画像のバイト列を取得する必要があります。

  1. Figure.canvas.draw() で図をレンダリングする。
  2. Figure.canvas.tostring_argb() または Figure.canvas.tostring_rgb() で画像を表すバイト列を取得する。
  3. Figure.canvas.get_width_height() で画像の大きさを取得する。
In [1]:
import numpy as np
from matplotlib import pyplot as plt

fig, ax = plt.subplots()
ax.pie([100, 200, 300, 400, 500])

# Figure をレンダリングする。
fig.canvas.draw()

# 画像をバイト列で取得する。
data = fig.canvas.tostring_rgb()
# アルファチャンネルを含む場合は tostring_argb()
# data = fig.canvas.tostring_argb()

# 画像の大きさを取得する。
w, h = fig.canvas.get_width_height()
c = len(data) // (w * h)

print(f"data size: {len(data)} bytes")
print(f"image shape: ({w}, {h}, {c})")
data size: 373248 bytes
image shape: (432, 288, 3)

Figure を numpy 配列に変換する

OpenCV の画像形式である numpy 配列に変換する方法です。 matplotlib はチャンネル順が RGB のため、OpenCV で扱う場合は最後にチャンネル順を BGR に変更します。

In [2]:
import cv2
import numpy as np
from matplotlib import pyplot as plt

fig, ax = plt.subplots()
ax.pie([100, 200, 300, 400, 500])

# Figure をレンダリングする。
fig.canvas.draw()

# 画像をバイト列で取得する。
data = fig.canvas.tostring_rgb()

# 画像の大きさを取得する。
w, h = fig.canvas.get_width_height()
c = len(data) // (w * h)

# numpy 配列に変換する
img = np.frombuffer(data, dtype=np.uint8).reshape(h, w, c)

# RGB を BGR に変更する。
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

plt.close()

Figure を PIL.Image に変換する

Pillow の画像形式である PIL.Image オブジェクト に変換する方法です。

In [3]:
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt

fig, ax = plt.subplots()
ax.pie([100, 200, 300, 400, 500])

# Figure をレンダリングする。
fig.canvas.draw()

# 画像をバイト列で取得する。
data = fig.canvas.tostring_rgb()

# 画像の大きさを取得する。
w, h = fig.canvas.get_width_height()
c = len(data) // (w * h)

# PIL.Image に変換する
img = Image.frombytes("RGB", (w, h), data, "raw")

plt.close()

Figure を jpg のバイト列に変換する

エンコードされた jpg のバイト列を取得する場合は、次の手順で行います。

  1. savefig()format="jpg" を指定し、出力先に BytesIO オブジェクトを指定する。
  2. 書き込まれたバイト列を BytesIO.getvalue() で取得する。
In [4]:
from io import BytesIO

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

fig, ax = plt.subplots()
ax.pie([100, 200, 300, 400, 500])

# Bytes IO に対して、エンコード結果を書き込む。
ofs = BytesIO()
fig.savefig(ofs, format="jpg")
png_data = ofs.getvalue()
print(f"data size: {len(png_data)} bytes")

plt.close()
data size: 9960 bytes

Figure を png のバイト列に変換する

エンコードされた png のバイト列を取得する場合は、次の手順で行います。

  1. savefig()format="png" を指定し、出力先に BytesIO オブジェクトを指定する。
  2. 書き込まれたバイト列を BytesIO.getvalue() で取得する。
In [5]:
from io import BytesIO

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

fig, ax = plt.subplots()
ax.pie([100, 200, 300, 400, 500])

# Bytes IO に対して、エンコード結果を書き込む。
ofs = BytesIO()
fig.savefig(ofs, format="png")
png_data = ofs.getvalue()
print(f"data size: {len(png_data)} bytes")

plt.close()
data size: 8208 bytes

Figure を base64 文字列に変換する

In [6]:
import base64
from io import BytesIO

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

fig, ax = plt.subplots()
ax.pie([100, 200, 300, 400, 500])

# Bytes IO に対して、エンコード結果を書き込む。
ofs = BytesIO()
fig.savefig(ofs, format="png")
png_data = ofs.getvalue()

plt.close()

# バイト列を base64 文字列に変換する。
base64_data = base64.b64encode(png_data).decode()
print(f"base64 string length: {len(base64_data)}")
base64 string length: 10944

Flask で図を表示する

先に紹介した base64 文字列に変換するコードを使って、Flask で Figure を動的に HTML 上に描画する例を紹介します。

app.py

import base64
from io import BytesIO

import numpy as np
import matplotlib

matplotlib.use("Agg")
from matplotlib import pyplot as plt
from flask import Flask, render_template

app = Flask(__name__)


def fig_to_base64(fig):
    """Figure を base64 文字列に変換する。
    """
    # Bytes IO に対して、エンコード結果を書き込む。
    ofs = BytesIO()
    fig.savefig(ofs, format="png")
    png_data = ofs.getvalue()

    # バイト列を base64 文字列に変換する。
    base64_data = base64.b64encode(png_data).decode()

    return base64_data


def create_graph():
    """matplotlib のグラフを作成する。
    """
    x = np.linspace(-10, 10, 100)
    y = x ** 2

    fig, ax = plt.subplots()
    ax.plot(x, y)
    ax.set_title("Title", c="darkred", size="large")

    return fig


@app.route("/")
def hello():
    # グラフを作成する。
    fig, ax = plt.subplots()
    ax.pie([100, 200, 300, 400, 500])

    # グラフを base64 文字列に変換する。
    img = fig_to_base64_img(fig)

    return render_template("index.html", img=img)


if __name__ == "__main__":
    app.run(debug=True)

templates/index.html

<!DOCTYPE html>
<html lang="ja">

<head>
    <meta charset="utf-8">
    <title>Example</title>
</head>

<body>
    <h1>Example</h1>

    <img src="data:image/png;base64,{{ img }}" />
</body>

</html>

コメント

コメントする

目次