Comments (2)
class Pos2Weight(nn.Module):
    """Meta-SR weight-prediction network: position vector -> flattened conv kernel.

    A two-layer MLP that maps each input position vector to the weights of a
    ``kernel_size ** spatial_dims`` kernel connecting ``inC`` input channels to
    ``outC`` output channels.

    Args:
        inC: number of input feature channels.
        kernel_size: side length of the predicted kernel.
        outC: number of output channels.
        spatial_dims: number of spatial kernel dimensions. 2 reproduces the
            original 2D module (``kernel_size**2`` weights per channel pair);
            3 predicts ``kernel_size**3`` weights for volumetric upsampling.
        in_features: length of each position vector fed to the MLP. 3 is the
            original default; a 3D position matrix that carries a leading
            ``1/scale`` column plus (h, w, d) offsets has 4 entries.
    """

    def __init__(self, inC, kernel_size=3, outC=3, spatial_dims=2, in_features=3):
        super(Pos2Weight, self).__init__()
        self.inC = inC
        self.kernel_size = kernel_size
        self.outC = outC
        self.spatial_dims = spatial_dims
        self.meta_block = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(inplace=True),
            # one weight per (kernel element, input channel, output channel)
            nn.Linear(256, (kernel_size ** spatial_dims) * inC * outC),
        )

    def forward(self, x):
        """Map ``x[..., in_features]`` to ``[..., kernel_size**spatial_dims * inC * outC]`` weights."""
        output = self.meta_block(x)
        return output
把输出的卷积核参数从 2D 卷积换成 3D 卷积即可:self.kernel_size**2 -> self.kernel_size**3。注意 3D 的位置向量为 (1/scale, h, w, d),共 4 维,因此第一层也要从 nn.Linear(3, 256) 改为 nn.Linear(4, 256)。
# Fragment of a Meta-SR upsampler forward(); the enclosing `def` is not part of
# this snippet. It reads self.repeat_x, self.scale, self.add_mean and a
# `local_weight` tensor created outside the snippet (presumably the output of
# Pos2Weight -- TODO confirm). The predicted per-pixel kernels are applied as a
# matrix product: unfold the input into 3x3 patches, then multiply every patch
# by its own predicted weight matrix. Below, r = ceil(self.scale) and x is
# assumed to be (N, C, inH, inW) -- verify against the caller.
up_x = self.repeat_x(x) ### the output is (N*r*r,inC,inH,inW)
# unfold(kernel=3, padding=1) -> (N*r*r, inC*3*3, inH*inW), one column per pixel.
# NOTE(review): F.unfold only supports 4-D input, so a 3D version needs its own
# patch extraction (e.g. padding + Tensor.unfold on each of the 3 spatial dims).
cols = nn.functional.unfold(up_x, 3,padding=1)
scale_int = math.ceil(self.scale)
# -> (N, r*r, inH*inW, 1, inC*9): a row vector of patch values per LR pixel.
cols = cols.contiguous().view(cols.size(0)//(scale_int**2),scale_int**2, cols.size(1), cols.size(2), 1).permute(0,1, 3, 4, 2).contiguous()
# local_weight must factor as (inH, r, inW, r, -1, 3) (-1 presumably inC*9, and 3
# the output channels); reorder to (r, r, inH, inW, -1, 3).
local_weight = local_weight.contiguous().view(x.size(2),scale_int, x.size(3),scale_int,-1,3).permute(1,3,0,2,4,5).contiguous()
# -> (r*r, inH*inW, -1, 3): one weight matrix per sub-pixel position and LR pixel.
local_weight = local_weight.contiguous().view(scale_int**2, x.size(2)*x.size(3),-1, 3)
# (N, r*r, HW, 1, K) @ (r*r, HW, K, 3) -> (N, r*r, HW, 1, 3), permuted to (N, r*r, 3, HW, 1).
out = torch.matmul(cols,local_weight).permute(0,1,4,2,3)
# -> (N, 3, inH, r, inW, r): interleave the r x r sub-pixels with the LR grid.
out = out.contiguous().view(x.size(0),scale_int,scale_int,3,x.size(2),x.size(3)).permute(0,3,4,1,5,2)
# -> (N, 3, r*inH, r*inW)
out = out.contiguous().view(x.size(0),3, scale_int*x.size(2),scale_int*x.size(3))
# self.add_mean is defined elsewhere (presumably adds back the dataset mean -- TODO confirm).
out = self.add_mean(out)
return out
这里参考2d卷积的底层计算即转化成矩阵运算,把3d卷积也转化为矩阵运算即可
input_matrix_wpn 矩阵输入从(h,w) 变成(h,w,d)即可
from meta-sr-pytorch.
感谢!但是 input_matrix_wpn 我没改出来,能否帮忙修正一下?感谢!
def _axis_offsets(in_size, scale, scale_int):
    """Return flat ``(offset, mask)`` vectors of length ``in_size * scale_int`` for one axis.

    Entry ``n * scale_int + s`` is slot ``s`` of LR index ``n``. HR index ``t``
    projects to ``t / scale``; its floor picks the LR index and its fractional
    part is stored as the offset. Slots no HR index maps to keep offset 1 and
    mask False.
    """
    out_size = int(scale * in_size)
    # .float(): older PyTorch releases (before type promotion) kept an integer
    # dtype for int_tensor.mul(float), which would zero every offset.
    coord = torch.arange(0, out_size, 1).float().mul(1.0 / scale)
    lr_index = torch.floor(coord)
    frac = coord - lr_index
    lr_index = lr_index.long()

    offset = torch.ones(in_size, scale_int)
    mask = torch.zeros(in_size, scale_int, dtype=torch.bool)
    prev, slot = -1, 0
    for t in range(out_size):
        n = int(lr_index[t])
        # consecutive HR indices mapping to the same LR index fill its slots in order
        slot = slot + 1 if n == prev else 0
        prev = n
        offset[n, slot] = frac[t]
        mask[n, slot] = True
    return offset.view(-1), mask.view(-1)


def _offset_period(offsets, eps=1e-6):
    """Return the first index >= 1 whose offset is ~0 (one repeat period), else ``len(offsets)``."""
    n = offsets.numel()
    p = 1
    # Check the bound before indexing so a non-repeating axis cannot raise IndexError.
    while p < n and offsets[p] >= eps:
        p += 1
    return p


def input_matrix_wpn_new(inH, inW, inD, scale, add_scale=True):
    """Build the 3D Meta-SR position matrix and its validity mask.

    This is the 3D extension of Meta-SR's ``input_matrix_wpn``. Along each axis,
    HR index ``t`` projects to LR coordinate ``t / scale``: the integer part
    selects the LR voxel and the fractional part is the offset fed to the
    weight-prediction network. Every LR voxel owns ``ceil(scale)`` slots per
    axis. Slots that no HR voxel maps to keep offset 1 and are masked out.

    Args:
        inH, inW, inD: size of the LR feature volume along each axis.
        scale: upsampling factor (expected >= 1).
        add_scale: if True, prepend a ``1/scale`` column to every position row.

    Returns:
        pos_mat_small: float tensor of shape ``(1, i*j*k, 3)``, or
            ``(1, i*j*k, 4)`` when ``add_scale`` is set. Each row holds the
            ``(h, w, d)`` offsets, preceded by ``1/scale`` when ``add_scale``.
            Only one repeating period of ``i x j x k`` slots is kept, flattened
            with ``d`` varying fastest (e.g. scale 2 gives i = j = k = 2).
        mask_mat: bool tensor of shape
            ``(ceil(scale)*inH, ceil(scale)*inW, ceil(scale)*inD)``. It is True
            where the slot is a real HR voxel on all three axes.
    """
    scale_int = int(math.ceil(scale))

    off_h, mask_h = _axis_offsets(inH, scale, scale_int)
    off_w, mask_w = _axis_offsets(inW, scale, scale_int)
    off_d, mask_d = _axis_offsets(inD, scale, scale_int)

    # Broadcast the per-axis vectors over the full (H, W, D) slot grid.
    grid = (off_h.numel(), off_w.numel(), off_d.numel())
    pos_mat = torch.stack(
        (
            off_h.view(-1, 1, 1).expand(grid),
            off_w.view(1, -1, 1).expand(grid),
            off_d.view(1, 1, -1).expand(grid),
        ),
        dim=-1,
    )  # (scale_int*inH, scale_int*inW, scale_int*inD, 3)

    # A slot is valid only when it is valid on every axis. This is the 3D
    # counterpart of the 2D version's "sum of masks == 2".
    mask_mat = mask_h.view(-1, 1, 1) & mask_w.view(1, -1, 1) & mask_d.view(1, 1, -1)

    # The offset pattern restarts at offset 0 once per period. Keep one period per
    # axis so the weight-prediction network runs on far fewer rows.
    i = _offset_period(off_h)
    j = _offset_period(off_w)
    k = _offset_period(off_d)
    pos_mat_small = pos_mat[:i, :j, :k, :].contiguous().view(1, -1, 3)

    if add_scale:
        scale_col = torch.full((1, pos_mat_small.size(1), 1), 1.0 / scale)
        pos_mat_small = torch.cat((scale_col, pos_mat_small), 2)
    return pos_mat_small, mask_mat
from meta-sr-pytorch.
Related Issues (20)
- Have you debugged it yet?
- Meta-Upscale Module
- meta-upscale
- meta-upscale的输入
- 请问输入矩阵为什么需要mask
- Meta-upscale的实现 HOT 3
- RuntimeError: cuda runtime error (2) HOT 5
- Trying to train Meta-RCAN but failed HOT 2
- Testing directories HOT 3
- rewrite dataloader for more recent PyTorch
- meta-learning for weight prediction
- dataloader error, help plz~
- Higher PSNR when i use pretrained model?
- 请问怎样运行 geberate_LR_metasr_X1_X4.m 文件?
- Pretrained models
- 如何将MetaUpSampler 改成适用于3d图像的上采样?
- 你好,能帮忙指点下吗? 改成3d 后 pos_mat_small 维度不是Scale x Scale x Scale x 3的维度? h_offset这需要改吗? HOT 3
- 你好,cols = nn.functional.unfold(up_x.permute(0, 2, 3, 1), self.kernel_size, padding=1),该咋改呀? HOT 1
- Pre-training model selection for testing
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door to the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization: use data art.
-
Game
Something interesting about games, to make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from meta-sr-pytorch.