
Comments (6)

taolei87 commented on August 22, 2024

Hi

Since U is computed by multiplying X with the weight matrix using torch.mm()
https://github.com/taolei87/sru/blob/master/cuda_functional.py#L677

PyTorch will handle the backward gradient computation automatically (i.e. grad_u -> grad_w).
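
A minimal sketch of what I mean (shapes are arbitrary placeholders, not the actual SRU code):

import torch

# U = X * W via torch.mm(); autograd then derives grad_w from grad_u on its own
length, batch, n_in, n_out = 8, 4, 16, 16
x_2d = torch.randn(length * batch, n_in)
w = torch.randn(n_in, n_out * 3, requires_grad=True)

u = torch.mm(x_2d, w)            # forward
u.backward(torch.ones_like(u))   # pretend grad_u is all ones
print(w.grad.shape)              # (n_in, n_out * 3) -- no manual grad_w needed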

I believe TF supports the same functionality as well?


kzjeef commented on August 22, 2024

Hi Tao,

I think TF cannot do this computation automatically, because I pass W and X as two inputs and do the mm() operation inside the op's Compute() function (like the forward() function in PyTorch). I need to give TF the gradients w.r.t. the inputs W and X, so I have to do this computation inside the GradOp's Compute() function.

I found that https://github.com/musyoku/chainer-sru does the same computation (grad_u -> grad_x), and I have asked about the algorithm details (see musyoku/chainer-sru#3 (comment)).

Since I only care about how to convert grad_u -> grad_w, I think what that code does is like the following numpy code:

import numpy as np

length, batch, n_in, n_out = 8, 4, 16, 16      # placeholder sizes; n_in == n_out in this case
grad_u = np.zeros((length, batch, n_out * 3))  # gradient w.r.t. u (values are placeholders, shapes are what matter)
x = np.zeros((length, batch, n_in))            # input x

# reshape grad_u and x so their outer product can be done element-wise via broadcasting
grad_u_4d = grad_u.reshape((length, batch, 1, n_out * 3))
x_4d = x.reshape((length, batch, n_in, 1))
grad_w = (grad_u_4d * x_4d).sum((0, 1))  # reduce the length and batch dims
# so the output shape will be (n_in, n_out * 3)

I'm not sure whether this math is correct.


kzjeef commented on August 22, 2024

Hi,

Actually, TF cannot do this computation automatically, because the newly defined op takes W as input and does the mm inside the op's Compute() function (like the forward() function in PyTorch).

During the backward pass, TF gives the gradient w.r.t. the op's output and expects the op to compute the gradients w.r.t. its inputs; it doesn't really know about U, so I need to convert grad_u into grad_w myself.
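
Roughly, the wiring looks like this on the Python side ("SruFwd", sru_bwd(), and the conversion helpers are placeholder names for my custom op, not real library calls):

import tensorflow as tf

@tf.RegisterGradient("SruFwd")        # placeholder op name
def _sru_fwd_grad(op, grad_h):
    # TF hands in the gradient w.r.t. the op's output and expects one
    # gradient back per input (x and W). The backward kernel only gives
    # me grad_u, so the grad_u -> grad_x / grad_w conversion is on me.
    x, w = op.inputs[0], op.inputs[1]
    grad_u = sru_bwd(op, grad_h)          # placeholder backward kernel
    grad_x = grad_u_to_grad_x(grad_u, w)  # the conversions I'm asking about
    grad_w = grad_u_to_grad_w(grad_u, x)
    return grad_x, grad_w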

I saw this computation in chainer-sru's code (https://github.com/musyoku/chainer-sru/blob/b7cb02397423fb7d53396b56fa0382963944eef2/sru/sru.py#L435), and I also asked him about a simpler version of this code in musyoku/chainer-sru#3.

Translated into numpy, the computation looks like the code below:

length, batch, n_in, n_out, direction_count = 8, 4, 16, 16, 1  # placeholder sizes
hidden_count = n_out * 3 * direction_count
grad_u = np.zeros((length, batch, hidden_count))
grad_u_reshape = grad_u.reshape((length, batch, 1, hidden_count))

x = np.zeros((length, batch, n_in)).reshape((length, batch, n_in, 1))

grad_w = (grad_u_reshape * x).sum((0, 1))  # reduce the length and batch dims, so the shape becomes (n_in, hidden_count)

I don't really understand why this computation works. Is backpropagation through time (BPTT) in SRU a sum over all time steps?

Could you give some advice on how to correctly compute grad_u -> grad_w?

Thanks.


taolei87 commented on August 22, 2024

Hi @kzjeef

As you can see in my code and musyoku's code, there are two possible ways to implement it.

(1) You define your operator to take U as input instead of W. The SRU computation then involves two steps: (a) a matrix multiplication U = Wx, for which you can use TF's built-in op; (b) the newly defined op you implement, which takes U. This is how I did it in my PyTorch version.

(2) If you define your op to take W as the input, you have to compute grad_u -> grad_w and grad_u -> grad_x inside your defined op:

  • Say the shape of W is (n_in, ndir, n_out, k). In my version, it is stored as a matrix (n_in, ndir*n_out*k)

  • the shape of x is (length, batch, n_in), so x_2d has shape (length*batch, n_in)

  • U is the mm of x_2d and W, so it has shape (length*batch, ndir*n_out*k); so does grad_u

  • grad_w would be x_2d.transpose() multiplied by grad_u:
    (n_in, length*batch) x (length*batch, ndir*n_out*k) --> (n_in, ndir*n_out*k)

  • grad_x2d would be grad_u multiplied by W.transpose():
    (length*batch, ndir*n_out*k) x (ndir*n_out*k, n_in) --> (length*batch, n_in)

You can do this using tensordot() or dot().
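
For example, a minimal numpy sketch of the two products above (arbitrary small shapes, with ndir = 1 and k = 3):

import numpy as np

length, batch, n_in, n_out, ndir, k = 8, 4, 16, 16, 1, 3
x_2d = np.random.randn(length * batch, n_in)
w = np.random.randn(n_in, ndir * n_out * k)
grad_u = np.random.randn(length * batch, ndir * n_out * k)

grad_w = x_2d.T.dot(grad_u)    # (n_in, ndir*n_out*k)
grad_x2d = grad_u.dot(w.T)     # (length*batch, n_in)

# tensordot over the shared axis gives the same results
assert np.allclose(grad_w, np.tensordot(x_2d, grad_u, axes=([0], [0])))
assert np.allclose(grad_x2d, np.tensordot(grad_u, w, axes=([1], [1])))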


kzjeef commented on August 22, 2024

Thanks for your detailed explanation.


kzjeef commented on August 22, 2024

@taolei87

I did some gradient checking on my TensorFlow implementation of SRU and found some gradient error between the grad function and the numeric gradient.
Here is some data:

When k == 3:
using grad_x from sru_bwd()'s output, I got a gradient error of 0.41330277919769287;
using grad_x2d calculated by the formula above, I got a gradient error of 0.340061485767.

However, when k == 4:
using grad_x2d, I get a gradient error of 0.000147312879562, which is pretty good; I believe it should be the correct gradient.

But I'm confused that when k == 3, x's gradient seems to have a large error with both methods.

Could you give me some guidance about this?
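
For reference, the numeric side of my check is a standard central-difference estimate, roughly like this (f is any scalar loss of the parameter; eps and the layout are arbitrary choices, not taken from the SRU code itself):

import numpy as np

def numeric_grad(f, x, eps=1e-5):
    # central difference: df/dx[i] ~= (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=["multi_index"])
    while not it.finished:
        idx = it.multi_index
        orig = x[idx]
        x[idx] = orig + eps
        f_plus = f(x)
        x[idx] = orig - eps
        f_minus = f(x)
        x[idx] = orig            # restore the entry before moving on
        grad[idx] = (f_plus - f_minus) / (2 * eps)
        it.iternext()
    return grad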

