def conv_down(inp, oup):
    """Build a stride-2 downsampling stage.

    A 4x4 convolution with stride 2 and padding 1 reduces each spatial
    dimension to ``floor(size / 2)``; a 3x3 convolution then refines the
    result at the new resolution. Each convolution is followed by an
    in-place LeakyReLU with slope 0.2.

    Args:
        inp: Number of input channels.
        oup: Number of output channels (used by both convolutions).

    Returns:
        An ``nn.Sequential`` of four layers: conv, activation, conv, activation.
    """
    layers = [
        nn.Conv2d(inp, oup, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(oup, oup, kernel_size=3, stride=1, padding=1),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    return nn.Sequential(*layers)
def conv_up(inp, oup):
    """Build a 2x upsampling stage.

    A 2x2 transposed convolution with stride 2 doubles each spatial
    dimension and is followed by an in-place LeakyReLU with slope 0.2.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.

    Returns:
        An ``nn.Sequential`` of the transposed convolution and its activation.
    """
    upsample = nn.ConvTranspose2d(inp, oup, kernel_size=2, stride=2, padding=0)
    activation = nn.LeakyReLU(0.2, inplace=True)
    return nn.Sequential(upsample, activation)
def conv_merge(inp, oup):
    """Build the block that fuses concatenated skip and upsampled features.

    A 1x1 convolution first maps ``inp`` channels to ``oup`` channels, then
    two 3x3 convolutions (stride 1, padding 1) refine the result. Every
    convolution is followed by an in-place LeakyReLU with slope 0.2, and the
    spatial size is preserved throughout.

    Args:
        inp: Number of input channels (channels after concatenation).
        oup: Number of output channels.

    Returns:
        An ``nn.Sequential`` of six layers (three conv/activation pairs).
    """
    layers = [
        nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    # Two identical 3x3 refinement stages, appended in order so the layer
    # indices inside the Sequential stay 2..5.
    for _ in range(2):
        layers.append(nn.Conv2d(oup, oup, kernel_size=3, stride=1, padding=1))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*layers)
def conv(inp, oup):
    """Build a size-preserving 3x3 convolution followed by LeakyReLU(0.2).

    Args:
        inp: Number of input channels.
        oup: Number of output channels.

    Returns:
        An ``nn.Sequential`` holding the convolution and its in-place
        activation.
    """
    convolution = nn.Conv2d(inp, oup, kernel_size=3, stride=1, padding=1)
    activation = nn.LeakyReLU(0.2, inplace=True)
    return nn.Sequential(convolution, activation)
def convbn(in_channel, out_channel, kernel_size, stride, pad, dilation):
    """Build a single (optionally dilated) convolution wrapped in a Sequential.

    Despite the name, no batch-normalisation layer is added.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        pad: Padding used when ``dilation`` is not greater than 1.
        dilation: Dilation rate; when greater than 1 it is also used as the
            padding, overriding ``pad``.

    Returns:
        An ``nn.Sequential`` containing exactly one ``nn.Conv2d``.
    """
    # A dilated kernel pads by its dilation rate; otherwise honour `pad`.
    padding = pad if dilation <= 1 else dilation
    conv_layer = nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
    )
    return nn.Sequential(conv_layer)
class BasicBlock(nn.Module):
    """Two-convolution residual block without normalisation layers.

    The input is passed through a 3x3 convolution with LeakyReLU(0.2) and a
    second 3x3 convolution, and the block input is then added to the result.
    """

    expansion = 1

    def __init__(self, c1, c2, s, downsample, p, d):
        """Create the two convolutions of the block.

        Args:
            c1: Number of input channels.
            c2: Number of output channels.
            s: Stride of the first convolution.
            downsample: Accepted for signature compatibility but not used.
            p: Padding forwarded to ``convbn``.
            d: Dilation forwarded to ``convbn``.
        """
        super(BasicBlock, self).__init__()
        # NOTE(review): `downsample` is never stored or applied, so the
        # residual add in forward() only works when the block input already
        # has a shape compatible with conv2's output -- confirm with callers.
        first_conv = convbn(c1, c2, 3, s, p, d)
        activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv1 = nn.Sequential(first_conv, activation)
        self.conv2 = convbn(c2, c2, 3, 1, p, d)
        self.stride = s

    def forward(self, x):
        """Return ``conv2(conv1(x)) + x``."""
        residual = self.conv2(self.conv1(x))
        # In-place accumulation onto the convolution output.
        residual += x
        return residual
class Unet(nn.Module):
    """U-Net style feature extractor with four stride-2 stages.

    The encoder produces features at full, 1/2, 1/4, 1/8 and 1/16 resolution
    with 16, 16, 24, 24 and 32 channels respectively. The decoder upsamples
    step by step and fuses each level with the encoder feature of the same
    resolution through ``conv_merge``.
    """

    def __init__(self):
        """Create the encoder, decoder and merge stages (3-channel input)."""
        super().__init__()
        # Encoder: channel widths 16, 16, 24, 24, 32.
        self.conv1 = conv(3, 16)
        self.down1 = conv_down(16, 16)
        self.down2 = conv_down(16, 24)
        self.down3 = conv_down(24, 24)
        self.down4 = nn.Sequential(
            conv_down(24, 32),
            nn.Conv2d(32, 32, 3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(32, 32, 3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        # Decoder: each stage doubles the resolution.
        self.up4 = conv_up(32, 24)
        self.up3 = conv_up(24, 24)
        self.up2 = conv_up(24, 16)
        self.up1 = conv_up(16, 16)
        # Merge stages take skip + upsampled channels concatenated.
        self.merge4 = conv_merge(24 + 24, 24)
        self.merge3 = conv_merge(24 + 24, 24)
        self.merge2 = conv_merge(16 + 16, 16)
        self.merge1 = conv_merge(16 + 16, 16)

    @staticmethod
    def _match_size(up, skip):
        """Pad or crop ``up`` at the bottom/right to the H x W of ``skip``.

        The stride-2 encoder stage yields ``floor(H / 2)`` rows, so the 2x
        transposed convolution returns ``2 * floor(H / 2)`` rows: one fewer
        than the skip feature whenever H (or W) is odd. Zero-padding restores
        the missing row/column so the two tensors can be concatenated.
        Cropping is kept only as a safeguard for an oversized ``up``.
        """
        target_h, target_w = skip.size(2), skip.size(3)
        missing_h = target_h - up.size(2)
        missing_w = target_w - up.size(3)
        if missing_h == 0 and missing_w == 0:
            return up
        if missing_h > 0 or missing_w > 0:
            # Pad order for the last two dims is (left, right, top, bottom).
            up = nn.functional.pad(up, (0, max(missing_w, 0), 0, max(missing_h, 0)))
        return up[:, :, :target_h, :target_w]

    def forward(self, x):
        """Extract multi-scale features.

        Args:
            x: Input batch with 3 channels, laid out as (N, 3, H, W).

        Returns:
            A list ``[x_down4, x_up4, x_up3, x_up2, x_up1]`` ordered from the
            coarsest (1/16 resolution, 32 channels) to the finest (full
            resolution, 16 channels) feature map.
        """
        # Encoder path.
        x_down = self.conv1(x)           # 16 channels, full resolution
        x_down1 = self.down1(x_down)     # 16 channels, 1/2
        x_down2 = self.down2(x_down1)    # 24 channels, 1/4
        x_down3 = self.down3(x_down2)    # 24 channels, 1/8
        x_down4 = self.down4(x_down3)    # 32 channels, 1/16

        # Decoder path: upsample, align with the skip feature, concatenate
        # along channels and fuse.
        x_up4 = self._match_size(self.up4(x_down4), x_down3)
        x_up4 = self.merge4(torch.cat((x_down3, x_up4), dim=1))  # 24 channels, 1/8

        x_up3 = self._match_size(self.up3(x_up4), x_down2)
        x_up3 = self.merge3(torch.cat((x_down2, x_up3), dim=1))  # 24 channels, 1/4

        x_up2 = self._match_size(self.up2(x_up3), x_down1)
        x_up2 = self.merge2(torch.cat((x_down1, x_up2), dim=1))  # 16 channels, 1/2

        x_up1 = self._match_size(self.up1(x_up2), x_down)
        x_up1 = self.merge1(torch.cat((x_down, x_up1), dim=1))   # 16 channels, full

        return [x_down4, x_up4, x_up3, x_up2, x_up1]
class Initial(nn.Module):
    """Initialisation head that matches left/right features at one scale.

    Both views are embedded with the *same* 4x4 convolution. The left view
    uses stride 4 in both directions; the right view uses stride 4 vertically
    and stride 1 horizontally, so it keeps (almost) full horizontal
    resolution. The two embeddings are handed to ``calc_init_disp`` and a 1x1
    "descriptor" convolution produces a 13-channel output from the result.
    """

    # Input channels of the descriptor convolution for each supported scale.
    # 17 = 1 + 16: `cost_d_init` concatenated with the 16-channel left
    # embedding. For the other scales `fea_cat` takes the place of the
    # embedding (presumably 32 / 24 channels -- verify against the caller).
    _DESCRIPTOR_IN_CHANNELS = {16: 17, 8: 17, 4: 33, 2: 25, 1: 25}

    def __init__(self, inp, maxdisp=256, scale=8):
        """Create the shared embedding and the per-scale descriptor.

        Args:
            inp: Number of channels of the incoming feature maps.
            maxdisp: Maximum disparity at full resolution; ``forward``
                divides it by ``scale``.
            scale: Pyramid scale of this head; one of 16, 8, 4, 2 or 1.

        Raises:
            ValueError: If ``scale`` is not one of the supported values.
        """
        super().__init__()
        self.maxdisp = maxdisp
        self.scale = scale
        self.pad = torch.nn.ZeroPad2d(padding=(1, 2, 0, 0))

        # Shared 4x4 embedding. Weights and bias are re-drawn from a standard
        # normal distribution on the default device; move the whole module
        # with .to(device) / .cuda() like any other nn.Module.
        self.conv4_4 = nn.Conv2d(inp, 16, 4, stride=4)
        self.conv4_4.weight = nn.Parameter(torch.randn(16, inp, 4, 4))
        self.conv4_4.bias = nn.Parameter(torch.randn(16))

        self.post = nn.Sequential(
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(16, 16, 1, stride=1, padding=0),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

        if scale not in self._DESCRIPTOR_IN_CHANNELS:
            # Fail at construction instead of leaving `descriptor` undefined
            # and crashing later inside forward().
            raise ValueError(
                "unsupported scale %r; expected one of %s"
                % (scale, sorted(self._DESCRIPTOR_IN_CHANNELS))
            )
        self.descriptor = nn.Sequential(
            nn.Conv2d(self._DESCRIPTOR_IN_CHANNELS[scale], 13, 1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, fea_l, fea_r, fea_cat=None):
        """Embed both views and compute the initial hypothesis.

        Args:
            fea_l: Left-view feature map with ``inp`` channels.
            fea_r: Right-view feature map with ``inp`` channels.
            fea_cat: Extra features concatenated with ``cost_d_init`` in place
                of the left embedding; required when ``scale`` is 4, 2 or 1
                and ignored for scales 16 and 8.

        Returns:
            ``[fea_l, fea_r, p_init, d_init, cost, cost_d_init]`` where
            ``fea_l`` / ``fea_r`` are the 16-channel embeddings, ``p_init`` is
            the 13-channel descriptor output, and the remaining entries come
            from ``calc_init_disp`` (``d_init`` with dimension 1 squeezed).
        """
        shared = self.conv4_4
        fea_l = self.post(shared(fea_l))

        # Right view: same kernel, but stride 1 along the width. The
        # functional call reuses the shared weights without temporarily
        # mutating `conv4_4.stride`, so the module is never left in a
        # modified state if something raises.
        fea_r = nn.functional.conv2d(
            fea_r,
            shared.weight,
            shared.bias,
            stride=(4, 1),
            padding=shared.padding,
            dilation=shared.dilation,
            groups=shared.groups,
        )
        fea_r = self.post(fea_r)

        # Disparity range expressed at this head's resolution.
        maxdisp = self.maxdisp // self.scale
        cost, d_init, cost_d_init = calc_init_disp(fea_l, fea_r, maxdisp)
        d_init = d_init.squeeze(1)

        if self.scale == 16 or self.scale == 8:
            descriptor_input = torch.cat((cost_d_init, fea_l), 1)
        else:
            descriptor_input = torch.cat((cost_d_init, fea_cat), 1)
        p_init = self.descriptor(descriptor_input)

        return [fea_l, fea_r, p_init, d_init, cost, cost_d_init]