Background: applying Transformers to computer vision (CV)
Paper: https://arxiv.org/abs/2010.11929
PyTorch implementation: pytorch_classification/vision_transformer (code by the blogger 太阳花的小绿豆)
Researchers have been exploring both hybrid CNN+Transformer designs and pure-Transformer implementations.
The paper's abstract states: "We show that this reliance on CNNs is not necessary and a pure transformer applied directly to sequences of image patches can perform very well on image classification tasks."
For a more detailed walkthrough, see 太阳花的小绿豆's post《Vision Transformer详解》(the related figures come from that blogger). Here I will instead explain the model through its exported ONNX structure, which may be more intuitive.
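For reference, here is a minimal sketch of how such an ONNX file can be exported from the PyTorch model. The import path and the factory name vit_base_patch16_224 are assumptions based on the referenced repo, not verified code:

import torch
# Assumed import: in the referenced repo the model factory is expected
# to live in vit_model.py as vit_base_patch16_224 (hypothetical path).
from vit_model import vit_base_patch16_224

model = vit_base_patch16_224(num_classes=5)
model.eval()

# A dummy 1x3x224x224 input fixes the graph shapes during tracing.
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy, "vit.onnx", opset_version=11,
                  input_names=["input"], output_names=["output"])

The resulting vit.onnx can then be opened in a graph viewer such as Netron to inspect the structure discussed below.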
1. ViT overall architecture
(figure: ViT overall structure diagram, from the referenced post)

2. ViT shape transformations
PyTorch API: summary(model, (3, 224, 224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
(1) Pre-processing
            Conv2d-1          [-1, 768, 14, 14]         590,592
          Identity-2             [-1, 196, 768]               0
        PatchEmbed-3             [-1, 196, 768]               0
           Dropout-4             [-1, 197, 768]               0
(2) Transformer encoder, block 1
         LayerNorm-5             [-1, 197, 768]           1,536
            Linear-6            [-1, 197, 2304]       1,771,776
           Dropout-7         [-1, 12, 197, 197]               0
            Linear-8             [-1, 197, 768]         590,592
           Dropout-9             [-1, 197, 768]               0
        Attention-10             [-1, 197, 768]               0
         Identity-11             [-1, 197, 768]               0
        LayerNorm-12             [-1, 197, 768]           1,536
           Linear-13            [-1, 197, 3072]       2,362,368
             GELU-14            [-1, 197, 3072]               0
          Dropout-15            [-1, 197, 3072]               0
           Linear-16             [-1, 197, 768]       2,360,064
          Dropout-17             [-1, 197, 768]               0
              Mlp-18             [-1, 197, 768]               0
         Identity-19             [-1, 197, 768]               0
            Block-20             [-1, 197, 768]               0
blocks 2-12 (layers 21-196): structurally identical to block 1,
with the same 7,087,872 params per block; repeated rows omitted
(3) Post-processing
        LayerNorm-197            [-1, 197, 768]           1,536
         Identity-198                 [-1, 768]               0
           Linear-199                   [-1, 5]           3,845
================================================================
Total params: 85,650,437
Trainable params: 85,650,437
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 408.54
Params size (MB): 326.73
Estimated Total Size (MB): 735.84
----------------------------------------------------------------
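For completeness, a minimal sketch of how the summary above can be generated, assuming the torchsummary package and the same (hypothetical) 5-class model factory as before:

import torch
from torchsummary import summary
from vit_model import vit_base_patch16_224  # assumed model factory

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = vit_base_patch16_224(num_classes=5).to(device)
# Prints per-layer output shapes and parameter counts for a 3x224x224 input.
summary(model, (3, 224, 224))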
3. Data pre-processing
The 3×224×224 input passes through a Conv2d with 768 kernels of size 16×16 (stride 16), producing a 768×14×14 output.
Flatten the spatial dimensions: 768×196 (14×14 = 196).
Transpose the axes to get 196×768 (tokens × channels).
Prepend a learnable class token of shape 1×768, concatenating it with the 196×768 patch tokens to get 197×768.
Add the position embedding pos (an element-wise add; the shape stays 197×768). The whole pipeline is sketched below.
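A minimal sketch of this pre-processing pipeline in PyTorch (shapes in the comments; the variable names are mine, not the referenced repo's):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)  # 768 kernels of 16x16
cls_token = nn.Parameter(torch.zeros(1, 1, 768))     # learnable class token
pos_embed = nn.Parameter(torch.zeros(1, 197, 768))   # learnable position embedding

x = torch.randn(1, 3, 224, 224)                      # [B, 3, 224, 224], B = 1
x = conv(x)                                          # [1, 768, 14, 14]
x = x.flatten(2)                                     # [1, 768, 196]
x = x.transpose(1, 2)                                # [1, 196, 768]
x = torch.cat([cls_token.expand(x.shape[0], -1, -1), x], dim=1)  # [1, 197, 768]
x = x + pos_embed                                    # [1, 197, 768]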

4. Feeding the data into the Transformer encoder (ONNX graph view)
Honestly, the ONNX ops here are a bit hard to handle: LayerNorm gets decomposed into a surprisingly complicated subgraph, and I haven't looked into op fusion yet (I'll study it later when I have time). In any case, this subgraph is exactly one Transformer encoder block:
LayerNorm-5 [-1, 197, 768]
Linear-6 [-1, 197, 2304]
Dropout-7 [-1, 12, 197, 197]
Linear-8 [-1, 197, 768]
Dropout-9 [-1, 197, 768]
Attention-10 [-1, 197, 768]
Identity-11 [-1, 197, 768]
LayerNorm-12 [-1, 197, 768]
Linear-13 [-1, 197, 3072]
GELU-14 [-1, 197, 3072]
Dropout-15 [-1, 197, 3072]
Linear-16 [-1, 197, 768]
Dropout-17 [-1, 197, 768]
Mlp-18 [-1, 197, 768]
Identity-19 [-1, 197, 768]
Block-20 [-1, 197, 768]
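To connect the layer names above to actual computation, here is a minimal sketch of one encoder block with matching shapes (12 heads, 768-dim embedding, 4× MLP ratio). This is my own simplified rewrite for illustration, not the repo's exact code (dropout layers omitted):

import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim=768, heads=12, mlp_ratio=4):
        super().__init__()
        self.heads = heads
        self.norm1 = nn.LayerNorm(dim)                 # LayerNorm-5
        self.qkv = nn.Linear(dim, dim * 3)             # Linear-6: 768 -> 2304
        self.proj = nn.Linear(dim, dim)                # Linear-8
        self.norm2 = nn.LayerNorm(dim)                 # LayerNorm-12
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),           # Linear-13: 768 -> 3072
            nn.GELU(),                                 # GELU-14
            nn.Linear(dim * mlp_ratio, dim),           # Linear-16: 3072 -> 768
        )

    def forward(self, x):                              # x: [B, 197, 768]
        B, N, C = x.shape
        h = self.norm1(x)
        qkv = self.qkv(h).reshape(B, N, 3, self.heads, C // self.heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)           # each [B, 12, 197, 64]
        attn = (q @ k.transpose(-2, -1)) * (C // self.heads) ** -0.5
        attn = attn.softmax(dim=-1)                    # [B, 12, 197, 197]: Dropout-7's shape
        h = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = x + self.proj(h)                           # residual 1 (Attention branch)
        x = x + self.mlp(self.norm2(x))                # residual 2 (Mlp branch)
        return x

Note that the [-1, 12, 197, 197] shape at Dropout-7 in the summary is the per-head attention matrix: 12 heads, each attending over 197 tokens.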

5. Post-processing
LayerNorm-197 [-1, 197, 768]
Identity-198 [-1, 768]
Linear-199 [-1, 5]
And there you have it: that whole subgraph is what a single LayerNorm op decomposes into in ONNX (I'll spare you the complaints).

Finally, the class token is fed through a fully connected layer to produce the classification output:
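A minimal sketch of this post-processing step (again my own simplification):

import torch
import torch.nn as nn

norm = nn.LayerNorm(768)       # LayerNorm-197
head = nn.Linear(768, 5)       # Linear-199: 5 output classes

x = torch.randn(1, 197, 768)   # encoder output
x = norm(x)                    # [1, 197, 768]
cls = x[:, 0]                  # keep only the class token -> [1, 768] (Identity-198)
logits = head(cls)             # [1, 5]
probs = logits.softmax(dim=-1)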

Summary
Looking at the ops, ViT introduces no new operators; it is just a composition of familiar layers, yet it performs remarkably well. A plain, unadorned structure doing profound work. Keep learning; there is no end to it! If enough readers are interested, I can upload the related ONNX code later for everyone to try, so feel free to follow or leave a comment.

Example prediction on the 5-class flower dataset:
class: daisy       prob: 0.995
class: dandelion   prob: 0.00298
class: roses       prob: 0.000599
class: sunflowers  prob: 0.000633
class: tulips      prob: 0.000771