GithubHelp home page GithubHelp logo

请问,想改成 针对3d数据,该怎么改? 比如(batch,C, h, w, d),超分到(batch, C, H, W, D)。 about meta-sr-pytorch HOT 2 OPEN

FreshmanMa avatar FreshmanMa commented on July 18, 2024
请问,想改成 针对3d数据,该怎么改? 比如(batch,C, h, w, d),超分到(batch, C, H, W, D)。

from meta-sr-pytorch.

Comments (2)

XuecaiHu avatar XuecaiHu commented on July 18, 2024

class Pos2Weight(nn.Module):
def init(self,inC, kernel_size=3, outC=3):
super(Pos2Weight,self).init()
self.inC = inC
self.kernel_size=kernel_size
self.outC = outC
self.meta_block=nn.Sequential(
nn.Linear(3,256),
nn.ReLU(inplace=True),
nn.Linear(256,self.kernel_sizeself.kernel_sizeself.inC*self.outC)
)
def forward(self,x):

    output = self.meta_block(x)
    return output

把输出的卷积核参数从2d卷积换成3d卷积即可,self.kernel_size^2 -> self.kernel_size^3

    up_x = self.repeat_x(x)     ### the output is (N*r*r,inC,inH,inW)

    cols = nn.functional.unfold(up_x, 3,padding=1)
    scale_int = math.ceil(self.scale)

    cols = cols.contiguous().view(cols.size(0)//(scale_int**2),scale_int**2, cols.size(1), cols.size(2), 1).permute(0,1, 3, 4, 2).contiguous()

    local_weight = local_weight.contiguous().view(x.size(2),scale_int, x.size(3),scale_int,-1,3).permute(1,3,0,2,4,5).contiguous()
    local_weight = local_weight.contiguous().view(scale_int**2, x.size(2)*x.size(3),-1, 3)

    out = torch.matmul(cols,local_weight).permute(0,1,4,2,3)
    out = out.contiguous().view(x.size(0),scale_int,scale_int,3,x.size(2),x.size(3)).permute(0,3,4,1,5,2)
    out = out.contiguous().view(x.size(0),3, scale_int*x.size(2),scale_int*x.size(3))
    out = self.add_mean(out)

    return out

这里参考2d卷积的底层计算即转化成矩阵运算,把3d卷积也转化为矩阵运算即可

input_matrix_wpn 矩阵输入从(h,w) 变成(h,w,d)即可

from meta-sr-pytorch.

FreshmanMa avatar FreshmanMa commented on July 18, 2024

感谢,但是input_matrix_wpn ,我没改出来,能否帮忙修正下,感谢!
`def input_matrix_wpn_new(inH, inW, inD, scale, add_scale=True): #10, 10, 10, 2.0
'''
inH, inW: the size of the feature maps
scale: is the upsampling times
'''
#计算上采样后的输出大小
outH, outW, outD = int(scale * inH), int(scale * inW), int(scale * inD)
#### mask records which pixel is invalid, 1 valid or o invalid
#### h_offset and w_offset caculate the offset to generate the input matrix
scale_int = int(math.ceil(scale)) #向上取整数, math.ceil(1.5) = 2

h_offset = torch.ones(inH, scale_int * inW, scale_int, 1)  #10x20x2x1
mask_h = torch.zeros(inH, scale_int * inW, scale_int, 1)

w_offset = torch.ones(1, inW, scale_int * inD, scale_int)  #1x10x2   1x10x2x2
mask_w = torch.zeros(1, inW, scale_int * inD, scale_int)

d_offset = torch.ones(scale_int, 1, inD, scale_int * inH)  #1x10x2   2x1x10x2
mask_d = torch.zeros(scale_int, 1, inD, scale_int * inH)


####projection  coordinate  and caculate the offset, 如果改成3d这里加一个 d_project_coord
#h
h_project_coord = torch.arange(0, outH, 1).mul(1.0 / scale)  #[0, 0.5, 1, 3/2, 4/2, 5/2, .... 9/2]
int_h_project_coord = torch.floor(h_project_coord)           #再向下取整 [0, 0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0] #浮点数

offset_h_coord = h_project_coord - int_h_project_coord       #[0, 0.5,  0, 0.5,  0, 0.5,  0, 0.5,   0, 0.5]
int_h_project_coord = int_h_project_coord.int()              #[0, 0, 1, 1, 2, 2 , 3, 3, 4, 4] int类型

#w 同理
w_project_coord = torch.arange(0, outW, 1).mul(1.0 / scale)
int_w_project_coord = torch.floor(w_project_coord)

offset_w_coord = w_project_coord - int_w_project_coord
int_w_project_coord = int_w_project_coord.int()

#D
d_project_coord = torch.arange(0, outD, 1).mul(1.0 / scale)
int_d_project_coord = torch.floor(d_project_coord)

offset_d_coord = d_project_coord - int_d_project_coord
int_d_project_coord = int_d_project_coord.int()

####flag for   number for current coordinate LR image, 改成3d这里应该也得加
flag = 0
number = 0
#int_h_project_coord =  [0, 0, 1, 1, 2, 2 , 3, 3, 4, 4] int类型
for i in range(outH):
    if int_h_project_coord[i] == number:
        h_offset[int_h_project_coord[i], flag, flag, 0] = offset_h_coord[i] # h_offset = 10x2x2x1
        mask_h[int_h_project_coord[i], flag, flag, 0] = 1                   # mask_h = 10x2x2x1
        flag += 1
    else:
        h_offset[int_h_project_coord[i], 0, 0, 0] = offset_h_coord[i]
        mask_h[int_h_project_coord[i], 0, 0, 0] = 1
        number += 1
        flag = 1

#W
flag = 0
number = 0
for i in range(outW):
    if int_w_project_coord[i] == number:
        w_offset[0, int_w_project_coord[i], flag, flag] = offset_w_coord[i]  #w_offset= 1x10x2x2
        mask_w[0, int_w_project_coord[i], flag, flag] = 1  #mask_w = 1x10x2  #1x10x2x2
        flag += 1
    else:
        w_offset[0, int_w_project_coord[i], 0, 0] = offset_w_coord[i]
        mask_w[0, int_w_project_coord[i], 0, 0] = 1
        number += 1
        flag = 1

#D
flag = 0
number = 0
for i in range(outD):
    if int_d_project_coord[i] == number:
        d_offset[flag, 0, int_d_project_coord[i], flag] = offset_d_coord[i]  #d_offset= 2x1x10x2
        mask_d[flag, 0, int_d_project_coord[i], flag] = 1  #mask_d = 1x10x2  #2x1x10x2
        flag += 1
    else:
        d_offset[0, 0, int_d_project_coord[i], 0] = offset_d_coord[i]
        mask_d[0, 0, int_d_project_coord[i], 0] = 1
        number += 1
        flag = 1

## the size is scale_int* inH* (scal_int*inW),      #1在哪,就在哪拼接
h_offset_coord = torch.cat([h_offset] * (scale_int * inW), 3).view(-1, scale_int * inW, scale_int * inW, 1)  #10x2x2x1  拼接后的10x2x2x20
print('h_offset_coord', h_offset_coord.shape) #20x20x20x1
w_offset_coord = torch.cat([w_offset] * (scale_int * inH), 0).view(-1, scale_int * inW, scale_int * inW, 1)  #1x10x2x2  拼接后的20x10x2
print('w_offset_coord', w_offset_coord.shape) #20x20x20x1
d_offset_coord = torch.cat([d_offset] * (scale_int * inH), 1).view(-1, scale_int * inW, scale_int * inW, 1)  #2x1x10x2  拼接后的20x10x2
print('w_offset_coord', w_offset_coord.shape)  # 20x20x20x1

####
mask_h = torch.cat([mask_h] * (scale_int * inW), 3).view(-1, scale_int * inW, 1)  #10x2x2x1  10x2x2x20
print('mask_h', mask_h.shape)
mask_w = torch.cat([mask_w] * (scale_int * inH), 0).view(-1, scale_int * inW, 1)  #1x10x2x2
print('mask_w', mask_w.shape)
mask_d = torch.cat([mask_d] * (scale_int * inD), 1).view(-1, scale_int * inW, 1)  #2x1x10x2
print('mask_w', mask_w.shape)

#前面的操作把维度=1放在最后,最前面是超分后的结果大小。
pos_mat = torch.cat((h_offset_coord, w_offset_coord, d_offset_coord), 2)
print('pos_mat', pos_mat.shape) #应该得到20x20x20x3

mask_mat = torch.sum(torch.cat((mask_h, mask_w, mask_d), 2), 2).view(scale_int * inH, scale_int * inW, scale_int * inD)
print('mask_mat', mask_mat.shape)  #
mask_mat = mask_mat.eq(2) #与2进行比较,元素>2返回true,生成掩码

i = 1      #3d pos_mat 20x20x20x1, 用k 表示d
h, w, k, _ = pos_mat.size()
while (pos_mat[i][0][0][0] >= 1e-6 and i < h):
    i = i + 1

j = 1
# pdb.set_trace()
h, w, k, _ = pos_mat.size()
while (pos_mat[0][j][0][0] >= 1e-6 and j < w):
    j = j + 1

k = 1
# pdb.set_trace()
h, w, k, _ = pos_mat.size()
while (pos_mat[0][0][k][0] >= 1e-6 and j < k):
    k = k + 1

pos_mat_small = pos_mat[0:i, 0:j, 0:k, :]  #
print('pos_mat_small', pos_mat_small.shape) #2x2x2

pos_mat_small = pos_mat_small.contiguous().view(1, -1, 2) # 1x8x2
print('pos_mat_small', pos_mat_small.shape)
if add_scale:
    scale_mat = torch.zeros(1, 1)
    scale_mat[0, 0] = 1.0 / scale
    scale_mat = torch.cat([scale_mat] * (pos_mat_small.size(1)), 0)  ###(inH*inW*scale_int**2, 4)
    print('scale_mat', scale_mat.shape) #4x1
    pos_mat_small = torch.cat((scale_mat.view(1, -1, 1), pos_mat_small), 2) 

return pos_mat_small, mask_mat `

from meta-sr-pytorch.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.