Comments (4)
I have never implemented a prioritized replay buffer, but your idea looks like it can be extended to storing trajectories rather than terminal states.
For example, you could use the provided `ReplayBuffer` class and keep a sorted list of indices (the indices of the trajectories). Every time you add a trajectory, re-sort the list of indices, using the corresponding trajectory rewards as the sort key. You could then use the sorted list of indices to subsample your replay buffer, rather than using the `sample()` method.
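A rough sketch of the idea, with a plain Python list standing in for the buffer's storage (`SortedIndexBuffer`, `sample_top`, and the `(trajectory, reward)` layout are all made up for illustration, not torchgfn API):

```python
import bisect
import random

class SortedIndexBuffer:
    """Toy sketch: append-only storage of (trajectory, reward) pairs plus a
    list of indices kept sorted by reward. With torchgfn you would store
    Trajectories objects instead of the opaque `trajectory` values here."""

    def __init__(self):
        self.storage = []     # (trajectory, reward) pairs, append-only
        self.sorted_idx = []  # indices into `storage`, ordered by reward

    def add(self, trajectory, reward):
        idx = len(self.storage)
        self.storage.append((trajectory, reward))
        # Insert the new index at the position that keeps rewards ascending.
        rewards = [self.storage[i][1] for i in self.sorted_idx]
        self.sorted_idx.insert(bisect.bisect_left(rewards, reward), idx)

    def sample_top(self, k):
        # Sample uniformly among the k highest-reward trajectories.
        top = self.sorted_idx[-k:]
        return [self.storage[i][0] for i in random.sample(top, min(k, len(top)))]
```

Sampling from `sorted_idx` instead of calling `sample()` is what lets you bias toward high-reward trajectories.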
from torchgfn.
Calling `Trajectories.revert_backward_trajectories()` before `trajectories.extend(offline_trajectories)` seems to work, but I don't know whether there will be unexpected behavior downstream. It seems that `log_probs` needs to be padded after `Trajectories.revert_backward_trajectories()`. I would appreciate your insight.
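For context, the padding I mean looks roughly like this on plain Python lists (`pad_log_probs` is a made-up helper just to illustrate; torchgfn stores tensors, so the real shapes differ):

```python
def pad_log_probs(log_probs_list, max_len, fill=0.0):
    # Pad each trajectory's log-prob sequence with `fill` up to max_len,
    # so trajectories of different lengths can share one rectangular batch.
    return [lp + [fill] * (max_len - len(lp)) for lp in log_probs_list]
```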
If you actually don't care about the `log_probs` attribute of the obtained trajectories, then yeah, your suggestion should work fine.
Would it be better to sample forward trajectories stored in a `ReplayBuffer` instead?
I'm not sure what you mean by that. Did you mean "to store" instead of "to sample"? If so, note that a replay buffer essentially works with `extend`s. The difference from what you wrote is that a `ReplayBuffer` object has a limited capacity. What do you think such a replay buffer would be helpful for?
> If you actually don't care about the `log_probs` attribute of the obtained trajectories, then yeah your suggestion should work fine.
Alright, thank you! I can set `on_policy=False` on the loss function so that the `log_probs` attribute won't be used during training, right? I was just wondering whether I was converting backward trajectories into forward ones correctly, since I'm also using them in training.
> I'm not sure what you mean by that. Did you mean "to store" instead of "to sample"? If so, note that a replay buffer works essentially with `extend`s. The difference with what you wrote is that a `ReplayBuffer` object has a limited capacity. What do you think such a replay buffer would be helpful for?
I'm implementing a replay buffer where we can sample from the highest- (and lowest-) reward terminal states seen so far. It's PRT from https://arxiv.org/abs/2305.07170. We need it because our rewards are very skewed.
Currently, every time I add terminal states to the buffer, I remove any duplicate terminal states and sort them by their corresponding rewards. When I sample from the buffer, I sample only from the first n and last n states. If I reach maximum capacity, I remove states from the middle of the buffer.
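Roughly, what I have now looks like this (class and method names are just for illustration; my real buffer stores torchgfn states rather than plain hashable values):

```python
import random

class TopBottomBuffer:
    """Sketch of the buffer described above: keep unique terminal states
    sorted by reward, sample only from the n lowest and n highest, and
    evict from the middle when over capacity."""

    def __init__(self, capacity, n):
        self.capacity = capacity
        self.n = n
        self.items = []  # (reward, state) pairs, kept sorted by reward

    def add(self, states_with_rewards):
        # Deduplicate on the state (states must be hashable here),
        # keeping the latest reward, then re-sort by reward.
        seen = {s: r for r, s in self.items}
        for r, s in states_with_rewards:
            seen[s] = r
        self.items = sorted((r, s) for s, r in seen.items())
        # Evict from the middle until back within capacity.
        while len(self.items) > self.capacity:
            self.items.pop(len(self.items) // 2)

    def sample(self, k):
        # Sample only from the n lowest- and n highest-reward states.
        ends = self.items[: self.n] + self.items[-self.n :]
        return [s for _, s in random.sample(ends, min(k, len(ends)))]
```

Middle eviction keeps both tails of the reward distribution, which is the point when rewards are skewed.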
My question was: instead of storing just terminal states and sampling the highest- (and lowest-) reward ones from them, can I store the trajectories themselves and find a way to sample the trajectories with the desired rewards? That way I wouldn't need to generate backward trajectories from the terminal states and convert them to forward ones.
What do you think is the best way to implement such a buffer?
Alright, let me try that out! Thank you :)