guoshnbjtu / st-3dnet

This is a PyTorch implementation of ST-3DNet. The corresponding paper is available at https://ieeexplore.ieee.org/abstract/document/8684259/

st-3dnet's Issues

Question about the shape changes in the TensorFlow version of this project's model

def ST3DNet(c_conf=(6, 2, 16, 8), t_conf=(4, 2, 16, 8), external_dim=8, nb_residual_unit=4):
    len_closeness, nb_flow, map_height, map_width = c_conf
    # main input
    main_inputs = []
    outputs = []
    if len_closeness > 0:
        input = Input(shape=(nb_flow, len_closeness, map_height, map_width))  # (2,t_c,h,w)
        main_inputs.append(input)
        # Conv1 3D
        conv = Conv3D(filters=64, kernel_size=(6, 3, 3), strides=(1, 1, 1), padding="same",
                      kernel_initializer='random_uniform')(input)
        conv = Activation("relu")(conv)

        # Conv2 3D
        conv = Conv3D(filters=64, kernel_size=(3, 3, 3), strides=(3, 1, 1), padding="same")(conv)
        conv = Activation("relu")(conv)

        # Conv3 3D
        conv = Conv3D(filters=64, kernel_size=(3, 3, 3), strides=(3, 1, 1), padding="same")(conv)

        # (filter,1,height,width)
        reshape = Reshape((64, map_height, map_width))(conv)

        # Residual 2D [nb_residual_unit] Residual Units
        residual_output = ResUnits(_residual_unit, nb_filter=64, repetations=nb_residual_unit)(reshape)

        output_c = Recalibration()(residual_output)
        outputs.append(output_c)

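Suppose the input has shape (32, 2, 6, 50, 50), where 32 is the batch_size. Just before reshape = Reshape((64, map_height, map_width))(conv), the shape should be (32, 2, 64, 50, 50), so when the tensor is passed on to ResUnits the reshape should not be able to drop a dimension. Below is the network I tried to write in PyTorch. Could you please explain how the shape changes?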

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


class ResUnit(nn.Module):
    def __init__(self, in_channels, out_channels):
        # It takes in a four dimensional input (B, C, lng, lat)
        super(ResUnit, self).__init__()
        # self.ln1 = nn.LayerNorm(normalized_shape = (lng, lat))
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        # self.ln2 = nn.LayerNorm(normalized_shape = (lng, lat))
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
    
    def forward(self, x):
        # z = self.ln1(x)
        z = self.bn1(x)
        z = F.relu(z)
        z = self.conv1(z)
        # z = self.ln2(z)
        z = self.bn2(z)
        z = F.relu(z)
        z = self.conv2(z)
        return z + x
    

# Define the 3D convolution layers
class ST3DNet(nn.Module):
    def __init__(self, c_conf=(6, 2, 16, 8), t_conf=(4, 2, 16, 8), external_dim=8, nb_residual_unit=4):
        super(ST3DNet, self).__init__()
        self.len_closeness, self.nb_flow, self.h, self.w = c_conf
        self.c_net = nn.ModuleList([
            nn.Conv3d(
                in_channels=self.nb_flow, 
                out_channels=64, 
                kernel_size=(6, 3, 3), 
                stride=(1, 1, 1), 
                padding=(6 // 2, 3 // 2, 3 // 2)
                ), 
            nn.ReLU(inplace = True), 
            nn.Conv3d(
                in_channels=64, 
                out_channels=64, 
                kernel_size=(3, 3, 3), 
                stride=(3, 1, 1), 
                padding=(2, 1, 1)
                ), 
            nn.ReLU(inplace = True), 
            nn.Conv3d(
                in_channels=64, 
                out_channels=64, 
                kernel_size=(3, 3, 3), 
                stride=(3, 1, 1), 
                padding=(2, 1, 1)
                )
        ])

        self.c_net.append(ResUnit(64, 64))
        self.c_net.append(ResUnit(64, 64))
        self.c_net.append(ResUnit(64, 64))
        self.c_net.append(ResUnit(64, 64))
        
        # # Rc
        # self.c_net.append(ResUnit(128, 128))
        # Recalibration is not defined in this snippet; left disabled so the module can be constructed
        # self.recalibration = Recalibration(32)
        self.out_conv = nn.Conv2d(128, 64, stride=1, padding=1, kernel_size=3)

    def forward(self, x):
        # Apply the convolution operations
        x = self.c_net[0](x)
        x = self.c_net[1](x)
        x = self.c_net[2](x)
        x = self.c_net[3](x)
        x = self.c_net[4](x)
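        # Fold the remaining temporal axis into the channel axis (height is dim 3, width is dim 4)
        x = x.view(x.size(0), -1, x.size(3), x.size(4))
        # Shape at this point: torch.Size([32, 128, 32, 32])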
        x = self.out_conv(x)
        x = self.c_net[5](x)
        x = self.c_net[6](x)
        x = self.c_net[7](x)
        x = self.c_net[8](x)


        return x


conv3d = ST3DNet()

# Assumed input data
input_data = torch.randn(32, 2, 6, 32, 32)  # (batch_size, nb_flow, len_closeness, map_height, map_width)


output_data = conv3d(input_data)
print(output_data.shape)
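
The shape question above comes down to how each framework computes the output length along the temporal axis (len_closeness). The following is a minimal sketch in plain Python, not the authors' code. It assumes the Keras model runs with the channels_first data format, so that axis 2 of the input is len_closeness. The helper names `same_out`, `same_pad` and `torch_out` are hypothetical and only illustrate the standard output-size formulas.

```python
import math


def same_out(n, stride):
    # Keras/TensorFlow padding="same": output length is ceil(n / stride)
    return math.ceil(n / stride)


def same_pad(n, kernel, stride):
    # Total padding that "same" adds along one axis, split as (before, after);
    # the extra element goes at the end when the total is odd
    total = max((same_out(n, stride) - 1) * stride + kernel - n, 0)
    return total // 2, total - total // 2


def torch_out(n, kernel, stride, padding):
    # PyTorch Conv3d with dilation 1: floor((n + 2*padding - kernel) / stride) + 1
    return (n + 2 * padding - kernel) // stride + 1


# Temporal axis in the Keras snippet: (kernel, stride) per layer, padding="same"
keras_depths = [6]
keras_pads = []
for kernel, stride in ((6, 1), (3, 3), (3, 3)):
    keras_pads.append(same_pad(keras_depths[-1], kernel, stride))
    keras_depths.append(same_out(keras_depths[-1], stride))

# Temporal axis in the PyTorch attempt: (kernel, stride, padding) per layer
torch_depths = [6]
for kernel, stride, padding in ((6, 1, 3), (3, 3, 2), (3, 3, 2)):
    torch_depths.append(torch_out(torch_depths[-1], kernel, stride, padding))

print(keras_depths)  # [6, 6, 2, 1]
print(keras_pads)    # [(2, 3), (0, 0), (0, 1)]
print(torch_depths)  # [6, 7, 3, 2]
```

Under these assumptions the Keras stack ends with a temporal length of 1, so a (64, 1, h, w) tensor can be reshaped to (64, h, w). The PyTorch attempt ends with a temporal length of 2, which is consistent with the view producing 128 channels. The padding pairs suggest that matching the Keras behaviour in PyTorch would need asymmetric padding on the temporal axis (for example via torch.nn.functional.pad before each convolution), since the padding argument of nn.Conv3d pads both sides equally.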
