Commit 807acb90 authored by kurumuz's avatar kurumuz

change

parent 44328813
......@@ -19,6 +19,11 @@ import k_diffusion as K
import contextlib
import random
def seed_everything(seed: int):
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
def pil_upscale(image, scale=1):
device = image.device
dtype = image.dtype
......@@ -64,16 +69,14 @@ def prompt_mixing(model, prompt_body, batch_size):
def sample_start_noise(seed, C, H, W, f, device="cuda"):
if seed:
torch.manual_seed(seed)
np.random.seed(seed)
seed_everything(seed)
noise = torch.randn([C, (H) // f, (W) // f], device=device).unsqueeze(0)
return noise
def sample_start_noise_special(seed, request, device="cuda"):
if seed:
torch.manual_seed(seed)
np.random.seed(seed)
seed_everything(seed)
noise = torch.randn([request.latent_channels, request.height // request.downsampling_factor, request.width // request.downsampling_factor], device=device).unsqueeze(0)
return noise
......@@ -163,11 +166,6 @@ class StableInterface(nn.Module):
return x_0
def seed_everything(seed: int):
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
class StableDiffusionModel(nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment