Commit 9e1a4d34 authored by kurumuz's avatar kurumuz

lol maybe

parent 00f23fd6
...@@ -175,7 +175,7 @@ def init_config_model(): ...@@ -175,7 +175,7 @@ def init_config_model():
if config.model_path is None: if config.model_path is None:
try: try:
config.model_path = f"https://{config.s3_endpoint}/{config.s3_bucket}/{config.s3_file}" config.model_path = f"{config.s3_bucket}/{config.s3_folder}"
logger.info(f"Path is set to S3 {config.model_path}") logger.info(f"Path is set to S3 {config.model_path}")
except: except:
logger.error("No model path or S3 info provided") logger.error("No model path or S3 info provided")
......
...@@ -338,19 +338,26 @@ class StableDiffusionModel(nn.Module): ...@@ -338,19 +338,26 @@ class StableDiffusionModel(nn.Module):
def from_url(self, url): def from_url(self, url):
#read config url into bytes #read config url into bytes
config_path = self.config.s3_folder + "/config.yaml" s3_file = self.config.model_path + "/config.yaml"
s3_file = self.config.s3_folder + "/model.ckpt" headers = web.get_s3_secret_headers(endpoint=self.config.s3_endpoint,
model_config = requests.get(config_path, stream='True').raw access_key=self.config.s3_access_key,
model_config = OmegaConf.load(model_config) secret_key=self.config.s3_secret_key,
print(f"Downloading model from {url}") s3_file=s3_file
)
url = "https://" + self.config.s3_endpoint + "/" +s3_file
tensor_loader = web.CURLStreamFile(url, headers=headers)
model_config = OmegaConf.load(tensor_loader)
s3_file = self.config.model_path + "/model.ckpt"
print(f"Downloading model from {url}")
headers = web.get_s3_secret_headers(endpoint=self.config.s3_endpoint, headers = web.get_s3_secret_headers(endpoint=self.config.s3_endpoint,
access_key=self.config.s3_access_key, access_key=self.config.s3_access_key,
secret_key=self.config.s3_secret_key, secret_key=self.config.s3_secret_key,
s3_file=s3_file s3_file=s3_file
) )
url = "https://" + self.config.s3_endpoint + "/" +s3_file
tensor_loader = web.CURLStreamFile(url, headers=headers) tensor_loader = web.CURLStreamFile(url, headers=headers)
model = self.load_model_from_config(model_config, tensor_loader)
return model, model_config return model, model_config
def load_model_from_config(self, config, ckpt, verbose=False): def load_model_from_config(self, config, ckpt, verbose=False):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment