PyTorch: Using the PyTorch C++ Frontend

Updated 2025-06-23 17:39

This tutorial walks through model development and training with the PyTorch C++ frontend. Using a generative adversarial network (GAN) as the running example, it shows how to define models, load data, run a training loop, and save models in C++.

1. Environment Setup

First, install LibTorch, the C++ distribution of PyTorch. Builds for different operating systems can be downloaded from the PyTorch website.

## Download LibTorch (CPU build shown as an example)
wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip
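
After unpacking, a minimal program can verify that LibTorch is usable. This is only a sketch; it assumes you build against the unpacked libtorch directory (for example with CMake, passing that directory via CMAKE_PREFIX_PATH).

#include <torch/torch.h>
#include <iostream>

int main() {
  // If this builds and prints a 3x3 identity matrix, LibTorch is set up correctly.
  torch::Tensor tensor = torch::eye(3);
  std::cout << tensor << std::endl;
}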

2. Defining the Models

We will define a simple generator and a discriminator for producing digit images in the style of the MNIST dataset.

Generator

#include <torch/torch.h>

// Shorthand so the nn::... names used below resolve to torch::nn.
namespace nn = torch::nn;


// DCGAN generator: a stack of transposed convolutions that upsamples a noise
// tensor of shape [N, kNoiseSize, 1, 1] to a [N, 1, 28, 28] image.
struct DCGANGeneratorImpl : torch::nn::Module {
    DCGANGeneratorImpl(int kNoiseSize)
        : conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4).bias(false)),
          batch_norm1(256),
          conv2(nn::ConvTranspose2dOptions(256, 128, 3).stride(2).padding(1).bias(false)),
          batch_norm2(128),
          conv3(nn::ConvTranspose2dOptions(128, 64, 4).stride(2).padding(1).bias(false)),
          batch_norm3(64),
          conv4(nn::ConvTranspose2dOptions(64, 1, 4).stride(2).padding(1).bias(false)) {
        register_module("conv1", conv1);
        register_module("conv2", conv2);
        register_module("conv3", conv3);
        register_module("conv4", conv4);
        register_module("batch_norm1", batch_norm1);
        register_module("batch_norm2", batch_norm2);
        register_module("batch_norm3", batch_norm3);
    }


    torch::Tensor forward(torch::Tensor x) {
        x = torch::relu(batch_norm1(conv1(x)));
        x = torch::relu(batch_norm2(conv2(x)));
        x = torch::relu(batch_norm3(conv3(x)));
        x = torch::tanh(conv4(x));
        return x;
    }


    nn::ConvTranspose2d conv1, conv2, conv3, conv4;
    nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};


TORCH_MODULE(DCGANGenerator);

Discriminator

// DCGAN discriminator: plain convolutions that map a [N, 1, 28, 28] image to a
// single real/fake probability per sample (output shape [N, 1, 1, 1]).
auto discriminator = torch::nn::Sequential(
    nn::Conv2d(nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    nn::Conv2d(nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
    nn::BatchNorm2d(128),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    nn::Conv2d(nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
    nn::BatchNorm2d(256),
    nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
    nn::Conv2d(nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
    nn::Sigmoid()
);
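
The training code in the following sections assumes the generator has been constructed from the module holder produced by TORCH_MODULE(DCGANGenerator). A minimal sketch, with kNoiseSize chosen as 100 to match the noise tensors used later:

// kNoiseSize = 100 matches the {N, 100, 1, 1} noise tensors used in the
// training loop; `discriminator` is the torch::nn::Sequential defined above.
const int kNoiseSize = 100;
DCGANGenerator generator(kNoiseSize);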

3. Loading Data

Use the torch::data API of the PyTorch C++ frontend to load the MNIST dataset.

// "./mnist" must contain the unpacked MNIST files (e.g. train-images-idx3-ubyte).
auto dataset = torch::data::datasets::MNIST("./mnist")
    .map(torch::data::transforms::Normalize<>(0.5, 0.5))
    .map(torch::data::transforms::Stack<>());


auto data_loader = torch::data::make_data_loader(
    std::move(dataset),
    torch::data::DataLoaderOptions().batch_size(64).workers(2)
);
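
As a quick sanity check (not part of the original training code), you can iterate the loader once and print the batch shapes; with batch_size(64) the images should come out as [64, 1, 28, 28] and the labels as [64]:

// Peek at the first batch to confirm the expected shapes.
for (torch::data::Example<>& batch : *data_loader) {
  std::cout << "images: " << batch.data.sizes()
            << "  labels: " << batch.target.sizes() << std::endl;
  break;
}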

4. Training Loop

Define optimizers for the generator and the discriminator, then implement the training loop.

torch::optim::Adam generator_optimizer(
    generator->parameters(),
    // beta1 = 0.5, beta2 left at its default of 0.999; older LibTorch releases
    // expressed this as .beta1(0.5).
    torch::optim::AdamOptions(2e-4).betas(std::make_tuple(0.5, 0.999))
);
torch::optim::Adam discriminator_optimizer(
    discriminator->parameters(),
    torch::optim::AdamOptions(5e-4).betas(std::make_tuple(0.5, 0.999))
);


const int64_t kNumberOfEpochs = 30;
const int64_t batches_per_epoch = 938;  // 60,000 MNIST images / batch size 64, rounded up


for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
    int64_t batch_index = 0;
    for (torch::data::Example<>& batch : *data_loader) {
        // Train discriminator with real images
        discriminator->zero_grad();
        torch::Tensor real_images = batch.data;
        torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
        // Flatten the [N, 1, 1, 1] discriminator output to [N] to match the label shape.
        torch::Tensor real_output =
            discriminator->forward(real_images).reshape({batch.data.size(0)});
        torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
        d_loss_real.backward();


        // Train discriminator with fake images
        torch::Tensor noise = torch::randn({batch.data.size(0), 100, 1, 1});  // 100 == kNoiseSize
        torch::Tensor fake_images = generator->forward(noise);
        torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
        torch::Tensor fake_output =
            discriminator->forward(fake_images.detach()).reshape({batch.data.size(0)});
        torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
        d_loss_fake.backward();
        torch::Tensor d_loss = d_loss_real + d_loss_fake;
        discriminator_optimizer.step();


        // Train generator
        generator->zero_grad();
        fake_labels.fill_(1);
        fake_output = discriminator->forward(fake_images).reshape({batch.data.size(0)});
        torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
        g_loss.backward();
        generator_optimizer.step();


        std::printf(
            "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
            epoch, kNumberOfEpochs, ++batch_index, batches_per_epoch,
            d_loss.item<float>(), g_loss.item<float>()
        );
    }
}

5. Saving and Loading Models

Periodically save the model and optimizer state so that training can be resumed later.

// Declared once, before the training loop:
const int kCheckpointEvery = 100;
int checkpoint_counter = 0;


// Placed inside the batch loop, after both optimizer steps:
if (batch_index % kCheckpointEvery == 0) {
    torch::save(generator, "generator-checkpoint.pt");
    torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
    torch::save(discriminator, "discriminator-checkpoint.pt");
    torch::save(discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");


    torch::Tensor samples = generator->forward(torch::randn({8, 100, 1, 1}));
    // Map the tanh output from [-1, 1] back to [0, 1] before saving the samples.
    torch::save((samples + 1.0) / 2.0, torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
    std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
}
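
Loading is the mirror image of saving. A minimal sketch, assuming the generator, discriminator, and both optimizers have already been constructed as above; torch::load restores their state from the checkpoint files:

// Restore module and optimizer state written by the torch::save calls above.
torch::load(generator, "generator-checkpoint.pt");
torch::load(generator_optimizer, "generator-optimizer-checkpoint.pt");
torch::load(discriminator, "discriminator-checkpoint.pt");
torch::load(discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");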

With this tutorial, you can get started with the PyTorch C++ frontend on the 编程狮 (W3Cschool) platform. The C++ frontend extends PyTorch's capabilities and lets developers apply deep learning in more scenarios. Visit 编程狮 (W3Cschool) to learn more and sharpen your deep learning development skills.
