Commit 14943112 authored by novelailab's avatar novelailab

Use generators

parent 7dc1b8b3
......@@ -83,16 +83,20 @@ def prompt_mixing(model, prompt_body, batch_size):
def sample_start_noise(seed, C, H, W, f, device="cuda"):
    """Draw the initial latent noise tensor for sampling.

    Args:
        seed: RNG seed. If truthy, noise is drawn from a dedicated
            ``torch.Generator`` seeded with this value, so the draw is
            reproducible without touching global RNG state.
            NOTE(review): ``if seed:`` treats seed 0 (and None) as
            "unseeded" — confirm whether 0 should be a valid seed.
        C: number of latent channels.
        H, W: image height and width in pixels.
        f: downsampling factor; the latent spatial size is H//f by W//f.
        device: torch device for the noise tensor.

    Returns:
        Tensor of shape (1, C, H // f, W // f).
    """
    shape = [C, H // f, W // f]
    if seed:
        # Local generator: reproducible noise without mutating the
        # global torch RNG state (this commit's "use generators" change).
        gen = torch.Generator(device=device)
        gen.manual_seed(seed)
        noise = torch.randn(shape, generator=gen, device=device)
    else:
        noise = torch.randn(shape, device=device)
    return noise.unsqueeze(0)
def sample_start_noise_special(seed, request, device="cuda"):
    """Draw initial latent noise sized from a request object.

    Args:
        seed: RNG seed. If truthy, a dedicated ``torch.Generator`` is
            seeded with it so the draw is reproducible without touching
            global RNG state.
            NOTE(review): ``if seed:`` treats seed 0 (and None) as
            "unseeded" — confirm whether 0 should be a valid seed.
        request: object carrying ``latent_channels``, ``height``,
            ``width`` and ``downsampling_factor`` attributes.
        device: torch device for the noise tensor.

    Returns:
        Tensor of shape
        (1, latent_channels, height // downsampling_factor,
         width // downsampling_factor).
    """
    f = request.downsampling_factor
    shape = [request.latent_channels, request.height // f, request.width // f]
    if seed:
        # Local generator keeps seeding side-effect free w.r.t. the
        # global torch RNG (this commit's "use generators" change).
        gen = torch.Generator(device=device)
        gen.manual_seed(seed)
        noise = torch.randn(shape, generator=gen, device=device)
    else:
        noise = torch.randn(shape, device=device)
    return noise.unsqueeze(0)
@torch.no_grad()
......@@ -319,6 +323,8 @@ class StableDiffusionModel(nn.Module):
@torch.no_grad()
@torch.autocast("cuda", enabled=True, dtype=torch.float16)
def sample(self, request):
seed_everything(abs(hash(str(request.prompt))%(2**31-1))) # ensure consistent states
if request.module is not None:
if request.module == "vanilla":
pass
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment