jxwufan / associativeretrieval
TensorFlow implementation of Fast Weights
License: Apache License 2.0
Hi Fan Wu
Thanks for making your implementation of Fast Associative Memory available! A few of us (@zhongwen, @jiamings) have been discussing the memory aspect of this model, and we would like your thoughts on this issue if you can help.
It seems the most straightforward way is to keep the entire fast_weights in memory, since it is easy to update over time steps by multiplying by lambda. However, fast_weights then requires BatchSize x NumUnits x NumUnits memory, and for practical tasks, even a simple baseline task like character-level PTB, this breaks the memory budget, since typically BatchSize is ~100 and NumUnits is ~1000 to get close to SOTA-level results. From looking at your implementation and running it, it seems that is done here too, which is perfectly fine, since NumUnits is quite small for this task.
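To put a rough number on that (assuming float32 and the illustrative sizes above, which are not from the repo itself), the full fast-weights tensor alone costs:

```python
# Rough memory estimate for storing the full fast-weights tensor
# A(t) of shape [batch_size, num_units, num_units] in float32.
# batch_size=100 and num_units=1000 are the illustrative values from above.
batch_size = 100
num_units = 1000
bytes_per_float32 = 4

fast_weights_bytes = batch_size * num_units * num_units * bytes_per_float32
print(f"{fast_weights_bytes / 1e9:.1f} GB per copy")  # 0.4 GB
```

And backprop through time keeps one such tensor per unrolled time step, which is what makes this blow up in practice.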
From my reading of the paper, eq (3) describes storing the entire outer product as a batch_size x N x N tensor, while eq (4) possibly allows us to get away with storing only on the order of a batch_size x N tensor, making it possible to compute the fast_weights term on the fly and thus allowing this algorithm to work on practical tasks. However, the only way I thought of doing that is to keep track of the previous, say, M hidden states separately, where M is much smaller than N, in the hope that lambda ^ M is a very small number, so we can ignore anything before M time steps. I've been playing around with this variation at the end of this file: https://github.com/hardmaru/supercell/blob/master/supercell.py
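A minimal NumPy sketch of that idea (hypothetical function names, not code from either repo): eq (4) lets us apply the fast weights to a vector as a weighted sum over recent hidden states, optionally truncated to the last M of them, instead of materializing the N x N matrix from eq (3):

```python
import numpy as np

def fw_term_explicit(hs, h, lam, eta):
    """Eq (3) style: materialize A(t) = sum_s lam^(t-1-s) * eta * h_s h_s^T,
    then apply it to h. Costs O(N^2) memory per batch element."""
    t = len(hs)
    n = h.shape[0]
    A = np.zeros((n, n))
    for s, h_s in enumerate(hs):
        A += lam ** (t - 1 - s) * eta * np.outer(h_s, h_s)
    return A @ h

def fw_term_on_the_fly(hs, h, lam, eta, M=None):
    """Eq (4) style: compute sum_s lam^(t-1-s) * eta * h_s * (h_s . h) directly,
    optionally keeping only the last M hidden states (O(M*N) memory)."""
    t = len(hs)
    recent = hs if M is None else hs[-M:]
    offset = t - len(recent)
    out = np.zeros_like(h)
    for s, h_s in enumerate(recent):
        out += lam ** (t - 1 - (s + offset)) * eta * h_s * (h_s @ h)
    return out
```

With M=None the two agree exactly; with M small, the truncation error is bounded by the lambda^M tail, which is the approximation described above.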
I'm wondering whether you have also thought about this issue, and about a more memory-efficient implementation.
Thanks,
David Ha
Hi @jxwufan,
I am trying to build the same fast-weights RNN cell from the paper 'Using Fast Weights to Attend to the Recent Past', but I am using TensorFlow 2.4.
Sadly, my implementation returns an error; perhaps you can see what I am doing wrong? I am pretty new to TensorFlow, at least to model subclassing, so I do not really understand what is going wrong. Here is a link to my code, the error, and the model summary:
https://gist.github.com/ion-elgreco/cc6fed29a4f6fb6813b71d5e7e8f2c87
The error is:

TypeError: An op outside of the function building code is being passed a "Graph" tensor. It is possible to have Graph tensors leak out of the function building context by including a tf.init_scope in your function building code. For example, the following function will fail:

    @tf.function
    def has_init_scope():
        my_constant = tf.constant(1.)
        with tf.init_scope():
            added = my_constant * 2

The graph tensor has name: FW-RNN/while/fw_rnn_cell_2/add_2:0
I hope you can help me with this since you have successfully implemented this before!
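For context, this TypeError usually means a tensor created in eager mode (e.g. in `__init__`, or captured from an earlier call) is being reused inside the compiled graph. A minimal sketch of a fast-weights cell that keeps all tensor creation inside `build()`/`call()` (hypothetical names, simplified to one inner step with no layer norm; not the code from the gist):

```python
import tensorflow as tf

class FastWeightsCell(tf.keras.layers.Layer):
    """Sketch of a fast-weights RNN cell. All tensors are created in
    build()/call(), never in __init__, which is the usual fix for the
    'Graph tensor leaked out of the function building context' TypeError."""

    def __init__(self, units, lam=0.95, eta=0.5, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.lam = lam  # decay rate lambda
        self.eta = eta  # fast learning rate eta
        # State: hidden vector h [units] and fast weights A [units, units].
        self.state_size = [units, tf.TensorShape([units, units])]
        self.output_size = units

    def build(self, input_shape):
        self.W = self.add_weight(shape=(input_shape[-1], self.units), name="W")
        self.U = self.add_weight(shape=(self.units, self.units), name="U")
        self.b = self.add_weight(shape=(self.units,),
                                 initializer="zeros", name="b")
        super().build(input_shape)

    def get_initial_state(self, batch_size=None, **kwargs):
        return [tf.zeros([batch_size, self.units]),
                tf.zeros([batch_size, self.units, self.units])]

    def call(self, inputs, states):
        h_prev, A_prev = states
        # Decay fast weights, add outer product of the previous hidden state.
        A = self.lam * A_prev + self.eta * tf.einsum("bi,bj->bij",
                                                     h_prev, h_prev)
        # Slow-weight preliminary state, then one inner step through A.
        h0 = tf.matmul(inputs, self.W) + tf.matmul(h_prev, self.U) + self.b
        h = tf.tanh(h0 + tf.einsum("bij,bj->bi", A, tf.tanh(h0)))
        return h, [h, A]
```

Wrapped in `tf.keras.layers.RNN(FastWeightsCell(units))`, this runs end to end; comparing its structure against the gist's cell may show where a graph tensor is being captured.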