naver-ai / stylemapgan
Official pytorch implementation of StyleMapGAN (CVPR 2021)
Home Page: https://www.youtube.com/watch?v=qCapNyRA_Ng
License: Other
Hi authors, thank you for your great work. What should I do if I want to train your model on my own dataset? It would help to document how readers should process their own data if they want to use it for training. Another question: why not release the models for LSUN Car & Church? I would like to use them to verify how your model performs on LSUN Car & Church.
Hello, thanks for sharing your code. I'm very interested in your work and I have a question about running the metrics for reconstruction. The outputs of metrics.reconstruction are 0.015 (MSE) and 0.214 (LPIPS). These seem much better than the results reported in Table 3 of the paper (MSE = 0.024, LPIPS = 0.242). I would appreciate it if you could give me an explanation.
Hi, thank you for sharing your work. I can't download the weights through the script; please check.
Please add a Colab example.
Because my GPU has limited memory, I can only train the model with the small generator. But I don't fully understand the difference between the normal G and the small G. Can you explain the main difference, please?
I am searching for a way to add class conditioning for image editing.
How do you project an image into the latent space of Structured Noise[3] in Table 3?
Hi, thanks for your excellent work!
May I know how to calculate or obtain GT_labels and LMDB_test_mask for other images in CelebA_HQ? This repo only provides the processed masks for download via download.sh.
Thanks!
Hi,
I used the provided prepare_data.py to convert the downloaded raw images to the corresponding lmdb data at 256x256 size. But I find that the resulting lmdb data differs in size from the downloaded lmdb data (LMDB_train/test/val). I also tried other lmdb sizes such as 128, 512 and 1024, but no image size matches the downloaded lmdb data.
May I know the reason?
Thanks.
Hi, thanks for your great work. As you say in the paper, the generator and the encoder are trained jointly, not separately. How about training them together as a unified framework without fixing weights?
Hello,
We followed the guide to run prepare_data on our own images and generated the lmdb file. When we run testing:
python generate.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --mixing_type local_editing --test_lmdb data/celeba_hq/LMDB_test --local_editing_part nose
it reports that there are no images.
How do we test on our own images?
The README explains how to get the pretrained networks:
# Download raw images and create LMDB datasets using them
# Additional files are also downloaded for local editing
bash download.sh create-lmdb-dataset celeba_hq
# Download the pretrained network (256x256)
bash download.sh download-pretrained-network-256 celeba_hq # 20M-image-trained models
bash download.sh download-pretrained-network-256 celeba_hq_5M # 5M-image-trained models used in our paper for comparison with other baselines and for ablation studies.
# Download the pretrained network (1024x1024 image / 16x16 stylemap / Light version of Generator)
bash download.sh download-pretrained-network-1024 ffhq_16x16
But with these networks, it doesn't work:
File "demo.py", line 192, in <module>
ckpt = torch.load(args.ckpt)
File "/root/anaconda3/envs/stylemap/lib/python3.6/site-packages/torch/serialization.py", line 608, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "/root/anaconda3/envs/stylemap/lib/python3.6/site-packages/torch/serialization.py", line 777, in _legacy_load
magic_number = pickle_module.load(f, **pickle_load_args)
_pickle.UnpicklingError: invalid load key, '<'.
I also found that after running train.py I got 000000.pt (an untrained checkpoint),
and with this 000000.pt demo.py works (but the output image is just noise).
So is there any way to get a new pretrained network?
I tried on pytorch 1.4.0 and 1.10, on both a remote server (Docker) and Colab.
What are the minimum system requirements to run demo.py?
Thank you for publishing the code. Your work is very impressive. I wonder how many iterations you used to train the model. I noticed that the default is 1,400,000. Is this value used for all your training datasets? 1,400,000 training iterations is rather long, so I am curious whether a smaller number of training iterations would work.
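For a rough sense of scale, the default iteration count can be related to the "20M-image-trained" and "5M-image-trained" labels used for the pretrained networks elsewhere on this page. The sketch below assumes that one iteration consumes one global batch and that the batch size is 16 (the value printed in the train_args dump in a later issue); whether the released checkpoints were actually trained with that batch size is not confirmed here. The helper names are hypothetical.

```python
# Rough arithmetic only. Assumption: one iteration = one global batch of images.
DEFAULT_ITER = 1_400_000   # default iteration count mentioned in the question
DEFAULT_BATCH = 16         # batch size shown in the train_args dump (assumed here)


def images_seen(iterations: int, batch_size: int) -> int:
    """Number of images shown to the model after `iterations` steps."""
    return iterations * batch_size


def iterations_for(images: int, batch_size: int) -> int:
    """Number of steps needed to show `images` images (rounded down)."""
    return images // batch_size


if __name__ == "__main__":
    print(images_seen(DEFAULT_ITER, DEFAULT_BATCH))    # 22400000
    print(iterations_for(20_000_000, DEFAULT_BATCH))   # 1250000
    print(iterations_for(5_000_000, DEFAULT_BATCH))    # 312500
```

Under these assumptions, a 5M-image run would correspond to roughly 312,500 iterations, well below the 1,400,000 default.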
Hi! On the left side of the equation, does "h_{i+1}" mean "h_i'" or "the (i+1)-th layer"?
Hello. Thanks for this code.
I have trained the model on my own dataset, which is different from the face dataset. Image reconstruction and random generation work well. However, if I do style mixing, there is a hard-coded pkl file required that is related to the celeba_hq dataset (data/celeba_hq/local_editing/celeba_hq_test_GT_sorted_pair.pkl).
Just for the sake of running stylemixing I downloaded the dataset and ran the code, however, the results are not good.
Could you please share your thoughts on how to apply stylemixing on a custom dataset?
I downloaded the pre-trained networks using the updated download.sh files, and pytorch is unable to load them, giving the following error: _pickle.UnpicklingError: invalid load key, '<'
I suspect that the pre-trained network on the google drive is corrupted. Could you check whether it is the case?
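The '<' in "invalid load key" is the first byte the unpickler read, so the downloaded file begins with '<'. One plausible cause (an assumption, not confirmed in this thread) is that the download returned an HTML page, such as a Google Drive warning or quota page, saved under the checkpoint's name. A minimal sketch for checking this before calling torch.load; the function name and probe size are hypothetical:

```python
def looks_like_html(path: str, probe_bytes: int = 512) -> bool:
    """Return True if the file starts with HTML markup instead of binary data.

    Only the first `probe_bytes` bytes are read, so this is cheap even for
    multi-gigabyte checkpoint files.
    """
    with open(path, "rb") as f:
        head = f.read(probe_bytes).lstrip().lower()
    return head.startswith((b"<!doctype html", b"<html"))


if __name__ == "__main__":
    import sys

    for ckpt_path in sys.argv[1:]:
        verdict = "HTML page, not a checkpoint" if looks_like_html(ckpt_path) else "binary data"
        print(f"{ckpt_path}: {verdict}")
```

If the check returns True, opening the file in a text editor should show the message the server sent back in place of the weights.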
Hi, thanks for your great work. I want to know why you train the G before the D.
Hello authors,
I get an error with the pretrained checkpoints when torch.load(args.ckpt) is run in generate.py (or any other piece of code that calls it). I tried re-downloading the model as proposed in this issue.
The only change I made to accommodate my CUDA 11 version was to install pytorch with conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch. Thanks for your attention, it is much appreciated.
Hello, how can I get the pretrained model? It seems it can't be downloaded now.
Hi there, thanks so much for sharing the codes - really amazing work!
When I was trying to train the model on my own, I ran into an error at line 383 in train.py ((w_rec_loss * args.lambda_w_rec_loss).backward()):
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
I read the model definition and the implementation looked correct to me, so I don't understand why such an error was thrown. Do you have any idea what could have gone wrong?
In the meantime, to unblock myself, I modified the code a bit to run backward() for g_loss and w_rec_loss in one go (g_w_rec_loss in the following example). Does this modification make sense to you? Why did you separate the backward operations in the first place?
adv_loss, w_rec_loss, stylecode = model(None, "G")
adv_loss = adv_loss.mean()
w_rec_loss = w_rec_loss.mean()
g_loss = adv_loss * args.lambda_adv_loss
g_optim.zero_grad()
e_optim.zero_grad()
g_w_rec_loss = g_loss + w_rec_loss * args.lambda_w_rec_loss
g_w_rec_loss.backward()
gather_grad(
    g_module.parameters(), world_size
)  # Explicitly synchronize Generator parameters. There is a gradient sync bug in G.
g_optim.step()
e_optim.step()
Thanks in advance for your help!
When I press the button to generate a result, it works only once; the second time it throws a CUDA error.
/content/StyleMapGAN
train_args: Namespace(batch=16, batch_per_gpu=8, channel_multiplier=2, ckpt=None, d_reg_every=16, dataset='celeba_hq', iter=1400000, lambda_adv_loss=1, lambda_d_loss=1, lambda_indomainGAN_D_loss=1, lambda_indomainGAN_E_loss=1, lambda_perceptual_loss=1, lambda_w_rec_loss=1, lambda_x_rec_loss=1, latent_channel_size=64, latent_spatial_size=8, lr=0.002, lr_mul=0.01, mapping_layer_num=8, mapping_method='MLP', n_sample=16, ngpus=2, normalize_mode='LayerNorm', num_workers=10, r1=10, remove_indomain=False, remove_w_rec=False, size=256, small_generator=False, start_iter=0, train_lmdb='/data/celeba_hq_lmdb/train/LMDB_train', val_lmdb='/data/celeba_hq_lmdb/train/LMDB_val')
* Serving Flask app "demo" (lazy loading)
* Environment: production
WARNING: Do not use the development server in a production environment.
Use a production WSGI server instead.
* Debug mode: on
* Running on http://127.0.0.1:6006/ (Press CTRL+C to quit)
* Restarting with stat
train_args: Namespace(batch=16, batch_per_gpu=8, channel_multiplier=2, ckpt=None, d_reg_every=16, dataset='celeba_hq', iter=1400000, lambda_adv_loss=1, lambda_d_loss=1, lambda_indomainGAN_D_loss=1, lambda_indomainGAN_E_loss=1, lambda_perceptual_loss=1, lambda_w_rec_loss=1, lambda_x_rec_loss=1, latent_channel_size=64, latent_spatial_size=8, lr=0.002, lr_mul=0.01, mapping_layer_num=8, mapping_method='MLP', n_sample=16, ngpus=2, normalize_mode='LayerNorm', num_workers=10, r1=10, remove_indomain=False, remove_w_rec=False, size=256, small_generator=False, start_iter=0, train_lmdb='/data/celeba_hq_lmdb/train/LMDB_train', val_lmdb='/data/celeba_hq_lmdb/train/LMDB_val')
* Debugger is active!
* Debugger PIN: 746-628-559
127.0.0.1 - - [17/Sep/2021 23:56:51] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:56:51] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "POST /post HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/15.png?ehKqCSWsm2_dZDSEqGn2KA HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/14.png?bpSTEHoDmsEdxklrmK1Iyw HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/13.png?vNWTCVyaDi3SgiGKoZVgqg HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/12.png?YpPhMa2VRbTKOyk_bXF1uw HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/11.png?Rv2b_tny9RSuUGjNv-Aj8w HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/10.png?gUzT-B3BIBWt8zKBCN-UJQ HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/9.png?hgD_0xCxc1D3SvoRnOL4kQ HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/8.png?irG3xuBZ7T4aUC1KSpNl0g HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/7.png?FzAEfzS1Y-PaknX8gkh3Ow HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/6.png?qSPaQEGIwCJJcp_lqKLDrA HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/5.png?uHnXqW2275FRM4kIAWlWoQ HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/4.png?KnBxIm4rmz5WSjDZ3i9fUA HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/3.png?i5-jUUhv1ey199OxX_dbCg HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/2.png?CiAW6qyZmqz93OLXoZx_MQ HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/0.png?72jqdYShEUiOpkSMW9R_qg HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/1.png?8AK-OswmqdND3xa-F9ZJvw HTTP/1.1" 200 -
127.0.0.1 - - [18/Sep/2021 00:14:17] "POST /post HTTP/1.1" 500 -
Traceback (most recent call last):
File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 2309, in __call__
return self.wsgi_app(environ, start_response)
File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 2295, in wsgi_app
response = self.handle_exception(e)
File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 1741, in handle_exception
reraise(exc_type, exc_value, tb)
File "/usr/local/lib/python3.7/site-packages/flask/_compat.py", line 35, in reraise
raise value
File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 2292, in wsgi_app
response = self.full_dispatch_request()
File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 1815, in full_dispatch_request
rv = self.handle_user_exception(e)
File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 1718, in handle_user_exception
reraise(exc_type, exc_value, tb)
File "/usr/local/lib/python3.7/site-packages/flask/_compat.py", line 35, in reraise
raise value
File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 1813, in full_dispatch_request
rv = self.dispatch_request()
File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 1799, in dispatch_request
return self.view_functions[rule.endpoint](**req.view_args)
File "/content/StyleMapGAN/demo.py", line 181, in post
save_dir=save_dir,
File "/usr/local/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
return func(*args, **kwargs)
File "/content/StyleMapGAN/demo.py", line 133, in my_morphed_images
mixed = model(original_image, reference_images, masks, shift_values).cpu()
File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/content/StyleMapGAN/demo.py", line 73, in forward
mask=[masks, shift_values, args.interpolation_step],
File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/content/StyleMapGAN/training/model.py", line 1153, in forward
image = self.decoder(stylecode, mix_space=mix_space, mask=mask)
File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/content/StyleMapGAN/training/model.py", line 1073, in forward
out = self.convs[i](out, [style_codes[2 * i + 1], style_codes[2 * i + 2]])
File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/content/StyleMapGAN/training/model.py", line 720, in forward
out = self.conv2(out, stylecodes[1])
File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/content/StyleMapGAN/training/model.py", line 499, in forward
out = self.conv(input, style)
File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/content/StyleMapGAN/training/model.py", line 610, in forward
input = input * gamma + beta
RuntimeError: CUDA out of memory. Tried to allocate 512.00 MiB (GPU 0; 11.17 GiB total capacity; 6.90 GiB already allocated; 215.44 MiB free; 8.10 GiB reserved in total by PyTorch)
Hello, I would first like to thank you for sharing your work.
I am having a problem loading weights from checkpoints (i.e., when continuing halted training).
I am training StyleMapGAN on a custom dataset (~200K images in the training dataset, 1024x1024 resolution), and I am currently using 3 Titan RTX GPUs. I am using latent_spatial_size=16, considering the input image resolution and GPU memory. With this configuration, each GPU is allocated a batch of 2 and uses ~21 GiB of memory.
There is no problem training from scratch. I have not tried using the pretrained weights trained on FFHQ or CelebA because my data is quite different from human faces. Moreover, as I have succeeded in generating images with generate.py, I think the weights were saved properly.
However, a memory allocation problem occurs every time I load the custom weights to continue training. I assumed extra memory might be required when loading weights, so I tried a smaller batch size (batch 2 per GPU -> batch 1 per GPU), but the same memory shortage occurs.
To summarize, I cannot load weights to continue training, whereas training from scratch and loading weights to generate images both work well. Therefore, I would like to ask the following questions.
I would be grateful if you take a look into my question. Thank you!
I found that layer norm is slower than the other normalization modes at inference time.
Hello, thank you for interesting work!
I used python demo.py --ckpt expr/checkpoints/celeba_hq_8x8_20M_revised.pt --dataset celeba_hq
I am trying to evaluate the quality of real-photo embedding.
It works impressively well on the images provided by default:
however, it is drastically worse on photos from the internet:
(photos are aligned by face landmarks and cropped to 1024x1024 by https://github.com/ZPdesu/Barbershop/blob/main/align_face.py)
What can be the reason, and how do I fix this?
Thanks for the excellent work! I would like to know specifically how interpolation is implemented in the experiments.
How did you get the MSE and LPIPS of Image2StyleGAN, Structured Noise and StyleGAN2 in Table 3? Structured Noise doesn't provide code for image reconstruction, and Image2StyleGAN doesn't provide code at all (no official code found). And what about StyleGAN2? Could you tell me, please?
I would like to run some operations, such as reconstruction, on FFHQ_1024, but I am missing the lmdb files and setup. Can you make them available?