Commit 61649601 authored by kurumuz's avatar kurumuz

make encode work

parent 807acb90
......@@ -188,16 +188,13 @@ class StableDiffusionModel(nn.Module):
typex = torch.float32
self.model = model.to(config.device).to(typex)
if self.config.vae_path:
model.first_stage_model = model.first_stage_model.float()
ckpt=torch.load(self.config.vae_path, map_location="cpu")
loss = []
dec_ckpt = {}
for i in ckpt["state_dict"].keys():
if i[0:4] == "loss":
loss.append(i)
for i in loss:
del ckpt["state_dict"][i]
model.first_stage_model = model.first_stage_model.float()
model.first_stage_model.load_state_dict(ckpt["state_dict"])
if i[0:8] == "decoder.":
dec_ckpt[i[8:]] = ckpt["state_dict"][i]
x, y = model.first_stage_model.decoder.load_state_dict(dec_ckpt)
model.first_stage_model = model.first_stage_model.float()
del ckpt
del loss
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment