petr学习

petr和其它基于transformer目标检测模型的架构对比

petr模型的架构与detr模型基本一致,整体简洁优雅,避免了detr3d的一些较为复杂的流程。

petr模型首先是使用如resnet、vovnet等backbone提取图像2d特征,然后与生成的3d坐标一起进入encoder,encoder的结果与object query一起进入decoder,decoder的输出经过分类和回归的head得到bbox。

3d 坐标的生成:将图像坐标系空间,分成尺寸为(Wf, Hf, D)的meshgrid,meshgrid的点的坐标可以表示为(u × d, v × d, d)。meshgrid经过与相机内外侧得到的矩阵相乘,投影到3d世界坐标系,将在世界坐标系下的meshgrid的点,根据检测的roi区域进行归一化。就是生成的3d坐标。

3D position encoder,将3d坐标,经过一个mlp,得到position embedding,与2d feature经过一个1×1卷积后相加,得到3d position-aware feature。

3d object query:petr是初始化了在3d坐标系下从0到1均匀分布的可学习的一组anchor points,然后anchor points的坐标经过一个mlp,得到object query。然后训练时会更新这些anchor points

transformer(attention is all you need)学习

transformer最开始是做翻译的,之前看它的论文以及网上相关介绍,对一些东西没有弄明白,看了知乎的一个介绍 Transformer模型详解(图解最完整版) – 知乎 (zhihu.com) ,然后研究了一下GitHub上的代码 GitHub – jadore801120/attention-is-all-you-need-pytorch: A PyTorch implementation of the Transformer model in “Attention is All You Need”. ,算是基本把原本transformer弄懂了。在此总结一下之前没有弄明白的几个问题。

transformer总体结构

从上面的总体架构看,其实涵盖了网络的核心部分,但是有一些地方没有详细说明,也是给我造成了一些疑问,然后研究了一下代码才搞懂。首先从输入开始说明。输入是一段话,预处理就是将每个词根据词汇表映射成一个整数来表示它,我理解就是这个词的token,然后输入到网络里的就是这一串整数。网络将这串数字,首先经过torch.nn.embedding层,得到输入的embedding。

然后就是加上position_encoding,经过encoder。图中灰色的部分就是encoder_layer。encoder是n个不同的encoder_layer组成,而不是一个encoder_layer重复n次。右边的decoder也是同理。

然后就到decoder的部分。一开始会有一个初始的token,表示句子的开始,同样也是经过torch.nn.Embedding层得到embedding,然后经过decoder(需要encoder的输出),得到最后decoder的output,然后这个output经过linear层和softmax,得到一个结果,但是这里注意,这个结果并不是最后的结果,句子已经翻译完了,而只是得到下一个token,softmax输出表示下一个token的可能性。然后取可能性最大的作为预测的下一个token,将预测的token,接在之前的所有token后面,再次重复decoder的过程,直到某一个预测的token为表示句子结束的token。最后将得到的token序列,根据词汇表映射为对应的单词,完成句子的翻译。因此,翻译结果中每个单词是逐个得到而不是一次性完成的,每一个输出的单词都是由已经得到的结果,经过一次decoder得到的。

而decoder最后经过softmax得到下一个token,并不总是直接取可能性最大的token作为结果。实际上transformer维护着k个可能性最高的token序列以及它们对应的可能性得分。将这k个序列都输入decoder ,并每个仍然记录k个可能性最高的预测token,这样得到k^2的一个概率矩阵(矩阵每个元素是预测的token可能性与输入token序列的可能性相乘)。取矩阵中k个可能性最高的结果,更新维护的序列(由这些结果对应的原序列和预测token组成)及其对应的可能性得分(这些结果对应在矩阵中的概率)。

transformer预测的时候基本就是这样,然后之前还有一点没搞懂的就是它在decoder时候的mask操作。实际上在预测的时候,mask没有发挥作用,或者说mask全部没有遮挡。是在训练的时候用到的mask。前文说了,预测的时候是输入已经得到的结果序列,然后预测下一个token。但是在训练的时候,并不是每次都只输入一个不完整的序列,然后让它预测下一个token,再与对应的目标token来计算loss。而实际上是输入开始token+整个句子,目标序列是整个句子+结束token。训练时把整个输入序列都输入到decoder中,但是预测的时候输入的是不完整的序列,这里就存在差异,所以在这时候mask就派上用场了。在decoder的masked attention那里,用一个mask矩阵(下三角部分不遮挡,上三角部分遮挡)与计算得到的attention矩阵点乘,这样使得经过masked MHA之后输出的序列Z(和输入序列长度相同),每个位置的结果都只是由输入序列该位置之前的内容计算得到的,就等同于对Z的这个位置元素来说,输入就是从开头到这个位置的不完整序列。最后decoder的输出也是同理,每个位置的结果都只从输入序列该位置之前的内容计算得到。再经过linear(linear之后长度仍然和输入的长度一致)和softmax,取每个位置softmax后的结果,与该位置对应的目标token,计算loss,实现训练。(预测的时候是取最后一个位置对应的softmax的结果作为预测,前面位置的结果是无意义的)