赞
踩
欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://blog.csdn.net/caroline_wendy/article/details/131384440
Matplotlib 是一个用于绘制二维图形的 Python 库,提供了 pyplot 模块,用于创建各种类型的图表。其中一种图表是散点图(Scatter Plot),用于展示两个变量之间的关系以及数据的分布情况。绘制散点图可以使用 pyplot.scatter() 函数,也可以使用面向对象接口 Axes.scatter();下文源码使用后者,将两组数据逐点对比,并绘制 y = x 对角线作为参照。
主要用到的接口:
sns.set_theme(style="white")
fig.set_size_inches(10, 10)
ax.plot()
ax.scatter()
ax.annotate()
源码如下:
#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2022. All rights reserved.
Created by C. L. Wang on 2023/6/25
"""
import os
import seaborn as sns
sns.set_theme(style="white")
import matplotlib.pyplot as plt
import numpy as np
from myutils.project_utils import read_excel_to_df, get_current_time_str
from root_dir import DATA_DIR
def draw_diagonal_scatter_plots(
        data_better, data_ref, label_list,
        min_scale=0.0, max_scale=1.05,
        x_label="", y_label="",
        save_name="",
        annotate_threshold=0.05,
):
    """
    Draw a square scatter plot of paired values against the dashed y = x diagonal.

    Each point is (data_ref[i], data_better[i]). Points where data_ref exceeds
    data_better fall below the diagonal (lower-right region); those more than
    ``annotate_threshold`` below it are labelled with label_list[i].

    :param data_better: values plotted on the y axis
    :param data_ref: reference values plotted on the x axis
    :param label_list: one text label per point
    :param min_scale: lower limit of both axes and start of the diagonal
    :param max_scale: upper limit of both axes and end of the diagonal
    :param x_label: x-axis caption
    :param y_label: y-axis caption
    :param save_name: if non-empty, output image path; must end in .png, .jpg or .jpeg
    :param annotate_threshold: a point is labelled when data_ref - data_better exceeds this
    :return: the matplotlib Figure that was drawn
    :raises ValueError: if the three sequences differ in length, or save_name
        has an unsupported extension
    """
    # Convert to plain arrays so element access is positional; a pandas Series
    # with a non-default index would otherwise be looked up by label.
    ref = np.asarray(data_ref, dtype=float)
    better = np.asarray(data_better, dtype=float)
    labels = list(label_list)
    if not (len(ref) == len(better) == len(labels)):
        raise ValueError(
            f"length mismatch: data_better={len(better)}, "
            f"data_ref={len(ref)}, label_list={len(labels)}"
        )

    # Validate the output path before drawing anything, and remember the real
    # format so a .jpg name is not written with PNG content.
    save_format = ""
    if save_name:
        save_format = os.path.splitext(save_name)[1].lower().lstrip(".")
        if save_format not in ("png", "jpg", "jpeg"):
            raise ValueError(f"save_name must end in .png, .jpg or .jpeg: {save_name!r}")

    fig, ax = plt.subplots()
    fig.set_size_inches(10, 10)
    ax.grid(True)
    # Dashed y = x reference line spanning the visible range.
    ax.plot([min_scale, max_scale], [min_scale, max_scale], ls="--", c=".3")
    ax.scatter(ref, better, s=100, edgecolors="black")
    ax.set_xlim(min_scale, max_scale)
    ax.set_ylim(min_scale, max_scale)
    ax.tick_params(axis="both", labelsize=15)
    ax.set_xlabel(x_label, fontsize=20)
    ax.set_ylabel(y_label, fontsize=20)

    for x, y, text in zip(ref, better, labels):
        if x - y > annotate_threshold:
            # Nudge the label slightly to the right of the marker.
            ax.annotate(text, (x * 1.01, y), fontsize=10, fontweight='bold')

    if save_name:
        # transparent=True
        fig.savefig(save_name, bbox_inches='tight', format=save_format)
    plt.show()
    return fig
def main():
    """
    Load per-target scores from the Excel file, print summary statistics and plot them.

    Reads the "m0-score", "max-score" and "target" columns, prints mean±std
    for each score column, and draws the diagonal scatter plot to a
    timestamped PNG file.
    """
    df = read_excel_to_df(os.path.join(DATA_DIR, "20230630-best-tmscore.xlsx"))
    data1 = df["m0-score"]
    data2 = df["max-score"]
    label_list = df["target"]
    for name, data in (("data1", data1), ("data2", data2)):
        # Scale mean and std by the same factor (x100) so "mean±std" is on a single scale.
        mean = round(float(np.mean(data)) * 100, 4)
        std = round(float(np.std(data)) * 100, 4)
        print(f"{name} : {mean}±{std}")
    draw_diagonal_scatter_plots(data1, data2, label_list, min_scale=0.35, max_scale=1.05,
                                x_label="SOTA", y_label="Our Best",
                                save_name=f"img-{get_current_time_str()}.png")
if __name__ == "__main__":
    # Run the plotting pipeline only when executed as a script, not on import.
    main()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。