PyTorch 基础知识

2025-06-24 09:38 更新

PyTorch 是一个开源的机器学习库,广泛应用于计算机视觉、自然语言处理等领域。本教程将详细讲解 PyTorch 的基础知识。

一、PyTorch 简介

PyTorch 是由 Facebook AI Research 实验室开发的开源机器学习库,具有以下特点:

  • 动态计算图 :PyTorch 采用动态计算图,允许在运行时动态调整网络结构,提供了更大的灵活性。
  • 易于使用的 API :PyTorch 提供了简单直观的 API,方便用户快速上手。
  • 强大的社区支持 :PyTorch 拥有活跃的开发者社区,提供了丰富的教程和示例代码。

二、张量操作

张量是 PyTorch 中的基本数据结构,类似于 NumPy 中的数组。

import torch


## 创建张量
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])


## 张量运算
z = x + y
print(z)

三、自动求导

PyTorch 提供了自动求导功能,可以自动计算梯度。

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
y.backward()
print(x.grad)

四、模型构建

使用 PyTorch 构建神经网络模型非常简单。

import torch.nn as nn


class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 5)


    def forward(self, x):
        return self.linear(x)


model = SimpleModel()
print(model)

五、模型训练

训练模型是深度学习中的关键步骤。

import torch.optim as optim


## 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)


## 训练模型
for epoch in range(10):
    optimizer.zero_grad()
    output = model(torch.randn(1, 10))
    loss = criterion(output, torch.randn(1, 5))
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

六、模型保存与加载

训练完成后,可以保存和加载模型。

## 保存模型
torch.save(model.state_dict(), 'model.pth')


## 加载模型
model.load_state_dict(torch.load('model.pth'))

通过本教程,大家可以在编程狮(W3Cschool)平台上轻松掌握 PyTorch 的基础知识。PyTorch 是深度学习领域的重要工具,希望大家能够学以致用,在实际项目中灵活应用这些技术。在编程狮(W3Cschool)学习更多相关内容,提升你的深度学习开发技能。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号