
wcyy0123 / keras-attention


This project forked from bubbliiiing/keras-attention


This repository mainly contains implementations of attention mechanisms for LSTMs and convolutional neural networks.

Python 100.00%

keras-attention's Introduction

Attention: implementing attention mechanisms in Keras


Contents

  1. Environment
  2. Attention in LSTMs
  3. Attention in Conv layers

Environment

tensorflow-gpu==1.13.1
keras==2.1.5

Attention in LSTMs

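In this repository, I apply the attention mechanism to the LSTM's time steps, so that the network learns how much each step of an input sample matters. The sample data we use is as follows (10 time steps with 2 features each):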

X = [[-21.03816538   1.4249185 ]
     [  3.76040424 -12.83660875]
     [  1.           1.        ]
     [-10.17242648   5.37333323]
     [  2.97058584  -9.31965078]
     [  3.69295417   8.47650258]
     [ -6.91492102  11.00583167]
     [ -0.03511656  -1.71475966]
     [ 10.9554255   12.47562052]
     [ -5.70470182   4.70055424]]
Y = [1]

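When attention_column is set to 2, the input at step 2 (counting from 0, i.e. the third row of X) is set equal to the target output Y of the current sample (both features are 1 because Y = [1]), while the values at all other steps are random. The network should therefore learn to attend to the input at step 2, which is exactly the behavior we want it to learn.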
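The data described above, and the step-wise weighting an attention layer applies to such a sequence, can be sketched in NumPy. This is a minimal sketch under assumptions, not code from this repository: the helper names `get_data_recurrent` and `step_attention` are hypothetical, the random steps are assumed to be normal noise scaled by 10 (guessed from the magnitude of the sample values), the labels are assumed binary, and the attention follows the common "Dense + softmax along the time axis" pattern, with `W` and `b` standing in for a trained Dense layer.

```python
import numpy as np


def get_data_recurrent(n, time_steps, input_dim, attention_column=2, seed=0):
    """Hypothetical generator: random sequences where one step equals the label."""
    rng = np.random.default_rng(seed)
    # Assumption: noise steps ~ N(0, 10^2), roughly matching the sample X above
    x = rng.standard_normal((n, time_steps, input_dim)) * 10
    # Assumption: binary labels with shape (n, 1)
    y = rng.integers(0, 2, size=(n, 1))
    # Overwrite the attended step with the label, repeated across input_dim
    x[:, attention_column, :] = np.tile(y, (1, input_dim))
    return x, y


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def step_attention(h, W, b):
    """Weight each time step of h with shape (batch, time_steps, units).

    W (time_steps, time_steps) and b (time_steps,) are hypothetical stand-ins
    for a trained Dense layer applied along the time axis.
    """
    # (batch, units, time_steps): each unit sees its whole sequence
    scores = np.transpose(h, (0, 2, 1)) @ W + b
    # Softmax over time steps, so the weights sum to 1 for every unit
    a = softmax(scores, axis=-1)
    # Back to (batch, time_steps, units), then rescale the inputs
    a = np.transpose(a, (0, 2, 1))
    return h * a, a
```

In a real model `h` would be the LSTM output sequence (`return_sequences=True`) rather than the raw input; after training, the weights `a` at step `attention_column` are expected to be the largest.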

Attention in Conv layers

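In convolutional neural networks, I apply the attention mechanism over channels: it learns a weight for each channel of the incoming feature map, which expresses the importance of that channel. The implementation is as follows: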

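from keras import backend as K
from keras.layers import (Activation, Dense, GlobalAveragePooling2D,
                          Multiply, Reshape)

# relu6 and hard_swish are not defined in the original snippet;
# the standard MobileNetV3 definitions are assumed here.
def relu6(x):
    return K.relu(x, max_value=6.0)

def hard_swish(x):
    return x * K.relu(x + 3.0, max_value=6.0) / 6.0

#---------------------------------------#
#   Channel attention unit
#   Two fully connected layers compute a weight for each channel
#   Can be attached after any feature map
#---------------------------------------#
def squeeze(inputs):
    input_channels = int(inputs.shape[-1])
    # Squeeze: one average per channel -> (batch, C)
    x = GlobalAveragePooling2D()(inputs)

    # Excitation: bottleneck C -> C/4 -> C
    x = Dense(int(input_channels/4))(x)
    x = Activation(relu6)(x)

    x = Dense(input_channels)(x)
    # Note: SE-Net uses sigmoid and MobileNetV3 uses hard-sigmoid for this gate
    x = Activation(hard_swish)(x)

    # Reshape to (batch, 1, 1, C) so the weights broadcast over H and W
    x = Reshape((1, 1, input_channels))(x)
    x = Multiply()([inputs, x])
    return x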
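To see what `squeeze` computes without building a Keras model, its forward pass can be traced in NumPy. This is an illustrative sketch, not the repository's code: the function name `squeeze_numpy` and the weights `W1, b1, W2, b2` are hypothetical stand-ins for the two trained `Dense` layers, and `relu6`/`hard_swish` use their usual MobileNetV3 definitions.

```python
import numpy as np


def relu6(x):
    # ReLU clipped at 6
    return np.clip(x, 0.0, 6.0)


def hard_swish(x):
    # x * ReLU6(x + 3) / 6
    return x * relu6(x + 3.0) / 6.0


def squeeze_numpy(inputs, W1, b1, W2, b2):
    """Channel attention on channels-last inputs of shape (batch, H, W, C)."""
    # Squeeze: global average pooling over H and W -> (batch, C)
    s = inputs.mean(axis=(1, 2))
    # Excitation: C -> C/4 -> C, mirroring the two Dense layers
    z = relu6(s @ W1 + b1)
    w = hard_swish(z @ W2 + b2)
    # Insert H and W axes of size 1 so the weights broadcast, like Reshape + Multiply
    return inputs * w[:, None, None, :], w
```

Because each weight depends only on the pooled channel statistics, every spatial position within a channel is scaled by the same factor.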

Contributors

bubbliiiing
