
maskgit-pytorch's People

Contributors

kifarid, llvictorll, mickaelchen


maskgit-pytorch's Issues

reproducibility

Hi @llvictorll, thanks for your nice reproduction. When I evaluated the provided checkpoint with the following command:

torchrun --standalone --nnodes=1 --nproc_per_node=1 main.py --bsize 128 --data-folder imagenet --vit-folder pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth --vqgan-folder pretrained_maskgit/VQGAN/ --writer-log logs --num_workers 16 --img-size 256 --epoch 301 --resume --test-only

Size of model autoencoder: 72.142M                                                                                                                       
Acquired codebook size: 1024                                                                                                                             
load ckpt from: pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth                                                                                      
Size of model vit: 174.161M                                                                                                                              
Evaluation with hyper-parameter ->                                                                                                                       
scheduler: arccos, number of step: 8, softmax temperature: 1.0, cfg weight: 3, gumbel temperature: 4.5                                                   
{'Eval/fid_conditional': 7.655000113633889, 'Eval/inception_score_conditional': 228.72691345214844, 'Eval/precision_conditional': 0.8194600000000002, 'Eval/recall_conditional': 0.5016600000000001, 'Eval/density_conditional': 1.2358733333333334, 'Eval/coverage_conditional': 0.8560800000000001}    

The FID result is worse than the one you reported (6.80), as shown above. Could you please help figure out where this gap comes from? Thanks.

Target tokens for loss computation

Hi, I'd like to ask about the loss-computation part.
This repository (and the original repository?) computes the cross-entropy loss over all of the ground-truth tokens.
This implies that the model also learns to predict 'known (unmasked)' tokens, which are relatively easy to estimate.
As a result, training may exhibit a strong bias towards the known tokens.

loss = self.criterion(pred.reshape(-1, 1024 + 1), code.view(-1)) / self.args.grad_cum

Intuitively, in this case the model would at first ignore the loss on the 'masked' tokens, and the loss on the known tokens would decrease drastically at the beginning of training.

I think there is another option: masking the known positions in the target tokens, which forces the model to predict only the unknown (masked) tokens (the same approach as taken in the following repository).
https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/muse_maskgit_pytorch.py#L680
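For reference, the masked-target variant can be sketched as follows (a minimal sketch, not this repository's actual code; the mask-token id of 1024 and the tensor shapes are assumptions):

```python
import torch
import torch.nn.functional as F

def masked_ce_loss(pred, code, input_tokens, mask_token_id=1024):
    """Cross-entropy computed only at masked positions.

    pred:         (B, N, V) logits from the transformer
    code:         (B, N) ground-truth token ids
    input_tokens: (B, N) input ids, equal to mask_token_id where masked
    """
    target = code.clone()
    # Positions that were NOT masked are excluded from the loss
    # via cross_entropy's ignore_index mechanism.
    target[input_tokens != mask_token_id] = -100
    return F.cross_entropy(pred.reshape(-1, pred.size(-1)),
                           target.view(-1), ignore_index=-100)
```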

I'd like to know if you have any insights on the following two points regarding this.

  1. Which approach tends to yield better learning outcomes: masking the target or not masking the target?
  2. Could you share the loss curve through epochs during training so that we can confirm if our training is going well?

Thank you for considering my request!
Best regards,
Yukara

About the training intermediate result.

Hello, really thanks for your great work,
I have a question about the intermediate results. I am trying to reproduce the results in the video domain, but training is really hard: the loss does not drop significantly, and the model keeps producing flat images in which the whole frame is a single color. I just want to ask: is this also true for the intermediate results of MaskGIT?
Really looking forward to your reply, and thanks in advance.

Warm-up of CFG weight

First of all, thank you for providing such great code and materials. I was also struggling to reproduce MaskGIT, so this has been a tremendous help.

I noticed an implementation detail that was not mentioned in the report: the warm-up of the CFG weight during sampling.

_w = w * (indice / len(scheduler))

If you don't mind, could you please provide insights into the differences in results when this warm-up is applied versus not applied?

Here's another minor point, but would it be more in line with the intended processing if the weight calculation is modified as follows?
_w = w * (indice / (len(scheduler)-1))

Sampling with CFG = 0

Hello,

In vit.py, on line 381, there is

logit = self.vit(code.clone(), labels, drop_label=~drop)

When debugging, I found that drop is Tensor([True, True, ...]), so ~drop is Tensor([False, False, ...]), meaning the labels are not dropped.
I'm wondering whether this is working as expected, since a CFG of 0 usually means that the label is ignored, right?
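For context, whether w = 0 drops the label depends on which classifier-free-guidance convention the sampler uses; two common forms are sketched below (an illustration only, not this repository's exact code):

```python
import torch

def cfg_a(cond, uncond, w):
    # Convention A: logits = (1 + w) * cond - w * uncond
    # w = 0 -> purely conditional logits (the label is still used).
    return (1 + w) * cond - w * uncond

def cfg_b(cond, uncond, w):
    # Convention B: logits = uncond + w * (cond - uncond)
    # w = 0 -> purely unconditional logits (the label is ignored).
    return uncond + w * (cond - uncond)
```

Under convention A, keeping the label at CFG = 0 would be the expected behavior; under convention B it would not.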

train vqgan

I want to train on my own dataset. Do I need to retrain the VQGAN? If so, the VQGAN training code seems to be missing the discriminator; how do I train the VQGAN?

Regarding training a MaskGIT model with my own data, could you please provide guidance on the steps involved?

Thank you for your great work. I have some questions I would like to ask you, if you don't mind.
data_folder="/datasets_local/ImageNet/"
vit_folder="./pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth"
vqgan_folder="./pretrained_maskgit/VQGAN/"
writer_log="./logs/"
num_worker=16
bsize=64

Single GPU

python main.py --bsize ${bsize} --data-folder "${data_folder}" --vit-folder "${vit_folder}" --vqgan-folder "${vqgan_folder}" --writer-log "${writer_log}" --num_workers ${num_worker} --img-size 256 --epoch 301 --resume

Multiple GPUs single node

torchrun --standalone --nnodes=1 --nproc_per_node=gpu main.py --bsize ${bsize} --data-folder "${data_folder}" --vit-folder "${vit_folder}" --vqgan-folder "${vqgan_folder}" --writer-log "${writer_log}" --num_workers ${num_worker} --img-size 256 --epoch 301 --resume
If I want to train the MaskGIT transformer on custom data, what changes do I need to make to this code? I've already trained my own VQGAN.
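Not an answer from the authors, but a common starting point is to swap the ImageNet loader for a dataset of your own that yields (image, label) pairs of the same shape; anything matching that interface should plug into the same training loop (a hypothetical sketch; the actual data-loading code in main.py may differ):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class MyImageDataset(Dataset):
    # Hypothetical stand-in for the ImageNet dataset: any dataset
    # returning (image_tensor, class_label) pairs with the expected
    # resolution should be usable in its place.
    def __init__(self, images, labels):
        self.images = images    # (N, 3, 256, 256) float tensor
        self.labels = labels    # (N,) long tensor of class ids

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

# Dummy data standing in for real images:
images = torch.randn(8, 3, 256, 256)
labels = torch.randint(0, 10, (8,))
loader = DataLoader(MyImageDataset(images, labels),
                    batch_size=4, shuffle=True)
```

The class count also matters: the class-conditional embedding in the ViT is sized for ImageNet's 1000 classes, so it would need resizing for a dataset with a different label space.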

questions about two stage training

Hey @llvictorll and team,

Really appreciate you reproducing this and open-sourcing it! It's really helpful for the community. I want to further understand the training and fine-tuning strategy mentioned in Sec. 2 of the tech report. Does it mean that the first-stage training is at 256×256 and the second-stage fine-tuning is at 512×512?

It would be very helpful if you could kindly explain it in more detail.
