PyTorch LSTM Word 语言模型上的(实验)动态量化

2025-06-23 11:58 更新

在自然语言处理任务中,模型的性能和效率至关重要。动态量化是一种有效的技术,可以减小模型尺寸、加快推断速度,同时对模型准确性影响较小。本教程将详细讲解如何在 LSTM Word 语言模型上应用动态量化。

一、模型定义

我们首先定义一个基于 LSTM 的语言模型,该模型包括编码器、循环模块和解码器。

import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMModel(nn.Module):
    """包含编码器、循环模块和解码器的容器模块。"""


    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        self.init_weights()
        self.nhid = nhid
        self.nlayers = nlayers


    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)


    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        return decoded, hidden


    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))

二、数据预处理

我们将使用 Wikitext-2 数据集来训练和评估模型。首先,我们需要对数据进行预处理,包括分词和转换为张量。

class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []


    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]


    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))


    def tokenize(self, path):
        """对文本文件进行分词。"""
        assert os.path.exists(path)
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)
        with open(path, 'r', encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ['<eos>']
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))
            ids = torch.cat(idss)
        return ids


model_data_filepath = 'data/'
corpus = Corpus(model_data_filepath + 'wikitext-2')

三、加载预训练模型

为了应用动态量化,我们首先需要加载预训练的模型权重。

ntokens = len(corpus.dictionary)
model = LSTMModel(
    ntoken=ntokens,
    ninp=512,
    nhid=256,
    nlayers=5,
)
model.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu')
    )
)
model.eval()
print(model)

四、动态量化应用

使用 PyTorch 的 quantize_dynamic 函数对模型进行动态量化。

import torch.quantization


quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
print(quantized_model)

五、性能评估

比较量化前后模型的大小和推断速度。

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p") / 1e6)
    os.remove('temp.p')


print_size_of_model(model)
print_size_of_model(quantized_model)

测试模型的推断性能:

torch.set_num_threads(1)


def time_model_evaluation(model, test_data):
    s = time.time()
    loss = evaluate(model, test_data)
    elapsed = time.time() - s
    print('''loss: {0:.3f}
elapsed time (seconds): {1:.1f}'''.format(loss, elapsed))


time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)

六、结果分析

通过动态量化,我们可以看到模型尺寸明显减小,推断速度显著提高,而模型的准确性几乎没有受到影响。在实际应用中,这种优化对于模型的部署和推理效率提升具有重要意义。

在编程狮(W3Cschool)平台上,你可以找到更多关于 PyTorch 模型优化和动态量化的详细教程和示例代码,帮助你深入理解和应用这些技术。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号