Matplotlib绘图系列一;基础图形绘制

本文旨在介绍matplotlib中简单图形的绘制,主要涉及:折线图、柱状图、散点图、多子图绘制及相关注意事项

环境:

Python==3.9.0
matplotlib==3.6
seaborn==0.12.2

前言

首先,导入所需的Python库

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sn
matplotlib.use('TkAgg')

其次,指定需要绘制图的风格,字体等信息

plt.style.use('ggplot') # 该命令指定以ggplot风格绘制图形。
# 使用以下命令可以查看matplotlib支持的所有风格
print(plt.style.library)  # 共有28种风格,其中有一些存在重叠。每种风格定义了其背景样式(颜色,底纹),以及绘制图案时的颜色集合)

font_plot = {
    'family': 'Times New Roman',  # 字体类型
    # 'color': 'darkred',		# 字体颜色
    'weight': 'medium',			# 不关键,可不写
    'style': 'italic',			 # 斜体
    'variant': 'small-caps',     # 不关键,可不写
    'stretch': 'extra-condensed', # 不关键,可不写
    'size': 12  # 关键,决定字体大小
}
# 一般用法如下
font14 = {'family': 'Times New Roman',
          'size': 14}

接着,定义简单的数据集作为绘制数据

datas = [
    [100.0, 90.4, 86.7, 92.0, 52.4, 37.3, 38.5, 33.5],
    [100.0, 91.5, 88.3, 92.6, 52.1, 38.5, 36.3, 34.1],
    [100.0, 91.5, 88.4, 94.5, 51.8, 37.1, 37.5, 30.9],
    [100.0, 92.2, 88.1, 93.9, 52.9, 37.6, 38.3, 32.1],
    [100.0, 91.9, 88.5, 94.5, 53.4, 37.2, 38.8, 30.7],
    [100.0, 92.2, 88.7, 94.4, 53.0, 37.0, 38.6, 32.0],
    [100.0, 91.8, 89.1, 94.9, 53.8, 37.6, 40.0, 32.7],
    [100.0, 92.0, 88.2, 94.2, 54.2, 37.0, 39.9, 30.9],
    [100.0, 92.3, 89.3, 94.1, 53.2, 34.8, 38.8, 31.2],
    [100.0, 92.2, 88.9, 94.5, 54.3, 35.1, 39.7, 31.3],
] # 数据维度是 10 * 8,每一列表示不同模型,每一行表示不同实验设置

column_names = ["M1", "M2", "M3", "M4", "M5", "M6", "M7", "M8"]
row_names = ["Exp1", "Exp2", "Exp3", "Exp4", "Exp5", "Exp6", "Exp7", "Exp8"]

最后定义绘制图形的marker,以及颜色

markers = [".", "o", "v", "^", "<", ">", "*", "x", "D", "d", "2", "p", "+"]
colors = ['#E24A33', '#348ABD', '#988ED5', '#777777', '#FBC15E', '#8EBA42', '#FFB5B8', '#bcbd22']

一、折线图

1.1 单个折线图

fig = plt.figure()
x_axis = list(range(len(steps)))  # length:  10
data = np.asarray(resnet50_data_conv_different_steps).T  # shape: [8, 10]
for i, row in enumerate(data):
    # row length: 10
    plt.plot(x_axis, row, color=colors[i % len(colors)], marker=markers[i], label=legend_label[i])
plt.xticks(x_axis, steps, font=font14)
plt.ylabel("Y-axis-label", font=font14)
plt.xlabel("X-axis-label", font=font14)
plt.legend()
plt.show()

1.2 多子图绘制

fig, axs = plt.subplots(nrows=2, ncols=5, figsize=[13, 10], sharex=True, sharey=True,  constrained_layout=True)  # default is [6.4, 4.8]  constrained_layout=False
for idx, ax in enumerate(axs.flat):
    x_axis = list(range(len(steps)))  # length:  10
    data = np.asarray(resnet50_data_conv_different_steps).T  # shape: [8, 10]
    for i, row in enumerate(data):
        # row length: 10
        ax.plot(x_axis, row, color=colors[i % len(colors)], marker=markers[i], label=legend_label[i])
    ax.set_xticks(x_axis, steps, font=font12, rotation=-30)
    ax.set_title(title_list[idx])


lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
lines, labels_leg = [sum(lol, []) for lol in zip(*lines_labels)]

fig.legend(lines[:8], labels_leg[:8], ncol=8, bbox_to_anchor=(0.85, -0.01))
fig.text(0, 0.5, 'Y-axis-label', va='center', rotation='vertical', font=font12)
fig.text(0.45, 0, 'X-axis-label', va='center', rotation='horizontal', font=font12)
plt.savefig("analysis_example_subplot.png", bbox_inches='tight',
                    pad_inches=0, dpi=300)
plt.show()

二、柱状图

2.1 单个柱状图绘制

title_list = [f"data_{i}" for i in range(10)]

def plot_three_bar_figure(data_lsit, xticks, title_name, bar_width=0.2, save_fig=False):
    # 创建示例数据
    method_names = ['dropout', 'uniform', 'normal']
    data_dropout = data_lsit[0]  # length: 8
    data_uniform = data_lsit[1]
    data_normal = data_lsit[2]

    data = np.stack([data_dropout, data_uniform, data_normal], axis=-1)   # shape: [8, 3]
    # 创建X轴刻度位置
    x = np.arange(len(data_normal))  # length: 8
    # 绘制柱状图
    for i in range(len(method_names)):  # length: 3
        plt.bar(x + i * bar_width, data[:, i], bar_width, label=method_names[i], color=colors[i])
    # 设置X轴标签
    plt.xlabel('Models', font=font14)
    plt.xticks(x + bar_width, xticks, font=font14, rotation=-15)
    # 设置Y轴标签
    plt.ylabel('Y-axis-label', font=font14)
    # 添加图例
    plt.legend(loc='upper right')
    # # 设置标题
    plt.title(title_name)
    # 调整子图之间的间距
    plt.tight_layout()
    if save_fig:
        plt.savefig(f"analysis_example_bar.pdf", bbox_inches='tight',
                    pad_inches=0, dpi=300)
        plt.savefig(f"analysis_example_bar.png", bbox_inches='tight',
                    pad_inches=0, dpi=300)
    # 显示图形
    plt.show()
# 注意,这里使用了新得数据集
data = [
    [69.1, 73.7, 75.5, 68.5, 100, 68.7, 44.2, 62.9],
    [76.3, 79.9, 84, 73.8, 99.9, 69.7, 45.7, 63.3],
    [75.9, 80.4, 83.5, 74.1, 100, 69.9, 47.3, 64.5],
]

plot_three_bar_figure(data, legend_label, title_name=title_list[0], save_fig=False)

2.2 多子图绘制

title_list = [f"data_{i}" for i in range(10)]

def plot_three_bar_figure(axes, data_lsit, x_axis, title_name, bar_width=0.2):
    # 创建示例数据
    method_names = ['dropout', 'uniform', 'normal']
    data_dropout = data_lsit[0]  # length: 8
    data_uniform = data_lsit[1]
    data_normal = data_lsit[2]

    data = np.stack([data_dropout, data_uniform, data_normal], axis=-1)   # shape: [8, 3]
    # 创建X轴刻度位置
    # 绘制柱状图
    for i in range(len(method_names)):  # length: 3
        fill_plot = ['/', 'x', '\\']
        axes.bar(x_axis + i * bar_width, data[:, i], bar_width, hatch=fill_plot[i],
                 label=method_names[i], color=colors[i], edgecolor='white')
    
    # # 设置标题
    axes.set_title(title_name)
    # 调整子图之间的间距

data = [
    [69.1, 73.7, 75.5, 68.5, 100, 68.7, 44.2, 62.9],
    [76.3, 79.9, 84, 73.8, 99.9, 69.7, 45.7, 63.3],
    [75.9, 80.4, 83.5, 74.1, 100, 69.9, 47.3, 64.5],
]

bar_width = 0.2
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=[13, 10], sharex=True, sharey=True,  constrained_layout=True)  # default is [6.4, 4.8]  constrained_layout=False
for idx, ax in enumerate(axs.flat):
    x_axis = np.arange(len(data[0]))  # length: 8
    for i, row in enumerate(data):
        plot_three_bar_figure(ax, data, x_axis, title_name=title_list[i])
    ax.set_xticks(x_axis + bar_width, legend_label, font=font12, rotation=90)
    ax.set_title(title_list[idx])

lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
lines, labels_leg = [sum(lol, []) for lol in zip(*lines_labels)]
fig.legend(lines[:3], labels_leg[:3], ncol=3, bbox_to_anchor=(0.65, -0.01))

fig.text(0, 0.5, 'Y-axis-label', va='center', rotation='vertical', font=font12)
fig.text(0.45, 0, 'Models', va='center', rotation='horizontal', font=font12)

# plt.savefig(f"analysis_example_bar.pdf", bbox_inches='tight', pad_inches=0, dpi=300)
plt.savefig(f"analysis_example_bar_subpolt.png", bbox_inches='tight', pad_inches=0, dpi=300)
# plt.show()

3. 散点图

散点图无多子样例

label1 = ["A1", "A2", "A3", "A4", "A5"]
label2 = ["A1-S", "A2-S", "A3-S", "A4-S", "A5-S"]

markers = [".", "o","v", "^","<",">","*", "x", "D","d", "2", "p", "+"]
colors = ['#1f77b4','#ff7f0e','#2ca02c','#d62728','#9467bd','#8c564b','#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

markers1 = ["o","o","o","o","o"]
markers2 = ["v","v","v","v","v"]
colors1 = ['#1f77b4','#ff7f0e','#2ca02c','#d62728','#9467bd']



data_x = {
    "Ori": [55.57, 66.86, 71.21, 66.93, 72.26],
    "S":  [61.64, 70.71, 72.01, 68.29, 73.93]
}

data_y = {
    "Ori": [3.709545851, 3.538772821, 3.847475052, 3.89070797, 3.351880789],
    "S": [3.642693996, 3.538482189, 3.550876141, 3.903666258, 4.207882881]
}


markers3 = ['.',',', 'o','v','^','<','>','8','s','p','*','+','D','d','x','|','_']

fig = plt.figure()
def scatter_plot(data1, data2, markers, type='Ori'):
    for i, (p1, p2, m) in enumerate(zip(data1[type], data2[type], markers)):
        if type == "Ori":
            plt.scatter(p1, p2, marker=m, label=label1[i], color=colors[i], s=60)
        else:
            plt.scatter(p1, p2, marker=m, label=label1[i]+"-S", color=colors[i], s=60)

scatter_plot(data_x, data_y, markers=markers3[:5], type='Ori')
scatter_plot(data_x, data_y, markers=markers3[6:], type='S')
plt.ylabel("data y-axis")
plt.xlabel("data xaxis")

plt.legend(ncol=2)
plt.savefig(f"analysis_example_scatter.png", bbox_inches='tight', pad_inches=0, dpi=300)
plt.show()
Table of Contents