$F(m \times m, r \times r)$:用一个 $(m+r-1) \times (m+r-1)$ 的输入特征图和一个 $r \times r$ 的卷积核进行 2D 卷积,得到 $m \times m$ 的输出。若采用直接卷积,则需要 $m^2 r^2$ 次乘法;而若采用 Winograd 算法,则乘法次数减少为 $(m+r-1)^2$。具体计算过程如下:

$$Y = A^T\left[(G g G^T) \odot (B^T d B)\right] A$$

其中,$\odot$ 表示逐元素相乘,g 为
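$r \times r$ 的卷积核,d 为 $n \times n$ 的输入 tile($n = m + r - 1$),而 G、B、A 分别是卷积核变换矩阵、输入特征变换矩阵以及逆变换矩阵。对于 $F(2 \times 2, 3 \times 3)$,$A$、$G$、$B$ 的值分别如下(与下文代码中使用的矩阵一致):

$$A^T = \begin{bmatrix} 1 & 1 & 1 & 0 \\ 0 & 1 & -1 & 1 \end{bmatrix}, \quad
G = \begin{bmatrix} 1 & 0 & 0 \\ \frac{1}{2} & \frac{1}{2} & \frac{1}{2} \\ \frac{1}{2} & -\frac{1}{2} & \frac{1}{2} \\ 0 & 0 & 1 \end{bmatrix}, \quad
B^T = \begin{bmatrix} 1 & 0 & -1 & 0 \\ 0 & 1 & 1 & 0 \\ 0 & -1 & 1 & 0 \\ 0 & -1 & 0 & 1 \end{bmatrix}$$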
现有 $N \times H \times W$ 的输入特征图和 $M \times N \times r \times r$($r = 3$)的卷积核,使用 Winograd 算法进行计算,步骤如下。我们选用 $F(2 \times 2, 3 \times 3)$ 进行计算,设 $g_{k,t}$ 为第 k 个输出通道、第 t 个输入通道的卷积核。我们对输入特征图进行 $4 \times 4$ 的分块,滑动步长为 $m = 2$(因此相邻 tile 重叠 $r - 1 = 2$ 行/列)。设 $d_{t,b}$ 表示第 t 个输入通道的第 b 个 tile,那么我们有:
$$U_{k,t} = G\, g_{k,t}\, G^T$$

$$V_{t,b} = B^T d_{t,b}\, B$$

易知,对每个给定的 $k, t, b$,$U_{k,t}$ 和 $V_{t,b}$ 都是 $n \times n$ 的矩阵。现在我们固定这 $n \times n$ 个元素的横纵坐标 $(i, j)$,让 $k, t, b$ 变化,则得到如下 $n \times n$ 对矩阵(B 为 tile 总数):

$$U^{i,j} = \left[U^{i,j}_{k,t}\right]_{M \times N}, \quad V^{i,j} = \left[V^{i,j}_{t,b}\right]_{N \times B}, \quad i, j = 0, 1, \dots, n-1$$

将这
$n \times n$ 对矩阵分别做矩阵乘法(即对输入通道 t 求和),得到

$$Q^{i,j}_{k,b} = \sum_{t} U^{i,j}_{k,t}\, V^{i,j}_{t,b}$$

再将 $Q$ 按坐标 $(i, j)$ 重新组合为 $n \times n$ 的矩阵 $Q_{k,b}$,并进行逆变换,得到

$$Y_{k,b} = A^T Q_{k,b}\, A$$

此时 Y 的维度为 $M \times B \times m \times m$(M 为输出通道数,B 为 tile 总数)。将每个 tile 的 $m \times m$ 输出块拼回其在输出中的位置,即可最终得到输出特征图 $O(k, h_o, w_o)$。
'''
A^T=[1 1 1 0
0 1 -1 1]
G=[ 1 0 0
0.5 0.5 0.5
0.5 -0.5 0.5
0 0 1]
B^T=[1 0 -1 0
0 1 1 0
0 -1 1 0
0 -1 0 1]
Y=A^T[(GgG^T)*(B^TdB)]A
U(k,t)=Gg(k,t)G^T
V(t,b)=B^Td(t,b)B
Q(i,j)=U(i,j)*V(i,j) i=0,1,...,n-1,j=0,1,...,n-1,shape=(k,b)
gather i,j to shape Q(k,b,n,n)
Y=A^TQA shape Y(k,b,m,m)
'''
import numpy as np
#F(2,2,3,3)
def FeatureTransform(x):
    """Return the F(2x2, 3x3) input transform B^T x B of one 4x4 tile.

    The B^T used here has last row [0, -1, 0, 1] (the negation of the usual
    Lavin & Gray row); the A matrix in ReverseTransform carries the matching
    sign flip, so the end-to-end Winograd result is unchanged.
    """
    b_t = np.array([[1, 0, -1, 0],
                    [0, 1, 1, 0],
                    [0, -1, 1, 0],
                    [0, -1, 0, 1]])
    return b_t @ x @ b_t.T
def FilterTransform(x):
    """Return the F(2x2, 3x3) kernel transform G x G^T (3x3 kernel -> 4x4)."""
    g_mat = np.array([
        [1, 0, 0],
        [0.5, 0.5, 0.5],
        [0.5, -0.5, 0.5],
        [0, 0, 1],
    ])
    left = g_mat @ x
    return left @ g_mat.T
def ReverseTransform(x):
    """Return the F(2x2, 3x3) output transform A^T x A (4x4 -> 2x2)."""
    a_t = np.array([[1, 1, 1, 0],
                    [0, 1, -1, 1]])
    return a_t @ x @ a_t.T
def conv2d(x, w):
    """Direct single-channel 'valid' 2-D cross-correlation (reference result).

    Args:
        x: 2-D input of shape (h, w_in); need not be square.
        w: 2-D kernel of shape (r, s); need not be square.

    Returns:
        Float array of shape (h - r + 1, w_in - s + 1).

    Raises:
        ValueError: if the kernel is larger than the input in either dimension.
    """
    x = np.asarray(x)
    w = np.asarray(w)
    h, w_in = x.shape
    r, s = w.shape
    h_o, w_o = h - r + 1, w_in - s + 1
    if h_o < 0 or w_o < 0:
        raise ValueError("kernel is larger than the input")
    out = np.zeros((h_o, w_o))
    for i in range(h_o):
        for j in range(w_o):
            out[i, j] = np.sum(w * x[i:i + r, j:j + s])
    return out
def conv(x, f):
    """Direct multi-channel 'valid' 2-D cross-correlation (reference result).

    Args:
        x: input feature map of shape (c, h, w).
        f: filters of shape (n, c, kh, kw); kh and kw may differ.

    Returns:
        Float array of shape (n, h - kh + 1, w - kw + 1).

    Raises:
        ValueError: if x and f disagree on the number of input channels.
    """
    c, h, w = x.shape
    n, c_f, kh, kw = f.shape
    if c_f != c:
        # Without this check a 1-channel side would silently broadcast.
        raise ValueError(f"channel mismatch: input has {c}, filters expect {c_f}")
    h_o = h - kh + 1
    w_o = w - kw + 1
    out = np.zeros((n, h_o, w_o))
    for i in range(h_o):
        for j in range(w_o):
            window = x[:, i:i + kh, j:j + kw]
            # Contract (channel, row, col) against every filter at once -> shape (n,).
            out[:, i, j] = np.tensordot(f, window, axes=3)
    return out
def winograd1(x, w):
    """Winograd convolution of a single tile: A^T [(G w G^T) * (B^T x B)] A.

    For the F(2x2, 3x3) transforms, x is one 4x4 input tile and w one 3x3
    kernel, and the result is the 2x2 output block.
    """
    elementwise = FilterTransform(w) * FeatureTransform(x)
    return ReverseTransform(elementwise)
def wingrad2(x, f):
    """Multi-channel 'valid' 2-D convolution using Winograd F(2x2, 3x3).

    Args:
        x: input feature map of shape (n, h, w).
        f: filters of shape (m, n, 3, 3).

    Returns:
        Float array of shape (m, h - 2, w - 2), the same result as conv(x, f),
        or None when the kernel is not 3x3.

    Raises:
        ValueError: if x and f disagree on the channel count, or the input is
            smaller than the kernel.

    The input is cut into 4x4 tiles with stride 2; each tile yields one 2x2
    output block. When the output height or width is odd, the input is
    zero-padded on the bottom/right so the last tile is complete, and the
    assembled output is cropped back to (h - 2, w - 2).
    """
    c_in, h, w = x.shape
    m, n, kx, ky = f.shape
    if kx != 3 or ky != 3:
        return None
    if n != c_in:
        raise ValueError(
            f"channel mismatch: input has {c_in} channels, filters expect {n}")
    h_o, w_o = h - kx + 1, w - ky + 1
    if h_o < 0 or w_o < 0:
        raise ValueError("input is smaller than the 3x3 kernel")
    step = 2                   # output tile size m, which is also the tile stride
    tile = step + kx - 1       # input tile size n = m + r - 1 = 4
    tiles_h = -(-h_o // step)  # ceil(h_o / step) in integer arithmetic
    tiles_w = -(-w_o // step)
    b = tiles_h * tiles_w      # number of tiles
    # Zero-pad bottom/right so every tile is a full 4x4 window.
    xp = np.pad(x, ((0, 0),
                    (0, tiles_h * step + kx - 1 - h),
                    (0, tiles_w * step + ky - 1 - w)))
    U = np.zeros((m, n, tile, tile))
    V = np.zeros((n, b, tile, tile))
    Q = np.zeros((m, b, tile, tile))
    Y = np.zeros((m, b, step, step))
    # Filter transform: U[k, t] = G g[k, t] G^T.
    for mm in range(m):
        for nn in range(n):
            U[mm, nn] = FilterTransform(f[mm, nn])
    # Input transform: V[t, b] = B^T d[t, b] B; tiles are numbered row-major.
    for nn in range(n):
        for bb in range(b):
            r, c = divmod(bb, tiles_w)
            d = xp[nn, r * step:r * step + tile, c * step:c * step + tile]
            V[nn, bb] = FeatureTransform(d)
    # Per (i, j) position: (m x n) @ (n x b), i.e. a sum over input channels.
    for i in range(tile):
        for j in range(tile):
            Q[:, :, i, j] = U[:, :, i, j] @ V[:, :, i, j]
    # Inverse transform: Y[k, b] = A^T Q[k, b] A (4x4 -> 2x2).
    for mm in range(m):
        for bb in range(b):
            Y[mm, bb] = ReverseTransform(Q[mm, bb])
    # Place each 2x2 block exactly once, then crop away the padding.
    out = np.zeros((m, tiles_h * step, tiles_w * step))
    for bb in range(b):
        r, c = divmod(bb, tiles_w)
        out[:, r * step:(r + 1) * step, c * step:(c + 1) * step] = Y[:, bb]
    return out[:, :h_o, :w_o].copy()
# Self-check 1: one 4x4 tile -- the two printed 2x2 arrays should be equal.
d = np.random.randint(low=0, high=30, size=(4, 4))
g = np.random.randint(low=0, high=20, size=(3, 3))
print(winograd1(d, g))
print(conv2d(d, g))

# Self-check 2: multi-channel. An 8x6 input with a 3x3 kernel gives a 6x4
# output, i.e. 3x2 = 6 tiles of 2x2. The printed max abs error should be 0.0.
ch_in, ch_out = 8, 10
height, width = 8, 6
X = np.random.randint(low=0, high=100, size=(ch_in, height, width))
W = np.random.randint(low=0, high=100, size=(ch_out, ch_in, 3, 3))
print(np.max(np.abs(conv(X, W) - wingrad2(X, W))))
测试:运行结果为 0.0,说明 Winograd 实现与直接卷积的结果完全一致,功能无误!
下面是 $F(4 \times 4, 3 \times 3)$ 的版本,变换矩阵如图所示(即代码中的 B、G、A),相应的 Python 代码如下(已解决输出尺寸不能被 tile 步长整除时的分块问题):
'''
Y=A^T[(GgG^T)*(B^TdB)]A
U(k,t)=Gg(k,t)G^T
V(t,b)=B^Td(t,b)B
Q(i,j)=U(i,j)*V(i,j) i=0,1,...,n-1,j=0,1,...,n-1,shape=(k,b)
gather i,j to shape Q(k,b,n,n)
Y=A^TQA shape Y(k,b,m,m)
'''
import math
import numpy as np
#F(4,4,3,3)
#input tile shape=(6,6)
def FeatureTransform(x):
    """Return the F(4x4, 3x3) input transform B^T x B of one 6x6 tile."""
    b_t = np.array([
        [4, 0, -5, 0, 1, 0],
        [0, -4, -4, 1, 1, 0],
        [0, 4, -4, -1, 1, 0],
        [0, -2, -1, 2, 1, 0],
        [0, 2, -1, -2, 1, 0],
        [0, 4, 0, -5, 0, 1],
    ])
    return b_t @ x @ b_t.T
def FilterTransform(x):
    """Return the F(4x4, 3x3) kernel transform G x G^T (3x3 kernel -> 6x6)."""
    g_mat = np.array([
        [1 / 4, 0, 0],
        [-1 / 6, -1 / 6, -1 / 6],
        [-1 / 6, 1 / 6, -1 / 6],
        [1 / 24, 1 / 12, 1 / 6],
        [1 / 24, -1 / 12, 1 / 6],
        [0, 0, 1],
    ])
    left = g_mat @ x
    return left @ g_mat.T
def ReverseTransform(x):
    """Return the F(4x4, 3x3) output transform A^T x A (6x6 -> 4x4)."""
    a_t = np.array([
        [1, 1, 1, 1, 1, 0],
        [0, 1, -1, 2, -2, 0],
        [0, 1, 1, 4, 4, 0],
        [0, 1, -1, 8, -8, 1],
    ])
    return a_t @ x @ a_t.T
def conv2d(x, w):
    """Direct single-channel 'valid' 2-D cross-correlation (reference result).

    Args:
        x: 2-D input of shape (h, w_in); need not be square.
        w: 2-D kernel of shape (r, s); need not be square.

    Returns:
        Float array of shape (h - r + 1, w_in - s + 1).

    Raises:
        ValueError: if the kernel is larger than the input in either dimension.
    """
    x = np.asarray(x)
    w = np.asarray(w)
    h, w_in = x.shape
    r, s = w.shape
    h_o, w_o = h - r + 1, w_in - s + 1
    if h_o < 0 or w_o < 0:
        raise ValueError("kernel is larger than the input")
    out = np.zeros((h_o, w_o))
    for i in range(h_o):
        for j in range(w_o):
            out[i, j] = np.sum(w * x[i:i + r, j:j + s])
    return out
def conv(x, f):
    """Direct multi-channel 'valid' 2-D cross-correlation (reference result).

    Args:
        x: input feature map of shape (c, h, w).
        f: filters of shape (n, c, kh, kw); kh and kw may differ.

    Returns:
        Float array of shape (n, h - kh + 1, w - kw + 1).

    Raises:
        ValueError: if x and f disagree on the number of input channels.
    """
    c, h, w = x.shape
    n, c_f, kh, kw = f.shape
    if c_f != c:
        # Without this check a 1-channel side would silently broadcast.
        raise ValueError(f"channel mismatch: input has {c}, filters expect {c_f}")
    h_o = h - kh + 1
    w_o = w - kw + 1
    out = np.zeros((n, h_o, w_o))
    for i in range(h_o):
        for j in range(w_o):
            window = x[:, i:i + kh, j:j + kw]
            # Contract (channel, row, col) against every filter at once -> shape (n,).
            out[:, i, j] = np.tensordot(f, window, axes=3)
    return out
def winograd1(x, w):
    """Winograd convolution of a single tile: A^T [(G w G^T) * (B^T x B)] A.

    For the F(4x4, 3x3) transforms, x is one 6x6 input tile and w one 3x3
    kernel, and the result is the 4x4 output block.
    """
    elementwise = FilterTransform(w) * FeatureTransform(x)
    return ReverseTransform(elementwise)
def winograd2(x, f):
    """Multi-channel 'valid' 2-D convolution using Winograd F(4x4, 3x3).

    Args:
        x: input feature map of shape (n, h, w).
        f: filters of shape (m, n, 3, 3).

    Returns:
        Float array of shape (m, h - 2, w - 2), the same result as conv(x, f),
        or None when the kernel is not 3x3.

    Raises:
        ValueError: if x and f disagree on the channel count, or the input is
            smaller than the kernel.

    The input is cut into 6x6 tiles with stride 4; each tile yields one 4x4
    output block. When the output height or width is not a multiple of 4,
    the input is zero-padded on the bottom/right so the last tile is
    complete, and the assembled output is cropped back to (h - 2, w - 2).
    """
    c_in, h, w = x.shape
    m, n, kx, ky = f.shape
    if kx != 3 or ky != 3:
        return None
    if n != c_in:
        raise ValueError(
            f"channel mismatch: input has {c_in} channels, filters expect {n}")
    h_o, w_o = h - kx + 1, w - ky + 1
    if h_o < 0 or w_o < 0:
        raise ValueError("input is smaller than the 3x3 kernel")
    step = 4                          # output tile size m, which is also the tile stride
    tile = step + kx - 1              # input tile size n = m + r - 1 = 6
    tiles_h = math.ceil(h_o / step)
    tiles_w = math.ceil(w_o / step)
    b = tiles_h * tiles_w             # number of tiles
    # Zero-pad bottom/right so every tile is a full 6x6 window.
    xp = np.pad(x, ((0, 0),
                    (0, tiles_h * step + kx - 1 - h),
                    (0, tiles_w * step + ky - 1 - w)))
    U = np.zeros((m, n, tile, tile))
    V = np.zeros((n, b, tile, tile))
    Q = np.zeros((m, b, tile, tile))
    Y = np.zeros((m, b, step, step))
    # Filter transform: U[k, t] = G g[k, t] G^T.
    for mm in range(m):
        for nn in range(n):
            U[mm, nn] = FilterTransform(f[mm, nn])
    # Input transform: V[t, b] = B^T d[t, b] B; tiles are numbered row-major.
    for nn in range(n):
        for bb in range(b):
            r, c = divmod(bb, tiles_w)
            d = xp[nn, r * step:r * step + tile, c * step:c * step + tile]
            V[nn, bb] = FeatureTransform(d)
    # Per (i, j) position: (m x n) @ (n x b), i.e. a sum over input channels.
    for i in range(tile):
        for j in range(tile):
            Q[:, :, i, j] = U[:, :, i, j] @ V[:, :, i, j]
    # Inverse transform: Y[k, b] = A^T Q[k, b] A (6x6 -> 4x4).
    for mm in range(m):
        for bb in range(b):
            Y[mm, bb] = ReverseTransform(Q[mm, bb])
    # Place each 4x4 block exactly once, then crop away the padding.
    out = np.zeros((m, tiles_h * step, tiles_w * step))
    for bb in range(b):
        r, c = divmod(bb, tiles_w)
        out[:, r * step:(r + 1) * step, c * step:(c + 1) * step] = Y[:, bb]
    return out[:, :h_o, :w_o].copy()
关注
打赏
最近更新
- 深拷贝和浅拷贝的区别(重点)
- 【Vue】走进Vue框架世界
- 【云服务器】项目部署—搭建网站—vue电商后台管理系统
- 【React介绍】 一文带你深入React
- 【React】React组件实例的三大属性之state,props,refs(你学废了吗)
- 【脚手架VueCLI】从零开始,创建一个VUE项目
- 【React】深入理解React组件生命周期----图文详解(含代码)
- 【React】DOM的Diffing算法是什么?以及DOM中key的作用----经典面试题
- 【React】1_使用React脚手架创建项目步骤--------详解(含项目结构说明)
- 【React】2_如何使用react脚手架写一个简单的页面?