
kobiso / cbam-tensorflow

CBAM implementation in TensorFlow

License: MIT License

Python 98.84% Shell 1.16%
cbam senet tensorflow resnext inception-resnet-v2 inception-v4

cbam-tensorflow's Introduction

CBAM-TensorFlow

This is a TensorFlow implementation of "CBAM: Convolutional Block Attention Module". The repository also implements "Squeeze-and-Excitation Networks" (SENet), so you can train and compare three variants: a base CNN model, the base model with a CBAM block, and the base model with an SE block. The base CNN models are ResNeXt, Inception-V4, and Inception-ResNet-V2. Their implementation is revised from Junho Kim's code, SENet-Tensorflow.

If you want a more sophisticated implementation with more base models, check the repository "CBAM-TensorFlow-Slim". It aims to be compatible with the TensorFlow-Slim image classification model library and supports more base models.

CBAM: Convolutional Block Attention Module

CBAM proposes an architectural unit called the "Convolutional Block Attention Module" (CBAM) block. It improves representation power with an attention mechanism that focuses on important features and suppresses unnecessary ones. This research can be considered a descendant and an improvement of "Squeeze-and-Excitation Networks".
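The mechanism described above can be illustrated with a small forward pass in NumPy. The CBAM paper defines two sub-modules. Channel attention applies a shared MLP to average-pooled and max-pooled channel descriptors. Spatial attention applies a 7x7 convolution to the channel-wise average and max maps. This is a minimal sketch, not the repository's TensorFlow code. It assumes a single (H, W, C) feature map with no batch axis and random untrained weights. The function names are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def channel_attention(x, w1, b1, w2, b2):
    """Channel attention on an (H, W, C) map.

    A shared two-layer MLP (C -> C/r -> C, ReLU in between) is applied to the
    average-pooled and max-pooled channel descriptors; the outputs are summed
    and passed through a sigmoid to give one weight per channel.
    """
    def mlp(v):
        return np.maximum(v @ w1 + b1, 0.0) @ w2 + b2

    avg_desc = x.mean(axis=(0, 1))  # (C,)
    max_desc = x.max(axis=(0, 1))   # (C,)
    weights = sigmoid(mlp(avg_desc) + mlp(max_desc))  # (C,)
    return x * weights  # broadcast over H and W


def spatial_attention(x, kernel):
    """Spatial attention on an (H, W, C) map with a (k, k, 2) conv kernel."""
    # Pool along the channel axis and stack the two maps: (H, W, 2)
    pooled = np.stack([x.mean(axis=2), x.max(axis=2)], axis=-1)
    k = kernel.shape[0]
    pad = k // 2
    padded = np.pad(pooled, ((pad, pad), (pad, pad), (0, 0)))  # zero "same" padding
    h, w = x.shape[:2]
    logits = np.empty((h, w))
    for i in range(h):
        for j in range(w):
            # Single-output convolution (cross-correlation) at position (i, j)
            logits[i, j] = np.sum(padded[i:i + k, j:j + k, :] * kernel)
    return x * sigmoid(logits)[:, :, None]  # broadcast over channels


def cbam_block(x, params):
    """Channel attention first, then spatial attention."""
    x = channel_attention(x, *params["mlp"])
    return spatial_attention(x, params["kernel"])


# Usage with random (untrained) weights
rng = np.random.default_rng(0)
channels, ratio = 16, 8
hidden = channels // ratio
params = {
    "mlp": (rng.normal(size=(channels, hidden)), np.zeros(hidden),
            rng.normal(size=(hidden, channels)), np.zeros(channels)),
    "kernel": rng.normal(size=(7, 7, 2)) * 0.1,
}
x = rng.normal(size=(8, 8, channels))
y = cbam_block(x, params)
print(y.shape)  # (8, 8, 16)
```

The block applies channel attention before spatial attention, which is the sequential order the CBAM paper uses by default. Both sub-modules only rescale the input, so the output keeps the input's shape.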

Diagram of a CBAM_block

Diagram of each attention sub-module

Classification results on ImageNet-1K

Prerequisites

  • Python 3.x
  • TensorFlow 1.x
  • tflearn

Prepare Data set

This repository uses the CIFAR-10 dataset. The dataset is downloaded automatically when you run the training script.
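The repository's own download code is not shown here. For reference, the following sketch converts the public "CIFAR-10 python version" batch files into channels-last images. It assumes each batch file is the pickled dict described on the dataset's page: `b'data'` holds N x 3072 uint8 rows (1024 red, 1024 green, then 1024 blue values, each 32x32 in row-major order) and `b'labels'` holds the class indices. The helper names are hypothetical, and the repository's loader may preprocess the data differently.

```python
import pickle

import numpy as np


def to_nhwc(data):
    """(N, 3072) uint8 rows in CIFAR-10 order -> (N, 32, 32, 3) channels-last images."""
    return data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)


def load_cifar10_batch(path):
    """Load one pickled batch file, e.g. a hypothetical 'cifar-10-batches-py/data_batch_1'."""
    with open(path, "rb") as f:
        batch = pickle.load(f, encoding="bytes")  # keys are bytes in Python 3
    images = to_nhwc(batch[b"data"])
    labels = np.asarray(batch[b"labels"], dtype=np.int64)
    return images, labels


# Synthetic check: image 0 is pure red, image 1 is pure blue
data = np.zeros((2, 3072), dtype=np.uint8)
data[0, :1024] = 255
data[1, 2048:] = 255
images = to_nhwc(data)
print(images.shape)               # (2, 32, 32, 3)
print(images[0, 0, 0].tolist())   # [255, 0, 0]
print(images[1, 5, 7].tolist())   # [0, 0, 255]
```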

Models Supporting CBAM_block and SE_block

You can train and test a base CNN model, the base model with a CBAM block, and the base model with an SE block. To add a CBAM_block or SE_block to any model in the list below, pass the argument --attention_module=cbam_block or --attention_module=se_block when you train it.

  • Inception V4 + CBAM / + SE
  • Inception-ResNet-v2 + CBAM / + SE
  • ResNeXt + CBAM / + SE
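This README does not say how the training scripts read their flags. If they use Python's argparse, `--attention_module=cbam_block` and `--attention_module cbam_block` are equivalent, and a repeated flag keeps its last value. The parser below is hypothetical. Its flag names are copied from the example training commands in this README, and its defaults are assumptions.

```python
import argparse


def build_parser():
    # Hypothetical CLI: flag names mirror the README's commands; defaults are assumptions.
    parser = argparse.ArgumentParser(description="Hypothetical training CLI")
    parser.add_argument("--model_name", type=str, default="model")
    parser.add_argument("--attention_module", choices=["cbam_block", "se_block"],
                        default=None, help="omit to train the base model")
    parser.add_argument("--reduction_ratio", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=0.0005)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--total_epoch", type=int, default=100)
    return parser


parser = build_parser()
a = parser.parse_args(["--attention_module=cbam_block"])
b = parser.parse_args(["--attention_module", "cbam_block"])
c = parser.parse_args([])
d = parser.parse_args(["--attention_module", "cbam_block", "--attention_module", "se_block"])
print(a.attention_module == b.attention_module)  # True
print(c.attention_module)                        # None (base model)
print(d.attention_module)                        # se_block (last value wins)
```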

Change the Reduction Ratio

To change the reduction ratio of the attention module's MLP, add an argument such as --reduction_ratio=8.
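The reduction ratio r sets the width of the bottleneck MLP used by both the SE block and CBAM's channel attention. With C input channels, the hidden layer has C/r units, so the two fully connected layers hold 2*C*C/r weights. The arithmetic below is a sketch under two assumptions: biases are ignored, and CBAM's spatial sub-module is a single 7x7 convolution over a 2-channel map, as in the CBAM paper. The repository's exact layer layout may differ, and the helper names are hypothetical.

```python
def mlp_hidden_units(channels, reduction_ratio):
    # Width of the bottleneck layer in the shared MLP
    return channels // reduction_ratio


def attention_weight_count(channels, reduction_ratio, module):
    hidden = mlp_hidden_units(channels, reduction_ratio)
    mlp = channels * hidden + hidden * channels  # two FC layers, biases ignored
    if module == "se_block":
        return mlp
    if module == "cbam_block":
        return mlp + 7 * 7 * 2  # plus one 7x7 conv over the [avg; max] map
    raise ValueError(f"unknown module: {module}")


for r in (8, 16):
    print(r, mlp_hidden_units(256, r),
          attention_weight_count(256, r, "se_block"),
          attention_weight_count(256, r, "cbam_block"))
# 8 32 16384 16482
# 16 16 8192 8290
```

Doubling r halves the MLP cost. CBAM's spatial branch adds only a small constant on top of that.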

Train a Model

You can run a model simply by executing one of the following scripts.

  • sh train_ResNext.sh
  • sh train_inception_resnet_v2.sh
  • sh train_inception_v4.sh

Train a model with CBAM_block

The script below gives an example of training a model with CBAM_block.

CUDA_VISIBLE_DEVICES=0 python ResNeXt.py \
--model_name put_your_model_name \
--reduction_ratio 8 \
--learning_rate 0.1 \
--weight_decay 0.0005 \
--momentum 0.9 \
--batch_size 128 \
--total_epoch 100 \
--attention_module cbam_block
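The --learning_rate, --momentum and --weight_decay values in the commands above usually feed one update rule. The sketch below shows a classic (non-Nesterov) momentum step in the form documented for TensorFlow 1.x's tf.train.MomentumOptimizer: the accumulator becomes momentum * accumulator + gradient, and the variable is reduced by learning_rate * accumulator. It also assumes weight decay is an L2 penalty weight_decay * sum(w**2) / 2 added to the loss, which contributes weight_decay * w to the gradient. The repository's actual optimizer and learning-rate schedule are not stated here, and the function name is hypothetical.

```python
import numpy as np


def sgd_momentum_step(w, grad, velocity, lr=0.1, momentum=0.9, weight_decay=0.0005):
    # Assumed L2 penalty: d/dw [weight_decay * sum(w**2) / 2] = weight_decay * w
    g = grad + weight_decay * w
    # Momentum accumulator, then the parameter update
    velocity = momentum * velocity + g
    w = w - lr * velocity
    return w, velocity


# With a zero loss gradient, weight decay alone shrinks the weights toward 0
w = np.array([1.0, -2.0])
v = np.zeros_like(w)
w, v = sgd_momentum_step(w, np.zeros_like(w), v)
print(w)
```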

Train a model with SE_block

The script below gives an example of training a model with SE_block.

CUDA_VISIBLE_DEVICES=0 python ResNeXt.py \
--model_name put_your_model_name \
--reduction_ratio 8 \
--learning_rate 0.1 \
--weight_decay 0.0005 \
--momentum 0.9 \
--batch_size 128 \
--total_epoch 100 \
--attention_module se_block

Train a model without attention module

The script below gives an example of training a model without an attention module.

CUDA_VISIBLE_DEVICES=0 python ResNeXt.py \
--model_name put_your_model_name \
--reduction_ratio 8 \
--learning_rate 0.1 \
--weight_decay 0.0005 \
--momentum 0.9 \
--batch_size 128 \
--total_epoch 100

Related Works

Reference

Author

Byung Soo Ko / [email protected]

cbam-tensorflow's Issues

x = init + x ?

Why would you want to add init to the output of the attention block?
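In ResNet-style blocks such as ResNeXt and Inception-ResNet-v2, a line like x = init + x is most likely the identity shortcut. The attention module rescales the residual branch, and the block input is then added back. The CBAM paper places the module in a ResBlock in this way. The sketch below illustrates the idea with hypothetical stand-ins for the convolution branch and the attention module. It is not the repository's code. If the attention weights are all zero, the block reduces to the identity, and this is part of why the shortcut keeps deep networks easy to train.

```python
import numpy as np


def residual_unit(init, branch, attention):
    # Residual branch (stand-in for, e.g., the grouped convs of a ResNeXt block)
    x = branch(init)
    # Attention module (SE or CBAM) rescales the branch output
    x = attention(x)
    # Identity shortcut: add the block input back
    return init + x


init = np.array([[1.0, -2.0], [3.0, 0.5]])
double = lambda t: 2.0 * t       # hypothetical conv branch
gate_closed = lambda t: 0.0 * t  # all attention weights 0
gate_open = lambda t: 1.0 * t    # all attention weights 1

print(np.array_equal(residual_unit(init, double, gate_closed), init))      # True
print(np.array_equal(residual_unit(init, double, gate_open), 3.0 * init))  # True
```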

About your MobileNet+SE result on ImageNet

In the SENet paper, the top-1 accuracy of SE-MobileNet on ImageNet is 74.7%, but your result is 70.03%, which is a large gap.
What caused this gap: the optimizer, or the hyper-parameters?
Thanks.

How to implement distributed deep learning on a small master-slave architecture with a data-parallelism approach?

I am a beginner, and I would like to deploy a distributed deep learning model on Hadoop with a toy example. I want to use three personal computers (PCs): one would work as a parameter server and the other two as worker machines. First, I want to configure Hadoop across the three machines (I do not know exactly how to do this). Then the parameter server would split the data and distribute the pieces to the two worker machines for training. Suppose I have 10 GB of data: 5 GB would go to the first worker PC and the other 5 GB to the second. Then I would like to apply synchronous data parallelism to the dataset. What are the steps to implement a distributed deep learning system on such a small network of machines?
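The core of the synchronous data-parallel scheme described in this question can be simulated in one process. Each worker computes a gradient on its own shard. The parameter server averages the gradients, weighted by shard size, and applies a single update. This is a minimal sketch that assumes a toy linear-regression model. It uses no Hadoop, networking, or TensorFlow distribution APIs, and the function names are hypothetical. With shard-size weighting, one synchronous step equals one full-batch gradient step.

```python
import numpy as np


def local_gradient(w, X, y):
    # Gradient of 0.5 * mean((X @ w - y)**2) on one worker's shard
    return X.T @ (X @ w - y) / len(y)


def sync_data_parallel_step(w, shards, lr):
    # Workers compute gradients; the parameter server averages them by shard size
    grads = [local_gradient(w, X, y) for X, y in shards]
    sizes = np.array([len(y) for _, y in shards], dtype=float)
    avg = sum(g * n for g, n in zip(grads, sizes)) / sizes.sum()
    return w - lr * avg


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
true_w = np.array([1.0, -2.0, 0.5])
y = X @ true_w
shards = [(X[:100], y[:100]), (X[100:], y[100:])]  # two workers, half the data each

w = np.zeros(3)
for _ in range(500):
    w = sync_data_parallel_step(w, shards, lr=0.1)
print(np.round(w, 3))
```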

TypeError: argument should be integer or bytes-like object, not 'str'

I am new to Python. I am running the following code:

import matplotlib as mpl;
import matplotlib.pyplot as plt;
import numpy as np;
import gzip;
#import StringIO; Updated one is below
#import io.StringIO
from io import StringIO


def parse_header_of_csv(csv_str):
    # Isolate the headline columns:
    headline = csv_str[:csv_str.index('\n')];
    columns = headline.split(',');

    # The first column should be timestamp:
    assert columns[0] == 'timestamp';
    # The last column should be label_source:
    assert columns[-1] == 'label_source';

    # Search for the column of the first label:
    for (ci, col) in enumerate(columns):
        if col.startswith('label:'):
            first_label_ind = ci;
            break;
        pass;

    # Feature columns come after timestamp and before the labels:
    feature_names = columns[1:first_label_ind];
    # Then come the labels, till the one-before-last column:
    label_names = columns[first_label_ind:-1];
    for (li, label) in enumerate(label_names):
        # In the CSV the label names appear with prefix 'label:', but we don't need it after reading the data:
        assert label.startswith('label:');
        label_names[li] = label.replace('label:', '');
        pass;

    return (feature_names, label_names);


def parse_body_of_csv(csv_str, n_features):
    # Read the entire CSV body into a single numeric matrix:
    full_table = np.loadtxt(StringIO.StringIO(csv_str), delimiter=',', skiprows=1);

    # Timestamp is the primary key for the records (examples):
    timestamps = full_table[:, 0].astype(int);

    # Read the sensor features:
    X = full_table[:, 1:(n_features + 1)];

    # Read the binary label values, and the 'missing label' indicators:
    trinary_labels_mat = full_table[:, (n_features + 1):-1];  # This should have values of either 0., 1. or NaN
    M = np.isnan(trinary_labels_mat);  # M is the missing label matrix
    Y = np.where(M, 0, trinary_labels_mat) > 0.;  # Y is the label matrix

    return (X, Y, M, timestamps);


'''
Read the data (precomputed sensor-features and labels) for a user.
This function assumes the user's data file is present.
'''


def read_user_data(uuid):
    user_data_file = '%s.features_labels.csv.gz' % uuid;

    # Read the entire csv file of the user:
    with gzip.open(user_data_file, 'r') as fid:
        csv_str = fid.read();
        pass;

    (feature_names, label_names) = parse_header_of_csv(csv_str);
    n_features = len(feature_names);
    (X, Y, M, timestamps) = parse_body_of_csv(csv_str, n_features);

    return (X, Y, M, timestamps, feature_names, label_names);

# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    uuid = '1155FF54-63D3-4AB2-9863-8385D0BD0A13';
    (X, Y, M, timestamps, feature_names, label_names) = read_user_data(uuid);

I am getting an error on the line headline = csv_str[:csv_str.index('\n')]: TypeError: argument should be integer or bytes-like object, not 'str'. I do not know why it occurs.
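This error comes from the file mode. gzip.open(path, 'r') opens the file in binary mode, so fid.read() returns bytes, and bytes.index('\n') rejects a str argument. Opening the file with 'rt' returns str instead; decoding the bytes with .decode('utf-8') would also work. A second problem follows later in the code. `from io import StringIO` imports the class itself, so the call must be StringIO(csv_str), not StringIO.StringIO(csv_str). The sketch below shows both fixes on a small hypothetical file. Its column names and values are made up for illustration.

```python
import gzip
import os
import tempfile
from io import StringIO

import numpy as np


def read_csv_text(path):
    # 'rt' = text mode: read() returns str, so str methods like .index('\n') work
    with gzip.open(path, "rt") as fid:
        return fid.read()


def load_table(csv_str):
    # StringIO is the class imported from io, so call it directly
    return np.loadtxt(StringIO(csv_str), delimiter=",", skiprows=1)


# Self-check with a hypothetical two-row file
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "demo.features_labels.csv.gz")
    with gzip.open(path, "wt") as f:
        f.write("timestamp,feat1,label:A,label_source\n1,0.5,1,2\n2,0.25,nan,2\n")
    csv_str = read_csv_text(path)
    header = csv_str[:csv_str.index("\n")]
    table = load_table(csv_str)

print(header.split(","))  # ['timestamp', 'feat1', 'label:A', 'label_source']
print(table.shape)        # (2, 4)
```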

Visualization

Can you share some code related to visualization? Thank you.
