up你好,我将正向加噪过程过程中使用到的高斯噪声保存了下来,在去噪的时候用到了这些噪声,但是发现最终得到的图像全是噪声点,请问一下这是咋回事啊,下面是我的代码,我是你在b站上的粉丝。`import numpy as np
import torch
from PIL import Image
import os
def preprocess_input(x):
x /= 255
x -= 0.5
x /= 0.5
return x
def cvtColor(image):
if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
return image
else:
image = image.convert('RGB')
return image
def postprocess_output(x):
x *= 0.5
x += 0.5
x *= 255
return x
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def perturb_x(sqrt_alphas_cumprod, x, t, noise, sqrt_one_minus_alphas_cumprod):
return (
extract(sqrt_alphas_cumprod, t, x.shape) * x +
extract(sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
)
def remove_noise(remove_noise_coeff, noise, reciprocal_sqrt_alphas, x, t, use_ema=False):
if use_ema:
return (
(x - extract(remove_noise_coeff, t, x.shape) * noise) *
extract(reciprocal_sqrt_alphas, t, x.shape)
)
else:
return (
(x - extract(remove_noise_coeff, t, x.shape) * noise) *
extract(reciprocal_sqrt_alphas, t, x.shape)
)
num_timesteps = 100
save_path = "tmp.jpg"
if not os.path.exists("original_pic"):
os.makedirs("original_pic")
if not os.path.exists("after_pic"):
os.makedirs("after_pic")
image = Image.open("0_clean.png")
image = cvtColor(image).resize([128, 128], Image.BICUBIC)
image = np.array(image, dtype=np.float32)
image = np.transpose(preprocess_input(image), (2, 0, 1))
x = torch.from_numpy(np.array(image, np.float32))
x = x[None,:,:,:]
betas = torch.linspace(start=0.0001, end=0.02, steps=1000)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas,dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)
reciprocal_sqrt_alphas = torch.sqrt(1 / alphas)
remove_noise_coeff = betas / torch.sqrt(1 - alphas_cumprod)
sigma = torch.sqrt(betas)
保留加噪过程中的epsilon,用于下个阶段的还原
epsilon_list = []
for t in range(num_timesteps):
t = torch.tensor([t])
epsilon = torch.randn_like(x)
epsilon_list.append(epsilon)
x_t = perturb_x(sqrt_alphas_cumprod, x, t, epsilon, sqrt_one_minus_alphas_cumprod)
tmp1 = x_t.clone()
test_images = postprocess_output(tmp1[0].cpu().data.numpy().transpose(1, 2, 0))
Image.fromarray(np.uint8(test_images)).save(os.path.join("original_pic", str(t) + ".png"))
去噪过程随机采样的xt
x = x_t
#x = torch.randn((1, 3, 128, 128))
for t in range(num_timesteps - 1, -1, -1):
t_batch = torch.tensor([t]).repeat(1)
x = remove_noise(remove_noise_coeff, epsilon_list[t], reciprocal_sqrt_alphas, x, t_batch)
if t > 0:
x += extract(sigma, t_batch, x.shape) * torch.randn_like(x)
tmp = x.clone()
test_images = postprocess_output(tmp[0].cpu().data.numpy().transpose(1, 2, 0))
Image.fromarray(np.uint8(test_images)).save(os.path.join("after_pic", str(t) + ".png"))
test_images = postprocess_output(x[0].cpu().data.numpy().transpose(1, 2, 0))
Image.fromarray(np.uint8(test_images)).save(save_path)`