Commit 10aca1ca authored by AUTOMATIC's avatar AUTOMATIC

more careful loading of model weights (eliminates some issues with checkpoints...

more careful loading of model weights (eliminates some issues with checkpoints that have weird cond_stage_model layer names)
parent c1093b80
......@@ -122,11 +122,33 @@ def select_checkpoint():
return checkpoint_info
chckpoint_dict_replacements = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}
def transform_checkpoint_dict_key(k):
for text, replacement in chckpoint_dict_replacements.items():
if k.startswith(text):
k = replacement + k[len(text):]
return k
def get_state_dict_from_checkpoint(pl_sd):
if "state_dict" in pl_sd:
return pl_sd["state_dict"]
pl_sd = pl_sd["state_dict"]
sd = {}
for k, v in pl_sd.items():
new_key = transform_checkpoint_dict_key(k)
if new_key is not None:
sd[new_key] = v
return pl_sd
return sd
def load_model_weights(model, checkpoint_info):
......@@ -141,7 +163,7 @@ def load_model_weights(model, checkpoint_info):
print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd)
model.load_state_dict(sd, strict=False)
missing, extra = model.load_state_dict(sd, strict=False)
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment