Comments (11)
This is trivial with normal slicing:
julia> W = randn(3, 10)
3×10 Array{Float64,2}:
  1.52921     0.34536   -1.54119   …  -0.303022   0.123047   0.00684176
 -0.138502    0.4285    -0.841863      0.0720752 -2.19672   -0.0968025
 -0.00694323  0.916636   0.834962     -0.239865   0.0237941  0.0236618

julia> W[:, 5]
3-element Array{Float64,1}:
 -0.972017
  0.293653
 -1.20708

julia> W[:, [5, 6]]
3×2 Array{Float64,2}:
 -0.972017   2.58503
  0.293653  -2.29474
 -1.20708    1.82632

julia> W[:, [5 6; 7 8]]
3×2×2 Array{Float64,3}:
[:, :, 1] =
 -0.972017  -0.914234
  0.293653   0.333886
 -1.20708   -0.386469

[:, :, 2] =
  2.58503   -0.303022
 -2.29474    0.0720752
  1.82632   -0.239865
There may be some benefit to having a convenience wrapper though. Should be easy to put together if you want to set up a PR.
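A minimal sketch of such a wrapper (the `Embedding` name and constructor here are hypothetical, not an existing Flux layer):

```julia
# Minimal embedding lookup built on ordinary slicing.
# Each column of the weight matrix is one embedding vector.
struct Embedding{T<:AbstractMatrix}
    W::T
end

Embedding(dim::Int, vocab::Int) = Embedding(randn(dim, vocab))

# Indexing with an integer, vector, or matrix of token ids reuses
# Base slicing, so the output shape follows the index shape.
(e::Embedding)(ids) = e.W[:, ids]
```

With `e = Embedding(3, 10)`, `e(5)` is a length-3 vector and `e([5 6; 7 8])` is 3×2×2, matching the slicing examples above.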
from flux.jl.
Better still is to use the OneHot magic, I think; this is expressly what it is for. A onehot-encoded value takes up no more space than an Int, and onehot multiplication is slicing under the hood.
http://fluxml.github.io/Flux.jl/latest/data/onehot.html
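The slicing equivalence is easy to check in plain Julia (using a dense 0/1 vector here for illustration rather than Flux's optimized OneHotVector):

```julia
W = randn(3, 10)

# A dense one-hot vector for index 5 over a vocabulary of 10.
e5 = zeros(10)
e5[5] = 1.0

# Matrix-times-one-hot selects column 5, the same result as slicing;
# Flux's OneHotVector multiplication just skips the arithmetic.
W * e5 ≈ W[:, 5]  # true
```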
Thanks. I've tried onehot and got the following error:
julia> x = onehot(1, 1:10)
10-element Flux.OneHotVector:
true
false
false
false
false
false
false
false
false
false
julia> m = Chain(Dense(10, 5), Dense(5, 2))
Chain(Dense(10, 5), Dense(5, 2))
julia> m(x)
ERROR: MethodError: *(::TrackedArray{…,Array{Float64,2}}, ::Flux.OneHotVector) is ambiguous. Candidates:
*(A::AbstractArray{T,2} where T, b::Flux.OneHotVector) in Flux at /home/phuoc/.julia/v0.6/Flux/src/onehot.jl:10
*(a::Flux.Tracker.TrackedArray{T,2,A} where A where T, b::AbstractArray{T,1} where T) in Flux.Tracker at /home/phuoc/.julia/v0.6/Flux/src/tracker/lib.jl:67
Possible fix, define
*(::Flux.Tracker.TrackedArray{T,2,A} where A where T, ::Flux.OneHotVector)
Stacktrace:
[1] (::Flux.Dense{Base.#identity,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}})(::Flux.OneHotVector) at /home/phuoc/.julia/v0.6/Flux/src/layers/basic.jl:61
[2] mapfoldl_impl(::Base.#identity, ::Flux.##45#46, ::Flux.OneHotVector, ::Array{Any,1}, ::Int64) at ./reduce.jl:43
[3] (::Flux.Chain)(::Flux.OneHotVector) at /home/phuoc/.julia/v0.6/Flux/src/layers/basic.jl:28
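One possible workaround until the ambiguity is resolved (an assumption on my part, not an official fix) is to materialize the one-hot vector as a dense Float64 vector before calling the model; this gives up the slicing shortcut but dispatches to the unambiguous TrackedArray-times-AbstractVector method:

```julia
using Flux
using Flux: onehot

x = onehot(1, 1:10)
m = Chain(Dense(10, 5), Dense(5, 2))

# collect produces a plain Vector{Float64}, avoiding the ambiguous
# *(::TrackedArray, ::OneHotVector) dispatch.
m(collect(Float64, x))
```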
I also tried the following model but failed to train it, since there were no parameters returned by params. Did I miss something?
julia> using Flux
julia> using Flux: onehotbatch, unstack, truncate!, throttle, logloss, initn
julia> using Flux.Tracker: param, back!, data, grad
julia> struct LangModel{E,R,F}
emb::E
rnn::R
fc::F
end
julia> LangModel(v::Int, e::Int, h::Int) = LangModel(param(initn(e, v)),
LSTM(e, h),
Dense(h, v))
LangModel
julia> Flux.children(m::LangModel) = (m.emb, m.rnn, m.fc,)
julia> (m::LangModel)(x) = softmax(m.fc(m.rnn(m.emb[:,x])))
julia> m = LangModel(2,3,4)
LangModel{TrackedArray{…,Array{Float64,2}},Flux.Recur{Flux.LSTMCell{Flux.Dense{NNlib.#σ,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}},Flux.Dense{Base.#tanh,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}},TrackedArray{…,Array{Float64,1}}}},Flux.Dense{Base.#identity,TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,1}}}}(param([0.00883202 -0.00318895; 0.0140951 0.0126121; -0.0082283 -0.00736651]), Recur(LSTMCell(3, 4)), Dense(4, 2))
julia> params(m)
0-element Array{Any,1}
The ambiguity issue is one I can fix.
If you're on the latest Flux you need to change the Flux.children line to Flux.treelike(LangModel).
Thank you. Flux.treelike(LangModel) works perfectly.
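For reference, the working definition under that Flux API would then be (a sketch matching the code earlier in the thread, with treelike replacing the manual children definition):

```julia
using Flux
using Flux: initn
using Flux.Tracker: param

struct LangModel{E,R,F}
    emb::E
    rnn::R
    fc::F
end

LangModel(v::Int, e::Int, h::Int) =
    LangModel(param(initn(e, v)), LSTM(e, h), Dense(h, v))

# Replaces the manual Flux.children definition: treelike generates the
# traversal methods, so params(m) can find the weights in every field.
Flux.treelike(LangModel)

(m::LangModel)(x) = softmax(m.fc(m.rnn(m.emb[:, x])))
```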
Closing this for now as I think matmul is fine for embeddings, unless anyone has a specific proposal. I've noted the ambiguity issue though so I'll fix that ASAP.
How about a layer that would provide a wrapper to import weights for pretrained embeddings? https://discuss.pytorch.org/t/can-we-use-pre-trained-word-embeddings-for-weight-initialization-in-nn-embedding/1222
And then a way to freeze them during training.
@datnamer it is gloriously trivial:
Load the pretrained weights into a matrix W from Embeddings.jl (the code to do the import lives there).
Then use the onehot product discussed above, e.g. for a word with one-hot vector ei:
if you want to allow the embedding to fine-tune, use param(W)*ei; if you want to freeze it, use W*ei.
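Putting those steps together (a sketch; the load_embeddings call and the `embeddings`/`vocab` field names follow my reading of the Embeddings.jl docs and should be checked against the current release):

```julia
using Embeddings            # provides load_embeddings (assumed API)
using Flux: onehot
using Flux.Tracker: param

table = load_embeddings(GloVe)                  # pretrained vectors
W = table.embeddings                            # dim × vocab matrix
index = Dict(w => i for (i, w) in enumerate(table.vocab))

ei = onehot(index["king"], 1:size(W, 2))        # one-hot for the word "king"

frozen    = W * ei         # plain matrix: embedding stays fixed in training
trainable = param(W) * ei  # tracked: gradients flow, embedding fine-tunes
```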
@oxinabox thanks, that looks great. Would it be similarly easy with graph embeddings? I'm just starting to play around, but there aren't a lot of tutorials for that.
I'm not sure. I don't know that pretrained graph embeddings are a thing.