Comments (5)
The steps you tried are quite reasonable! The missing ingredient is that preprocessing computes various vocabularies of motif/atom types from the data, and these vocabularies determine the shapes of some layers in the model. When a pretrained model is loaded, those shapes are already fixed, hence the shape mismatch. Intuitively, the error is saying that the number of motif/atom types found in your fine-tuning dataset was smaller than the number of types found in the original dataset, which is expected.
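To make the mismatch concrete, here is a minimal sketch (using NumPy and a hypothetical layer name, not the actual MoLeR code) of why a final-layer kernel saved for a 167-type vocabulary cannot be assigned to a layer built for a 142-type vocabulary:

```python
import numpy as np

# Hypothetical illustration: a node-type selector's final layer has kernel shape
# (hidden_dim, num_node_types), so its shape depends on the vocabulary size.
def make_final_layer_kernel(hidden_dim, num_node_types, rng):
    return rng.normal(size=(hidden_dim, num_node_types))

rng = np.random.default_rng(0)
pretrained_kernel = make_final_layer_kernel(256, 167, rng)  # original vocabulary
finetune_kernel = make_final_layer_kernel(256, 142, rng)    # smaller vocabulary

def assign(dst, src):
    # Mimics the shape check that tf.Variable.assign performs.
    if dst.shape != src.shape:
        raise ValueError(f"Shape mismatch: {dst.shape} vs {src.shape}")
    dst[...] = src

try:
    assign(finetune_kernel, pretrained_kernel)
except ValueError as e:
    print(e)  # Shape mismatch: (256, 142) vs (256, 167)
```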
So, one would have to tell preprocessing to use the vocabularies from the pretrained checkpoint instead of computing them afresh. This isn't supported in the current code (we did briefly experiment with fine-tuning, but not enough for it to end up in the release), but it shouldn't be hard to add. I'll hack something together this week and share it with you as a branch; once you test it out, I can make a PR and merge it into main.
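As a sketch of what reusing the checkpoint's vocabularies could look like (all function names here are hypothetical, not the actual molecule_generation API):

```python
def build_vocabulary(molecules):
    # Hypothetical: collect the distinct atom/motif types seen in the dataset.
    return sorted({atom for mol in molecules for atom in mol})

def preprocess(molecules, pretrained_vocab=None):
    # If a pretrained vocabulary is given, reuse it so layer shapes match the
    # checkpoint; otherwise compute it afresh from the fine-tuning data.
    if pretrained_vocab is not None:
        return pretrained_vocab
    return build_vocabulary(molecules)

pretrain_data = [["C", "N", "O", "S"], ["C", "F"]]
finetune_data = [["C", "N"]]  # smaller dataset -> fewer types seen

pretrained_vocab = build_vocabulary(pretrain_data)
fresh = preprocess(finetune_data)                    # smaller vocab -> mismatch later
reused = preprocess(finetune_data, pretrained_vocab) # shapes stay compatible
print(len(fresh), len(reused))  # 2 5
```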
from molecule-generation.
Oh, I see, I get it now.
Thank you so much!
@gmseabra: can you pull kmaziarz/finetuning, re-install the package from there, and try fine-tuning again? The only change to the workflow you described is passing `--pretrained-model-path` when doing preprocessing.
However, note that by default `molecule_generation train` will run validation every 5000 steps and continue until there is no improvement on the validation dataset. If you're fine-tuning on a small set of molecules, it may make sense to set this interval to something lower (so that training has a chance to stop before it overfits) and/or limit the total number of such validation rounds. For example, passing `--model-params-override '{"num_train_steps_between_valid": 50}' --max-epochs 8` means training will run some multiple of 50 steps, at most 8 * 50 = 400, but possibly fewer if validation stops improving.
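The schedule these flags imply can be sketched as follows (a simplified toy loop, not the actual molecule_generation trainer, which has more stopping options):

```python
# Train in rounds of `num_train_steps_between_valid` steps, validate after each
# round, and stop when validation stops improving or after `max_epochs` rounds.
def train_with_early_stopping(validate, num_train_steps_between_valid=50, max_epochs=8):
    best_metric = float("inf")
    total_steps = 0
    for _ in range(max_epochs):
        total_steps += num_train_steps_between_valid  # one round of training
        metric = validate()
        if metric >= best_metric:  # no improvement -> stop early
            break
        best_metric = metric
    return total_steps

# Toy validation curve that improves twice, then plateaus:
losses = iter([1.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8])
steps = train_with_early_stopping(lambda: next(losses))
print(steps)  # 150 -- stopped well before the 400-step cap
```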
Let me know how this goes!
@gmseabra Did you have any luck with fine-tuning?
Hi! I've been trying to replicate this example with the steps you provided above, fine-tuning on a small set of 3K molecules, but I still encounter the following error (running everything on Colab now). I took the existing checkpoint and wanted to fine-tune it on the smaller set:
```
!molecule_generation preprocess input output trace \
  --pretrained-model-path --load-saved-model /content/drive/MyDrive/subset_gpu_finetuning/moler/molecule-generation/best_model/GNN_Edge_MLP_MoLeR__2022-02-24_07-16-23_best.pkl

!molecule_generation train MoLeR trace \
  --model-params-override '{"num_train_steps_between_valid": 50}' --max-epochs 8 \
  --load-saved-model /content/drive/MyDrive/subset_gpu_finetuning/moler/molecule-generation/best_model/GNN_Edge_MLP_MoLeR__2022-02-24_07-16-23_best.pkl \
  --load-weights-only
```
But even when aligning the metadata with what the model was originally trained on (I took the same Guacamol files), it still wouldn't run. This is the error I encountered:
```
Traceback (most recent call last):
  File "/usr/local/bin/molecule_generation", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/site-packages/molecule_generation/cli/cli.py", line 35, in main
    run_and_debug(lambda: commands[args.command].run_from_args(args), getattr(args, "debug", False))
  File "/usr/local/lib/python3.10/site-packages/dpu_utils/utils/debughelper.py", line 21, in run_and_debug
    func()
  File "/usr/local/lib/python3.10/site-packages/molecule_generation/cli/cli.py", line 35, in <lambda>
    run_and_debug(lambda: commands[args.command].run_from_args(args), getattr(args, "debug", False))
  File "/usr/local/lib/python3.10/site-packages/molecule_generation/cli/train.py", line 140, in run_from_args
    loaded_model_dataset = training_utils.get_model_and_dataset(
  File "/usr/local/lib/python3.10/site-packages/tf2_gnn/cli_utils/model_utils.py", line 319, in get_model_and_dataset
    load_weights_verbosely(trained_model_file, model)
  File "/usr/local/lib/python3.10/site-packages/tf2_gnn/cli_utils/model_utils.py", line 148, in load_weights_verbosely
    K.batch_set_value(tfvar_weight_tuples)
  File "/usr/local/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.10/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 1022, in assign
    raise ValueError(
ValueError: Cannot assign value to variable 'decoder/node_type_selector/MLP_final_layer/kernel:0': Shape mismatch. The variable shape (256, 142), and the assigned value shape (256, 167) are incompatible.
```
I'm not sure whether I did something wrong, or where I could fix it!