PyTorch torch.utils.tensorboard
PyTorch TensorBoard 可视化详解:从入门到精通
一、什么是 TensorBoard?
TensorBoard 是一个可视化工具,用于跟踪机器学习实验的指标和可视化模型结构。通过 TensorBoard,我们可以轻松地监控训练过程中的损失、准确率、学习率等指标的变化,以及可视化模型的计算图、图像、直方图等信息。PyTorch 提供了 torch.utils.tensorboard
模块,方便我们将 PyTorch 模型和指标记录到 TensorBoard 中进行可视化。
二、安装 TensorBoard
在使用 TensorBoard 之前,您需要先安装它。可以通过以下命令安装 TensorBoard:
pip install tensorboard
安装完成后,可以通过以下命令启动 TensorBoard:
tensorboard --logdir=runs
其中,--logdir
参数指定日志文件的存储目录。默认情况下,PyTorch 会将日志文件保存在当前目录下的 runs
文件夹中。
三、核心工具:SummaryWriter
SummaryWriter
是 PyTorch 中用于将数据记录到 TensorBoard 的主要工具。它提供了多种方法来记录不同类型的可视化数据,如标量、图像、直方图、模型图等。
(一)创建 SummaryWriter
from torch.utils.tensorboard import SummaryWriter
## 创建一个 SummaryWriter,默认日志文件保存在 ./runs/ 目录下
writer = SummaryWriter()
## 也可以指定日志文件的保存目录
writer = SummaryWriter(log_dir='my_logs')
(二)记录标量数据
import numpy as np
for n_iter in range(100):
writer.add_scalar('Loss/train', np.random.random(), n_iter)
writer.add_scalar('Loss/test', np.random.random(), n_iter)
writer.add_scalar('Accuracy/train', np.random.random(), n_iter)
writer.add_scalar('Accuracy/test', np.random.random(), n_iter)
writer.close()
(三)记录图像数据
from torchvision import datasets, transforms
from torchvision.utils import make_grid
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('mnist_train', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
images, labels = next(iter(trainloader))
grid = make_grid(images)
writer.add_image('images', grid, 0)
writer.close()
(四)记录模型结构
import torchvision.models as models
model = models.resnet50(False)
images, labels = next(iter(trainloader))
writer.add_graph(model, images)
writer.close()
(五)记录直方图
for i in range(10):
x = np.random.random(1000)
writer.add_histogram('distribution centers', x + i, i)
writer.close()
(六)记录嵌入
import keyword
import torch
meta = []
while len(meta) < 100:
meta = meta + keyword.kwlist # 获取一些字符串
meta = meta[:100]
for i, v in enumerate(meta):
meta[i] = v + str(i)
label_img = torch.rand(100, 3, 10, 32)
for i in range(100):
label_img[i] *= i / 100.0
writer.add_embedding(torch.randn(100, 5), metadata=meta, label_img=label_img)
writer.close()
(七)记录超参数
with SummaryWriter() as w:
for i in range(5):
w.add_hparams({'lr': 0.1 * i, 'bsize': i}, {'hparam/accuracy': 10 * i, 'hparam/loss': 10 * i})
四、高级功能与技巧
(一)自定义标量布局
通过 add_custom_scalars
方法,可以自定义标量的显示布局,将多个标量组织成组,便于比较和分析。
layout = {
'Taiwan': {
'twse': ['Multiline', ['twse/0050', 'twse/2330']]
},
'USA': {
'dow': ['Margin', ['dow/aaa', 'dow/bbb', 'dow/ccc']],
'nasdaq': ['Margin', ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]
}
}
writer.add_custom_scalars(layout)
(二)记录网格数据
可以使用 add_mesh
方法记录 3D 网格数据,用于可视化 3D 模型或点云数据。
vertices_tensor = torch.as_tensor([
[1, 1, 1],
[-1, -1, 1],
[1, -1, -1],
[-1, 1, -1],
], dtype=torch.float).unsqueeze(0)
colors_tensor = torch.as_tensor([
[255, 0, 0],
[0, 255, 0],
[0, 0, 255],
[255, 0, 255],
], dtype=torch.int).unsqueeze(0)
faces_tensor = torch.as_tensor([
[0, 2, 3],
[0, 3, 1],
[0, 1, 2],
[1, 3, 2],
], dtype=torch.int).unsqueeze(0)
writer.add_mesh('my_mesh', vertices=vertices_tensor, colors=colors_tensor, faces=faces_tensor)
writer.close()
(三)记录 PR 曲线
PR 曲线(Precision-Recall Curve)用于评估分类模型在不同阈值下的性能。可以通过 add_pr_curve
方法记录 PR 曲线。
labels = np.random.randint(2, size=100) # 二分类标签
predictions = np.random.rand(100) # 预测概率
writer.add_pr_curve('pr_curve', labels, predictions, 0)
writer.close()
五、总结
通过本教程,我们详细介绍了 PyTorch 中的 torch.utils.tensorboard
模块的使用方法,包括如何记录标量、图像、直方图、模型结构、嵌入、超参数等多种数据类型,以及如何利用 TensorBoard 进行可视化。这些功能可以帮助我们更好地理解和优化机器学习模型的训练过程。
更多建议: