Commit 87844d91 authored by Eren Doğan's avatar Eren Doğan Committed by GitHub

custom vae support

parent ed4d8848
......@@ -139,6 +139,7 @@ def init_config_model():
# Resolve where we get our model and data from.
config.model_path = os.getenv('MODEL_PATH', None)
config.vae_path = os.getenv('VAE_PATH', None)
config.module_path = os.getenv('MODULE_PATH', None)
config.prior_path = os.getenv('PRIOR_PATH', None)
......
......@@ -170,6 +170,18 @@ class StableDiffusionModel(nn.Module):
else:
typex = torch.float32
self.model = model.to(config.device).to(typex)
if self.config.vae_path:
ckpt=torch.load(self.config.vae_path, map_location="cpu")
loss = []
for i in ckpt["state_dict"].keys():
if i[0:4] == "loss":
loss.append(i)
for i in loss:
del ckpt["state_dict"][i]
model.first_stage_model.load_state_dict(ckpt["state_dict"])
del ckpt
del loss
self.k_model = K.external.CompVisDenoiser(model)
self.k_model = StableInterface(self.k_model)
self.device = config.device
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment