
gmfss's People

Contributors

hyw-dev, justin62628

gmfss's Issues

Request: fp16 inference

This repository is amazing and I really like the results, but VRAM usage is quite high when I try to use it with VapourSynth: 512x512 px needs around 15 GB of VRAM with a real video. The speed is not bad, but I struggle to reach a usable resolution. With your test.py it needs much less VRAM; when I use the Model class directly, it needs a lot more memory for some reason.

When I use test.py in a loop, also at 512x512:

from tqdm import tqdm
for i in tqdm(range(1000)):
  result = make_inference(I0, I1, n_frames, scale, pred_bidir_flow)



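import torch
import torch.nn as nn
import torch.nn.functional as F
# Model is the GMFSS model class from the repository (import it from your checkout)

class Model_inference(nn.Module):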
    def __init__(self):
        super(Model_inference, self).__init__()
        self.model = Model()

    def forward(self, I0, I1):
        n_frames = 1
        scale = 1.0  # flow scale
        timesteps = [i / (n_frames + 1) for i in range(1, n_frames + 1)]

        # padding frames
        n, c, h, w = I0.shape
        tmp = max(32, int(32 / scale))
        ph = ((h - 1) // tmp + 1) * tmp
        pw = ((w - 1) // tmp + 1) * tmp
        padding = (0, pw - w, 0, ph - h)
        I0 = F.pad(I0, padding)
        I1 = F.pad(I1, padding)        
        return self.model.inference(I0, I1, timesteps, scale, pred_bidir_flow=False)[0][:h, :w, :]


model = Model_inference()
model.eval().cuda()

test_input = torch.rand(1,3,720,1280).cuda()

out = model(test_input, test_input)
print(out.shape)


from tqdm import tqdm
for i in tqdm(range(1000)):
  out = model(test_input, test_input)
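Before calling the model, forward() does two small pieces of arithmetic. It builds the list of intermediate timesteps, and it pads each spatial dimension up to a multiple of max(32, int(32 / scale)). The pure-Python sketch below reproduces both. The helper names timesteps_for and padded_size are hypothetical and not part of GMFSS. It shows, for example, that a 720p frame is padded to 736 rows before inference:

```python
def timesteps_for(n_frames: int) -> list:
    # Evenly spaced interpolation times strictly between 0 (I0) and 1 (I1),
    # same expression as in forward().
    return [i / (n_frames + 1) for i in range(1, n_frames + 1)]


def padded_size(h: int, w: int, scale: float = 1.0) -> tuple:
    # Same rounding as forward(): round h and w up to a multiple of tmp.
    tmp = max(32, int(32 / scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    return ph, pw


print(timesteps_for(1))                   # [0.5]
print(padded_size(720, 1280))             # (736, 1280)
print(padded_size(512, 512))              # (512, 512)
print(padded_size(720, 1280, scale=0.5))  # (768, 1280)
```

With the default scale = 1.0 the multiple is 32. A 512x512 input therefore needs no padding, while 720p gains 16 extra rows that are cropped away again by the [:h, :w, :] slice.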


720p is not possible.

512x512 is the highest I can do with VapourSynth for now.

I tried adding fp16 support but haven't gotten far yet. With fp16 it should use a lot less VRAM.
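The expectation that fp16 saves VRAM comes from element size: a float16 value takes 2 bytes, while a float32 value takes 4. The back-of-the-envelope sketch below computes this for one padded 720p input frame. It counts only that single input tensor. Real VRAM use is dominated by the model's intermediate activations, and with mixed precision (autocast) some ops stay in float32, so the actual saving can be less than half:

```python
import math

BYTES_PER_ELEMENT = {"float32": 4, "float16": 2}


def tensor_bytes(shape, dtype: str) -> int:
    # Number of elements times the size of one element.
    return math.prod(shape) * BYTES_PER_ELEMENT[dtype]


# One padded 720p RGB frame in NCHW layout (720 rows padded to 736).
shape = (1, 3, 736, 1280)
fp32 = tensor_bytes(shape, "float32")
fp16 = tensor_bytes(shape, "float16")
print(fp32, fp16)                                        # 11304960 5652480
print(f"{fp32 / 2**20:.2f} MiB vs {fp16 / 2**20:.2f} MiB")  # 10.78 MiB vs 5.39 MiB
```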

import itertools
import numpy as np
import vapoursynth as vs
import functools
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

import torch 
device = torch.device("cuda")
torch.set_grad_enabled(False)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

def make_inference(model, I0, I1, n, scale, pred_bidir_flow=False):
    timesteps = [i / (n + 1) for i in range(1, n + 1)]
    return model.inference(I0, I1, timesteps, scale, pred_bidir_flow)

# https://github.com/HolyWu/vs-rife/blob/master/vsrife/__init__.py
def GMFSS(
    clip: vs.VideoNode,
    fp16: bool = True,
) -> vs.VideoNode:

    core = vs.core
    from .GMFSS_arch import Model
    import torch

    n_frames = 1
    scale = 1.0  # flow scale
    pred_bidir_flow = False  # Estimate bilateral optical flow at once (accelerate)
    skip_framelist = []  # frame numbers to pass through without interpolation (checked in execute)

    device = torch.device("cuda")
    torch.set_grad_enabled(False)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    model = Model()
    model.load_model("/workspace/tensorrt/VSGAN-tensorrt-docker/weights/", -1)
    model.eval()
    model.device("cuda")

    w = clip.width
    h = clip.height
    scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]


    def frame_to_tensor(frame: vs.VideoFrame):
        return np.stack(
            [np.asarray(frame[plane]) for plane in range(frame.format.num_planes)]
        )

    def tensor_to_frame(f: vs.VideoFrame, array) -> vs.VideoFrame:
        for plane in range(f.format.num_planes):
            d = np.asarray(f[plane])
            np.copyto(d, array[plane, :, :])
        return f

    def tensor_to_clip(clip: vs.VideoNode, image) -> vs.VideoNode:
        clip = core.std.BlankClip(
            clip=clip, width=image.shape[-1], height=image.shape[-2]
        )
        return core.std.ModifyFrame(
            clip=clip,
            clips=clip,
            selector=lambda n, f: tensor_to_frame(f.copy(), image),
        )

    def execute(n: int, clip: vs.VideoNode) -> vs.VideoNode:
        if (
            (n % 2 == 0)
            or n == 0
            or n in skip_framelist
            or n == clip.num_frames - 1
        ):
            return clip

        I0 = frame_to_tensor(clip.get_frame(n - 1))
        I1 = frame_to_tensor(clip.get_frame(n + 1))

        I0 = torch.Tensor(I0).unsqueeze(0).to("cuda", non_blocking=True)
        I1 = torch.Tensor(I1).unsqueeze(0).to("cuda", non_blocking=True)
        middle = make_inference(model, I0, I1, n_frames, scale, pred_bidir_flow)[0]
        middle = middle.swapaxes(0, 2).swapaxes(1, 2)/255

        return tensor_to_clip(clip=clip, image=middle)

    clip = core.std.Interleave([clip, clip])
    return core.std.FrameEval(
        core.std.BlankClip(clip=clip, width=clip.width, height=clip.height),
        functools.partial(execute, clip=clip),
    )
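In execute(), middle.swapaxes(0, 2).swapaxes(1, 2) reorders the interpolated frame into one array per plane, which is the layout tensor_to_frame copies from. This assumes middle comes back as height x width x channels, as the [:h, :w, :] crop in Model_inference suggests. Under that assumption, the double swap is equivalent to a single transpose(2, 0, 1). Here is a small NumPy check with a made-up array:

```python
import numpy as np

# Hypothetical HWC frame: height 4, width 6, 3 colour channels.
middle = np.arange(4 * 6 * 3, dtype=np.float32).reshape(4, 6, 3)

chw_script = middle.swapaxes(0, 2).swapaxes(1, 2)  # (H,W,C) -> (C,W,H) -> (C,H,W)
chw_direct = middle.transpose(2, 0, 1)             # (H,W,C) -> (C,H,W) in one step

print(chw_script.shape)                        # (3, 4, 6)
print(np.array_equal(chw_script, chw_direct))  # True
```

If middle is a torch tensor rather than a NumPy array, the equivalent single call is middle.permute(2, 0, 1).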

I can't figure out why your test.py needs less VRAM.

Update:
Right after writing this, I figured out that

torch.set_grad_enabled(False)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

reduces VRAM usage a lot; I'm leaving this here in case someone else runs into it. Sadly, it does not help in my VapourSynth script. My request for fp16 still stands, though.
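There may be a reason the same setting helps in test.py but not in the VapourSynth script. PyTorch's grad mode is thread-local, so torch.set_grad_enabled(False) affects only the thread that calls it. Assume (unverified here) that VapourSynth invokes the FrameEval callback on one of its own worker threads. In that case autograd would still be enabled there, and intermediate activations could be kept around for a backward pass. The sketch below demonstrates the thread-local behaviour, using a toy nn.Linear as a stand-in for the GMFSS model:

```python
import threading

import torch

layer = torch.nn.Linear(4, 4)  # toy stand-in for the GMFSS model
x = torch.rand(1, 4)
results = {}


def worker():
    # A different thread does not inherit the caller's grad mode,
    # so autograd is enabled again here.
    results["worker"] = layer(x).requires_grad


with torch.set_grad_enabled(False):  # same call as the script, scoped to this block
    results["main"] = layer(x).requires_grad
    t = threading.Thread(target=worker)
    t.start()
    t.join()

print(results)  # {'main': False, 'worker': True}
```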

Update2:
After a lot of trial and error, I figured out that

with torch.inference_mode():

does the job as well and uses even less VRAM than the torch.backends.cudnn settings above.
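torch.inference_mode() is also thread-local. That is likely why it works when entered in the code that actually runs the model: it then applies on whichever thread executes that code. The sketch below is a drop-in variant of the script's make_inference that wraps the call this way. It assumes the model.inference(I0, I1, timesteps, scale, pred_bidir_flow) signature used above. The fp16 flag and its torch.autocast path are a hypothetical addition and untested against GMFSS, whose layers may not all behave well in half precision. ToyModel is a hypothetical stand-in so the example runs on CPU:

```python
import torch


def make_inference(model, I0, I1, n, scale, pred_bidir_flow=False, fp16=False):
    timesteps = [i / (n + 1) for i in range(1, n + 1)]
    # Entered here, inside the per-frame call, so it covers the calling thread.
    with torch.inference_mode():
        if fp16 and I0.is_cuda:
            # Hypothetical fp16 path: autocast runs eligible CUDA ops in float16.
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                return model.inference(I0, I1, timesteps, scale, pred_bidir_flow)
        return model.inference(I0, I1, timesteps, scale, pred_bidir_flow)


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in exposing the same inference() signature."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(6, 3, kernel_size=3, padding=1)

    def inference(self, I0, I1, timesteps, scale, pred_bidir_flow=False):
        x = torch.cat([I0, I1], dim=1)  # stack both frames along channels
        return [self.conv(x) for _ in timesteps]


I0 = torch.rand(1, 3, 64, 64)
I1 = torch.rand(1, 3, 64, 64)
out = make_inference(ToyModel(), I0, I1, 1, 1.0)
print(out[0].shape, out[0].requires_grad)  # torch.Size([1, 3, 64, 64]) False
```

An alternative to autocast would be calling model.half() and converting I0 and I1 with .half(). That stores the weights in fp16 too, but it runs every op in half precision, including ones that may lose accuracy.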
