
naver-ai / stylemapgan


Official pytorch implementation of StyleMapGAN (CVPR 2021)

Home Page: https://www.youtube.com/watch?v=qCapNyRA_Ng

License: Other

Python 90.86% Shell 2.51% C++ 0.98% Cuda 5.65%

stylemapgan's Introduction

StyleMapGAN - Official PyTorch Implementation

StyleMapGAN: Exploiting Spatial Dimensions of Latent in GAN for Real-time Image Editing
Hyunsu Kim, Yunjey Choi, Junho Kim, Sungjoo Yoo, Youngjung Uh
In CVPR 2021.

Paper: https://arxiv.org/abs/2104.14754
5-minute video (CVPR): https://www.youtube.com/watch?v=7sJqjm1qazk
Demo video: https://youtu.be/qCapNyRA_Ng

Abstract: Generative adversarial networks (GANs) synthesize realistic images from random latent vectors. Although manipulating the latent vectors controls the synthesized outputs, editing real images with GANs suffers from i) time-consuming optimization for projecting real images to the latent vectors or ii) inaccurate embedding through an encoder. We propose StyleMapGAN: the intermediate latent space has spatial dimensions, and a spatially variant modulation replaces AdaIN. It makes the embedding through an encoder more accurate than existing optimization-based methods while maintaining the properties of GANs. Experimental results demonstrate that our method significantly outperforms state-of-the-art models in various image manipulation tasks such as local editing and image interpolation. Last but not least, conventional editing methods on GANs are still valid on our StyleMapGAN. Source code is available at https://github.com/naver-ai/StyleMapGAN.

Demo

YouTube video. Watch the teaser video at https://youtu.be/qCapNyRA_Ng.

Interactive demo app. Run the demo on your local machine:

All test images are from CelebA-HQ, AFHQ, and LSUN.

python demo.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --dataset celeba_hq

Installation

Tested environment: Ubuntu, gcc 7.4.0, CUDA, CUDA driver, cuDNN 7, conda, Python 3.6.12, PyTorch 1.4.0.

Clone this repository:

git clone https://github.com/naver-ai/StyleMapGAN.git
cd StyleMapGAN/

Install the dependencies:

conda create -y -n stylemapgan python=3.6.12
conda activate stylemapgan
./install.sh
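The installation notes pin Python 3.6.12 and PyTorch 1.4.0. As a convenience, a hypothetical check_environment() helper (not part of the repository) can report whether the activated conda environment matches that setup. It imports PyTorch lazily, so it also runs before ./install.sh has finished.

```python
import importlib
import sys

# Versions taken from the tested-environment list in this README.
EXPECTED_PYTHON = (3, 6, 12)
EXPECTED_TORCH = "1.4.0"


def check_environment():
    """Report how the current interpreter compares to the tested setup."""
    py_version = tuple(sys.version_info[:3])
    report = {
        "python": ".".join(str(p) for p in py_version),
        "python_matches": py_version == EXPECTED_PYTHON,
        "torch": None,
        "torch_matches": False,
        "cuda_available": False,
    }
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return report  # PyTorch is not installed yet
    report["torch"] = torch.__version__
    # Builds may carry a local suffix such as "+cu101"; compare the release part only.
    report["torch_matches"] = torch.__version__.split("+")[0] == EXPECTED_TORCH
    report["cuda_available"] = bool(torch.cuda.is_available())
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

A mismatch is not necessarily fatal (several issues below report running newer PyTorch versions), but it is a useful first thing to check when something fails.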

Datasets and pre-trained networks

We provide a script to download datasets used in StyleMapGAN and the corresponding pre-trained networks. The datasets and network checkpoints will be downloaded and stored in the data and expr/checkpoints directories, respectively.

CelebA-HQ. To download the CelebA-HQ dataset and parse it, run the following commands:

# Download raw images and create LMDB datasets using them
# Additional files are also downloaded for local editing
bash download.sh create-lmdb-dataset celeba_hq

# Download the pretrained network (256x256) 
bash download.sh download-pretrained-network-256 celeba_hq # 20M-image-trained models
bash download.sh download-pretrained-network-256 celeba_hq_5M # 5M-image-trained models used in our paper for comparison with other baselines and for ablation studies.

# Download the pretrained network (1024x1024 image / 16x16 stylemap / Light version of Generator)
bash download.sh download-pretrained-network-1024 ffhq_16x16

AFHQ. For AFHQ, replace 'celeba_hq' with 'afhq' in the commands above.
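Several of the issues reproduced further down report _pickle.UnpicklingError: invalid load key, '<' when torch.load opens a downloaded checkpoint. A '<' as the first byte usually means the file is an HTML page (for example, a download-quota or confirmation page from the file host) rather than a PyTorch checkpoint. A minimal, hypothetical sniff_checkpoint() helper can inspect the leading bytes before loading. The magic bytes below assume either PyTorch's zip-based format (the torch.save default since 1.6) or the legacy protocol-2 pickle format written by older versions such as 1.4.0.

```python
ZIP_MAGIC = b"PK\x03\x04"     # zip-based torch.save output (assumed, PyTorch >= 1.6 default)
PICKLE_PROTO_2 = b"\x80\x02"  # legacy torch.save output starts with a protocol-2 pickle


def sniff_checkpoint(path):
    """Guess what a downloaded 'checkpoint' really is from its first bytes."""
    with open(path, "rb") as f:
        head = f.read(512)
    if head.startswith(ZIP_MAGIC):
        return "torch-zip"
    if head.startswith(PICKLE_PROTO_2):
        return "torch-legacy"
    if head.lstrip().startswith(b"<"):
        return "html"  # this is what makes torch.load fail with "invalid load key, '<'"
    return "unknown"
```

For example, sniff_checkpoint("expr/checkpoints/celeba_hq_256_8x8.pt") returning "html" means the download did not succeed; delete the file and download it again (possibly manually in a browser).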

Train network

Implemented using DistributedDataParallel.
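DistributedDataParallel runs one process per GPU, and each process handles only its share of the global batch. The training arguments printed in the demo log further down (batch=16, batch_per_gpu=8, ngpus=2) suggest that train.py splits the global batch evenly across GPUs. The hypothetical helper below only mirrors that arithmetic, which is handy when choosing --batch for a given number of GPUs; it is not taken from train.py.

```python
def per_gpu_batch(total_batch, n_gpus):
    """Split a global batch evenly across DistributedDataParallel processes."""
    if n_gpus < 1:
        raise ValueError("need at least one GPU")
    if total_batch % n_gpus:
        # An uneven split would give processes different batch sizes.
        raise ValueError(f"batch {total_batch} is not divisible by {n_gpus} GPUs")
    return total_batch // n_gpus
```

With the logged defaults, per_gpu_batch(16, 2) gives 8; a 3-GPU setup with 2 images per GPU corresponds to a global batch of 6.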

# CelebA-HQ
python train.py --dataset celeba_hq --train_lmdb data/celeba_hq/LMDB_train --val_lmdb data/celeba_hq/LMDB_val

# AFHQ
python train.py --dataset afhq --train_lmdb data/afhq/LMDB_train --val_lmdb data/afhq/LMDB_val

# CelebA-HQ / 1024x1024 image / 16x16 stylemap / Light version of Generator
python train.py --size 1024 --latent_spatial_size 16 --small_generator --dataset celeba_hq --train_lmdb data/celeba_hq/LMDB_train --val_lmdb data/celeba_hq/LMDB_val 

Generate images

Reconstruction. Results are saved to expr/reconstruction.

# CelebA-HQ
python generate.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --mixing_type reconstruction --test_lmdb data/celeba_hq/LMDB_test

# AFHQ
python generate.py --ckpt expr/checkpoints/afhq_256_8x8.pt --mixing_type reconstruction --test_lmdb data/afhq/LMDB_test

W interpolation. Results are saved to expr/w_interpolation.

# CelebA-HQ
python generate.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --mixing_type w_interpolation --test_lmdb data/celeba_hq/LMDB_test

# AFHQ
python generate.py --ckpt expr/checkpoints/afhq_256_8x8.pt --mixing_type w_interpolation --test_lmdb data/afhq/LMDB_test
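W interpolation presumably encodes two images into stylemaps and decodes a sequence of blends between them. Assuming plain linear interpolation (the exact schedule used by generate.py is not documented here), the blending step can be sketched with nested Python lists so it runs without PyTorch; with tensors, the same operation is torch.lerp(w_a, w_b, t).

```python
def lerp_stylemaps(w_a, w_b, steps):
    """Linearly interpolate two equally shaped stylemaps given as nested lists of floats."""
    if steps < 2:
        raise ValueError("steps must be >= 2 to include both endpoints")

    def mix(a, b, t):
        # Recurse through nested lists so any (C, H, W) layout works.
        if isinstance(a, list):
            return [mix(x, y, t) for x, y in zip(a, b)]
        return (1.0 - t) * a + t * b

    # t runs from 0 (pure w_a) to 1 (pure w_b) in equal increments.
    return [mix(w_a, w_b, i / (steps - 1)) for i in range(steps)]
```

Because the stylemap keeps its spatial layout, each position is blended independently, which is what allows the same idea to be restricted to a region in local editing.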

Local editing. Results are saved to expr/local_editing. We pair images using target semantic mask similarity. For details, see preprocessor/README.md.

# Using GroundTruth(GT) segmentation masks for CelebA-HQ dataset.
python generate.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --mixing_type local_editing --test_lmdb data/celeba_hq/LMDB_test --local_editing_part nose

# Using half-and-half masks for AFHQ dataset.
python generate.py --ckpt expr/checkpoints/afhq_256_8x8.pt --mixing_type local_editing --test_lmdb data/afhq/LMDB_test
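The pairing criterion itself is implemented in preprocessor/. As a rough illustration only, assume "mask similarity" means intersection-over-union (IoU) of the target part mask (for example, the nose) in the source image and in each candidate reference. The helpers below are hypothetical, use plain 2-D lists of 0/1 values, and may differ from the repository's actual metric.

```python
def mask_iou(mask_a, mask_b):
    """Intersection over union of two equally sized binary masks (2-D lists of 0/1)."""
    inter = union = 0
    for row_a, row_b in zip(mask_a, mask_b):
        for a, b in zip(row_a, row_b):
            if a and b:
                inter += 1
            if a or b:
                union += 1
    # Two empty masks are treated as identical.
    return inter / union if union else 1.0


def pick_reference(src_mask, candidate_masks):
    """Index of the candidate whose part mask overlaps the source mask best."""
    scores = [mask_iou(src_mask, m) for m in candidate_masks]
    return max(range(len(scores)), key=scores.__getitem__)
```

Pairing by overlap keeps the edited part in roughly the same place in both images, which matters because the stylemaps are blended position by position.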

Unaligned transplantation. Results are saved to expr/transplantation. It shows local transplantation examples on AFHQ. We recommend using the demo code instead.

python generate.py --ckpt expr/checkpoints/afhq_256_8x8.pt --mixing_type transplantation --test_lmdb data/afhq/LMDB_test

Random Generation. Results are saved to expr/random_generation. It shows random generation examples.

python generate.py --mixing_type random_generation --ckpt expr/checkpoints/celeba_hq_256_8x8.pt

Style Mixing. Results are saved to expr/stylemixing. It shows style mixing examples.

python generate.py --mixing_type stylemixing --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --test_lmdb data/celeba_hq/LMDB_test
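All generate.py invocations above share one shape: --ckpt, --mixing_type, an optional --test_lmdb, and mode-specific flags such as --local_editing_part. A hypothetical wrapper that builds the argument list and records the documented output directory might look like this. Note that some modes are dataset-specific: transplantation is shown only for AFHQ, and an issue below reports that stylemixing expects a CelebA-HQ pair file.

```python
import subprocess

# Output directories documented in this README for each --mixing_type.
OUTPUT_DIRS = {
    "reconstruction": "expr/reconstruction",
    "w_interpolation": "expr/w_interpolation",
    "local_editing": "expr/local_editing",
    "transplantation": "expr/transplantation",
    "random_generation": "expr/random_generation",
    "stylemixing": "expr/stylemixing",
}


def build_generate_cmd(mixing_type, ckpt, test_lmdb=None, extra_args=()):
    """Return (argv, output_dir) for one generate.py run."""
    if mixing_type not in OUTPUT_DIRS:
        raise ValueError(f"unknown mixing_type: {mixing_type}")
    argv = ["python", "generate.py", "--ckpt", ckpt, "--mixing_type", mixing_type]
    if test_lmdb is not None:
        argv += ["--test_lmdb", test_lmdb]
    argv += list(extra_args)
    return argv, OUTPUT_DIRS[mixing_type]


def run_generate(mixing_type, ckpt, test_lmdb=None, extra_args=()):
    """Run generate.py from the repository root; return where results should appear."""
    argv, out_dir = build_generate_cmd(mixing_type, ckpt, test_lmdb, extra_args)
    subprocess.run(argv, check=True)  # raises CalledProcessError if generate.py fails
    return out_dir
```

Passing an argument list (rather than a shell string) to subprocess.run avoids quoting problems with paths.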

Semantic Manipulation. Results are saved to expr/semantic_manipulation. It shows local semantic manipulation examples.

python semantic_manipulation.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --LMDB data/celeba_hq/LMDB --svm_train_iter 10000
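The --svm_train_iter flag indicates that semantic_manipulation.py trains an SVM on positive and negative samples (see Related Projects). In the common InterFaceGAN-style recipe, which this sketch assumes rather than transcribes, the normal vector n of the SVM's separating hyperplane is used as an editing direction, and a latent is moved along the unit normal: w' = w + alpha * n / ||n||. The helper below is hypothetical and works on flattened latents stored as lists.

```python
import math


def edit_along_direction(w, normal, alpha):
    """Move a flattened latent w by alpha units along the unit-length hyperplane normal."""
    if len(w) != len(normal):
        raise ValueError("latent and normal must have the same length")
    norm = math.sqrt(sum(n * n for n in normal))
    if norm == 0:
        raise ValueError("normal vector must be non-zero")
    # Positive alpha moves toward the "positive" side of the SVM, negative away from it.
    return [wi + alpha * ni / norm for wi, ni in zip(w, normal)]
```

Normalizing the direction makes alpha a distance in latent space, so the same alpha produces comparable edit strengths across attributes.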

Metrics

  • Reconstruction: LPIPS, MSE
  • W interpolation: FID_lerp
  • Generation: FID
  • Local editing: MSE_src, MSE_ref, Detectability (refer to CNNDetection)

For details, see metrics/README.md.
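For local editing, MSE_src and MSE_ref are read here as errors measured against the source image outside the edited region and against the reference image inside it; this reading is an assumption, and metrics/README.md is authoritative. A plain-Python sketch of MSE and the two masked variants on flattened pixel lists:

```python
def mse(a, b, weights=None):
    """Mean squared error between equal-length pixel lists, optionally restricted by 0/1 weights."""
    if weights is None:
        weights = [1] * len(a)
    total = sum(w * (x - y) ** 2 for x, y, w in zip(a, b, weights))
    count = sum(weights)
    return total / count if count else 0.0


def local_editing_mse(result, source, reference, mask):
    """Return (MSE_src, MSE_ref): source error outside the mask, reference error inside it."""
    outside = [1 - m for m in mask]
    return mse(result, source, outside), mse(result, reference, mask)
```

A good local edit keeps both numbers low: the unedited region should still match the source, and the edited region should match the reference. Note that absolute MSE values depend on the pixel range used (for example [0, 1] versus [-1, 1]).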

License

The source code, pre-trained models, and dataset are available under the Creative Commons BY-NC 4.0 license by NAVER Corporation. You can use, copy, transform, and build upon the material for non-commercial purposes as long as you give appropriate credit by citing our paper and indicate if changes were made.

For business inquiries, please contact [email protected].
For technical and other inquiries, please contact [email protected].

Citation

If you find this work useful for your research, please cite our paper:

@inproceedings{kim2021stylemapgan,
  title={Exploiting Spatial Dimensions of Latent in GAN for Real-time Image Editing},
  author={Kim, Hyunsu and Choi, Yunjey and Kim, Junho and Yoo, Sungjoo and Uh, Youngjung},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  year={2021}
}

Related Projects

The model code is based on the unofficial StyleGAN2 PyTorch code, which in turn refers to the official StyleGAN2 code. LPIPS, FID, and CNNDetection code is used for evaluation. For semantic manipulation, we used a pretrained StyleGAN network to obtain positive and negative samples by ranking. The demo code is based on Neural-Collage.

stylemapgan's People

Contributors

blandocs


stylemapgan's Issues

Questions about metrics for reconstruction

Hello, thanks for sharing your code. I'm very interested in your work, and I have a question about running the reconstruction metrics. The output of metrics.reconstruction is 0.015 (MSE) and 0.214 (LPIPS), which seems much better than the results reported in Table 3 of the paper (MSE = 0.024, LPIPS = 0.242). I would appreciate it if you could give me an explanation.

How to do StyleMixing on custom dataset

Hello. Thanks for this code.
I have trained the model on my own dataset, which is different from the face datasets. Image reconstruction and random generation work well. However, style mixing requires a hard-coded pkl file that is specific to the CelebA-HQ dataset (data/celeba_hq/local_editing/celeba_hq_test_GT_sorted_pair.pkl).

Just for the sake of running style mixing, I downloaded the dataset and ran the code; however, the results are not good.

Could you please share your thoughts on how to apply stylemixing on a custom dataset?

Fail to open demo

Thanks for the excellent work!

I failed to run the demo on my local machine. The app does not report any error, but the web page cannot load the demo images. Can you help me fix it?

I can run all the other commands, such as the unaligned transplantation command; only the demo fails.

Encoding quality of real images

Hello, thank you for the interesting work!
I used: python demo.py --ckpt expr/checkpoints/celeba_hq_8x8_20M_revised.pt --dataset celeba_hq
I am trying to evaluate the embedding quality of real photos.

It works impressively well on the images provided by default, but drastically worse on photos from the internet (the photos are aligned by face landmarks and cropped to 1024x1024 with https://github.com/ZPdesu/Barbershop/blob/main/align_face.py).

What could be the reason, and how can I fix this?

Source images: barack, curly2, durov.

Download.sh updated link points to possibly corrupted models

I downloaded the pretrained networks using the updated download.sh file, and PyTorch is unable to load them, giving the following error: _pickle.UnpicklingError: invalid load key, '<'

I suspect that the pre-trained network on the google drive is corrupted. Could you check whether it is the case?

GPU memory shortage problem when loading weights from checkpoints

Hello, I would first like to thank you for sharing your work.

I am having problems loading weights from checkpoints (i.e., when resuming halted training).

I am training StyleMapGAN on a custom dataset (~200K training images at 1024x1024 resolution), and I am currently using 3 Titan RTX GPUs. I am using latent_spatial_size=16, considering the input image resolution and GPU memory. With this configuration, a batch of 2 is allocated per GPU, using ~21 GiB of memory.

Training from scratch works without problems. I have not tried using pretrained weights trained on FFHQ or CelebA because my data is quite different from human faces. Moreover, since I have succeeded in generating images with generate.py, I think the weights were saved properly.

However, a memory allocation error occurs every time I load my custom weights to continue training. I assumed extra memory might be required when loading weights, so I tried a smaller batch size (batch 2 per GPU -> batch 1 per GPU), but the same memory shortage occurs.

To summarize, I cannot load weights to continue training, whereas training from scratch and loading weights to generate images work well. Therefore, I would like to ask the following questions.

  1. Have any of the authors experienced similar problems?
  2. Are there any possible solutions to my problem?

I would be grateful if you could look into my question. Thank you!
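One frequently suggested cause for this pattern (an assumption, not confirmed in this thread) is that torch.load restores tensors to the device they were saved from, so every DistributedDataParallel process may load a full copy of the checkpoint onto the same GPU (typically GPU 0). Passing map_location to torch.load, either "cpu" or each process's own device, usually avoids that spike. The helpers below are hypothetical; how train.py's resume logic actually calls torch.load is not shown here.

```python
def checkpoint_map_location(local_rank=None, to_cpu=True):
    """Choose a map_location for torch.load so DDP processes don't pile tensors onto GPU 0."""
    if to_cpu or local_rank is None:
        # Load into host memory; load_state_dict later copies values into
        # parameters that already live on each process's GPU.
        return "cpu"
    return f"cuda:{local_rank}"


def load_checkpoint(path, local_rank=None, to_cpu=True):
    """Load a checkpoint with an explicit map_location (requires PyTorch)."""
    import torch  # imported lazily so the helper above works without PyTorch

    return torch.load(path, map_location=checkpoint_map_location(local_rank, to_cpu))
```

Loading to CPU costs some host RAM and a copy, but it keeps the resume path's GPU footprint close to that of training from scratch.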

can't download model

Hello, how can I get the pretrained model? It seems it can no longer be downloaded.

UnpicklingError: invalid load key, '<'.

Hello authors,

I get an error when using the pretrained checkpoints: it occurs when torch.load(args.ckpt) runs in generate.py (or any piece of code that calls that function). I tried re-downloading the model as proposed in this issue.
The only change I made to accommodate my CUDA 11 version was to install PyTorch with conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch. Thanks for your attention; it is much appreciated.

How to edit my own images ?

Hello,
We followed the guide to run prepare_data on our own images and generated an LMDB file. When running the test:
python generate.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --mixing_type local_editing --test_lmdb data/celeba_hq/LMDB_test --local_editing_part nose
it reports that there are no images.

How can we test on our own images?

Different sizes between downloaded lmdb data and generated data with prepare_data.py

Hi,

I use the provided prepare_data.py to convert the downloaded raw images into the corresponding LMDB data at 256x256. However, the resulting LMDB data has a different size from the downloaded LMDB data (LMDB_train/test/val). I also tried other sizes, such as 128, 512, and 1024, but none of them matches the downloaded LMDB data.

Could you help me figure out the reason?

Thanks.

Generated Result out of memory

When I press the button to generate a result, it works only once; the second time, it throws a CUDA error.

/content/StyleMapGAN
train_args:  Namespace(batch=16, batch_per_gpu=8, channel_multiplier=2, ckpt=None, d_reg_every=16, dataset='celeba_hq', iter=1400000, lambda_adv_loss=1, lambda_d_loss=1, lambda_indomainGAN_D_loss=1, lambda_indomainGAN_E_loss=1, lambda_perceptual_loss=1, lambda_w_rec_loss=1, lambda_x_rec_loss=1, latent_channel_size=64, latent_spatial_size=8, lr=0.002, lr_mul=0.01, mapping_layer_num=8, mapping_method='MLP', n_sample=16, ngpus=2, normalize_mode='LayerNorm', num_workers=10, r1=10, remove_indomain=False, remove_w_rec=False, size=256, small_generator=False, start_iter=0, train_lmdb='/data/celeba_hq_lmdb/train/LMDB_train', val_lmdb='/data/celeba_hq_lmdb/train/LMDB_val')
 * Serving Flask app "demo" (lazy loading)
 * Environment: production
   WARNING: Do not use the development server in a production environment.
   Use a production WSGI server instead.
 * Debug mode: on
 * Running on http://127.0.0.1:6006/ (Press CTRL+C to quit)
 * Restarting with stat
train_args:  Namespace(batch=16, batch_per_gpu=8, channel_multiplier=2, ckpt=None, d_reg_every=16, dataset='celeba_hq', iter=1400000, lambda_adv_loss=1, lambda_d_loss=1, lambda_indomainGAN_D_loss=1, lambda_indomainGAN_E_loss=1, lambda_perceptual_loss=1, lambda_w_rec_loss=1, lambda_x_rec_loss=1, latent_channel_size=64, latent_spatial_size=8, lr=0.002, lr_mul=0.01, mapping_layer_num=8, mapping_method='MLP', n_sample=16, ngpus=2, normalize_mode='LayerNorm', num_workers=10, r1=10, remove_indomain=False, remove_w_rec=False, size=256, small_generator=False, start_iter=0, train_lmdb='/data/celeba_hq_lmdb/train/LMDB_train', val_lmdb='/data/celeba_hq_lmdb/train/LMDB_val')
 * Debugger is active!
 * Debugger PIN: 746-628-559
127.0.0.1 - - [17/Sep/2021 23:56:51] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:56:51] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "POST /post HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/15.png?ehKqCSWsm2_dZDSEqGn2KA HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/14.png?bpSTEHoDmsEdxklrmK1Iyw HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/13.png?vNWTCVyaDi3SgiGKoZVgqg HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/12.png?YpPhMa2VRbTKOyk_bXF1uw HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/11.png?Rv2b_tny9RSuUGjNv-Aj8w HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/10.png?gUzT-B3BIBWt8zKBCN-UJQ HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/9.png?hgD_0xCxc1D3SvoRnOL4kQ HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/8.png?irG3xuBZ7T4aUC1KSpNl0g HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/7.png?FzAEfzS1Y-PaknX8gkh3Ow HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/6.png?qSPaQEGIwCJJcp_lqKLDrA HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/5.png?uHnXqW2275FRM4kIAWlWoQ HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/4.png?KnBxIm4rmz5WSjDZ3i9fUA HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/3.png?i5-jUUhv1ey199OxX_dbCg HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/2.png?CiAW6qyZmqz93OLXoZx_MQ HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/0.png?72jqdYShEUiOpkSMW9R_qg HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2021 23:57:30] "GET /demo/static/generated/93huqTOEE9NBQ3MPvw41/1.png?8AK-OswmqdND3xa-F9ZJvw HTTP/1.1" 200 -
127.0.0.1 - - [18/Sep/2021 00:14:17] "POST /post HTTP/1.1" 500 -
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 2309, in __call__
    return self.wsgi_app(environ, start_response)
  File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 2295, in wsgi_app
    response = self.handle_exception(e)
  File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 1741, in handle_exception
    reraise(exc_type, exc_value, tb)
  File "/usr/local/lib/python3.7/site-packages/flask/_compat.py", line 35, in reraise
    raise value
  File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 2292, in wsgi_app
    response = self.full_dispatch_request()
  File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 1815, in full_dispatch_request
    rv = self.handle_user_exception(e)
  File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 1718, in handle_user_exception
    reraise(exc_type, exc_value, tb)
  File "/usr/local/lib/python3.7/site-packages/flask/_compat.py", line 35, in reraise
    raise value
  File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 1813, in full_dispatch_request
    rv = self.dispatch_request()
  File "/usr/local/lib/python3.7/site-packages/flask/app.py", line 1799, in dispatch_request
    return self.view_functions[rule.endpoint](**req.view_args)
  File "/content/StyleMapGAN/demo.py", line 181, in post
    save_dir=save_dir,
  File "/usr/local/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
    return func(*args, **kwargs)
  File "/content/StyleMapGAN/demo.py", line 133, in my_morphed_images
    mixed = model(original_image, reference_images, masks, shift_values).cpu()
  File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/content/StyleMapGAN/demo.py", line 73, in forward
    mask=[masks, shift_values, args.interpolation_step],
  File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/content/StyleMapGAN/training/model.py", line 1153, in forward
    image = self.decoder(stylecode, mix_space=mix_space, mask=mask)
  File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/content/StyleMapGAN/training/model.py", line 1073, in forward
    out = self.convs[i](out, [style_codes[2 * i + 1], style_codes[2 * i + 2]])
  File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/content/StyleMapGAN/training/model.py", line 720, in forward
    out = self.conv2(out, stylecodes[1])
  File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/content/StyleMapGAN/training/model.py", line 499, in forward
    out = self.conv(input, style)
  File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/content/StyleMapGAN/training/model.py", line 610, in forward
    input = input * gamma + beta
RuntimeError: CUDA out of memory. Tried to allocate 512.00 MiB (GPU 0; 11.17 GiB total capacity; 6.90 GiB already allocated; 215.44 MiB free; 8.10 GiB reserved in total by PyTorch)

metrics in table 3

How did you get the MSE and LPIPS of Image2StyleGAN, Structured Noise, and StyleGAN2 in Table 3? Structured Noise didn't provide code for image reconstruction, and Image2StyleGAN didn't provide code (no official code found). What about StyleGAN2? Can you tell me, please?

Reconstruction on FFHQ?

I would like to run some operations, such as reconstruction, on FFHQ_1024, but I am missing the LMDB files and setup. Can you make them available?

How many iterations do you use to train the model

Thank you for publishing the code. Your work is very impressive. I wonder how many iterations you used to train the model. I noticed that the default is 1,400,000. Is this value used for all your training datasets? 1,400,000 training iterations seems a bit long; I am curious whether a smaller number of training iterations would work.
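A rough way to relate the default iter=1400000 to the README's "20M-image-trained" and "5M-image-trained" checkpoints is images seen = iterations x global batch. Assuming the default global batch of 16 (as in the training arguments printed in the demo log above), the arithmetic is sketched below; this is only arithmetic, and the authors' actual stopping points are not stated here.

```python
def images_seen(iterations, batch):
    """Number of training images processed after a given number of steps."""
    return iterations * batch


def iterations_for(images, batch):
    """Steps needed to show a given number of images, rounded up to a whole step."""
    return -(-images // batch)  # ceiling division on integers
```

Under that assumption, 1,400,000 iterations correspond to 22.4M images, while 5M images correspond to 312,500 iterations and 20M images to 1,250,000.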

How can I train on my own data

Hi authors, thank you for your great work. What should I do if I want to train your model on my own dataset? I think you should tell readers how to process their own data for training. Another question: why not release the models for LSUN Car & Church? I want to use them to verify how your model performs on LSUN Car & Church.

can i get a new pretrained network file?

Here is how to get the pretrained networks:

# Download raw images and create LMDB datasets using them
# Additional files are also downloaded for local editing
bash download.sh create-lmdb-dataset celeba_hq

# Download the pretrained network (256x256) 
bash download.sh download-pretrained-network-256 celeba_hq # 20M-image-trained models
bash download.sh download-pretrained-network-256 celeba_hq_5M # 5M-image-trained models used in our paper for comparison with other baselines and for ablation studies.

# Download the pretrained network (1024x1024 image / 16x16 stylemap / Light version of Generator)
bash download.sh download-pretrained-network-1024 ffhq_16x16

But with these networks, it doesn't work:

  File "demo.py", line 192, in <module>
    ckpt = torch.load(args.ckpt)
  File "/root/anaconda3/envs/stylemap/lib/python3.6/site-packages/torch/serialization.py", line 608, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/root/anaconda3/envs/stylemap/lib/python3.6/site-packages/torch/serialization.py", line 777, in _legacy_load
    magic_number = pickle_module.load(f, **pickle_load_args)
_pickle.UnpicklingError: invalid load key, '<'.

I also found that after running train.py, I got 000000.pt (untrained), and with this 000000.pt, demo.py runs without errors (but outputs a noisy image).

So is there any way to get a new pretrained network?

I tried PyTorch 1.4.0 and 1.10, both on a remote server (Docker) and on Colab.

RuntimeError: Trying to backward through the graph a second time

Hi there, thanks so much for sharing the codes - really amazing work!

When I was trying to train the model on my own, I ran into the following error at line 383 in train.py ((w_rec_loss * args.lambda_w_rec_loss).backward()):

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

I read the model definition, and the implementation looked correct to me, so I don't understand why such an error was thrown. Do you have any idea what could have gone wrong?

In the meantime, to unblock myself, I modified the code a bit to run backward() on g_loss and w_rec_loss in one go (g_w_rec_loss in the following example). Does this modification make sense to you? Why did you separate the backward operations in the first place?

        adv_loss, w_rec_loss, stylecode = model(None, "G")
        adv_loss = adv_loss.mean()
        w_rec_loss = w_rec_loss.mean()
        g_loss = adv_loss * args.lambda_adv_loss

        g_optim.zero_grad()
        e_optim.zero_grad()
        g_w_rec_loss = g_loss + w_rec_loss * args.lambda_w_rec_loss
        g_w_rec_loss.backward()
        gather_grad(
            g_module.parameters(), world_size
        )  # Explicitly synchronize Generator parameters. There is a gradient sync bug in G.
        g_optim.step()
        e_optim.step()

Thanks in advance for your help!
