Commit c0573b3c authored by kurumuz

it sucks!

parent 883a3277
@@ -140,9 +140,12 @@ def init_config_model():
    # Resolve where we get our model and data from.
    config.model_path = os.getenv('MODEL_PATH', None)
    config.enable_ema = os.getenv('ENABLE_EMA', "1")
    config.basedformer = os.getenv('BASEDFORMER', "0")
    config.penultimate = os.getenv('PENULTIMATE', "0")
    config.vae_path = os.getenv('VAE_PATH', None)
    config.module_path = os.getenv('MODULE_PATH', None)
    config.prior_path = os.getenv('PRIOR_PATH', None)
    config.default_config = os.getenv('DEFAULT_CONFIG', None)
    # Misc settings
    config.model_alias = os.getenv('MODEL_ALIAS')
......
@@ -167,7 +167,15 @@ class StableDiffusionModel(nn.Module):
        nn.Module.__init__(self)
        self.config = config
        self.premodules = None
        model, model_config = self.from_folder(config.model_path)
        if Path(self.config.model_path).is_dir():
            model, model_config = self.from_folder(config.model_path)
        elif Path(self.config.model_path).is_file():
            model, model_config = self.from_file(config.model_path)
        else:
            raise Exception("Invalid model path!")
        if config.dtype == "float16":
            typex = torch.float16
        else:
@@ -188,6 +196,11 @@ class StableDiffusionModel(nn.Module):
        del ckpt
        del loss
        if self.config.penultimate == "1":
            model.cond_stage_model.return_layer = -2
            model.cond_stage_model.do_final_ln = True
        model.cond_stage_model.inference_mode = True
        self.k_model = K.external.CompVisDenoiser(model)
        self.k_model = StableInterface(self.k_model)
        self.device = config.device
@@ -220,12 +233,25 @@ class StableDiffusionModel(nn.Module):
        model = self.load_model_from_config(model_config, model_path)
        return model, model_config

    def from_file(self, file):
        default_config = Path(self.config.default_config)
        if not default_config.is_file():
            raise Exception("Default config to load not found! Either point MODEL_PATH at a folder or specify a config to use with this checkpoint via DEFAULT_CONFIG")
        model_config = OmegaConf.load(default_config)
        model = self.load_model_from_config(model_config, file)
        return model, model_config

    def load_model_from_config(self, config, ckpt, verbose=False):
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        sd = pl_sd
        if self.config.basedformer == "1":
            sd = pl_sd
        else:
            sd = pl_sd["state_dict"]
        model = instantiate_from_config(config.model)
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
......
@@ -6,7 +6,10 @@ export MODEL_PATH="/home/xuser/nvme1/stableckpt/anime5000"
export MODULE_PATH="/home/xuser/nvme1/stableckpt/modules"
export TRANSFORMERS_CACHE="/home/xuser/nvme1/transformer_cache"
export SENTRY_URL="https://49ca8adcf4444f82a10eae1b3fd4182f@o846434.ingest.sentry.io/6612448"
export ENABLE_EMA="0"
export ENABLE_EMA="1"
export DEFAULT_CONFIG="/home/xuser/nvme1/stableckpt/defaultconfig.yaml"
export VAE_PATH="/home/xuser/nvme1/stableckpt/animevae.pt"
export BASEDFORMER="1"
export PENULTIMATE="1"
export PYTHONDONTWRITEBYTECODE=1
gunicorn main:app --workers 1 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:4315
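
For context, the new settings split checkpoint loading into two paths: a MODEL_PATH that is a directory is loaded with from_folder (the folder carries its own config), while a single checkpoint file is loaded with the new from_file method against the YAML named by DEFAULT_CONFIG. A minimal sketch of that branching follows; the helper name resolve_checkpoint and its return values are illustrative only and not part of the commit.

from pathlib import Path
from typing import Optional, Tuple

def resolve_checkpoint(model_path: str, default_config: Optional[str]) -> Tuple[str, Path]:
    # Illustrative sketch of the branching added in this commit.
    path = Path(model_path)
    if path.is_dir():
        # A folder ships its own config and weights -> from_folder()
        return "from_folder", path
    if path.is_file():
        # A bare checkpoint needs an external YAML config -> from_file()
        if default_config is None or not Path(default_config).is_file():
            raise Exception("Default config to load not found!")
        return "from_file", path
    raise Exception("Invalid model path!")

With the run script above, MODEL_PATH appears to point at a model folder, so the from_folder branch is taken; pointing MODEL_PATH at a single .ckpt instead would require DEFAULT_CONFIG, which the script now exports.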