
Comments (33)

kato-megumi commented on June 7, 2024

Btw I've had no luck converting GLSL shaders into actual PyTorch models. I believe the problem is transforming the weights.

@Fannovel16
This is my script to convert GLSL shaders to a PyTorch model.
https://gist.github.com/kato-megumi/44e52b4cc0e082e94d452a7df04243e0

Fannovel16 commented on June 7, 2024

@arianaa30 Here it is.
P/s: I made some changes based on kato's advice.

import torch
import torch.nn as nn
import torch.nn.functional as F

def get_luma(x):
    # BT.601 luma from an NCHW RGB tensor
    x = x[:, 0] * 0.299 + x[:, 1] * 0.587 + x[:, 2] * 0.114
    x = x.unsqueeze(1)
    return x

class MaxPoolKeepShape(nn.Module):
    # max pooling padded so the output keeps the input's spatial size (for stride=1)
    def __init__(self, kernel_size, stride=None):
        super(MaxPoolKeepShape, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride if stride is not None else kernel_size

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        kernel_height, kernel_width = self.kernel_size
        pad_height = (((height - 1) // self.stride + 1) - 1) * self.stride + kernel_height - height
        pad_width = (((width - 1) // self.stride + 1) - 1) * self.stride + kernel_width - width

        x = F.pad(x, (pad_width // 2, pad_width - pad_width // 2, pad_height // 2, pad_height - pad_height // 2))
        x = F.max_pool2d(x, kernel_size=self.kernel_size, stride=self.stride)
        return x

class ClampHighlight(nn.Module):
    # PyTorch version of Anime4K's ClampHighlight: clamp the shader output's luma
    # to the local (5x5) max luma of the original image to limit overshoot/ringing
    def __init__(self):
        super(ClampHighlight, self).__init__()
        self.max_pool = MaxPoolKeepShape(kernel_size=(5, 5), stride=1)

    def forward(self, shader_img, orig_img):
        curr_luma = get_luma(shader_img)
        statsmax = self.max_pool(get_luma(orig_img))
        if statsmax.shape != curr_luma.shape:
            statsmax = F.interpolate(statsmax, curr_luma.shape[2:4])
        new_luma = torch.min(curr_luma, statsmax)
        return shader_img - (curr_luma - new_luma)

# out (shader output) and image2 (original image) come from the notebook context
new_img = ClampHighlight()(out[None], image2)
display(to_pil(new_img[0]))

kato-megumi commented on June 7, 2024

I recommend using https://github.com/muslll/neosr/ to train models.
Just put the PyTorch model in the arch/ folder, tweak some config in the yml file, and train.

Fannovel16 commented on June 7, 2024

@Tama47 From what I've researched so far, there is no way to convert the current version of the MLModel to TF2 or ONNX. However, I managed to get Netron working and loading the weights:

  1. Change the file extension from .whml to .zip
  2. Compress all the files inside the *.mlpackage folder (not including the folder itself) into a zip file (see the sketch below)

Preset-a-hq: preset-a-hq.zip (attached)
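
For step 2, a minimal Python sketch using the standard zipfile module (the package file name here is just an example):

import pathlib
import zipfile

# zip the *contents* of the .mlpackage folder, not the folder itself (step 2 above)
pkg = pathlib.Path("preset-a-hq.mlpackage")   # example name
with zipfile.ZipFile("preset-a-hq.zip", "w", zipfile.ZIP_DEFLATED) as zf:
    for f in pkg.rglob("*"):
        if f.is_file():
            zf.write(f, str(f.relative_to(pkg)))   # archive paths relative to the package root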

kato-megumi commented on June 7, 2024

Of course, just string the models together like this: model2(model1(image))
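
A tiny PyTorch sketch of that chaining (the two models below are stand-in placeholders for the converted shader models):

import torch
import torch.nn as nn

# placeholders: swap in the converted restore/upscale models from the thread above
restore_model = nn.Identity()
upscale_model = nn.Upsample(scale_factor=2, mode='bilinear')

pipeline = nn.Sequential(restore_model, upscale_model)   # same as upscale_model(restore_model(image))
image = torch.rand(1, 3, 270, 480)
out = pipeline(image)                                    # here 1x3x540x960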

Tama47 commented on June 7, 2024

Is there a way to train/load S/M/L CNN models in tensorflow?

Yes, you would need to load the original models in TensorFlow.

I see that there is one specific model in the tensorflow directory, but I am not sure which one it is.

Someone has converted the original Anime4K models into Core ML models. I can provide you the link.

The ones you're looking for are under Models >
model-sr-s.wifm / model-sr-m.wifm / model-sr-l.wifm for upscale models
model-restore-s.wifm / model-restore-m.wifm / model-restore-l.wifm for restore models

You would need to convert them to TensorFlow, then create a Python or Jupyter Notebook script to load the weights and models. You can use the models to fine-tune and train your own, better model.

Note: I have not converted or trained the models myself and cannot guarantee success. I can only provide general steps, and you will need to do your own research. Supposedly, the steps to convert between Core ML and TensorFlow should be relatively straightforward. The training process itself should be more or less the same as training any other TensorFlow or ESRGAN model.

Sample Python Script:
import coremltools as ct

# Load Core ML model
coreml_model = ct.models.MLModel('model-sr-s.wifm')

# Convert Core ML to TensorFlow
tf_model = ct.convert(coreml_model, source='mlmodel', target='tensorflow')

arianaa30 commented on June 7, 2024

@Tama47 Is the training code located in the \tensorflow dir for the restore or the upscale model? And if it is the restore model, is it easy to change it to train the upscale model?

Fannovel16 commented on June 7, 2024

I can convert some GLSL files to PyTorch now, but I'm still stuck on converting the weights. Here is the code if anyone is interested:
https://colab.research.google.com/drive/11xAn4fyAUJPZOjrxwnL2ipl_1DGGegkB

Fannovel16 commented on June 7, 2024

@arianaa30

Is the training code located in the \tensorflow dir for the restore or the upscale model?

It contains both. In Gen_Shader.ipynb, SR1Model generates restore models while SR2Model generates upscale ones.

arianaa30 commented on June 7, 2024

@arianaa30

Is the training code located in the \tensorflow dir for the restore or the upscale model?

It contains both. In Gen_Shader.ipynb, SR1Model generates restore models while SR2Model generates upscale ones.

Thanks. Apparently the models there are for M-size shaders. Do you happen to know what parameter values (block_depth, etc.) I should use to get the S/L sizes?

Also, the code uses epochs=1 (three times). Should I change that to something like 100? I noticed the loss doesn't really decrease.

Fannovel16 commented on June 7, 2024

@arianaa30

Thanks. Apparently the models there are for M-size shaders. Do you happen to know what parameter values (block_depth, etc.) I should use to get the S/L sizes?

I guess you can figure out the block_depth from a model's components.
Conv2d(3, 4) means it has 3 input channels and 4 output channels. The CReLU() activation function doubles the channel size, e.g. (1, 128, 128, 4) -> (1, 128, 128, 8).
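
For reference, a minimal sketch of CReLU under its standard definition (concatenate ReLU(x) and ReLU(-x) along the channel axis), which is where the channel doubling comes from, shown here in PyTorch's NCHW layout:

import torch
import torch.nn as nn

class CReLU(nn.Module):
    # standard Concatenated ReLU: stack relu(x) and relu(-x) along the channel dimension
    def forward(self, x):
        return torch.cat((torch.relu(x), torch.relu(-x)), dim=1)

x = torch.randn(1, 4, 128, 128)
print(CReLU()(x).shape)   # torch.Size([1, 8, 128, 128]) -- 4 channels become 8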

Size S:

Anime4K(
  (conv2d_0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_last_0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (depth_to_space): PixelShuffle(upscale_factor=2)
)

Size M

Anime4K(
  (conv2d_0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_3_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_4_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_5_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_6_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_last_0): Conv2d(56, 4, kernel_size=(1, 1), stride=(1, 1))
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (depth_to_space): PixelShuffle(upscale_factor=2)
)

Size L:

Anime4K(
  (conv2d_0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_last_0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2d_last_1): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2d_last_2): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (depth_to_space): PixelShuffle(upscale_factor=2)
)

Size VL:

Anime4K(
  (conv2d_0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_3_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_3_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_4_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_4_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_5_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_5_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_6_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_6_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_last_0): Conv2d(112, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv2d_last_1): Conv2d(112, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv2d_last_2): Conv2d(112, 4, kernel_size=(1, 1), stride=(1, 1))
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (depth_to_space): PixelShuffle(upscale_factor=2)
)

Size UL:

Anime4K(
  (conv2d_0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_3_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_3_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_3_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_4_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_4_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_4_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_5_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_5_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_5_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_6_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_6_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_6_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_last_0): Conv2d(120, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv2d_last_1): Conv2d(120, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv2d_last_2): Conv2d(120, 4, kernel_size=(1, 1), stride=(1, 1))
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (depth_to_space): PixelShuffle(upscale_factor=2)
)

arianaa30 commented on June 7, 2024

Thanks. I have those architectures. But do you know what to pass to this function to get each of those S, L, and VL sizes? I need it for training.

def SR2Model(input_depth=3, highway_depth=4, block_depth=4, init='he_normal', init_last=RandomNormal(mean=0.0, stddev=0.001)):

Fannovel16 commented on June 7, 2024

@arianaa30 My main library is PyTorch so Idk tbh

arianaa30 commented on June 7, 2024

@Fannovel16 Btw, do you know how to measure the SSIM/PSNR of what the Anime4K shaders give me (the upscaled version of a low-res image) vs. the original high-resolution image? Is there a way to measure them?

Fannovel16 commented on June 7, 2024

@arianaa30 You can pass images to mpv or ffmpeg (compiled with libplacebo) from the command line and save the upscaled images. I'm not sure exactly how to do that though, since I'm not very familiar with ffmpeg or mpv. Alternatively, you can try Anime4K-GPU and use canvas.toDataURL("image/png") to save the results.
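
Once the upscaled frames are saved as image files, a minimal sketch of the SSIM/PSNR measurement itself using scikit-image (file names are placeholders; both images must have the same resolution, and channel_axis needs scikit-image >= 0.19):

from skimage.io import imread
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

upscaled = imread("upscaled.png")       # shader output saved to disk (placeholder path)
reference = imread("original_hr.png")   # original high-res frame (placeholder path)

psnr = peak_signal_noise_ratio(reference, upscaled, data_range=255)
ssim = structural_similarity(reference, upscaled, channel_axis=-1, data_range=255)
print(f"PSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}")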

Fannovel16 commented on June 7, 2024

Btw I've had no luck converting GLSL shaders into actual PyTorch models. I believe the problem is transforming the weights.

arianaa30 commented on June 7, 2024

@arianaa30 You can pass images to mpv or ffmpeg (compiled with libplacebo) from the command line and save the upscaled images. I'm not sure exactly how to do that though, since I'm not very familiar with ffmpeg or mpv. Alternatively, you can try Anime4K-GPU and use canvas.toDataURL("image/png") to save the results.

Hmm, ok, thanks. The problem is that we apply multiple Anime4K shaders (restore, upscale, restore, ...). Not sure if we can do that.

Fannovel16 commented on June 7, 2024

@arianaa30 It's possible: mpv-player/mpv#9589. But now that you mention it, I kind of wonder how the A4K shaders were actually trained.

arianaa30 commented on June 7, 2024

@Fannovel16 Yeah, the training has some unknowns. Using the TensorFlow script, I trained a model/shader by calling the SR2Model() function, and it works. But when I train SR1Model (which should be the restore model), the .h5 model trains fine, yet when I try to convert it with Gen_Shader.py, it shows me a "Shape mismatch" error. Have you experienced this before?

  • Adding @Tama47 in case you might have insights on this.

 Layer (type)                                Output Shape                                 Param #        Connected to
======================================================================================================================================================
 input.MAIN (InputLayer)                     [(None, None, None, 3)]                      0              []

 conv2d (Conv2D)                             (None, None, None, 4)                        112            ['input.MAIN[0][0]']

 tf.compat.v1.nn.crelu (TFOpLambda)          (None, None, None, 8)                        0              ['conv2d[0][0]']

 conv2d_1 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu[0][0]']

 tf.compat.v1.nn.crelu_1 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_1[0][0]']

 conv2d_2 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu_1[0][0]']

 tf.compat.v1.nn.crelu_2 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_2[0][0]']

 conv2d_3 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu_2[0][0]']

 tf.compat.v1.nn.crelu_3 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_3[0][0]']

 conv2d_4 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu_3[0][0]']

 tf.compat.v1.nn.crelu_4 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_4[0][0]']

 conv2d_5 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu_4[0][0]']

 tf.compat.v1.nn.crelu_5 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_5[0][0]']

 conv2d_6 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu_5[0][0]']

 tf.compat.v1.nn.crelu_6 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_6[0][0]']

 concatenate (Concatenate)                   (None, None, None, 56)                       0              ['tf.compat.v1.nn.crelu[0][0]',
                                                                                                          'tf.compat.v1.nn.crelu_1[0][0]',
                                                                                                          'tf.compat.v1.nn.crelu_2[0][0]',
                                                                                                          'tf.compat.v1.nn.crelu_3[0][0]',
                                                                                                          'tf.compat.v1.nn.crelu_4[0][0]',
                                                                                                          'tf.compat.v1.nn.crelu_5[0][0]',
                                                                                                          'tf.compat.v1.nn.crelu_6[0][0]']

 conv2d_lastresid.MAIN (Conv2D)              (None, None, None, 3)                        171            ['concatenate[0][0]']

 add.ignore.MAIN (Add)                       (None, None, None, 3)                        0              ['conv2d_lastresid.MAIN[0][0]',
                                                                                                          'input.MAIN[0][0]']

======================================================================================================================================================
Total params: 2035 (7.95 KB)
Trainable params: 2035 (7.95 KB)
Non-trainable params: 0 (0.00 Byte)
______________________________________________________________________________________________________________________________________________________
Traceback (most recent call last):
  File "Gen_Shader.py", line 141, in <module>
    model.load_weights("model-checkpoint-new.h5")
  File "/opt/tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/opt/tensorflow/lib/python3.10/site-packages/keras/src/backend.py", line 4361, in _assign_value_to_variable
    variable.assign(value)
ValueError: Cannot assign value to variable ' conv2d_lastresid.MAIN/kernel:0': Shape mismatch.The variable shape (1, 1, 56, 3), and the assigned value shape (12, 56, 1, 1) are incompatible.

arianaa30 commented on June 7, 2024

Btw I've had no luck converting GLSL shaders into actual PyTorch models. I believe the problem is transforming the weights.

@Fannovel16 This is my script to convert GLSL shaders to a PyTorch model. https://gist.github.com/kato-megumi/44e52b4cc0e082e94d452a7df04243e0

Is the displayed image the upscaled output? Can we apply multiple shaders as well?

Fannovel16 commented on June 7, 2024

@kato-megumi Thanks! It seems like I got the CReLU formula wrong.

arianaa30 commented on June 7, 2024

Of course, just string the models together like this: model2(model1(image))

Great, thanks. Can we simply add other shaders to the chain as well? I want to use Anime4K_Clamp_Highlights.glsl too; the instructions highly recommend including it in the list, as it significantly improves quality.

kato-megumi commented on June 7, 2024

  • ClampHighlight clamps the output of another shader using the original image's luminance, so it requires two images as input.
  • The kernel size should be (5, 5).

Fannovel16 commented on June 7, 2024

@kato-megumi

The kernel size should be (5, 5).

Oh, so the first block iterates over the x-axis while the second block iterates over the y-axis.

ClampHighlight clamps the output of another shader using the original image's luminance, so it requires two images as input.

What is PREKERNEL? I assumed it is the same as MAIN, as the mpv doc is a bit ambiguous.

kato-megumi commented on June 7, 2024

Oh, so the first block iterates over the x-axis while the second block iterates over the y-axis.

Yeah, it reduces the computation cost compared to finding the max of 25 pixels in a single pass.
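
A quick sanity check of that equivalence in PyTorch (a sketch assuming stride 1; F.max_pool2d pads implicitly with negative infinity, so the borders match too):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 64, 64)

# single 5x5 max pool, stride 1
full = F.max_pool2d(x, kernel_size=5, stride=1, padding=2)

# same result from a 1x5 pass followed by a 5x1 pass: 2*5 comparisons per pixel instead of 25
two_pass = F.max_pool2d(x, kernel_size=(1, 5), stride=1, padding=(0, 2))
two_pass = F.max_pool2d(two_pass, kernel_size=(5, 1), stride=1, padding=(2, 0))

print(torch.equal(full, two_pass))   # True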

What is PREKERNEL? I assumed it is the same as MAIN, as the mpv doc is a bit ambiguous.

In the Anime4K doc about ClampHighlight: "Computes and saves image statistics at the location it is placed in the shader stage, then clamps the image highlights at the end after all the shaders to prevent overshoot and reduce ringing."

PREKERNEL
The image immediately before the scaler kernel runs.

I think it refers to the image right before mpv performs its internal scaling.
The other shaders are hooked to MAIN, which comes before PREKERNEL in mpv's rendering process, so those should run first.

Fannovel16 commented on June 7, 2024

I added ClampHighlight, AutoDownscalePre, automatic GLSL downloading, and a pipeline class for convenience:
https://colab.research.google.com/drive/11xAn4fyAUJPZOjrxwnL2ipl_1DGGegkB

arianaa30 commented on June 7, 2024

@arianaa30 Here it is. P/s: I made some changes based on kato's advice.
(ClampHighlight code quoted above)

Great, I will try it.

arianaa30 commented on June 7, 2024

@Fannovel16 Btw, do you have training code for the PyTorch models? Would you be able to share it?

Fannovel16 commented on June 7, 2024

@arianaa30 No, but you can use my notebook to get the model and randomize its parameters to train it.
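
As a rough sketch of that idea (standard PyTorch init calls; the commented-out Anime4K(...) constructor is from the notebook and assumed here):

import torch.nn as nn

def reset_weights(model: nn.Module) -> nn.Module:
    # re-initialize every conv layer so training starts from scratch
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
    return model

# model = reset_weights(Anime4K(...))   # Anime4K(...) as constructed in the notebook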

arianaa30 commented on June 7, 2024

@arianaa30 No, but you can use my notebook to get the model and randomize its parameters to train it.

Should I fine-tune it (only train the last layers) or train the whole network?
Btw, your notebook shows some errors in the convert() function and the use of combination() when I try to run the pipeline code. Maybe something changed recently.

Fannovel16 commented on June 7, 2024

@arianaa30 I forgot to test 😅. It works now.

Should I fine-tune it (only train the last layers) or train the whole network?

Anime4K's CNN networks are pretty small, so training from scratch is a better choice, imo.
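
For reference, a bare-bones from-scratch training loop sketch (the tiny model and random dataset below are placeholders; with the real 2x upscale model, the HR patches would be twice the LR size):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# placeholders: swap in the converted Anime4K model and a real paired LR/HR dataset
model = nn.Conv2d(3, 3, 3, padding=1)
train_pairs = TensorDataset(torch.rand(64, 3, 32, 32), torch.rand(64, 3, 32, 32))

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loader = DataLoader(train_pairs, batch_size=16, shuffle=True)

for epoch in range(100):
    for lr_img, hr_img in loader:
        optimizer.zero_grad()
        loss = F.l1_loss(model(lr_img), hr_img)   # L1 is a common loss for super-resolution
        loss.backward()
        optimizer.step()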
