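  • Introduction to the GPT-2 model

  • GPT-2 (Generative Pre-trained Transformer 2) is a natural language processing (NLP) model developed by OpenAI. It is based on the Transformer architecture and is designed to generate fluent, natural-sounding text.

  • It is trained without human-labeled data: its pretraining task is to predict the next token in raw text (self-supervised learning). The goal is to capture the complexity of human language and produce natural-sounding text.

  • GPT-2 was pretrained on a large web-text corpus (OpenAI's WebText, about 40 GB of text from roughly 8 million web pages). This lets it generate fluent, coherent text, although nothing guarantees that the content is factually accurate.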

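  • KerasNLP and GPT-2

  • KerasNLP is an extension library for Keras that provides convenient support for NLP tasks, including text generation.

  • With KerasNLP, you can easily load a pretrained GPT-2 model and use it for text generation.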

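  • The text generation process

  • Use GPT2Tokenizer to convert the input text into a format the model understands (token IDs).

  • Pass the token IDs to the GPT-2 model as input.

  • The model generates new token IDs from the input context, one at a time. Each new ID is appended to the input before the next one is predicted.

  • Use GPT2Tokenizer to decode the generated token IDs back into text.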
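The four steps above form a loop: each newly predicted token ID is appended to the input before the next prediction. The sketch below walks through that loop with a hypothetical seven-token vocabulary and a lookup-table "model" standing in for GPT-2. The names `VOCAB`, `NEXT_TOKEN`, `tokenize`, `predict_next` and `detokenize` are illustrative only and are not part of KerasNLP.

```python
# Toy illustration of the tokenize -> predict -> append -> decode loop.
# Everything here is hypothetical: GPT-2 uses a byte-pair-encoding tokenizer
# and a Transformer, not a word list and a lookup table.

VOCAB = ["<eos>", "my", "trip", "to", "yosemite", "was", "great"]
TOKEN_TO_ID = {token: i for i, token in enumerate(VOCAB)}
EOS_ID = TOKEN_TO_ID["<eos>"]

# Stand-in "model": maps the last token ID to the predicted next token ID.
NEXT_TOKEN = {1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: EOS_ID}


def tokenize(text):
    """Step 1: text -> token IDs (whitespace split, for illustration)."""
    return [TOKEN_TO_ID[word] for word in text.lower().split()]


def predict_next(token_ids):
    """Steps 2-3: feed the IDs to the 'model' and get one new token ID."""
    return NEXT_TOKEN[token_ids[-1]]


def detokenize(token_ids):
    """Step 4: token IDs -> text."""
    return " ".join(VOCAB[i] for i in token_ids)


def generate(prompt, max_length):
    token_ids = tokenize(prompt)
    while len(token_ids) < max_length:
        next_id = predict_next(token_ids)
        if next_id == EOS_ID:
            break  # the model signalled the end of the sequence
        token_ids.append(next_id)  # the new token becomes part of the context
    return detokenize(token_ids)


print(generate("My trip", max_length=10))
```

`max_length` counts the prompt tokens too, so `generate("My trip", max_length=3)` stops after a single new token.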

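  • Features and advantages

  • GPT-2 has a large number of pretrained parameters, which give it strong expressive power and generalization ability.

  • It can generate many kinds of text, such as news articles, stories, dialogue, and code.

  • Compared with many earlier neural language models, GPT-2 stands out for its self-supervised training, which needs no labeled data. It can also handle several tasks without task-specific training. Its training data is mostly English, however, so its ability in other languages is limited.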

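  • Performance and scale

  • GPT-2 comes in several sizes, from small to extra large, to suit different compute budgets and quality requirements.

  • Parameter counts range from about 124 million (the smallest model) to about 1.5 billion (the largest). Stored as float32 weights, that is roughly 0.5 GB to 6 GB. In KerasNLP the four sizes are available as the presets gpt2_base_en, gpt2_medium_en, gpt2_large_en and gpt2_extra_large_en.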
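As a rough check on model size, the parameter count of the smallest model can be derived from its configuration. The sketch below assumes the commonly published GPT-2 small settings:

- a 50,257-token vocabulary and 1,024 positions;
- hidden size 768, 12 decoder layers and MLP width 3,072;
- output logits computed with the shared token-embedding matrix, so there is no separate output layer.

`gpt2_param_count` is a hypothetical helper, not a KerasNLP function.

```python
# Back-of-the-envelope parameter count for a GPT-2-style decoder.
# Assumes tied input/output embeddings (no separate output projection).


def gpt2_param_count(vocab=50257, positions=1024, d=768, layers=12, d_ff=3072):
    embeddings = vocab * d + positions * d        # token + position embeddings
    attention = (d * 3 * d + 3 * d) + (d * d + d)  # fused QKV + output projection
    mlp = (d * d_ff + d_ff) + (d_ff * d + d)       # two dense layers with biases
    layer_norms = 2 * (2 * d)                      # two LayerNorms (scale + bias)
    per_layer = attention + mlp + layer_norms
    final_norm = 2 * d
    return embeddings + layers * per_layer + final_norm


params = gpt2_param_count()
bytes_fp32 = params * 4  # 4 bytes per float32 weight
print(f"{params:,} parameters ~ {bytes_fp32 / 1e9:.2f} GB in float32")
```

Passing `d=1600, layers=48, d_ff=6400` (the published settings for the 1.5-billion-parameter model) gives 1,557,611,200 parameters, about 6.2 GB in float32.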

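  • Usage examples

  • With the interfaces and pretrained models that KerasNLP provides, text generation takes only a few lines of code.

  • By changing the input prompt and the generation settings (such as max_length and the sampler), you can generate text in different styles and on different topics.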

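  • Caveats

  • Generated text may not be fully grammatical or logical, because the model predicts tokens from statistical patterns learned during training rather than from verified knowledge.

  • In real applications, generated text needs appropriate post-processing and filtering to ensure its quality and suitability.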
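Here is a minimal sketch of such post-processing. It assumes two simple heuristics of our own choosing:

- Trim a trailing sentence fragment. Generation often stops mid-sentence when it reaches `max_length`.
- Reject outputs dominated by repeated word n-grams.

The helper names and the 0.3 threshold are hypothetical. The repetition check assumes whitespace-separated text, so it would not suit Chinese output as written.

```python
def trim_to_last_sentence(text):
    """Drop any trailing fragment after the last sentence-ending mark."""
    cut = max(text.rfind(mark) for mark in ".!?。！？")
    return text[: cut + 1] if cut != -1 else text


def repeated_ngram_ratio(text, n=3):
    """Share of word n-grams that duplicate another one (0.0 = no repeats)."""
    words = text.split()
    ngrams = [tuple(words[i : i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:
        return 0.0
    return 1.0 - len(set(ngrams)) / len(ngrams)


def keep_output(text, max_repeat_ratio=0.3):
    """Accept a generation only if it is not dominated by repetition."""
    return repeated_ngram_ratio(text) <= max_repeat_ratio


sample = "I went to the park. It was sunny and I"
print(trim_to_last_sentence(sample))
```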





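Running GPT-2 is resource-intensive. Before you begin, go to Runtime -> Change runtime type (in Colab) and select the GPU hardware accelerator runtime (which should have >12G host RAM and ~15G GPU RAM), since you will fine-tune the GPT-2 model. Running this tutorial on a CPU runtime will take hours.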


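This example uses Keras 3, so it works with any of the "tensorflow", "jax" or "torch" backends. Support for Keras 3 is built into KerasNLP: simply change the "KERAS_BACKEND" environment variable to select the backend of your choice. We select the JAX backend below.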

```python
!pip install git+https://github.com/keras-team/keras-nlp.git -q

import os

# Select the backend before importing Keras; it is read when Keras is imported.
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"

import keras_nlp
import keras
import tensorflow as tf
import time
```

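Generative LLMs are usually based on deep neural networks, such as the Transformer architecture invented by Google researchers in 2017, and are trained on massive amounts of text data, often billions of words. Models such as Google LaMDA and PaLM are trained on large datasets drawn from many sources, which lets them produce output for many tasks. At their core, generative LLMs predict the next word in a sentence, which is usually called causal language model (causal LM) pretraining. This is how LLMs can generate coherent text from user prompts. For a more pedagogical discussion of language models, see the Stanford CS324 LLM course.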





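In a nutshell, for generative LLMs, KerasNLP offers:

  • Pretrained models with a generate() method, such as keras_nlp.models.GPT2CausalLM and keras_nlp.models.OPTCausalLM.
  • Sampler classes that implement generation algorithms such as Top-K, beam and contrastive search. These samplers can also be used to generate text with custom models.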

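3 Loading the model

3.1 Load a pretrained GPT-2 model and generate some text

KerasNLP provides many pretrained models, such as Google BERT and GPT-2. You can find the list of available models in the KerasNLP repository.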


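```python
# To speed up training and generation, we use a preprocessor with sequence
# length 128 instead of the full length of 1024.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",
    sequence_length=128,
)
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_base_en", preprocessor=preprocessor
)
```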
```python
start = time.time()

output = gpt2_lm.generate("My trip to Yosemite was", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")

start = time.time()

output = gpt2_lm.generate("That Italian restaurant is", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
```
Example output of the second call (sampling is random, so your text will differ):

```
GPT-2 output:
That Italian restaurant is known for its delicious food, and the best part is that it has a full bar, with seating for a whole host of guests. And that's only because it's located at the heart of the neighborhood.
```
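3.2 GPT-2 tools in KerasNLP

  • keras_nlp.models.GPT2Tokenizer: the tokenizer used by the GPT-2 model, which is a byte-pair encoder.
  • keras_nlp.models.GPT2CausalLMPreprocessor: the preprocessor used for GPT-2 causal LM training. It tokenizes the text and does other preprocessing work, such as creating the labels and appending the end token.
  • keras_nlp.models.GPT2Backbone: the GPT-2 model itself, a stack of keras_nlp.layers.TransformerDecoder layers. This is usually just called GPT2.
  • keras_nlp.models.GPT2CausalLM: wraps GPT2Backbone and multiplies its output by the embedding matrix to produce logits over the vocabulary tokens.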
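To make "byte-pair encoder" concrete: BPE training starts from single symbols and repeatedly merges the most frequent adjacent pair into a new vocabulary entry. The toy sketch below works on the characters of a tiny made-up corpus. It does not reproduce GPT-2's real tokenizer in three respects:

- GPT-2 works on UTF-8 bytes, not characters.
- It pre-splits text with a regular expression.
- It ships with merges learned from its own training data.

Breaking ties by first occurrence is an arbitrary choice of this sketch.

```python
from collections import Counter


def most_frequent_pair(words):
    """Return the most frequent adjacent symbol pair (first one wins ties)."""
    pairs = Counter()
    for symbols, freq in words.items():
        for pair in zip(symbols, symbols[1:]):
            pairs[pair] += freq
    return max(pairs, key=pairs.get) if pairs else None


def merge_pair(words, pair):
    """Replace every occurrence of `pair` with a single merged symbol."""
    merged = Counter()
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] += freq
    return merged


def learn_bpe(corpus, num_merges):
    """Learn up to `num_merges` merge rules from a whitespace-split corpus."""
    words = Counter(tuple(word) for word in corpus.split())
    merges = []
    for _ in range(num_merges):
        pair = most_frequent_pair(words)
        if pair is None:
            break
        merges.append(pair)
        words = merge_pair(words, pair)
    return merges, words


merges, words = learn_bpe("low low low lower lowest", num_merges=3)
print(merges)  # [('l', 'o'), ('lo', 'w'), ('low', 'e')]
```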

3.3 Fine-tuning on the Reddit dataset


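```python
import tensorflow_datasets as tfds

reddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)
train_ds = (
    reddit_ds.map(lambda document, _: document)  # keep only the post text
    .batch(32).cache().prefetch(tf.data.AUTOTUNE)
)
train_ds = train_ds.take(500)
num_epochs = 1

# Linearly decaying learning rate.
learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)
gpt2_lm.fit(train_ds, epochs=num_epochs)
```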
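After fine-tuning, generate text again with the same generate() function:

```python
start = time.time()

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
```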
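4 Sampling methods

Much like optimizers and activations, there are two ways to specify a sampler:

  • Use a string identifier, such as "greedy"; this way you get the default configuration.
  • Pass a keras_nlp.samplers.Sampler instance; this way you can use a custom configuration.

```python
# Use a string identifier.
gpt2_lm.compile(sampler="top_k")
output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)

# Use a `Sampler` instance. `GreedySampler` tends to repeat itself.
greedy_sampler = keras_nlp.samplers.GreedySampler()
gpt2_lm.compile(sampler=greedy_sampler)

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)
```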
5 Fine-tuning on a Chinese poetry dataset



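```python
# Load the Chinese poetry dataset.
!git clone https://github.com/chinese-poetry/chinese-poetry.git

import os
import json

poem_collection = []
for file in os.listdir("chinese-poetry/全唐诗"):
    # Keep only the poem files (names containing "poet"); skip other files.
    if ".json" not in file or "poet" not in file:
        continue
    full_filename = "%s/%s" % ("chinese-poetry/全唐诗", file)
    with open(full_filename, "r", encoding="utf-8") as f:
        content = json.load(f)
        poem_collection.extend(content)

paragraphs = ["".join(data["paragraphs"]) for data in poem_collection]
train_ds = (
    tf.data.Dataset.from_tensor_slices(paragraphs)
    .batch(16).cache().prefetch(tf.data.AUTOTUNE)
)

# Running the whole dataset takes a long time, so only take 500 batches and
# run 1 epoch for demo purposes.
train_ds = train_ds.take(500)
num_epochs = 1

learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-4,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)
gpt2_lm.fit(train_ds, epochs=num_epochs)

output = gpt2_lm.generate("昨夜雨疏风骤", max_length=200)
print(output)
```
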
```python
!pip install git+https://github.com/keras-team/keras-nlp.git -q

import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"

import keras_nlp
import keras
import tensorflow as tf
import time
```


## Introduction to Generative Large Language Models (LLMs)

Large language models (LLMs) are a type of machine learning model that is
trained on a large corpus of text data to generate outputs for various natural
language processing (NLP) tasks, such as text generation, question answering,
and machine translation.

Generative LLMs are typically based on deep learning neural networks, such as
the [Transformer architecture](https://arxiv.org/abs/1706.03762) invented by
Google researchers in 2017, and are trained on massive amounts of text data,
often involving billions of words. These models, such as Google [LaMDA](https://blog.google/technology/ai/lamda/)
and [PaLM](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html),
are trained with a large dataset from various data sources which allows them to
generate output for many tasks. The core of Generative LLMs is predicting the
next word in a sentence, often referred to as **Causal LM Pretraining**. In this
way LLMs can generate coherent text based on user prompts. For a more
pedagogical discussion on language models, you can refer to the
[Stanford CS324 LLM class](https://stanford-cs324.github.io/winter2022/lectures/introduction/).

## Introduction to KerasNLP

Large Language Models are complex to build and expensive to train from scratch.
Luckily there are pretrained LLMs available for use right away. [KerasNLP](https://keras.io/keras_nlp/)
provides a large number of pre-trained checkpoints that allow you to experiment
with SOTA models without needing to train them yourself.

KerasNLP is a natural language processing library that supports users through
their entire development cycle. KerasNLP offers both pretrained models and
modularized building blocks, so developers can easily reuse pretrained models
or stack their own LLM.

In a nutshell, for generative LLMs, KerasNLP offers:

- Pretrained models with a `generate()` method, e.g.,
    `keras_nlp.models.GPT2CausalLM` and `keras_nlp.models.OPTCausalLM`.
- `Sampler` classes that implement generation algorithms such as Top-K, beam and
    contrastive search. These samplers can be used to generate text with
    custom models.
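Of these algorithms, beam search is the least self-explanatory. Instead of committing to the single most likely token at each step, it keeps the `beam_width` highest-scoring partial sequences. The toy sketch below uses a hypothetical table of next-token probabilities; it is not KerasNLP's own beam sampler. It shows a width-2 beam finding a sequence with higher total probability than greedy decoding (width 1).

```python
import math

# Hypothetical next-token probabilities, conditioned only on the previous token.
NEXT_PROBS = {
    "<s>": {"the": 0.5, "a": 0.4, "an": 0.1},
    "the": {"cat": 0.4, "dog": 0.35, "bird": 0.25},
    "a": {"cat": 0.9, "dog": 0.1},
    "an": {"owl": 1.0},
}


def beam_search(start, beam_width, steps):
    """Keep the `beam_width` best partial sequences by total log-probability."""
    beams = [([start], 0.0)]
    for _ in range(steps):
        candidates = []
        for tokens, score in beams:
            for token, prob in NEXT_PROBS.get(tokens[-1], {}).items():
                candidates.append((tokens + [token], score + math.log(prob)))
        if not candidates:
            break
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
    return beams


for width in (1, 2):
    tokens, score = beam_search("<s>", beam_width=width, steps=2)[0]
    print(width, tokens[1:], round(math.exp(score), 2))
```

Greedy decoding commits to "the" (0.5) and ends with probability 0.2. The wider beam keeps "a" alive and finds "a cat" with probability 0.36.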

## Load a pre-trained GPT-2 model and generate some text

KerasNLP provides a number of pre-trained models, such as Google BERT
and [GPT-2](https://openai.com/research/better-language-models). You can see
the list of models available in the [KerasNLP repository](https://github.com/keras-team/keras-nlp/tree/master/keras_nlp/models).

It's very easy to load the GPT-2 model as you can see below:

```python
# To speed up training and generation, we use preprocessor of length 128
# instead of full length 1024.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",
    sequence_length=128,
)
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_base_en", preprocessor=preprocessor
)
```

Once the model is loaded, you can use it to generate some text right away. Run
the cells below to give it a try. It's as simple as calling a single function,
*generate()*:

```python
start = time.time()

output = gpt2_lm.generate("My trip to Yosemite was", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
```

Try another one:

```python
start = time.time()

output = gpt2_lm.generate("That Italian restaurant is", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
```

Notice how much faster the second call is. This is because the computational
graph is [XLA compiled](https://www.tensorflow.org/xla) in the first run and
reused in the second behind the scenes.

The quality of the generated text looks OK, but we can improve it via fine-tuning.

## More on the GPT-2 model from KerasNLP

Next up, we will actually fine-tune the model to update its parameters, but
before we do, let's take a look at the full set of tools we have for working
with GPT2.

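The code of GPT2 can be found under the models directory of the [KerasNLP repository](https://github.com/keras-team/keras-nlp/tree/master/keras_nlp/models).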
Conceptually the `GPT2CausalLM` can be hierarchically broken down into several
modules in KerasNLP, all of which have a *from_preset()* function that loads a
pretrained model:

- `keras_nlp.models.GPT2Tokenizer`: The tokenizer used by GPT2 model, which is a
    [byte-pair encoder](https://huggingface.co/course/chapter6/5?fw=pt).
- `keras_nlp.models.GPT2CausalLMPreprocessor`: the preprocessor used by GPT2
    causal LM training. It does the tokenization along with other preprocessing
    work such as creating the label and appending the end token.
- `keras_nlp.models.GPT2Backbone`: the GPT2 model, which is a stack of
    `keras_nlp.layers.TransformerDecoder`. This is usually just referred to as
    `GPT2`.
- `keras_nlp.models.GPT2CausalLM`: wraps `GPT2Backbone`, it multiplies the
    output of `GPT2Backbone` by embedding matrix to generate logits over
    vocab tokens.
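The label creation that `GPT2CausalLMPreprocessor` performs amounts to shifting the packed token sequence by one position, so that the label at each position is the next token. The sketch below is a simplified, plain-Python approximation of that idea under our own assumptions:

- a single unbatched example;
- made-up token IDs 1, 2 and 0 for start, end and padding;
- an already-tokenized input.

It is not the library's implementation, and `causal_lm_pack` is a hypothetical name.

```python
def causal_lm_pack(token_ids, sequence_length, start_id, end_id, pad_id=0):
    """Hypothetical helper: build (x, padding_mask, y, sample_weight).

    Pack [start] + tokens + [end] into sequence_length + 1 slots, then shift
    by one position so that y[i] is the token that follows x[i].
    """
    body = token_ids[: sequence_length - 1]  # leave room for start and end
    ids = [start_id] + body + [end_id]
    mask = [1] * len(ids)
    pad = sequence_length + 1 - len(ids)
    ids += [pad_id] * pad
    mask += [0] * pad
    x, padding_mask = ids[:-1], mask[:-1]  # model inputs
    y, sample_weight = ids[1:], mask[1:]   # next-token labels; padding ignored
    return x, padding_mask, y, sample_weight


x, padding_mask, y, sample_weight = causal_lm_pack(
    [10, 11, 12], sequence_length=6, start_id=1, end_id=2
)
print("x:", x)
print("y:", y)
```

The zero entries in `sample_weight` keep padded positions out of the loss, which is why the tutorial compiles the model with `weighted_metrics`.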

## Finetune on Reddit dataset

Now that you know about the GPT-2 model in KerasNLP, you can go one step
further and fine-tune the model so that it generates text in a specific
style, short or long, strict or casual. In this tutorial, we will use the Reddit
dataset as an example.

```python
import tensorflow_datasets as tfds

reddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)
```

Let's take a look inside sample data from the reddit TensorFlow Dataset. There
are two features:

- **document**: text of the post.
- **title**: the title.

```python
for document, title in reddit_ds.take(1):
    print(document.numpy(), title.numpy())
```

In our case, we are performing next word prediction in a language model, so we
only need the 'document' feature.

```python
train_ds = (
    reddit_ds.map(lambda document, _: document)
    .batch(32).cache().prefetch(tf.data.AUTOTUNE)
)
```

Now you can finetune the model using the familiar *fit()* function. Note that
`preprocessor` will be automatically called inside `fit` method since
`GPT2CausalLM` is a `keras_nlp.models.Task` instance.

This step takes quite a bit of GPU memory and a long time if we were to train
it all the way to a fully trained state. Here we just use part of the dataset
for demo purposes.

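```python
train_ds = train_ds.take(500)
num_epochs = 1

# Linearly decaying learning rate.
learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)
```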
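As the comment in the code says, this schedule decays linearly. With Keras's default `power=1.0`, `PolynomialDecay` interpolates in a straight line from the initial learning rate down to `end_learning_rate` over `decay_steps` steps, then holds it there. The plain-Python function below is our own restatement of that documented formula, handy for checking the rate at a given step. `polynomial_decay` is a hypothetical name, not part of Keras, and the 5e-5 initial rate and 500 steps are illustrative values.

```python
def polynomial_decay(step, initial_lr, decay_steps, end_lr=0.0, power=1.0):
    """Learning rate at `step` for a polynomial decay schedule (no cycling)."""
    step = min(step, decay_steps)  # hold at end_lr once decay_steps is reached
    remaining = 1.0 - step / decay_steps
    return (initial_lr - end_lr) * remaining**power + end_lr


for step in (0, 250, 500, 600):
    print(step, polynomial_decay(step, initial_lr=5e-5, decay_steps=500))
```

Halfway through training the rate is half the initial value, and it reaches zero on the last step.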

After fine-tuning is finished, you can again generate text using the same
*generate()* function. This time, the text will be closer to Reddit writing
style, and the generated length will be close to our preset length in the
training set.

```python
start = time.time()

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
```

## Into the Sampling Method

In KerasNLP, we offer a few sampling methods, e.g., contrastive search,
Top-K and beam sampling. By default, our `GPT2CausalLM` uses Top-k search, but
you can choose your own sampling method.
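Top-K sampling, the default mentioned above, works in three steps:

1. Keep only the k most likely next tokens.
2. Renormalize their probabilities with a softmax.
3. Sample one of them.

Below is a plain-Python sketch of a single sampling step with hypothetical logits. KerasNLP's own version is `keras_nlp.samplers.TopKSampler`, which works on batched tensors inside the generation loop; this sketch is not that implementation.

```python
import math
import random


def top_k_sample(logits, k, rng):
    """Sample a token index from the k highest-scoring logits."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax restricted to the kept logits (subtract the max for stability).
    best = max(logits[i] for i in top)
    weights = [math.exp(logits[i] - best) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0, 3.0]  # hypothetical scores for a 5-token vocab
rng = random.Random(0)
print([top_k_sample(logits, k=2, rng=rng) for _ in range(10)])
```

With `k=2` only tokens 4 and 0 can ever be drawn; setting `k=1` reduces this to greedy decoding.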

Much like optimizers and activations, there are two ways to specify your custom
sampler:

- Use a string identifier, such as "greedy"; this way you use the default
configuration.
- Pass a `keras_nlp.samplers.Sampler` instance; this way you can use a custom
configuration.

```python
# Use a string identifier.
gpt2_lm.compile(sampler="top_k")
output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)

# Use a `Sampler` instance. `GreedySampler` tends to repeat itself.
greedy_sampler = keras_nlp.samplers.GreedySampler()
gpt2_lm.compile(sampler=greedy_sampler)

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)
```
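Why does `GreedySampler` tend to repeat itself? Greedy decoding is deterministic: it always takes the argmax. So once the context returns to a state it has seen before, it follows the same path again. The toy sketch below makes the loop visible. Its table of next-word scores is hypothetical and depends only on the previous word, a far cruder model than GPT-2. Real models condition on the whole context, so the effect is softer, but it shows up in the same way as repeated phrases.

```python
# Hypothetical next-word scores: SCORES[current][candidate] -> logit.
SCORES = {
    "i": {"like": 2.0, "play": 1.0},
    "like": {"basketball": 2.0, "to": 1.5},
    "basketball": {"and": 2.0, ".": 1.0},
    "and": {"i": 2.0, "basketball": 0.5},
    "play": {"basketball": 1.0},
    "to": {"play": 1.0},
    ".": {"i": 1.0},
}


def greedy_decode(start, steps):
    """Always append the single highest-scoring next word (argmax)."""
    words = [start]
    for _ in range(steps):
        candidates = SCORES[words[-1]]
        words.append(max(candidates, key=candidates.get))
    return " ".join(words)


print(greedy_decode("i", steps=8))
```

The output cycles through "i like basketball and" indefinitely, while a sampling method such as Top-K could leave the cycle.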

For more details on the KerasNLP `Sampler` class, you can check its code in the [KerasNLP repository](https://github.com/keras-team/keras-nlp).

## Finetune on Chinese Poem Dataset

We can also finetune GPT2 on non-English datasets. For readers who know Chinese,
this part illustrates how to fine-tune GPT2 on a Chinese poem dataset to teach
our model to become a poet!

Because GPT2 uses a byte-pair encoder, and the original pretraining dataset
contains some Chinese characters, we can use the original vocab to finetune on a
Chinese dataset.
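The key property is that GPT-2's byte-pair encoder operates on bytes. Its base vocabulary covers all 256 byte values, so any UTF-8 text, Chinese included, can be tokenized without unknown tokens, even if its characters end up split into several byte-level pieces. The sketch below shows only this byte-level starting point in plain Python; `to_byte_ids` and `from_byte_ids` are hypothetical helpers. The actual GPT-2 tokenizer then applies its learned merges, and those merges decide how many tokens a Chinese character becomes.

```python
def to_byte_ids(text):
    """Byte-level view of a string: each character becomes 1-4 ids in 0..255."""
    return list(text.encode("utf-8"))


def from_byte_ids(ids):
    """Inverse of to_byte_ids."""
    return bytes(ids).decode("utf-8")


prompt = "昨夜雨疏风骤"
ids = to_byte_ids(prompt)
print(len(prompt), "characters ->", len(ids), "byte ids")
```

Each of these common Chinese characters takes 3 bytes in UTF-8, so the 6-character prompt starts out as 18 byte-level symbols before any merges.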

```python
# Load the Chinese poetry dataset.
!git clone https://github.com/chinese-poetry/chinese-poetry.git
```

Load text from the JSON files. We only use 《全唐诗》 (*Complete Tang Poems*) for demo purposes.

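```python
import os
import json

poem_collection = []
for file in os.listdir("chinese-poetry/全唐诗"):
    # Keep only the poem files (names containing "poet"); skip other files.
    if ".json" not in file or "poet" not in file:
        continue
    full_filename = "%s/%s" % ("chinese-poetry/全唐诗", file)
    with open(full_filename, "r", encoding="utf-8") as f:
        content = json.load(f)
        poem_collection.extend(content)

paragraphs = ["".join(data["paragraphs"]) for data in poem_collection]
```

Let's take a look at sample data:

```python
print(paragraphs[0])
```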


As in the Reddit example, we convert the data to a TF dataset and only use part
of it for training.

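```python
train_ds = (
    tf.data.Dataset.from_tensor_slices(paragraphs)
    .batch(16).cache().prefetch(tf.data.AUTOTUNE)
)

# Running through the whole dataset takes long, only take `500` and run 1
# epoch for demo purposes.
train_ds = train_ds.take(500)
num_epochs = 1

learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-4,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)
gpt2_lm.fit(train_ds, epochs=num_epochs)
```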

Let's check the result!

```python
output = gpt2_lm.generate("昨夜雨疏风骤", max_length=200)
print(output)
```

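  • Environment setup: the tutorial first explains how to select a GPU hardware accelerator runtime in Colab so that the GPT-2 model can be fine-tuned.

  • Installation and configuration: it then walks through installing the KerasNLP library and choosing a backend (tensorflow, jax or torch).

  • Introduction to large language models (GPT-2): it explains the concept of large language models and how GPT-2 is pretrained on large amounts of text data.

  • Introduction to the KerasNLP library: it describes what KerasNLP provides. This includes pretrained models and modular building blocks that developers can reuse or stack into their own LLMs.

  • Loading the pretrained GPT-2 model: it shows how to load a pretrained GPT-2 model and use it to generate text.

  • Fine-tuning the model: next, it shows how to fine-tune GPT-2 on the Reddit dataset so that it generates text in a particular style.

  • Sampling methods: it discusses the sampling methods that KerasNLP provides, such as Top-K, beam and contrastive search, and shows how to use them.

  • Fine-tuning on a Chinese poetry dataset: finally, it shows how to fine-tune GPT-2 on a non-English dataset (Chinese poetry) to teach the model to become a poet.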


