Comments (11)

MikeInnes commented on May 18, 2024

This is trivial with normal slicing:

julia> W = randn(3, 10)
3×10 Array{Float64,2}:
  1.52921     0.34536   -1.54119   …  -0.303022    0.123047    0.00684176
 -0.138502    0.4285    -0.841863      0.0720752  -2.19672    -0.0968025
 -0.00694323  0.916636   0.834962     -0.239865    0.0237941   0.0236618

julia> W[:, 5]
3-element Array{Float64,1}:
 -0.972017
  0.293653
 -1.20708 

julia> W[:, [5, 6]]
3×2 Array{Float64,2}:
 -0.972017   2.58503
  0.293653  -2.29474
 -1.20708    1.82632

julia> W[:, [5 6; 7 8]]
3×2×2 Array{Float64,3}:
[:, :, 1] =
 -0.972017  -0.914234
  0.293653   0.333886
 -1.20708   -0.386469

[:, :, 2] =
  2.58503  -0.303022 
 -2.29474   0.0720752
  1.82632  -0.239865

There may be some benefit to having a convenience wrapper though. Should be easy to put together if you want to set up a PR.
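
As a rough sketch of what such a convenience wrapper could look like, plain Julia is enough. The `Embedding` name, its field `W`, and the `(vocab, dim)` constructor are assumptions made for illustration, not an existing Flux layer from this thread, and gradient tracking is not addressed here:

```julia
# Hypothetical wrapper around the column slicing shown above.
struct Embedding{M<:AbstractMatrix}
    W::M   # one column per vocabulary entry
end

# Random initialisation: `dim` rows, `vocab` columns (assumed convention).
Embedding(vocab::Integer, dim::Integer) = Embedding(randn(dim, vocab))

# Look up a single index, or any array of indices, by slicing columns.
(e::Embedding)(i::Integer) = e.W[:, i]
(e::Embedding)(idx::AbstractArray{<:Integer}) = e.W[:, idx]

emb = Embedding(10, 3)
size(emb(5))            # (3,)
size(emb([5, 6]))       # (3, 2)
size(emb([5 6; 7 8]))   # (3, 2, 2)
```

The index array's shape is appended to the embedding dimension, which matches the `3×2×2` result in the transcript above.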

from flux.jl.

oxinabox commented on May 18, 2024

Better still is to use the OneHot magic, I think.
This is expressly what it is for.
A one-hot encoded value takes up no more space than an Int, and one-hot multiplication is slicing under the hood.
http://fluxml.github.io/Flux.jl/latest/data/onehot.html
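
To make the "slicing under the hood" point concrete, here is a minimal sketch in plain Julia. `OneHotSketch` is a hypothetical stand-in written for this example, not Flux's actual `OneHotVector` implementation; it only stores the hot index and the length, and its `*` method returns a column instead of doing a full multiply:

```julia
W = randn(3, 10)

# A dense one-hot vector: multiplying by it picks out column 5.
e5 = [i == 5 for i in 1:10]
W * e5 ≈ W[:, 5]   # true

# Hypothetical compact one-hot type: two integers instead of n Bools.
struct OneHotSketch <: AbstractVector{Bool}
    ix::Int   # position of the single `true`
    n::Int    # length of the vector
end
Base.size(v::OneHotSketch) = (v.n,)
Base.getindex(v::OneHotSketch, i::Integer) = i == v.ix

# Matrix-times-one-hot is just a column lookup.
Base.:*(A::AbstractMatrix, v::OneHotSketch) = A[:, v.ix]

W * OneHotSketch(5, 10) == W[:, 5]   # true
```

A specialised `*` method of this kind is more specific in its second argument only, so it can become ambiguous with another method that is more specific in its first argument.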

ngphuoc commented on May 18, 2024

Thanks. I've tried onehot and got the following error:

julia> x = onehot(1, 1:10)
10-element Flux.OneHotVector:
  true
 false
 false
 false
 false
 false
 false
 false
 false
 false

julia> m = Chain(Dense(10, 5), Dense(5, 2))
Chain(Dense(10, 5), Dense(5, 2))

julia> m(x)
ERROR: MethodError: *(::TrackedArray{…,Array{Float64,2}}, ::Flux.OneHotVector) is ambiguous. Candidates:
  *(A::AbstractArray{T,2} where T, b::Flux.OneHotVector) in Flux at /home/phuoc/.julia/v0.6/Flux/src/onehot.jl:10
  *(a::Flux.Tracker.TrackedArray{T,2,A} where A where T, b::AbstractArray{T,1} where T) in Flux.Tracker at /home/phuoc/.julia/v0.6/Flux/src/tracker/lib.jl:67
Possible fix, define
  *(::Flux.Tracker.TrackedArray{T,2,A} where A where T, ::Flux.OneHotVector)
Stacktrace:
 [1] (::Flux.Dense{Base.#identity,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}})(::Flux.OneHotVector) at /home/phuoc/.julia/v0.6/Flux/src/layers/basic.jl:61
 [2] mapfoldl_impl(::Base.#identity, ::Flux.##45#46, ::Flux.OneHotVector, ::Array{Any,1}, ::Int64) at ./reduce.jl:43
 [3] (::Flux.Chain)(::Flux.OneHotVector) at /home/phuoc/.julia/v0.6/Flux/src/layers/basic.jl:28

ngphuoc commented on May 18, 2024

I also tried to add the following model but failed to train it, since there were no parameters returned by params. Did I miss something?

julia> using Flux
julia> using Flux: onehotbatch, unstack, truncate!, throttle, logloss, initn
julia> using Flux.Tracker: param, back!, data, grad
julia> struct LangModel{E,R,F}
         emb::E
         rnn::R
         fc::F
       end

julia> LangModel(v::Int, e::Int, h::Int) = LangModel(param(initn(e, v)),
                                      LSTM(e, h),
                                      Dense(h, v))
LangModel

julia> Flux.children(m::LangModel) = (m.emb, m.rnn, m.fc,)

julia> (m::LangModel)(x) = softmax(m.fc(m.rnn(m.emb[:,x])))
julia> m = LangModel(2,3,4)
LangModel{TrackedArray{…,Array{Float64,2}},Flux.Recur{Flux.LSTMCell{Flux.Dense{NNlib.#σ,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}},Flux.Dense{Base.#tanh,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}},TrackedArray{…,Array{Float64,1}}}},Flux.Dense{Base.#identity,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}}}(param([0.00883202 -0.00318895; 0.0140951 0.0126121; -0.0082283 -0.00736651]), Recur(LSTMCell(3, 4)), Dense(4, 2))

julia> params(m)
0-element Array{Any,1}

MikeInnes commented on May 18, 2024

The ambiguity issue is one I can fix.

If you're on the latest Flux you need to change the Flux.children line to Flux.treelike(LangModel).

ngphuoc commented on May 18, 2024

Thank you. Flux.treelike(LangModel) works perfectly.

MikeInnes commented on May 18, 2024

Closing this for now as I think matmul is fine for embeddings, unless anyone has a specific proposal. I've noted the ambiguity issue though so I'll fix that ASAP.

datnamer commented on May 18, 2024

How about a layer that would provide a wrapper to import weights for pretrained embeddings? https://discuss.pytorch.org/t/can-we-use-pre-trained-word-embeddings-for-weight-initialization-in-nn-embedding/1222

And then a way to freeze them during training.

oxinabox commented on May 18, 2024

@datnamer it is gloriously trivial

Load the pretrained weights into a matrix W from Embeddings.jl.
(Code to do the import lives there)

Then use the onehot product discussed above.
e.g. for a word with one-hot vector ei.

If you want to allow the embedding to be fine-tuned, use param(W)*ei;
if you want to freeze it, use W*ei.

datnamer commented on May 18, 2024

@oxinabox thanks, that looks great. Would it be similarly easy with graph embeddings? I'm just starting to play around, but there aren't a lot of tutorials for that.

oxinabox commented on May 18, 2024

I'm not sure. I don't know that pretrained graph embeddings are a thing.
