Commit c2d5b290 authored by Jairo Correa's avatar Jairo Correa

Move silu to sd_hijack

parent c938679d
......@@ -12,6 +12,7 @@ from ldm.util import default
from einops import rearrange
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
from torch.nn.functional import silu
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
......@@ -100,14 +101,6 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2)
def nonlinearity_hijack(x):
# swish
t = torch.sigmoid(x)
x *= t
del t
return x
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
......@@ -245,11 +238,12 @@ class StableDiffusionModelHijack:
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
self.clip = m.cond_stage_model
ldm.modules.diffusionmodules.model.nonlinearity = silu
if cmd_opts.opt_split_attention_v1:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
def flatten(el):
......
......@@ -22,10 +22,7 @@ import modules.txt2img
import modules.img2img
import modules.swinir as swinir
import modules.sd_models
from torch.nn.functional import silu
import ldm
ldm.modules.diffusionmodules.model.nonlinearity = silu
modules.codeformer_model.setup_codeformer()
modules.gfpgan_model.setup_gfpgan()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment