Comments (2)
An easy fix is to make the anonymous functions proper ones, because they don't need to be anonymous:
x -> x[begin:inputpoints, 1, :]
# becomes
f(x, inputpoints) = x[begin:inputpoints, 1, :]
...
mainblock = Chain(
    Base.Fix2(f, inputpoints),
    ...
)
# and
x -> x[:, (begin+1):end, :]
# becomes
g(x) = x[:, (begin+1):end, :] # no need for Fix1/Fix2
But this only handles the anonymous functions. For a more robust general solution, let's continue in #2263.
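As a hedged sketch (toy sizes, not from the issue), this shows why the suggestion works: `Base.Fix2` turns the two-argument named function `f` into the one-argument callable a `Chain` expects, while `g` is already one-argument and needs no wrapper.

```julia
# Toy demonstration of the named-function replacement; `inputpoints = 4`
# is an illustrative value, the issue itself uses 24 * 7.
inputpoints = 4
f(x, inputpoints) = x[begin:inputpoints, 1, :]
g(x) = x[:, (begin+1):end, :]           # already one-argument: no Fix1/Fix2 needed

x = reshape(collect(Float32, 1:24), 6, 2, 2)
main_slice = Base.Fix2(f, inputpoints)  # fixes f's second argument to inputpoints
size(main_slice(x))                     # (4, 2): rows 1:4 of feature 1, all samples
size(g(x))                              # (6, 1, 2): every feature after the first
```

Because `main_slice` and `g` are named, globally defined callables rather than closures, they can be reconstructed by name in a fresh session, which is what makes the model reloadable.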
from flux.jl.
Following @ToucheSir's suggestion to write a more robust solution, this is how it would apply to this particular case (replacing BSON.jl with JLD2.jl):
Creating and saving the model:
### SAVING TEST
using Flux
using Random
using JLD2

inputpoints = 24 * 7
auxfeatures = 3  # one main feature, 3 aux features
samples = 2000
labelpoints = 24 * 2
inputs = randn(Float32, inputpoints, 1 + auxfeatures, samples)

# main feature block
mainblock = Chain(
    x -> x[begin:inputpoints, 1, :],
    Dense(inputpoints, labelpoints)
)

# aux features block
auxblock = Chain(
    x -> x[:, (begin+1):end, :],
    Flux.flatten,
    Dense(auxfeatures * inputpoints, labelpoints)
)

struct TestModel
    architecture
end
Flux.@functor TestModel
TestModel() = TestModel(Parallel(+, mainblock, auxblock))

model = TestModel()
model_state = Flux.state(model)
jldsave("testsavemodel_1.jld2"; model_state)
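The key idea above is saving only the numeric state and rebuilding the architecture in code. A hedged stand-alone sketch of the same pattern, using only the standard library (Serialization stands in for JLD2, and the hand-written `state`/`loadstate!` stand in for `Flux.state`/`Flux.loadmodel!`; `Toy`, `weight`, and `bias` are illustrative names):

```julia
using Serialization

struct Toy
    weight::Matrix{Float64}
    bias::Vector{Float64}
end

state(m::Toy) = (weight = m.weight, bias = m.bias)  # plain data, like Flux.state
function loadstate!(m::Toy, s)                      # copies state in, like Flux.loadmodel!
    copyto!(m.weight, s.weight)
    copyto!(m.bias, s.bias)
    return m
end

m = Toy(randn(2, 3), randn(2))
path = tempname()
serialize(path, state(m))            # save only arrays, never the struct itself

m2 = Toy(zeros(2, 3), zeros(2))      # "new session": rebuild the struct in code
loadstate!(m2, deserialize(path))
m2.weight == m.weight                # true: the state round-tripped
```

Because the file contains only plain arrays, nothing in it depends on anonymous functions or type definitions that a fresh session would fail to deserialize.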
Starting a new session and loading the model:
### LOADING TEST
using Flux
using Random
using JLD2

inputpoints = 24 * 7
auxfeatures = 3  # one main feature, 3 aux features
samples = 2000
labelpoints = 24 * 2
inputs = randn(Float32, inputpoints, 1 + auxfeatures, samples)

# main feature block
mainblock = Chain(
    x -> x[begin:inputpoints, 1, :],
    Dense(inputpoints, labelpoints)
)

# aux features block
auxblock = Chain(
    x -> x[:, (begin+1):end, :],
    Flux.flatten,
    Dense(auxfeatures * inputpoints, labelpoints)
)

struct TestModel
    architecture
end
Flux.@functor TestModel
TestModel() = TestModel(Parallel(+, mainblock, auxblock))

model = TestModel()
model_state = JLD2.load("testsavemodel_1.jld2", "model_state");
Flux.loadmodel!(model, model_state)
A forward pass after loading is successful:
julia> test_fwdpass = model.architecture(inputs)
48×2000 Matrix{Float32}:
4.44022 0.802483 -0.785774 1.168 … -2.20038 -2.76761 -2.75335
-1.81597 3.47688 0.0940425 -0.837038 0.155797 2.5794 -0.0938851
2.21066 -0.4048 0.903007 0.167684 1.19697 -1.00276 -1.79072
0.929377 1.1883 -1.82898 1.01884 -0.725962 1.04085 2.17898
0.00539947 1.49683 1.25519 1.50682 2.72747 0.716122 2.52785
1.03204 3.22989 -1.66981 -0.999194 … -0.215202 -1.27665 0.376921
2.51096 2.41828 0.436551 0.517585 1.67277 0.609859 -1.54591
-1.34194 -0.228893 1.87149 -0.986849 -0.191224 0.687425 2.22133
1.9604 0.951124 1.43568 -0.238653 -1.622 4.54916 -3.99599
-0.993577 -2.96885 -1.70936 -0.713654 1.94885 -1.54148 0.403749
0.18666 0.834455 2.35449 1.00192 … -0.136148 0.861816 -1.7685
-2.36995 1.94883 1.31425 -1.37012 1.78269 -1.19305 0.525236
-0.556477 0.447952 -0.959529 0.850635 -1.19533 -0.692481 -1.17249
2.14281 0.17941 -0.65601 -3.38384 -0.336295 0.250721 -0.866344
2.52481 3.07921 -0.58382 -0.656336 -0.994389 -0.602142 0.530116
1.2512 0.877351 -0.74357 -0.797333 … 3.61359 -1.4924 -2.77331
0.0869287 -0.671315 -0.128169 1.9544 -0.242938 0.586071 -0.168547
1.09363 0.708124 -1.0453 2.32946 5.08991 -3.25003 0.0925286
-0.548058 -0.681359 0.0118403 -3.75676 -1.88147 0.104736 0.480259
⋮ ⋱
-0.416707 3.67179 -2.48939 1.52213 -0.776104 -0.346431 1.32079
0.655315 -0.415754 -1.45568 0.0851394 … 4.02886 2.77373 2.17698
-0.317264 -0.439673 -0.530158 -0.837444 0.284554 -1.00613 -0.366141
-0.296634 -1.96891 -2.48071 2.27509 -0.6101 -0.508833 -1.74481
1.96883 1.32886 -0.969475 -1.23352 -3.45104 2.03444 1.31539
3.81404 1.32852 2.34517 -2.12479 1.67277 0.0501646 1.32144
-0.0490075 -0.218952 2.18 3.05685 … -0.44117 -2.41891 -1.35152
1.33143 -0.689682 -1.03449 -0.0169412 -0.773172 -2.20266 -2.73936
-1.35926 -0.917676 4.6618 -1.13945 -3.41797 0.761221 0.333108
-0.225759 0.278201 1.78722 -0.131045 -2.63882 0.433773 -2.62248
4.81922 -0.870089 -4.80774 -1.5178 0.123205 -2.02181 1.56211
0.127518 0.723261 3.8712 -0.400356 … 0.197132 -3.68057 2.66511
-1.14512 -0.829157 0.0856611 0.0258443 -0.740243 -1.0791 0.617436
1.65157 -2.93585 0.989425 0.754669 0.606092 -1.09547 -1.23846
-0.490223 -0.190012 2.91653 1.45833 -0.137385 -2.23218 -1.20121
-1.37658 -4.13181 1.79136 -3.11379 -1.21975 0.521379 1.01322
-0.952514 1.02663 -0.793957 -1.69722 … -0.0394366 -4.34157 4.12784
-0.515579 -1.04139 -2.13667 1.92703 -0.915622 2.50567 -3.46607
3.25076 -2.62687 0.576621 1.19447 4.88387 0.0299822 -0.749113
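The output shape follows directly from the sizes defined above. A hedged sketch with plain matrices standing in for the two `Dense` layers: each branch maps the `inputpoints × 4 × samples` input to `labelpoints × samples`, and `Parallel(+, …)` adds the branch outputs elementwise, giving 48×2000.

```julia
# W1 and W2 are illustrative stand-ins for the Dense weights, not the trained ones.
inputpoints, auxfeatures, samples, labelpoints = 24 * 7, 3, 2000, 24 * 2
inputs = randn(Float32, inputpoints, 1 + auxfeatures, samples)

W1 = randn(Float32, labelpoints, inputpoints)                # main-branch Dense
W2 = randn(Float32, labelpoints, auxfeatures * inputpoints)  # aux-branch Dense

main = W1 * inputs[begin:inputpoints, 1, :]                  # 48 × 2000
aux  = W2 * reshape(inputs[:, (begin+1):end, :], :, samples) # flatten, then 48 × 2000
out  = main .+ aux                                           # what Parallel(+, …) computes
size(out)                                                    # (48, 2000)
```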
As such, I'm closing this issue, since the documentation need is already covered in #2263.