赞
踩
一、plt.subplots(nrows, ncols, ...)
- import matplotlib.pyplot as plt
- fig, axes = plt.subplots(1, 3, num='train', figsize=(12, 6))
上述代码创建了一个有1行3列axes的figure,figure的大小为(12,6),figure的名字为'train'。如下图所示。此时plt指向最右边的ax(因为是最后创建的)。
上述代码等价于:(和上面一样,此时plt指向最右边的ax)。
- import matplotlib.pyplot as plt
- plt.figure("train", (12, 6))
- plt.subplot(1,3,1)
- plt.subplot(1,3,2)
- plt.subplot(1,3,3)
二、plt当前所指的fig/ax永远是最新创建的fig/ax,在调用plt.xxx函数时,要注意操作的对象是哪一个fig的哪个ax。(但plt.show会显示所有figure)
- import matplotlib.pyplot as plt
- import numpy as np
-
- np.random.seed(0)
- epochs = 4
- epoch_loss_values = np.random.randint(5, size=epochs)
-
- fig, axes = plt.subplots(1, 3, num='train', figsize=(12, 6))
- x = [i + 1 for i in range(len(epoch_loss_values))]
- y = epoch_loss_values
- axes[0].plot(x, y) # ax也有plot方法
- axes[0].set_xlabel('aaa') # ax有set_xlabel方法,没有xlabel方法
- plt.xlabel("epoch")
- plt.title("Epoch Average Loss")
结果如下:
三、一个fig中新创建的ax可能会覆盖旧的ax
- import matplotlib.pyplot as plt
- import numpy as np
-
- np.random.seed(0)
- epochs = 4
- epoch_loss_values = np.random.randint(5, size=epochs)
-
- fig, axes = plt.subplots(1, 3, num='train', figsize=(12, 6))
- x = [i + 1 for i in range(len(epoch_loss_values))]
- y = epoch_loss_values
- axes[0].plot(x, y)
- axes[0].set_xlabel('aaa')
- plt.subplot(1,2,2)
- plt.xlabel("epoch")
- plt.title("Epoch Average Loss")
结果如下:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。