Commit e64da730 authored by novelailab

Remove fusing

parent 6be043cc
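
For context, the fusing path removed here traced the wrapped diffusion model's apply_model method with torch.jit.trace_module and swapped the traced module into the k-diffusion denoiser wrapper. Below is a minimal, self-contained sketch of that tracing pattern only; the TinyDenoiser module, tensor shapes, and values are illustrative assumptions, not code from this repository.

import torch
import torch.nn as nn

class TinyDenoiser(nn.Module):
    # Illustrative stand-in for the wrapped diffusion model (assumption, not this repo's class).
    def forward(self, x, sigma, cond):
        # Toy computation with the same argument layout as apply_model(x, t, c).
        return x * sigma.view(-1, 1, 1, 1) + cond.mean()

    def apply_model(self, x, sigma, cond):
        return self.forward(x, sigma, cond)

model = TinyDenoiser().eval()

# Example inputs: batched latents, per-sample sigmas, and conditioning,
# much smaller than the 64x64 latents the removed fuse_model built.
x = torch.randn(2, 4, 8, 8)
sigma = torch.ones(2)
cond = torch.randn(2, 77, 768)

# torch.jit.trace_module takes a dict mapping method names to example inputs
# and returns a ScriptModule whose apply_model replays the traced graph.
traced = torch.jit.trace_module(model, {"apply_model": (x, sigma, cond)})
out = traced.apply_model(x, sigma, cond)
print(out.shape)  # torch.Size([2, 4, 8, 8])
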
@@ -149,7 +149,6 @@ def init_config_model():
    config.default_config = os.getenv('DEFAULT_CONFIG', None)
    config.quality_hack = os.getenv('QUALITY_HACK', "0")
    config.clip_contexts = os.getenv('CLIP_CONTEXTS', "1")
    config.jit_optimize = os.getenv('JIT_OPTIMIZE', "0")
    try:
        config.clip_contexts = int(config.clip_contexts)
        if config.clip_contexts < 1 or config.clip_contexts > 10:
@@ -190,10 +189,6 @@ def init_config_model():
    modules = load_modules(config.module_path)
    #attach it to the model
    model.premodules = modules
    # enable JIT
    if config.jit_optimize == "1":
        model.fuse_model()
    config.model = model
@@ -233,9 +233,9 @@ class StableDiffusionModel(nn.Module):
        self.model_config = model_config
        self.plms = PLMSSampler(model)
        self.ddim = DDIMSampler(model)
        self.ema = True
        self.ema_manager = self.model.ema_scope
        if self.config.enable_ema == "0":
            self.ema = False
            self.ema_manager = contextlib.nullcontext
            config.logger.info("Disabling EMA")
        else:
            config.logger.info(f"Using EMA")
@@ -276,43 +276,6 @@ class StableDiffusionModel(nn.Module):
        }
        return DotMap(dict_config)

    def fuse_model(self, requires_grad=False):
        ema = self.ema
        for param in self.model.model.parameters():
            param.requires_grad = False
        c = self.model.get_learned_conditioning(["what the hell is wrong with you!"]).float()
        uc = self.model.get_learned_conditioning([""]).float()
        sigmas = self.k_model.get_sigmas(30)
        start_code = torch.randn([1, 4, 64, 64], device="cuda").float()
        x_0 = start_code * sigmas[0]
        test_sigma = sigmas[1] * x_0.new_ones([x_0.shape[0]])
        with torch.autocast("cuda", torch.float16):
            self.single_step(ema)
            x_two = torch.cat([x_0] * 2)
            cnd = torch.cat([uc, c])
            sigma_two = torch.cat([test_sigma] * 2)
            inputs = {'apply_model': (x_two, sigma_two, cnd)}
            traced_model = torch.jit.trace_module(self.model, inputs)
        if requires_grad:
            self.model.apply_model = lambda x, t, c: traced_model.apply_model(x, t, c)
        #traced_model = traced_model.half()
        self.k_model = K.external.CompVisDenoiser(traced_model)
        self.k_model = StableInterface(self.k_model)
        self.single_step(ema)
        for param in self.model.model.parameters():
            param.requires_grad = requires_grad

    def single_step(self, ema):
        config = self.get_default_config
        config.steps = 1
        config.prompt = ""
        self.sample(config)

    def from_folder(self, folder):
        folder = Path(folder)
        model_config = OmegaConf.load(folder / "config.yaml")
@@ -356,15 +319,6 @@ class StableDiffusionModel(nn.Module):
    @torch.no_grad()
    @torch.autocast("cuda", enabled=True, dtype=torch.float16)
    def sample(self, request):
        ema_manager = contextlib.nullcontext
        if self.ema and not self.copied_ema:
            self.model.model_ema.store(self.model.model.parameters())
            self.model.model_ema.copy_to(self.model.model)
            self.copied_ema = True
        if not self.ema and self.copied_ema:
            self.model.model_ema.restore(self.model.model.parameters())
            self.copied_ema = False
        if request.module is not None:
            if request.module == "vanilla":
                pass
@@ -514,15 +468,6 @@ class StableDiffusionModel(nn.Module):
    @torch.no_grad()
    def sample_two_stages(self, request):
        ema_manager = contextlib.nullcontext
        if self.ema and not self.copied_ema:
            self.model.model_ema.store(self.model.model.parameters())
            self.model.model_ema.copy_to(self.model.model)
            self.copied_ema = True
        if not self.ema and self.copied_ema:
            self.model.model_ema.restore(self.model.model.parameters())
            self.copied_ema = False
        request = DotMap(request)
        if request.seed is not None:
            seed_everything(request.seed)
Subproject commit 579edde884723227b5a3b87600f1ae3af93dc98f
Subproject commit acdc20a6de698156418ad20ee277ccc45fe6787b