当前位置:   article > 正文

OpenAI视频生成模型Sora的全面解析:从ViViT、Diffusion Transformer到NaViT、VideoPoet_dit = [vae encoder + vit + ddpm + vae decoder]

dit = [vae encoder + vit + ddpm + vae decoder]

前言

真没想到,距离视频生成上一轮的集中爆发(详见《Sora之前的视频生成发展史:从Gen2、Emu Video到PixelDance、SVD、Pika 1.0)才过去三个月,没想OpenAI一出手,该领域又直接变天了

  1. 自打2.16日OpenAI发布sora以来(其开发团队包括DALLE 3的4作Tim Brooks、DiT一作Bill Peebles、三代DALLE的核心作者之一Aditya Ramesh等13人),不但把同时段Google发布的Gemini 1.5干没了声音,而且网上各个渠道,大量新闻媒体、自媒体(含公号、微博、博客、视频)做了大量的解读,也引发了圈内外的大量关注
    很多人因此认为,视频生成领域自此进入了大规模应用前夕,好比NLP领域中GPT3的发布
  2. 一开始,我还自以为视频生成这玩意对于有场景的人,是重大利好,比如在影视行业的
    对于没场景的人,只能当热闹看看,而且我司大模型项目开发团队去年年底还考虑过是否做视频生成的应用,但当时想了好久,没找到场景,做别的应用去了

可当我接连扒出sora相关的10多篇论文之后,觉得sora和此前发布的视频生成模型有了质的飞跃(不只是一个60s),而是再次印证了大力出奇迹,大模型似乎可以在力大砖飞的情况下开始理解物理世界了,使得我司大模型项目组也愿意重新考虑开发视频生成的相关应用

本文主要分为三个部分(初步理解只看第一部分即可,深入理解看第二部分,更多细节则看第三部分)

  • 第一部分,侧重sora的核心技术解读
    方便大家把握重点,且会比一切新闻稿都更准确,此外
    \rightarrow  如果之前没有了解过DDPM、ViT的,建议先阅读下此文《从VAE、扩散模型DDPM、DETR到ViT、Swin transformer
    \rightarrow  如果之前没有了解过图像生成的,建议先阅读下此文《从CLIP到DALLE1/2、DALLE 3、Stable Diffusion、SDXL Turbo、LCM
    当然,如果个别朋友实在不想点开看上面的两篇文章,我也尽可能在本文中把相关重点交代清楚
  • 第二部分,侧重sora相近技术的发展演变
    把sora涉及到的关键技术在本文中全部全面、深入、细致的阐述清楚,毕竟如果人云亦云就不用我来写了
    且看完这部分你会发现,从来没有任何一个火爆全球的产品是一蹴而就的,且基本都是各种创新技术的集大成者(Google很多工作把transformer等各路技术发扬光大,但OpenAI则把各路技术 整合到极致了..)
  • 第三部分,根据sora的32个reference以窥探其背后的更多细节
    由于sora实在是太火了,网上各种解读非常多,有的很专业,有的看上去一本正经 实则是胡说八道(即便他的title看起来有一定的水平),为方便大家辨别什么样的解读是不对的,特把一些更深入的细节也介绍下

如果只允许用10个字定义sora的模型结构,则可以是:潜在扩散架构下的Video Transformer

第一部分 OpenAI Sora的关键技术点

1.1 Sora的三大Transformer组件

1.1.1 从前置工作DALLE 2到sora的三大组件

为方便大家更好的理解sora背后的原理,我们先来快速回顾下AI绘画的原理(理解了AI绘画,也就理解了sora一半)

以DALLE 2为例,如下图所示(以下内容来自此文:从CLIP到DALLE1/2、DALLE 3、Stable Diffusion、SDXL Turbo、LCM)

  1. CLIP训练过程:学习文字与图片的对应关系
    如上图所示,CLIP的输入是一对对配对好的的图片-文本对(根据对应文本一条狗,去匹配一条狗的图片),这些文本和图片分别通过Text Encoder和Image Encoder输出对应的特征,然后在这些输出的文字特征和图片特征上进行对比学习
  2. DALL·E2:prior + decoder
    上面的CLIP训练好之后,就将其冻住了,不再参与任何训练和微调,DALL·E2训练时,输入也是文本-图像对,下面就是DALL·E2的两阶段训练:
    \rightarrow  阶段一 prior的训练:根据文本特征(即CLIP text encoder编码后得到的文本特征),预测图像特征(CLIP image encoder编码后得到的图片特征)
    换言之,prior模型的输入就是上面CLIP编码的文本特征,然后利用文本特征预测图片特征(说明白点,即图中右侧下半部分预测的图片特征的ground truth,就是图中右侧上半部分经过CLIP编码的图片特征),就完成了prior的训练
    推理时,文本还是通过CLIP text encoder得到文本特征,然后根据训练好的prior得到类似CLIP生成的图片特征,此时图片特征应该训练的非常好,不仅可以用来生成图像,而且和文本联系的非常紧(包含丰富的语义信息)

    \rightarrow  阶段二 decoder生成图:常规的扩散模型解码器,解码生成图像
    这里的decoder就是升级版的GLIDE(GLIDE基于扩散模型),所以说DALL·E2 = CLIP + GLIDE

所以对于DALLE 2来说,正因为经过了大量上面这种训练,所以便可以根据人类给定的prompt画出人类预期的画作,说白了,可以根据text预测画作长什么样

最终,sora由三大Transformer组件组成(如果你还不了解transformer或注意力机制,请读此文):Visual Encoder(即Video transformer,类似下文将介绍的ViViT)、Diffusion TransformerTransformer Decoder,具体而言

  1. 训练中,给定一个原始视频X
    \rightarrow  Visual Encoder将视频压缩到较低维的潜在空间(潜在空间这个概念在stable diffusion中用的可谓炉火纯青了,详见此文的第三部分)
    \rightarrow  然后把视频分解为在时间和空间上压缩的潜在表示(不重叠的3D patches),即所谓的一系列时空Patches
    \rightarrow  再将这些patches拉平成一个token序列,这个token序列其实就是原始视频的表征:visual token序列
  2. Sora 在这个压缩的潜在空间中接受训练,还是类似扩散模型那一套,先加噪、再去噪
    这里,有两点必须注意的是
    \rightarrow  1 扩散过程中所用的噪声估计器U-net被替换成了transformer结构的DiT(加之视觉元素转换成token之后,transformer擅长长距离建模,下文详述DiT)
    \rightarrow  2 视频中这一系列帧在上个过程中是同时被编码的,去噪也是一系列帧并行去噪的(每一帧逐步去噪、多帧并行去噪)
    此外,去噪过程中,可以加入去噪的条件(即text condition),这个去噪条件可以是原始视频X的描述,也可以是二次创作的prompt
    比如可以将visual tokens视为query,将text tokens作为key和value,然后类似SD那样做cross attention
  3. OpenAI 还训练了相应的Transformer解码器模型,将生成的潜在表示映射回像素空间,从而生成视频X'

你会发现,上述整个过程,其实和SD的原理是有较大的相似性(SD原理见此文《从CLIP到DALLE1/2、DALLE 3、Stable Diffusion、SDXL Turbo、LCM》的3.2节),当然,不同之处也有很多,比如视频需要一次性还原多帧、图像只需要还原一帧

网上也有不少人画出了sora的架构图,比如来自魔搭社区的

1.1.2 如何理解所谓的时空编码(含其好处)

首先,一个视频无非就是沿着时间轴分布的图像序列而已

但其中有个问题是,因为像素的关系,一张图像有着比较大的维度(比如250 x 250),即一张图片上可能有着5万多个元素,如果根据上一张图片的5万多元素去逐一交互下一张图片的5万多个元素,未免工程过于浩大(而且,即便是同一张图片上的5万多个像素点之间两两做self-attention,你都会发现计算复杂度超级高)

  1. 故为降低处理的复杂度,可以类似ViT把一张图像划分为九宫格(如下图的左下角),如此,处理9个图像块总比一次性处理250 x 250个像素维度 要好不少吧(ViT的出现直接挑战了此前CNN在视觉领域长达近10年的绝对统治地位,其原理细节详见本文开头提到的此文第4部分)

  2. 当我们理解了一张静态图像的patch表示之后(不管是九宫格,还是16 x 9个格),再来理解所谓的时空Patches就简单多了,无非就是在纵向上加上时间的维度,比如t1 t2 t3 t4 t5 t6
    而一个时空patch可能跨3个时间维度,当然,也可能跨5个时间维度

    如此,同时间段内不同位置的立方块可以做横向注意力交互——空间编码
    不同时间段内相同位置的立方块则可以做纵向注意力交互——时间编码
    (如果依然还没有特别理解,没关系,可以再看下下文第二部分中对ViViT的介绍)

可能有同学问,这么做有什么好处呢?好处太多了

  • 一方面,时空建模之下,不仅提高单帧的流畅、更提高帧与帧之间的流畅,毕竟有Transformer的注意力机制,那无论哪一帧图像,各个像素块都不再是孤立的存在,都与周围的元素紧密联系
  • 二方面,可以兼容所有的数据素材:一个静态图像不过是时间=0的一系列时空patch,不同的像素尺寸、不同的时间长短,都可以通过组合一系列 “时空patch” 得到

总之,基于 patches 的表示,使 Sora 能够对不同分辨率、持续时间和长宽比的视频和图像进行训练。在推理时,也可以可以通过在适当大小的网格中排列随机初始化的 patches 来控制生成视频的大小

DiT 作者之一 Saining Xie 在推文中提到:Sora“可能还使用了谷歌的 Patch n’ Pack (NaViT) 论文成果,使其能够适应可变的分辨率/持续时间/长宽比”


当然,ViT本身也能够处理任意分辨率(不同分辨率相当于不同长度的图片块序列),但NaViT提供了一种高效训练的方法,关于NaViT的更多细节详见下文的介绍

而过去的图像和视频生成方法通常需要调整大小、进行裁剪或者是将视频剪切到标准尺寸,例如 4 秒的视频分辨率为 256x256。相反,该研究发现在原始大小的数据上进行训练,最终提供以下好处:

  1. 首先是采样的灵活性:Sora 可以采样宽屏视频 1920x1080p,垂直视频 1920x1080p 以及两者之间的视频。这使 Sora 可以直接以其天然纵横比为不同设备创建内容。Sora 还允许在生成全分辨率的内容之前,以较小的尺寸快速创建内容原型 —— 所有内容都使用相同的模型

    图片

  2. 其次使用视频的原始长宽比进行训练可以提升内容组成和帧的质量
    其他模型一般将所有训练视频裁剪成正方形,而经过正方形裁剪训练的模型生成的视频(如下图左侧),其中的视频主题只是部分可见;相比之下,Sora 生成的视频具有改进的帧内容(如下图右侧)

    图片

1.1.3 Diffusion Transformer(DiT):扩散过程中以Transformer为骨干网络

sora不是第一个把扩散模型和transformer结合起来用的模型,但是第一个取得巨大成功的,为何说它是结合体呢

  1. 一方面,它类似扩散模型那一套流程,给定输入噪声patches(以及文本提示等调节信息),训练出的模型来预测原始的不带噪声的patches「Sora is a diffusion model, given input noisy patches (and conditioning information like text prompts), it’s trained to predict the original “clean” patches
    类似把一张图片打上各种马赛克,然后训练一个模型,让它学会去除各种马赛克,且一开始各种失败没关系,反正有原图作为ground truth,不断缩小与原图之间的差异即可
    而当把图片打上全部马赛克之后,还可以训练该模型根据prompt直接创作的能力,让它画啥就画啥
    更多细节的理解请参看此文《从VAE、扩散模型DDPM、DETR到ViT、Swin transformer
  2. 二方面,它把DPPM中的噪声估计器所用的卷积架构U-Net换成了Transformer架构

图片

总之,总的来说,Sora是一个在不同时长、分辨率和宽高比的视频及图像上训练而成的扩散模型,同时采用了Transformer架构,如sora官博所说,Sora is a diffusion transformer,简称DiT

关于DiT的更多细节详见下文第二部分介绍的DiT

1.2 基于DALLE 3的重字幕技术:提升文本-视频数据质量

1.2.1 DALLE 3的重字幕技术:为文本-视频数据集打上详细字幕

首先,训练文本到视频生成系统需要大量带有相应文本字幕的视频,研究团队将 DALL・E 3 中的重字幕(re-captioning)技术应用于视频

  1. 具体来说,研究团队首先训练一个高度描述性的字幕生成器模型,然后使用它为训练集中所有视频生成文本字幕
  2. 与DALLE 3类似,研究团队还利用 GPT 将简短的用户 prompt 转换为较长的详细字幕,然后发送到视频模型,这使得 Sora 能够生成准确遵循用户 prompt 的高质量视频

关于DALLE 3的重字幕技术更具体的细节请见此文2.3节《AI绘画原理解析:从CLIP到DALLE1/2、DALLE 3、Stable Diffusion、SDXL Turbo、LCM

2.3 DALLE 3:Improving Image Generation with Better Captions

2.3.1 为提高文本图像配对数据集的质量:基于谷歌的CoCa​微调出图像字幕生成器

2.3.1.1 什么是谷歌的CoCa

2.1.1.2 分别通过短caption、长caption微调预训练好的image captioner

2.1.1.3 为提高合成caption对文生图模型的性能:采用描述详细的长caption,训练的混合比例高达95%..

1.2.2 类似Google的W.A.L.T工作:引入auto regressive进行视频扩展

其次,如之前所述,为了保证视频的一致性,模型层不是通过多个stage方式来进行预测,而是整体预测了整个视频的latent(即去噪时非先去噪几帧,再去掉几帧,而是一次性去掉全部帧的噪声)

但在视频内容的扩展上,比如从一段已有的视频向后拓展出新视频的训练过程中可能引入了auto regressive的task,以帮助模型更好的进行视频特征和帧间关系的学习

更多可以参考Google的W.A.L.T工作,下文将详述

1.3 对真实物理世界的模拟能力

1.3.1 sora学习了大量关于3D几何的知识

OpenAI 发现,视频模型在经过大规模训练后,会表现出许多有趣的新能力。这些能力使 Sora 能够模拟物理世界中的人、动物和环境的某些方面。这些特性的出现没有任何明确的三维、物体等归纳偏差 — 它们纯粹是规模现象

  1. 三维一致性(下图左侧)
    Sora 可以生成动态摄像机运动的视频。随着摄像机的移动和旋转,人物和场景元素在三维空间中的移动是一致的
    针对这点,sora一作Tim Brooks说道,sora学习了大量关于3D几何的知识,但是我们并没有事先设定这些,它完全是从大量数据中学习到的
    图片图片
    长序列连贯性和目标持久性(上图右侧)
    视频生成系统面临的一个重大挑战是在对长视频进行采样时保持时间一致性
    例如,即使人、动物和物体被遮挡或离开画面,Sora 模型也能保持它们的存在。同样,它还能在单个样本中生成同一角色的多个镜头,并在整个视频中保持其外观
  2. 与世界互动(下图左侧)
    Sora 有时可以模拟以简单方式影响世界状态的动作。例如,画家可以在画布上留下新的笔触,这些笔触会随着时间的推移而持续,而视频中一个人咬一口面包 则面包上会有一个被咬的缺口

    图片图片

    模拟数字世界(上图右侧)
    视频游戏就是一个例子。Sora 可以通过基本策略同时控制 Minecraft 中的玩家,同时高保真地呈现世界及其动态。只需在 Sora 的提示字幕中提及 「Minecraft」,就能零样本激发这些功能

1.3.2 sora真的会模拟真实物理世界了么

对于“sora真的会模拟真实物理世界”这个问题,网上的解读非常多,很多人说sora是通向通用AGI的必经之路、不只是一个视频生成,更是模拟真实物理世界的模拟器,这个事 我个人觉得从技术的客观角度去探讨更合适,那样会让咱们的思维、认知更冷静,而非人云亦云、最终不知所云

首先,作为“物理世界的模拟器”,需要能够在虚拟环境中重现物理现实,为用户提供一个逼真且不违反「物理规律」的数字世界

比如苹果不能突然在空中漂浮,这不符合牛顿的万有引力定律;比如在光线照射下,物体产生的阴影和高光的分布要符合光影规律等;比如物体之间产生碰撞后会破碎或者弹开

其次,李志飞等人在《为什么说 Sora 是世界的模拟器?》一文中提到,技术上至少有两种方式可以实现这样的模拟器

  • 一种是通过大数据学习出一个AI系统来模拟这个世界,比如说本文讨论的 Sora
  • 另外一种是弄懂物理世界各种现象背后的数学原理,并把这些原理手工编码到计算机程序里,从而让计算机程序“渲染”出物理世界需要的各种人、物、场景、以及他们之间的互动

虚幻引擎(Unreal Engine,UE)就是这种物理世界的模拟器

  1. 它内置了光照、碰撞、动画、刚体、材质、音频、光电等各种数学模型。一个开发者只需要提供人、物、场景、交互、剧情等配置,系统就能做出一个交互式的游戏,这种交互式的游戏可以看成是一个交互式的动态视频
  2. UE 这类渲染引擎所创造的游戏世界已经能够在某种程度上模拟物理世界,只不过它是通过人工数学建模及渲染而成,而非通过模型从数据中自我学习。而且,它也没有和语言代表的认知模型连接起来,因此本质上缺乏世界常识。而 Sora 代表的AI系统有可能避免这些缺陷和局限

不同于 UE 这一类渲染引擎,Sora 并没有显式地对物理规律背后的数学公式去“硬编码”,而是通过对互联网上的海量视频数据进行自监督学习,从而能够在给定一段文字描述的条件下生成不违反物理世界规律的长视频

与 UE 这一类“硬编码”的物理渲染引擎不同,Sora 视频创作的想象力来自于它端到端的数据驱动,以及跟LLM这类认知模型的无缝结合(比如ChatGPT已经确定了基本的物理常识)

最后值得一提的是,Sora 的训练可能用了 UE 合成的数据,但 Sora 模型本身应该没有调用 UE 的能力

第二部分 Sora相近技术的发展史:ViViT、DiT、NaViT、MAGVIT v2、W.A.L.T、VideoPoet

注意,和sora相关的技术其实有非常多,但有些技术在本博客之前的文章中写过了(详见本文开头),则本部分不再重复,比如DDPM、ViT、DALLE三代、Stable Diffusion(包括潜在空间LDM)等等

2.1 视频Transformer之ViViT:视频元素token化且批处理

在具体介绍ViViT之前,先说三个在其之前的工作

  1. 业界最早是用卷积那一套处理视频,比如时空3D CNN(Learning spatiotemporal features with 3d convolutional networks),由于3D CNN比图像卷积网络需要较多的计算量,许多架构在空间和时间维度上进行卷积的因式分解和/或使用分组卷积,且最近,还通过在后续层中引入自注意力来增强模型,以更好地捕捉长程依赖性
  2. 而Transformer在NLP领域大获成功,很快便出现了将Transformer架构应用到图像领域的ViT(Vision Transformer),ViT将图片按给定大小分为不重叠的patches,再将每个patch线性映射为一个token,随位置编码和cls token(可选)一起输入到Transformer的编码器中(下图来自萝卜社长,如果不熟悉或忘了ViT的,详见此文的第4部分)

  3. 2021年的这两篇论文《Is space-time attention all you need for video understanding?》、《Video transformer network》都是基于transformer做视频理解

而Google于2021年提出的ViViT(A Video Vision Transformer)便要尝试在视频中使用ViT模型,且他们充分借鉴了之前3D CNN因式分解等工作,比如考虑到视频作为输入会产生大量的时空token,处理时必须考虑这些长范围token序列的上下文关系,同时要兼顾模型效率问题

故作者团队在空间和时间维度上分别对Transformer编码器各组件进行分解,在ViT模型的基础上提出了三种用于视频分类的纯Transformer模型,如下图所示

区别于常规的二维图像数据,视频数据相当于需在三维空间内进行采样(拓展了一个时间维度),有两种方法来将视频\mathbf{V} \in \mathbb{R}^{T \times H \times W \times C}映射到token序列\tilde{\mathbf{z}} \in \mathbb{R}^{n_{t} \times n_{h} \times n_{w} \times d}(说白了,就是从视频中提取token,而后添加位置编码并对token进行reshape得到最终Transformer的输入\mathrm{z} \in \mathbb{R}^{\mathrm{N} \times \mathrm{d}})

  • 第一种,如下图所示,将输入视频划分为token的直接方法是从输入视频剪辑中均匀采样 n_t 个帧,使用与ViT 相同的方法独立地嵌入每个2D帧(embed each 2D frame independently using the same method as ViT),并将所有这些token连接在一起

    具体地说,如果从每个帧中提取 n_{h} \cdot n_{w} 个非重叠图像块(就像 ViT 一样),那么总共将有 n_{t} \cdot n_{h} \cdot n_{w} 个token通过transformer编码器进行传递,这个过程可以被看作是简单地构建一个大的2D图像,以便按照ViT的方式进行tokenised(这点和本节开头所提到的21年那篇论文space-time attention for video所用的方式一致)
  • 第二种则是把输入的视频划分成若干个tuplet(类似不重叠的带空间-时间维度的立方体)
    每个tuplet会变成一个token(因这个tublelt的维度就是: t * h * w,故token包含了时间、宽、高)
    经过spatial temperal attention进行空间和时间建模获得有效的视频表征token

2.1.1 spatio-temporal attention

上文说过,Google在ViT模型的基础上提出了三种用于视频分类的纯Transformer模型,接下来,介绍下这三种模型

当然,由于论文中把一个没有啥技巧且计算复杂度高的模型作为模型1:简单地将从视频中提取的所有时空token,然后每个transformer层都对所有配对进行建模,类似Neimark_Video_Transformer_Network_ICCVW_2021_paper的工作(其证明了VTN可以高效地处理非常长的视频)

所以下述三个模型在论文中被分别称之为模型2、3、4

2.1.2 factorised encoder及其代码实现

第二个模型如下图所示,该模型由两个串联的transformer编码器组成:

  1. 第一个模型是空间编码器Spatial Transformer Encoder
    处理来自相同时间索引的token之间的相互作用(相当于处理同一帧画面下的各个元素,时间维度都相同了自然时间层面上没啥要处理的了,只处理空间维度),以产生每个时间索引的潜在表示,并输出cls_token
  2. 第二个transformer编码器是时间编码器Temporal Transformer Encoder
    处理时间步之间的相互作用(相当于处理不同帧,即空间维度相同但时间维度不同)。 因此,它对应于空间和时间信息的“后期融合”
    换言之,将输出的cls_token和帧维度的表征token拼接输入到时间编码器中得到最终的结果

对应的代码如下(为方便大家一目了然,我不仅给每一行代码都加上了注释,且把代码分解成了8块,每一块代码的重点都做了细致说明)

  1. 首先定义ViViT类,且定义相关变量
    1. # 定义ViViT模型类
    2. class ViViT(nn.Module):
    3. def __init__(self, image_size, patch_size, num_classes, num_frames, dim=192, depth=4, heads=3, pool='cls', in_channels=3, dim_head=64, dropout=0.,
    4. emb_dropout=0., scale_dim=4):
    5. super().__init__() # 调用父类的构造函数
    6. # 检查pool参数是否有效
    7. assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
    8. # 确保图像尺寸能被patch尺寸整除
    9. assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
    10. # 计算patch数量
    11. num_patches = (image_size // patch_size) ** 2
    12. # 计算每个patch的维度
    13. patch_dim = in_channels * patch_size ** 2
    14. # 将图像切分成patch并进行线性变换的模块
    15. self.to_patch_embedding = nn.Sequential(
    16. Rearrange('b t c (h p1) (w p2) -> b t (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
    17. nn.Linear(patch_dim, dim),
    18. )
    为方便大家理解,我得解释一下上面中这行的含义:
    Rearrange('b t c (h p1) (w p2) -> b t (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
    且为方便大家和我之前介绍ViT的文章前后连贯起来,故还是用的ViT那篇文章中的例子(此文的第4部分)
    以ViT_base_patch16为例,一张224 x 224的图片先分割成 16 x 16 的 patch ,很显然会因此而存在 (224\times 224/16\times 16)^2=196 个 patch
    且图片的长宽由原来的224  x 224 变成:14  x 14(因为224/16 = 14)
    16*1616*1616*1616*1616*1616*1616*1616*1616*1616*1616*1616*1616*1616*16
    16*16
    16*16
    16*16
    ...
    所以对于上面那行意味着可以让批次大小b=1、时间维度t=2、RGB图像的通道数c=3
    原始维度即为:
    (1, 2, 3,
    旧的长 = 224 patch_size = 16, 旧的宽 = 224 patch_size = 16),Rearrange之后的维度则变为:
    (1, 2,
    新的长14 x 新的宽14 = 196, 16 x 16 x 3 = 768)
  2. 初始化位置编码和cls token
    self.pos_embedding 的维度为(1, num_frames, num_patches + 1, dim)
    在这里,num_frames 是 t,num_patches 是 n=196,dim 是 768,因此 pos_embedding 维度为 (1,2,197,768)
    1. # 位置编码
    2. self.pos_embedding = nn.Parameter(torch.randn(1, num_frames, num_patches + 1, dim))
    3. # 空间维度的cls token
    4. self.space_token = nn.Parameter(torch.randn(1, 1, dim))
    5. # 空间变换器
    6. self.space_transformer = Transformer(dim, depth, heads, dim_head, dim * scale_dim, dropout)
    7. # 时间维度的cls token
    8. self.temporal_token = nn.Parameter(torch.randn(1, 1, dim))
    9. # 时间变换器
    10. self.temporal_transformer = Transformer(dim, depth, heads, dim_head, dim * scale_dim, dropout)
    11. # dropout层
    12. self.dropout = nn.Dropout(emb_dropout)
    13. # 池化方式
    14. self.pool = pool
    15. # 最后的全连接层,用于分类
    16. self.mlp_head = nn.Sequential(
    17. nn.LayerNorm(dim),
    18. nn.Linear(dim, num_classes)
    19. )
  3. patch嵌入和cls token的拼接
    输入数据 x 的维度在经过嵌入层后变为 (1,2,196,768)
    self.space_token 的初始维度为 (1,1,768),被复制扩展成 (1,2,1,768) 以匹配批次和时间维度
    cls_space_tokens 和 x 在patch维度上拼接后,维度变为 (1,2,197,768)
    为何拼接之后成197了呢?原因很简单,如ViT那篇文章中所述:“[class] token的维度为 [1, 768] ,通过Concat操作,[196, 768]  与 [1, 768] 拼接得到 [197, 768]”
    1. def forward(self, x):
    2. # 将输入数据x转换为patch embeddings
    3. x = self.to_patch_embedding(x)
    4. b, t, n, _ = x.shape # 获取batch size, 时间维度, patch数量
    5. # 在每个空间位置加上cls token
    6. cls_space_tokens = repeat(self.space_token, '() n d -> b t n d', b=b, t=t)
    7. x = torch.cat((cls_space_tokens, x), dim=2) # 在维度2上进行拼接
  4. 添加位置编码和应用dropout
    加上位置编码后,x 保持 (1,2,197,768) 维度不变。应用dropout后,x 的维度仍然不变
    1. x += self.pos_embedding[:, :, :(n + 1)] # 加上位置编码
    2. x = self.dropout(x) # 应用dropout
  5. 空间Transformer
    重排 x 的维度为 (2,197,768),因为 b×t=1×2=2
    空间Transformer处理后,x 的维度变为 (2,197,768)
    1. # 将(b, t, n, d)重排为((b t), n, d),为了应用空间变换器
    2. x = rearrange(x, 'b t n d -> (b t) n d')
    3. x = self.space_transformer(x) # 应用空间变换器
    4. x = rearrange(x[:, 0], '(b t) ... -> b t ...', b=b) # 把输出重排回(b, t, ...)
  6. 时间Transformer
    self.temporal_token 的初始维度为(1,1,768),被复制扩展成 (1,2,768)
    cls_temporal_tokens 和 x 在时间维度上拼接后,维度变为(1,3,768)
    1. # 在每个时间位置加上cls token
    2. cls_temporal_tokens = repeat(self.temporal_token, '() n d -> b n d', b=b)
    3. x = torch.cat((cls_temporal_tokens, x), dim=1) # 在维度1上进行拼接
    4. x = self.temporal_transformer(x) # 应用时间变换器
  7. 池化
    如果 self.pool 是 'mean',则对 x 在时间维度上取均值,结果维度变为 (1,768)
    如果不是 'mean',则直接取 x 的第一个时间维度的cls token,结果维度同样是 (1,768)
    1. # 根据pool参数选择池化方式
    2. x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
  8. 分类头
    self.mlp_head,将 (1,768) 维度的 x 转换为最终的分类结果,其维度取决于类别数num_classes,如果 num_classes 是 10,则最终输出维度为 (1,10)
    1. # 通过全连接层输出最终的分类结果
    2. return self.mlp_head(x)

2.1.3 factorised self-attention

第二个模型如下图所示,会先计算空间自注意力(token中有相同的时间索引,相当于同一帧画面上的token元素),再计算时间的自注意力(token中有相同的空间索引,相当于不同帧下同一空间位置的token,比如一直在视频的左上角那一块的token块)

  1. 具体进行空间注意力交互的方法是:将初始视频序列生成的\mathrm{z} \in \mathbb{R}^{1 \times \mathrm{n}_{\mathrm{t}} \cdot \mathrm{n}_{\mathrm{w}} \cdot \mathrm{n}_{\mathrm{h}} \cdot \mathrm{d}},通过tensor的reshape思想映射为\mathrm{z}_{\mathrm{S}} \in \mathbb{R}^{\mathrm{n}_{\mathrm{t}} \times \mathrm{n}_{\mathrm{w}} \cdot \mathrm{n}_{\mathrm{h}} \cdot \mathrm{d}},而后计算得到空间自注意力结果
  2. 同理,在时间维度上映射得到\mathrm{z}_{\mathrm{t}} \in \mathbb{R}^{\mathrm{n}_{\mathrm{w}} \cdot \mathrm{n}_{\mathrm{h}} \times \mathrm{n}_{\mathrm{t}} \cdot \mathrm{d}},从而进行时间自注意力的计算

2.1.4 factorised dot-product attention

由于实验表明空间-时间自注意力或时间-空间自注意力的顺序并不重要,所以第三个模型的结构如下图所示,一半的头仅在空间轴上计算点积注意力,另一半头则仅在时间轴上计算,且其参数数量增加了,因为有一个额外的自注意力层

不过,该模型通过利用dot-product点积注意力操作来取代因式分解factorisation操作,通过注意力计算的方式来代替简单的张量reshape。思想是对于空间注意力和时间注意力分别构建对应的键、值,如下图所示(图源自萝卜社长)

在这里插入图片描述

2.2 DiT:将 U-Net 架构换成 Transformer

2.2.1 DiT = VAE encoder + ViT + DDPM + VAE decoder

在ViT之前,图像领域基本是CNN的天下,包括扩散过程中的噪声估计器所用的U-net也是卷积架构,但随着ViT的横空出世,人们自然而然开始考虑这个噪声估计器可否用Transform架构来代替

2022年年底,William Peebles(当时在UC Berkeley,Peebles在声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】

推荐阅读
相关标签