
saoyan / learntopayattention

162 stars · 48 forks · 1.15 MB

PyTorch implementation of the ICLR 2018 paper "Learn To Pay Attention" (with some modifications)

License: GNU General Public License v3.0

Python 100.00%


learntopayattention's People

Contributors

saoyan


learntopayattention's Issues

training attention without compatibility scores

Hello,
I want to introduce your attention model into a ResNet network. My model is already done and I just want to integrate the attention. I don't know whether I can add a learning step focusing on C1, C2, and C3, so I want to ask if I can train my model without the compatibility scores.
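For context, the attention in the paper is defined through the compatibility scores themselves: a score map is computed between each local feature map (C1, C2, C3) and the global descriptor, normalized over spatial positions, and used to pool the local features. A minimal sketch of such a block, attachable to e.g. a ResNet stage output (names and shapes here are illustrative, not taken from this repo):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearAttentionBlock(nn.Module):
    # Sketch of the paper's parametrised compatibility score c = <u, l + g>,
    # followed by a spatial softmax; illustrative only.
    def __init__(self, in_channels):
        super().__init__()
        self.u = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False)  # learnable vector u as a 1x1 conv

    def forward(self, l, g):
        # l: local features (N, C, H, W), e.g. a ResNet stage output
        # g: global descriptor (N, C, 1, 1), broadcast over the spatial grid
        c = self.u(l + g)                              # compatibility scores (N, 1, H, W)
        a = F.softmax(c.flatten(2), dim=2).view_as(c)  # attention weights summing to 1
        ga = (a * l).sum(dim=(2, 3))                   # attention-pooled descriptor (N, C)
        return a, ga

Training "without compatibility scores" would therefore remove the mechanism itself: pooling C1, C2, and C3 directly would still train, but it would no longer be the paper's attention.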

code crashing - with log

python train.py --attn_mode before --outf logs_before --normalize_attn --log_images

loading the dataset ...

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to CIFAR100_data/cifar-100-python.tar.gz
100.0%
Extracting CIFAR100_data/cifar-100-python.tar.gz to CIFAR100_data
Traceback (most recent call last):
  File "train.py", line 199, in <module>
    main()
  File "train.py", line 51, in main
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=True, num_workers=8, worker_init_fn=_init_fn)
NameError: name '_init_fn' is not defined
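One workaround, assuming _init_fn was meant to seed each DataLoader worker (it is referenced but never defined on this code path): define it above the DataLoader call, or simply drop the worker_init_fn=_init_fn argument.

import numpy as np

def _init_fn(worker_id):
    # assumed intent: give each DataLoader worker a deterministic seed
    np.random.seed(worker_id)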

model1.py same as model2.py

model1.py
x = self.conv_block1(x)
x = self.conv_block2(x)
l1 = self.conv_block3(x) # /1
x = tnf.max_pool2d(l1, kernel_size=2, stride=2, padding=0) # /2
l2 = self.conv_block4(x) # /2
x = tnf.max_pool2d(l2, kernel_size=2, stride=2, padding=0) # /4
l3 = self.conv_block5(x) # /4
x = tnf.max_pool2d(l3, kernel_size=2, stride=2, padding=0) # /8
x = self.conv_block6(x) # /32
g = self.dense(x) # batch_sizex512x1x1
###can modify as
x = self.conv_block1(x)
x = self.conv_block2(x)

l1 = self.conv_block3(x) # /1
l1 = tnf.max_pool2d(l1, kernel_size=2, stride=2, padding=0) # /2

l2 = self.conv_block4(l1) # /2
l2 = tnf.max_pool2d(l2, kernel_size=2, stride=2, padding=0) # /4

l3 = self.conv_block5(l2) # /4
x = tnf.max_pool2d(l3, kernel_size=2, stride=2, padding=0) # /8

x = self.conv_block6(x) # /32
g = self.dense(x) # batch_sizex512x1x1
##################################################
model2.py
x = self.conv_block1(x)
x = self.conv_block2(x)
x = self.conv_block3(x)
l1 = tnf.max_pool2d(x, kernel_size=2, stride=2, padding=0) # /2
l2 = tnf.max_pool2d(self.conv_block4(l1), kernel_size=2, stride=2, padding=0) # /4
l3 = tnf.max_pool2d(self.conv_block5(l2), kernel_size=2, stride=2, padding=0) # /8
x = self.conv_block6(l3) # /32
g = self.dense(x) # batch_sizex512x1x1
###can modify to
x = self.conv_block1(x)
x = self.conv_block2(x)

l1 = self.conv_block3(x) # /1
l1 = tnf.max_pool2d(l1, kernel_size=2, stride=2, padding=0) # /2

l2 = self.conv_block4(l1) # /2
l2 = tnf.max_pool2d(l2, kernel_size=2, stride=2, padding=0) # /4

l3 = self.conv_block5(l2) # /4
x = tnf.max_pool2d(l3, kernel_size=2, stride=2, padding=0) # /8

x = self.conv_block6(x) # /32
g = self.dense(x) # batch_sizex512x1x1

they are the same~~~~
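For the backbone output the two orderings are indeed equivalent; a quick check with stand-in conv blocks (illustrative modules, not the repo's):

import torch
import torch.nn as nn
import torch.nn.functional as tnf

torch.manual_seed(0)
b3, b4, b5, b6 = (nn.Conv2d(8, 8, 3, padding=1) for _ in range(4))
pool = lambda t: tnf.max_pool2d(t, kernel_size=2, stride=2, padding=0)
x0 = torch.randn(2, 8, 32, 32)

xa = b6(pool(b5(pool(b4(pool(b3(x0)))))))  # model2.py ordering: pooling chained onto each block

t = pool(b3(x0))                           # proposed rewrite: block and pooling on separate lines
t = pool(b4(t))
t = pool(b5(t))
xb = b6(t)

assert torch.equal(xa, xb)                 # backbone outputs are identical

One caveat worth checking, though: l1, l2, and l3 later feed the attention blocks, and in model1.py they are the pre-pooling maps, so rebinding them to the pooled tensors (as in the rewrite above) changes what the attention layers receive in that model.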

Loss does not change

Hello, thank you very much for posting this implementation of the "LearnToPayAttention" paper. I was hoping you could help me with an issue I am having when running the code. When I run model 1 with (or without) attention, using the default hyperparameter settings from the GitHub page (LR = 0.1, etc.) on CIFAR-100, the training loss and train/test accuracy do not change: the training loss is stuck at around 4.6 and the test accuracy at 1%. I tried PyTorch 0.4.1 and 1.0.0. Any help would be greatly appreciated. Thanks.
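For what it is worth, a loss of about 4.6 is exactly chance level for 100 classes, i.e. the network is predicting a uniform distribution and never moving off it, which often points to a too-large learning rate or a broken input pipeline:

import math

print(-math.log(1 / 100))  # 4.6051..., the cross-entropy of a uniform guess over 100 classes

If that is what is happening here, a smaller learning rate than the default 0.1 (say 0.01), or a check of the input normalization, would be a natural first experiment.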

Visualizing attention maps

Can you include a notebook showing how to extract the attention map for a prediction given the trained model and a sample image?
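Not a full notebook, but a minimal sketch of what the extraction could look like. Here model is assumed to be the trained network and img a normalized (3, H, W) image tensor, and the unpacking assumes the forward pass returns the logits together with the three attention maps; adjust it to the actual return signature in model1.py / model2.py:

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

model.eval()
with torch.no_grad():
    logits, c1, c2, c3 = model(img.unsqueeze(0))     # assumed return signature

a = F.softmax(c1.flatten(2), dim=2).view_as(c1)      # spatial softmax over the score map
a = F.interpolate(a, size=img.shape[1:], mode='bilinear', align_corners=False)  # upsample to image size

plt.imshow(img.permute(1, 2, 0).cpu())               # base image (denormalize first for faithful colors)
plt.imshow(a[0, 0].cpu(), cmap='jet', alpha=0.5)     # attention heat map overlay
plt.axis('off')
plt.show()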

Few questions regarding the attention model implementation

First of all, thank you very much for the code.
I have a few questions, hoping you can help.

  1. Why are you using the sigmoid function and not softmax?
  2. Why are you using adaptive_avg_pool2d? I haven't seen it in the paper. I guess these first two questions are related, since normalize_attn is True when running the code; I am basically asking why it is set to True.
  3. In model2.py, why are you using g = self.dense(x)? It is also a layer I haven't encountered in the paper. I see that it is a convolution with the same number of input and output channels. What is it for?

Thanks in advance for your help!
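Regarding questions 1 and 2, here is a sketch of the two normalizations that the normalize_attn flag appears to toggle (my reading of the code, not an authoritative description). With c holding the compatibility scores of shape (N, 1, H, W):

import torch
import torch.nn.functional as F

c = torch.randn(4, 1, 8, 8)  # stand-in compatibility scores

# normalize_attn=True: softmax over all H*W positions, so the weights sum to 1
# and the attended descriptor is a weighted sum of the local features
a_softmax = F.softmax(c.flatten(2), dim=2).view_as(c)

# normalize_attn=False: an independent sigmoid gate per position; the gated
# features are then averaged spatially, which is where adaptive_avg_pool2d
# comes in (it replaces the weighted sum used in the softmax case)
a_sigmoid = torch.sigmoid(c)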
