宣物莫大于言,存形莫善于画。
--【晋】陆机
多模态数据(文本、图像、声音)是人类认识、理解和表达世间万物的重要载体。近年来,多模态数据的爆炸性增长促进了内容互联网的繁荣,也带来了大量多模态内容理解和生成的需求。与常见的跨模态理解任务不同,文到图的生成任务是流行的跨模态生成任务,旨在生成与给定文本对应的图像。这一文图生成的任务,极大地释放了AI的想象力,也激发了人类的创意。典型的模型例如OpenAI开发的DALL-E和DALL-E2。近期,业界也训练出了更大、更新的文图生成模型,例如Google提出的Parti和Imagen。
然而,上述模型一般不能用于处理中文的需求,而且上述模型的参数量庞大,很难被开源社区的广大用户直接用来Fine-tune和推理。本次,EasyNLP开源框架再次迎来大升级,集成了先进的文图生成架构Transformer+VQGAN,同时,向开源社区免费开放不同参数量的中文文图生成模型的Checkpoint,以及相应Fine-tune和推理接口。用户可以在我们开放的Checkpoint基础上进行少量领域相关的微调,在不消耗大量计算资源的情况下,就能一键进行各种艺术创作。
EasyNLP是阿里云机器学习PAI 团队基于 PyTorch 开发的易用且丰富的中文NLP算法框架,并且提供了从训练到部署的一站式 NLP 开发体验。EasyNLP 提供了简洁的接口供用户开发 NLP 模型,包括NLP应用 AppZoo 、预训练模型 ModelZoo、数据仓库DataHub等特性。由于跨模态理解和生成需求的不断增加,EasyNLP也支持各种跨模态模型,特别是中文领域的跨模态模型,推向开源社区。例如,在先前的工作中,EasyNLP已经对中文图文检索CLIP模型进行了支持(看这里)。我们希望能够服务更多的 NLP 和多模态算法开发者和研究者,也希望和社区一起推动 NLP /多模态技术的发展和模型落地。本文简要介绍文图生成的技术,以及如何在EasyNLP框架中如何轻松实现文图生成,带你秒变艺术家。本文开头的展示图片即为我们模型创作的作品。
文图生成模型简述下面以几个经典的基于Transformer的工作为例,简单介绍文图生成模型的技术。DALL-E由OpenAI提出,采取两阶段的方法生成图像。在第一阶段,训练一个dVAE(discrete variational autoencoder)的模型将256×256的RGB图片转化为32×32的image token,这一步骤将图片进行信息压缩和离散化,方便进行文本到图像的生成。第二阶段,DALL-E训练一个自回归的Transformer模型,将文本输入转化为上述1024个image token。
由清华大学等单位提出的CogView模型对上述两阶段文图生成的过程进行了进一步的优化。在下图中,CogView采用了sentence piece作为text tokenizer使得输入文本的空间表达更加丰富,并且在模型的Fine-tune过程中采用了多种技术,例如图像的超分、风格迁移等。
ERNIE-ViLG模型考虑进一步考虑了Transformer模型学习知识的可迁移性,同时学习了从文本生成图像和从图像生成文本这两种任务。其架构图如下所示:
随着文图生成技术的不断发展,新的模型和技术不断涌现。举例来说,OFA将多种跨模态的生成任务统一在同一个模型架构中。DALL-E 2同样由OpenAI提出,是DALL-E模型的升级版,考虑了层次化的图像生成技术,模型利用CLIP encoder作为编码器,更好地融入了CLIP预训练的跨模态表征。Google进一步提出了Diffusion Model的架构,能有效生成高清大图,如下所示:
在本文中,我们不再对这些细节进行赘述。感兴趣的读者可以进一步查阅参考文献。
EasyNLP文图生成模型由于前述模型的规模往往在数十亿、百亿参数级别,庞大的模型虽然能生成质量较大的图片,然后对计算资源和预训练数据的要求使得这些模型很难在开源社区广泛应用,尤其在需要面向垂直领域的情况下。在本节中,我们详细介绍EasyNLP提供的中文文图生成模型,它在较小参数量的情况下,依然具有良好的文图生成效果。
模型架构模型框架图如下图所示:
考虑到Transformer模型复杂度随序列长度呈二次方增长,文图生成模型的训练一般以图像矢量量化和自回归训练两阶段结合的方式进行。
图像矢量量化是指将图像进行离散化编码,如将256×256的RGB图像进行16倍降采样,得到16×16的离散化序列,序列中的每个image token对应于codebook中的表示。常见的图像矢量量化方法包括:VQVAE、VQVAE-2和VQGAN等。我们采用VQGAN在ImageNet上训练的f16_16384(16倍降采样,词表大小为16384)的模型权重来生成图像的离散化序列。
自回归训练是指将文本序列和图像序列作为输入,在图像部分,每个image token仅与文本序列的tokens和其之前的image tokens进行attention计算。我们采用GPT作为backbone,能够适应不同模型规模的生成任务。在模型预测阶段,输入文本序列,模型以自回归的方式逐步生成定长的图像序列,再通过VQGAN decoder重构为图像。
开源模型参数设置在EasyNLP中,我们提供两个版本的中文文图生成模型,模型参数配置如下表:
模型配置pai-painter-base-zhpai-painter-large-zh参数量(Parameters)202M433M层数(Number of Layers)1224注意力头数(Attention Heads)1216隐向量维度(Hidden Size)7681024文本长度(Text Length)3232图像序列长度(Image Length)16 x 1616 x 16图像尺寸(Image Size)256 x 256256 x 256VQGAN词表大小(Codebook Size)1638416384 模型实现在EasyNLP框架中,我们在模型层构建基于minGPT的backbone构建模型,核心部分如下所示:
self.first_stage_model = VQModel(ckpt_path=vqgan_ckpt_path).eval()
self.transformer = GPT(self.config)
VQModel的Encoding阶段过程为:
# in easynlp/appzoo/text2image_generation/model.py
@torch.no_grad()
def encode_to_z(self, x):
quant_z, _, info = self.first_stage_model.encode(x)
indices = info[2].view(quant_z.shape[0], -1)
return quant_z, indices
x = inputs['image']
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
# one step to produce the logits
_, z_indices = self.encode_to_z(x) # z_indice: torch.Size([batch_size, 256])
VQModel的Decoding阶段过程为:
# in easynlp/appzoo/text2image_generation/model.py
@torch.no_grad()
def decode_to_img(self, index, zshape):
bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
quant_z = self.first_stage_model.quantize.get_codebook_entry(
index.reshape(-1), shape=bhwc)
x = self.first_stage_model.decode(quant_z)
return x
# sample为训练阶段的结果生成,与预测阶段的generate类似,详解见下文generate
index_sample = self.sample(z_start_indices, c_indices,
steps=z_indices.shape[1],
...)
x_sample = self.decode_to_img(index_sample, quant_z.shape)
Transformer采用minGPT进行构建,输入图像的离散编码,输出文本token。前向传播过程为:
# in easynlp/appzoo/text2image_generation/model.py
def forward(self, inputs):
x = inputs['image']
c = inputs['text']
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
# one step to produce the logits
_, z_indices = self.encode_to_z(x) # z_indice: torch.Size([batch_size, 256])
c_indices = c
if self.training and self.pkeep < 1.0:
mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
device=z_indices.device))
mask = mask.round().to(dtype=torch.int64)
r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
a_indices = mask*z_indices+(1-mask)*r_indices
else:
a_indices = z_indices
cz_indices = torch.cat((c_indices, a_indices), dim=1)
# target includes all sequence elements (no need to handle first one
# differently because we are conditioning)
target = z_indices
# make the prediction
logits, _ = self.transformer(cz_indices[:, :-1])
# cut off conditioning outputs - output i corresponds to p(z_i | z_{
关注
打赏
最近更新
- 深拷贝和浅拷贝的区别(重点)
- 【Vue】走进Vue框架世界
- 【云服务器】项目部署—搭建网站—vue电商后台管理系统
- 【React介绍】 一文带你深入React
- 【React】React组件实例的三大属性之state,props,refs(你学废了吗)
- 【脚手架VueCLI】从零开始,创建一个VUE项目
- 【React】深入理解React组件生命周期----图文详解(含代码)
- 【React】DOM的Diffing算法是什么?以及DOM中key的作用----经典面试题
- 【React】1_使用React脚手架创建项目步骤--------详解(含项目结构说明)
- 【React】2_如何使用react脚手架写一个简单的页面?