Commit 61649601 authored by kurumuz's avatar kurumuz

make encode work

parent 807acb90
...@@ -188,16 +188,13 @@ class StableDiffusionModel(nn.Module): ...@@ -188,16 +188,13 @@ class StableDiffusionModel(nn.Module):
typex = torch.float32 typex = torch.float32
self.model = model.to(config.device).to(typex) self.model = model.to(config.device).to(typex)
if self.config.vae_path: if self.config.vae_path:
model.first_stage_model = model.first_stage_model.float()
ckpt=torch.load(self.config.vae_path, map_location="cpu") ckpt=torch.load(self.config.vae_path, map_location="cpu")
loss = [] dec_ckpt = {}
for i in ckpt["state_dict"].keys(): for i in ckpt["state_dict"].keys():
if i[0:4] == "loss": if i[0:8] == "decoder.":
loss.append(i) dec_ckpt[i[8:]] = ckpt["state_dict"][i]
for i in loss: x, y = model.first_stage_model.decoder.load_state_dict(dec_ckpt)
del ckpt["state_dict"][i]
model.first_stage_model = model.first_stage_model.float()
model.first_stage_model.load_state_dict(ckpt["state_dict"])
model.first_stage_model = model.first_stage_model.float() model.first_stage_model = model.first_stage_model.float()
del ckpt del ckpt
del loss del loss
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment