当前位置:   article > 正文

Rich Human Feedback for Text-to-Image Generation 读论文笔记

Rich Human Feedback for Text-to-Image Generation 读论文笔记

摘要

Motivation:探索如何优化如Stable Diffusion T2I生成模型的优化问题,因为这些模型都会有诸如伪影,与文字描述不匹配和美学质量低等问题。本文参考大语言模型强化学习的方式,训练奖励模型来改进生成模型。

Contribusion:在收集的数据集(RichHF18K)收集feedback,通过选择高质量的训练数据和改进来生成模型,或者使用预测的heatmap来创建掩码,修复有问题的区域。

  • Rich Human Feedback dataset
  • 一个多模态Transformer模型对生成的图像进行丰富的反馈预测
  • improve method:方式:1. 标记有问题的图像区域 2.标记文本描述不匹配的prompt(被误报或漏报)3. 使用预测的分数来帮助微调图像生成模型

方法细节

收集数据的过程

RichHF-18K数据集
每个图片包含的标注和分数:

  • 图像高度的1 / 20为半径标记伪影和错位标注(两个heatmap,implausibility and misalignment heatmap)。
  • 没有对齐的关键词的标注
  • 四个细粒度的分数(合理性、一致性、美观性、总体评分)
    在这里插入图片描述

人类反馈确认

每个图像-文本对由三个标注员进行注释,所以对于分数直接做平均,文字对齐标注采取多数原则,点标注使用每个点区域的平均值(每个点被转换为热图上的一个磁盘区域,然后计算三个热图之间的平均热图)

数据集

在Pick-a-Pic dataset数据集选取的子集。选取的部分是照片等级的图像。为了平衡类别,使用PaLI visual question answering (VQA) model从Pick - a - Pic数据样本中提取一些基本特征。

VQA

是一种用于能够结合大语言模型和图像理解技术的多模态模型。

使用方法

输入问题:

  1. 图像有真实感吗
  2. 那个类别最能描述图像?在"人"、“动物”、“物”、“室内场景”、"室外场景"中任选其一

18K的数据集,16K作为训练集,1K作为验证,1K作为测试。

数据分析

分数统计

s − s min s max − s min = s − 1 5 − 1 \frac{s - s_{\text{min}}}{s_{\text{max}} - s_{\text{min}}} = \frac{s - 1}{5 - 1} smaxsminssmin=51s1
得到的分布如下:
在这里插入图片描述
基本符合高斯分布

评价一致性(pair alignment)

maxdiff = max ⁡ ( scores ) − min ⁡ ( scores ) \text{maxdiff} = \max(\text{scores}) - \min(\text{scores}) maxdiff=max(scores)min(scores)

在这里插入图片描述

实验

模型

模型架构

在这里插入图片描述
这个架构中有两个计算流,分别关注视觉和文本的部分,使用的架构分别是VIT和T5X。

文本信息通过对齐程度和heatmap传递给图像token,视觉信息传递给文本token用于视觉感知。使用WebLi预训练模型。

  1. 生成的图像输入ViT,然后在输出的地方成为高级表征,text则是嵌入成dense向量。
  2. 将两种token经过T5X的自注意力级联编码
  3. 编码后使用三种预测器来预测不同的输出。
typeoperate
heatmap输入:图像token 经过卷积反卷积和sigmoid 输出:不可信和heatmap
score输入:feature map 经过卷积,线性和sigmoid 输出:细粒度scores
misalignment输入:原始caption,target:修改的caption 使用T5X的解码器,不对齐的用后缀0表示,e.g.:如果生成的图像中包含黑猫,且黄色单词与图像不对齐,则为黄色0猫。

模型变体

  1. Multi-head 每个评分,heat map和misalignment有一个头对应,共七个
  2. 对每个预测类型使用单个头,即总共3个头,分别用于热图、得分和misalignment。
    在实验中,第二种方法具体操作是:增加能够让模型判断输出类型的prompt,比如如 “implausibility heatmap”,这样能够明确任务类型。通过将这种prompt与相应的任务进行结合,单个热图(得分)头就可以预测不同的热图(得分)。能够在有些任务中得到比第一种更好的结果。

模型其他优化

损失函数是热图MSE损失、评分MSE和序列CE的加权组合。

实验

针对三种标注和打分的方法:

Metrics

Score

使用的系数:Pearson线性相关系数(PLCC)和斯皮尔曼等级相关系数(SRCC)。PLCC测量预测和真实分数之间的线性相关性,表明预测以线性方式近似实际分数的程度。SRCC测量预测和实际分数之间的关系可以使用单调函数来描述,重点是排名顺序而不是确切值。

Heatmap

标准的显着性热图评估指标,如归一化扫描路径显着性(NSS),Kullback-Leibler发散(KLD)

Misalignment

Token-level precision, recall, 和 F1-score.精度测量预测的未对齐关键字的准确性(即,正确的预测关键字的比例),查全率测量完整性(即,被正确预测的实际未对齐关键字的比例),而F1-score通过计算它们的调和平均值来提供精确度和召回率之间的平衡。

量化结果

Score

在这里插入图片描述

Heatmap

在这里插入图片描述
在这里插入图片描述

Misalignment

在这里插入图片描述

表1和表3中变体都超过了ResNet50,表2中多头版本不如resnet50,但是三头版本优于resnet50。
作者在这里预测的原因是:可能在多头版本中,所有7个预测任务都使用相同的prompt(相对于3头版本),因此所有任务的特征图和文本标记都是相同的。在这些任务之间找到一个好的折衷可能并不容易,因此一些任务如伪影/不可信热图的性能会变得更差。
注意到misalignment heat map预测通常比伪影heatmap预测的结果更差,这可能是因为错配区域的定义较少,因此注释可能更嘈杂。

定性分析

在这里插入图片描述
在这里插入图片描述

从反馈中学习

研究从这些反馈中能不能学到知识用于改善图像生成。
使用基于遮蔽变换器架构的Muse模型作为改进的目标。

首先,我们使用预训练的Muse模型为12,564个prompt(通过PaLM 2生成的提示集)生成了八张图像。我们为每张图像预测RAHF分数,如果每个提示生成的图像中最高分超过一个固定阈值,它将被选为我们微调数据集的一部分。然后,Muse模型与这个数据集一起进行微调。前后对比:在这里插入图片描述
量化Muse微调的收益:作者使用100个新提示生成图像,并请6名注释者进行两张图像的并排比较,这两张图像分别来自原始的Muse和微调后的Muse。注释者在不知道哪个模型用于生成图像A/B的情况下,从五种可能的反应中选择(图像A明显/稍微好于图像B,大致相同,图像B稍微/明显好于图像A)。表5的结果显示,与原始Muse相比,经过RAHF可信度分数微调的Muse具有显著更少的人工痕迹/不可信之处。
在这里插入图片描述
展示了一个使用RAHF审美分数作为分类器指导对潜在扩散模型的示例
在这里插入图片描述
对于每张图像,首先预测不可信度heatmap,然后通过处理heatmap(使用阈值和扩张)创建一个掩码。在掩码区域内应用Muse修复,生成与文本提示相匹配的新图像。生成多张图像,最终图像由我们的RAHF预测的最高可信度分数选择。

在这里插入图片描述
总结来说就是使用训练的模型来判断生成模型中不合理的地方,并使用掩码模型做遮蔽处理,好让模型重新生成有问题的部位,类似图像编辑的内容。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小桥流水78/article/detail/754126
推荐阅读
相关标签
  

闽ICP备14008679号