Commit 3df8d169 authored by kurumuz's avatar kurumuz

change a lot

parent 9abde698
...@@ -143,7 +143,7 @@ def init_config_model(): ...@@ -143,7 +143,7 @@ def init_config_model():
config.s3_access_key = os.environ('S3_ACCESS_KEY', None) config.s3_access_key = os.environ('S3_ACCESS_KEY', None)
config.s3_secret_key = os.environ('S3_SECRET_KEY', None) config.s3_secret_key = os.environ('S3_SECRET_KEY', None)
config.s3_bucket = os.environ('S3_BUCKET', None) config.s3_bucket = os.environ('S3_BUCKET', None)
config.s3_file = os.environ('S3_FILE', None) config.s3_folder = os.environ('S3_FOLDER', None)
config.s3_endpoint = os.environ('S3_ENDPOINT', None) config.s3_endpoint = os.environ('S3_ENDPOINT', None)
# Resolve where we get our model and data from. # Resolve where we get our model and data from.
......
...@@ -338,20 +338,19 @@ class StableDiffusionModel(nn.Module): ...@@ -338,20 +338,19 @@ class StableDiffusionModel(nn.Module):
def from_url(self, url): def from_url(self, url):
#read config url into bytes #read config url into bytes
default_config = self.config.default_config config_path = self.config.s3_folder + "/config.yaml"
model_config = requests.get(default_config, stream='True').raw s3_file = self.config.s3_folder + "/model.ckpt"
model_config = requests.get(config_path, stream='True').raw
model_config = OmegaConf.load(model_config) model_config = OmegaConf.load(model_config)
print(f"Downloading model from {url}") print(f"Downloading model from {url}")
headers = web.get_s3_secret_headers(endpoint=self.config.s3_endpoint, headers = web.get_s3_secret_headers(endpoint=self.config.s3_endpoint,
access_key=self.config.s3_access_key, access_key=self.config.s3_access_key,
secret_key=self.config.s3_secret_key, secret_key=self.config.s3_secret_key,
s3_file=self.config.s3_file s3_file=s3_file
) )
tensor_loader = web.CURLStreamFile(url, headers=headers) tensor_loader = web.CURLStreamFile(url, headers=headers)
if not default_config.is_file():
raise Exception("Default config to load not found! Either give a folder on MODEL_PATH or specify a config to use with this checkpoint on DEFAULT_CONFIG")
model_config = OmegaConf.load(default_config)
model = self.load_model_from_config(model_config, tensor_loader)
return model, model_config return model, model_config
def load_model_from_config(self, config, ckpt, verbose=False): def load_model_from_config(self, config, ckpt, verbose=False):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment