cyclegan-keras's People

Contributors

shaofanl

cyclegan-keras's Issues

array dimension not consistent error

This project is great; I was looking for a CycleGAN implemented in Keras. When I run the demo, I get the error "expected input_3 to have shape (None, 3, 128, 128) but got array with shape (1, 128, 128, 3)" at fake_A_pool.extend(self.BtoA.predict(real_B)). It looks like there is a problem with how the array dimensions are converted. Can you confirm this? Thank you very much.
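The error says the model expects channels-first input, (None, 3, 128, 128), while the batch arrives channels-last, (1, 128, 128, 3). Assuming the image loader yields (N, H, W, C) arrays and the model really was built channels-first, one workaround is to transpose each batch before calling predict. This is a NumPy sketch; the helper name to_channels_first and the dummy batch are hypothetical, not part of the repository.

```python
import numpy as np

def to_channels_first(batch):
    """Reorder a (N, H, W, C) batch to (N, C, H, W)."""
    return np.transpose(batch, (0, 3, 1, 2))

# Hypothetical dummy batch with the shape reported in the error message.
real_B = np.random.rand(1, 128, 128, 3).astype(np.float32)
real_B_cf = to_channels_first(real_B)
print(real_B_cf.shape)  # (1, 3, 128, 128)
```

With this in place, the call would become something like self.BtoA.predict(to_channels_first(real_B)). The model's output would then also be channels-first, so any code that saves or displays images would need the reverse permutation, np.transpose(x, (0, 2, 3, 1)). Whether transposing the data or changing the Keras image data format is the better fix depends on how the models were constructed.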

Found during testing

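In resnet.py, resnet_6blocks computes the padding with true division:

    def resnet_6blocks(input_shape, output_nc, ngf, **kwargs):
        ks = 3
        f = 7
        p = (f-1)/2    # change to: p = (f-1)//2

In Python 3, (f-1)/2 evaluates to the float 3.0. Using integer division, (f-1)//2, avoids passing a float argument that causes an error later.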
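The arithmetic behind this fix can be checked in plain Python. The helpers below are hypothetical (they are not in resnet.py) and assume a stride-1 convolution with an odd kernel size f, where padding p = (f-1)//2 on each side keeps the spatial size unchanged.

```python
def same_padding(kernel_size):
    """Per-side padding that preserves spatial size for a stride-1, odd-sized kernel."""
    return (kernel_size - 1) // 2  # integer division always returns an int

def conv_output_size(in_size, kernel_size, padding, stride=1):
    """Standard output-size formula for a padded convolution."""
    return (in_size + 2 * padding - kernel_size) // stride + 1

f = 7
p = same_padding(f)
print(p, type(p).__name__)          # 3 int
print((f - 1) / 2)                  # 3.0 (true division yields a float in Python 3)
print(conv_output_size(128, f, p))  # 128
```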

identity loss is not correct

In cyclegan.py, you use opt.idloss as the weight of the identity loss:

    if opt.idloss > 0:
        G_trainner = Model([real_A, real_B],
                           [dis_fake_B, dis_fake_A, rec_A,    rec_B,    fake_B,     fake_A])

        G_trainner.compile(Adam(lr=opt.lr, beta_1=opt.beta1),
            loss=        ['MSE',      'MSE',      'MAE',    'MAE',    'MAE',      'MAE'],
            loss_weights=[1,          1,          opt.lmbd, opt.lmbd, opt.idloss, opt.idloss])

However, the original PyTorch version uses:

    if lambda_idt > 0:
        # G_A should be identity if real_B is fed.
        idt_A = self.netG_A(self.real_B)
        loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt  # loss part?

So I think the weight of the identity loss should be opt.idloss*opt.lmbd. Is that right?
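The proposal amounts to scaling the identity weight by the cycle weight, as the PyTorch line multiplies by both lambda_B and lambda_idt. The sketch below builds the loss_weights list for the six generator outputs in the order used by G_trainner. The function name, its flag, and the example values are hypothetical, chosen only to show the difference between the two weightings.

```python
def generator_loss_weights(lmbd, idloss, scale_identity_by_lambda=True):
    """Loss weights for the generator outputs, in G_trainner's order:
    dis_fake_B, dis_fake_A, rec_A, rec_B, fake_B, fake_A."""
    # PyTorch-style: identity loss weighted by lambda * lambda_idt.
    # Current Keras code: identity loss weighted by idloss alone.
    id_weight = idloss * lmbd if scale_identity_by_lambda else idloss
    return [1, 1, lmbd, lmbd, id_weight, id_weight]

# Example values only (hypothetical, not read from the repository's defaults).
print(generator_loss_weights(10.0, 0.5))         # [1, 1, 10.0, 10.0, 5.0, 5.0]
print(generator_loss_weights(10.0, 0.5, False))  # [1, 1, 10.0, 10.0, 0.5, 0.5]
```

The resulting list would be passed as loss_weights to G_trainner.compile. With the example values, the two schemes differ by a factor of lmbd (10x) in how strongly the identity term is weighted.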

Loss is incorrect?

Hi,

Thanks for the valuable code. However, I don't quite follow the logic for computing the GAN loss in cyclegan.py. My understanding is that the discriminator tries to label real images as real (1) and fake images as fake (0), so in cyclegan.py

            for _ in range(opt.d_iter):
                _, D_loss_real_A, D_loss_fake_A, D_loss_real_B, D_loss_fake_B = \
                    self.D_trainner.train_on_batch([real_A, fake_A, real_B, fake_B],
                        [zeros, ones*0.9, zeros, ones*0.9])

should be

            for _ in range(opt.d_iter):
                _, D_loss_real_A, D_loss_fake_A, D_loss_real_B, D_loss_fake_B = \
                    self.D_trainner.train_on_batch([real_A, fake_A, real_B, fake_B],
                        [ones*0.9, zeros, ones*0.9, zeros])

The generator, in turn, tries to fool the discriminator into treating fake images as real, so in cyclegan.py

                _, G_loss_fake_B, G_loss_fake_A, G_loss_rec_A, G_loss_rec_B = \
                    self.G_trainner.train_on_batch([real_A, real_B],
                        [zeros, zeros, real_A, real_B, ])

should be

                _, G_loss_fake_B, G_loss_fake_A, G_loss_rec_A, G_loss_rec_B = \
                    self.G_trainner.train_on_batch([real_A, real_B],
                        [ones, ones, real_A, real_B, ])

Have you tried the same experiment (e.g. horse2zebra) with this code? After how many epochs do you get reasonable results? I ran the code on the TensorFlow backend with slight modifications but never obtained correct results, and I am wondering whether the logic above has something to do with it. Please advise. Thanks.
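With MSE (least-squares) GAN losses, what matters is that the generator is trained toward the same label the discriminator uses for real images. The NumPy sketch below checks this on a discriminator that has perfectly separated real from fake. The function names are hypothetical and the sketch omits the 0.9 label smoothing used in cyclegan.py.

```python
import numpy as np

def d_loss(d_real, d_fake, real_label, fake_label):
    # Least-squares discriminator loss: push D(real) to real_label, D(fake) to fake_label.
    return np.mean((d_real - real_label) ** 2) + np.mean((d_fake - fake_label) ** 2)

def g_loss(d_fake, real_label):
    # The generator is rewarded when D(fake) matches the *real* label.
    return np.mean((d_fake - real_label) ** 2)

# Discriminator outputs under the convention real=1, fake=0, with D winning.
d_real, d_fake = np.ones(4), np.zeros(4)
print(d_loss(d_real, d_fake, real_label=1.0, fake_label=0.0))  # 0.0: D is doing well
print(g_loss(d_fake, real_label=1.0))                          # 1.0: G is penalized
print(g_loss(d_fake, real_label=0.0))                          # 0.0: wrong target, no signal for G

# The same situation with the labels flipped (real=0, fake=1).
print(d_loss(1 - d_real, 1 - d_fake, real_label=0.0, fake_label=1.0))  # 0.0
print(g_loss(1 - d_fake, real_label=0.0))                              # 1.0
```

The last two lines show that flipping both the discriminator labels and the generator target together gives the same training signal, so the convention itself matters less than keeping the generator target equal to the discriminator's real label.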

Dimensions error

ValueError: Dimensions must be equal, but are 3 and 4 for 'gen_A/instance_normalization2d_21/mul' (op: 'Mul') with input shapes: [1,3,1,1], [?,4,128,64].
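In this message, the [1,3,1,1] operand looks like a per-channel scale for a 3-channel, channels-first tensor, while the tensor it multiplies has size 4 on axis 1. That reading is an assumption. The NumPy sketch below is a generic instance normalization, not the repository's instance_normalization2d layer, and it shows that the scale's channel axis must match the input's channel axis.

```python
import numpy as np

def instance_norm_nchw(x, gamma, beta, eps=1e-5):
    """Instance normalization for channels-first input x of shape (N, C, H, W).
    gamma and beta have shape (1, C, 1, 1): one scale and shift per channel."""
    mean = x.mean(axis=(2, 3), keepdims=True)  # statistics per sample, per channel
    var = x.var(axis=(2, 3), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta

rng = np.random.default_rng(0)
gamma, beta = np.ones((1, 3, 1, 1)), np.zeros((1, 3, 1, 1))

ok = instance_norm_nchw(rng.normal(size=(2, 3, 8, 8)), gamma, beta)
print(ok.shape)  # (2, 3, 8, 8)

try:
    # Axis 1 has size 4 here, so the (1, 3, 1, 1) scale cannot broadcast against it.
    instance_norm_nchw(rng.normal(size=(2, 4, 8, 8)), gamma, beta)
except ValueError as err:
    print("ValueError:", err)
```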
