Commit 10874651 authored by kurumuz's avatar kurumuz

read_from_url

parent d9545e60
......@@ -19,6 +19,9 @@ from PIL import Image
import k_diffusion as K
import contextlib
import random
import web
import io
import requests
def seed_everything(seed: int):
torch.manual_seed(seed)
......@@ -190,16 +193,20 @@ class StableDiffusionModel(nn.Module):
nn.Module.__init__(self)
self.config = config
self.premodules = None
if Path(self.config.model_path).is_dir():
config.logger.info(f"Loading model from folder {self.config.model_path}")
model, model_config = self.from_folder(config.model_path)
elif Path(self.config.model_path).is_file():
config.logger.info(f"Loading model from file {self.config.model_path}")
model, model_config = self.from_file(config.model_path)
if self.config.model_path.startswith("https://"):
self.model = self.from_url(self.config.model_path)
else:
raise Exception("Invalid model path!")
if Path(self.config.model_path).is_dir():
config.logger.info(f"Loading model from folder {self.config.model_path}")
model, model_config = self.from_folder(config.model_path)
elif Path(self.config.model_path).is_file():
config.logger.info(f"Loading model from file {self.config.model_path}")
model, model_config = self.from_file(config.model_path)
else:
raise Exception("Invalid model path!")
if config.dtype == "float16":
typex = torch.float16
......@@ -329,8 +336,21 @@ class StableDiffusionModel(nn.Module):
model = self.load_model_from_config(model_config, file)
return model, model_config
def from_url(self, url):
#read config url into bytes
default_config = self.config.default_config
model_config = requests.get(default_config, stream='True').raw
model_config = OmegaConf.load(model_config)
print(f"Downloading model from {url}")
tensor_loader = web.CURLStreamFile(url)
if not default_config.is_file():
raise Exception("Default config to load not found! Either give a folder on MODEL_PATH or specify a config to use with this checkpoint on DEFAULT_CONFIG")
model_config = OmegaConf.load(default_config)
model = self.load_model_from_config(model_config, tensor_loader)
return model, model_config
def load_model_from_config(self, config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
#print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
......
import fcntl
import subprocess
import resource
import io
from io import SEEK_SET, SEEK_END
import numpy
import struct
import torch
import typing
import pickle
import pickletools
import time
import logging
import sys
import pathlib
import json
import os
from urllib import request
import tempfile
from collections import OrderedDict
from typing import Tuple, Union, List, Iterator, Callable
F_SETPIPE_SZ = 1031
# Whether the tensor is a parameter or a buffer on the model.
TENSOR_PARAM = 0
TENSOR_BUFFER = 1
# Setup logger
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
fh = logging.StreamHandler()
fh_formatter = logging.Formatter(
"%(asctime)s %(levelname)s %(filename)s(%(process)d) - %(message)s"
)
fh.setFormatter(fh_formatter)
logger.addHandler(fh)
# Silly function to convert to human bytes
def convert_bytes(num):
"""
this function will convert bytes to MB.... GB... etc
"""
step_unit = 1000.0
for x in ['bytes', 'KB', 'MB', 'GB', 'TB']:
if num < step_unit:
return "%3.1f %s" % (num, x)
num /= step_unit
class CURLStreamFile(object):
"""
CURLStreamFile implements a file-like object around an HTTP download, the
intention being to not buffer more than we have to.
"""
def __init__(self, uri: str) -> None:
# NOTE: `256mb` buffer on the python IO object.
self._curl = subprocess.Popen(['/usr/bin/curl',
'--header', 'Accept-Encoding: identity',
'-s', uri],
stdout=subprocess.PIPE,
bufsize=256 * 1024 * 1024)
# Read our max-fd-size, fall back to 1mb if invalid.
pipe_buf_sz = 1024 * 1024
try:
pipe_file = open("/proc/sys/fs/pipe-max-size", "r")
pipe_buf_sz = int(pipe_file.read())
logger.debug(f"pipe-max-size: {pipe_buf_sz}")
except IOError as e:
logger.warning(
f"Could not read /proc/sys/fs/pipe-max-size: {e.strerror}")
try:
fcntl.fcntl(self._curl.stdout.fileno(), F_SETPIPE_SZ, pipe_buf_sz)
except PermissionError as e:
logger.warning(
f"Couldn't fcntl F_SETPIPE_SZ to {pipe_buf_sz}: {e.strerror}")
self._curr = 0
self.closed = False
def _read_until(self, goal_position: int,
ba: Union[bytearray, None] = None) -> \
Union[bytes, int]:
if ba is None:
rq_sz = goal_position - self._curr
ret_buff = self._curl.stdout.read(rq_sz)
ret_buff_sz = len(ret_buff)
else:
rq_sz = len(ba)
ret_buff_sz = self._curl.stdout.readinto(ba)
ret_buff = ba
if ret_buff_sz != rq_sz:
self.closed = True
err = self._curl.stderr.read()
self._curl.terminate()
if self._curl.returncode != 0:
raise (IOError(f"curl error: {self._curl.returncode}, {err}"))
else:
raise (IOError(f"Requested {rq_sz} != {ret_buff_sz}"))
self._curr += ret_buff_sz
if ba is None:
return ret_buff
else:
return ret_buff_sz
def tell(self) -> int:
return self._curr
def readinto(self, ba: bytearray) -> int:
goal_position = self._curr + len(ba)
return self._read_until(goal_position, ba)
def read(self, size=None) -> bytes:
if self.closed:
raise (IOError("CURLStreamFile closed."))
if size is None:
return self._curl.stdout.read()
goal_position = self._curr + size
return self._read_until(goal_position)
@staticmethod
def writable() -> bool:
return False
@staticmethod
def fileno() -> int:
return -1
def close(self):
self.closed = True
self._curl.terminate()
def readline(self):
raise Exception("Unimplemented")
"""
This seek() implementation is effectively a no-op, and will throw an
exception for anything other than a seek to the current position.
"""
def seek(self, position, whence=SEEK_SET):
if position == self._curr:
return
if whence == SEEK_END:
raise (Exception("Unsupported `whence`"))
else:
raise (Exception("Seeking is unsupported"))
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment