mkocabas / coordconv-pytorch

PyTorch implementation of CoordConv, introduced in the paper 'An intriguing failing of convolutional neural networks and the CoordConv solution' (https://arxiv.org/pdf/1807.03247.pdf).

Python 100.00%

coordconv-pytorch's People

Contributors

levirve, valentinp

coordconv-pytorch's Issues

I think I found an error?

You implemented an additional channel, but I think it is wrong, or perhaps I don't understand it.
At line 104 you add a new channel r (the radius),
but at that point xx_channel is not in the range 0-1.
So I changed it as shown below, and also rearranged a few places to reduce the number of transposes:

import torch
import torch.nn as nn


class AddCoords(nn.Module):

    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: shape(batch, channel, x_dim, y_dim)
        """
        batch_size, _, x_dim, y_dim = input_tensor.size()

        # Index grids of shape (1, x_dim, y_dim):
        # xx varies along dim 2 of the input, yy along dim 3.
        xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1).transpose(1, 2)
        yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1)

        # Expand to (batch, 1, x_dim, y_dim).
        xx_channel = xx_channel.repeat(batch_size, 1, 1, 1)
        yy_channel = yy_channel.repeat(batch_size, 1, 1, 1)

        # Normalize to [0, 1]; these copies are kept for the radius channel.
        xx_channel_01 = xx_channel.float() / (x_dim - 1)
        yy_channel_01 = yy_channel.float() / (y_dim - 1)

        # Rescale to [-1, 1] for the coordinate channels.
        xx_channel = xx_channel_01 * 2 - 1
        yy_channel = yy_channel_01 * 2 - 1

        ret = torch.cat([
            input_tensor,
            xx_channel.type_as(input_tensor),
            yy_channel.type_as(input_tensor)], dim=1)

        if self.with_r:
            # Radius measured from the centre (0.5, 0.5) of the [0, 1] coordinates.
            rr = torch.sqrt(
                torch.pow(xx_channel_01.type_as(input_tensor) - 0.5, 2)
                + torch.pow(yy_channel_01.type_as(input_tensor) - 0.5, 2))
            ret = torch.cat([ret, rr], dim=1)

        return ret

Another way to implement this is to use torch.meshgrid.
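The torch.meshgrid variant mentioned here might look roughly like the sketch below. It is only an illustration, not the repository's code: the function names make_coord_channels and add_coords are hypothetical, the indexing="ij" argument assumes PyTorch 1.10 or newer, and the radius is scaled by 0.5 so that it matches a radius measured around (0.5, 0.5) in [0, 1] coordinates.

```python
import torch


def make_coord_channels(rows, cols, with_r=False):
    """Build coordinate channels of shape (C, rows, cols) with values in [-1, 1]."""
    row_axis = torch.linspace(-1.0, 1.0, steps=rows)
    col_axis = torch.linspace(-1.0, 1.0, steps=cols)
    # With indexing="ij", row_coord varies along dim 0 and col_coord along dim 1.
    row_coord, col_coord = torch.meshgrid(row_axis, col_axis, indexing="ij")
    channels = [row_coord, col_coord]
    if with_r:
        # Equals sqrt((u - 0.5)^2 + (v - 0.5)^2) for u, v in [0, 1],
        # because u - 0.5 = row_coord / 2 and v - 0.5 = col_coord / 2.
        channels.append(0.5 * torch.sqrt(row_coord ** 2 + col_coord ** 2))
    return torch.stack(channels, dim=0)


def add_coords(input_tensor, with_r=False):
    """Concatenate coordinate channels to a (batch, channel, rows, cols) tensor."""
    batch, _, rows, cols = input_tensor.shape
    # .to(input_tensor) matches both the dtype and the device of the input.
    coords = make_coord_channels(rows, cols, with_r).to(input_tensor)
    coords = coords.unsqueeze(0).expand(batch, -1, -1, -1)
    return torch.cat([input_tensor, coords], dim=1)


x = torch.zeros(2, 3, 4, 5)
out = add_coords(x, with_r=True)  # 3 input channels + 2 coordinates + radius
```

Because the grid is built once and broadcast over the batch with expand, no transposes or per-batch repeats are needed.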

Error with code notation.

Hi @mkocabas,

There is a slight mismatch in the notation used for the dimensions extracted from the shape of the input tensor, here:

batch_size, _, x_dim, y_dim = input_tensor.size()

You extract x_dim from dimension number 2 and y_dim from dimension number 3. In PyTorch's 4D image tensors, height comes before width. Please refer to the documentation of PyTorch's Conv2d layer and note the input and output tensor shapes.

The code runs fine, since the extracted dimensions end up in the right place when the output tensor is constructed, but to someone reading the code to check the implementation, as I was, it looks like a mistake at first glance. Perhaps the notation could be changed to follow PyTorch's convention.
Thanks

Cheers!
@akanimax
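The convention this issue refers to can be seen in a short sketch. The layer sizes and the names height and width are illustrative choices, not taken from the repository; nn.Conv2d documents its input as (N, C_in, H, W), so dimension 2 is the height and dimension 3 is the width.

```python
import torch
import torch.nn as nn

# kernel_size=3 with padding=1 keeps the spatial size unchanged.
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)

# (N, C, H, W): a batch of 2 images, 3 channels, height 32, width 48.
images = torch.zeros(2, 3, 32, 48)

# Naming the unpacked dimensions after what they hold in PyTorch's layout.
batch_size, channels, height, width = conv(images).size()
```

Renaming x_dim and y_dim to height and width in this way would not change the behaviour, only how the code reads.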

Training time increases linearly for each epoch.

I have been using this implementation in one of my object detection models.
As the model trains, the per-batch processing time increases linearly. I still haven't figured out why.

Has anyone seen similar behavior?

Object Detection Example

Is there example code that integrates this with Faster R-CNN, as mentioned in the paper?

It seems to me that this solution could have an impact on localization tasks applied to OCR. I'd be curious to see what effect it would have on the ICDAR 2013 Table Competition dataset.

See https://arxiv.org/pdf/1804.06236.pdf for reference.

Shifted radius channel?

I went through the paper on arXiv (https://arxiv.org/pdf/1807.03247.pdf)

but could not find the reason why the center of the radius channel is shifted by 0.5 in both the x and y directions. Is there a particular reason? My intuition is that it won't matter. Please let me know if I'm missing something.
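One way to check whether the shift matters is to look at where the radius is smallest along a single axis. The sketch below is only an illustration under stated assumptions: the helper radius_min_location and the grid size of 9 are hypothetical, and it does not claim to reproduce what the repository does. It compares coordinates normalized to [0, 1] with coordinates normalized to [-1, 1].

```python
import torch


def radius_min_location(coords_1d, shift):
    """Index along one axis where |coord - shift| is smallest."""
    return int(torch.argmin((coords_1d - shift).abs()))


size = 9
unit = torch.linspace(0.0, 1.0, steps=size)     # coordinates in [0, 1]
signed = torch.linspace(-1.0, 1.0, steps=size)  # coordinates in [-1, 1]

center_unit = radius_min_location(unit, 0.5)               # middle index, 4
center_signed = radius_min_location(signed, 0.5)           # off-center, 6
center_signed_noshift = radius_min_location(signed, 0.0)   # middle index, 4
```

Under these assumptions, a shift of 0.5 centers the radius on the image only when the coordinates lie in [0, 1]. Applied to coordinates in [-1, 1], the same shift moves the minimum away from the center, so whether the shift matters depends on which normalization the radius is computed from.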
