Commit 34a4ea83 authored by nanahira's avatar nanahira

first

parent b070739b
Pipeline #17126 canceled with stages
in 4 minutes and 51 seconds
.git*
__pycache__
.dockerignore
Dockerfile
/*.bat
/*.sh
__pycache__
stages:
- build
- deploy
variables:
GIT_DEPTH: "1"
before_script:
- docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY
.build-image:
stage: build
script:
- docker build --pull -t $TARGET_IMAGE .
- docker push $TARGET_IMAGE
build-x86:
extends: .build-image
tags:
- docker
variables:
TARGET_IMAGE: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG
.deploy:
stage: deploy
tags:
- docker
script:
- docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG
- docker push $TARGET_IMAGE
deploy_latest:
extends: .deploy
variables:
TARGET_IMAGE: $CI_REGISTRY_IMAGE:latest
only:
- master
FROM nvidia/cuda:11.6.1-runtime-ubuntu20.04
RUN apt update && apt -y install python3-pip python-is-python3 && \
pip3 install -U pip && \
rm -rf /var/lib/apt/lists/* /var/log* /tmp/* /var/tmp/*
WORKDIR /app
COPY ./requirements.txt ./
RUN pip install -r requirements.txt
COPY . ./
#ENV SAVE_FILES="1"
ENV DTYPE="float32"
ENV CLIP_CONTEXTS=3
ENV AMP="1"
ENV MODEL="stable-diffusion"
ENV DEV="True"
ENV MODEL_PATH="models/animefull-final-pruned"
#these aren't actually used by the site
#ENV MODULE_PATH="models/modules"
#unclear if these are used either
#ENV PRIOR_PATH="models/vector_adjust_v2.pt"
ENV ENABLE_EMA="1"
ENV VAE_PATH="models/animevae.pt"
ENV PENULTIMATE="1"
ENV PYTHONDONTWRITEBYTECODE=1
CMD ["python3", "-m", "uvicorn", "--host", "0.0.0.0", "--port=6969", "main:app"]
# naifu
## Volumes
- `/app/models`
from .clip import *
import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
from pkg_resources import packaging
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm
from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()
_MODELS = {
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}
def _download(url: str, root: str):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def _convert_image_to_rgb(image):
return image.convert("RGB")
def _transform(n_px):
return Compose([
Resize(n_px, interpolation=BICUBIC),
CenterCrop(n_px),
_convert_image_to_rgb,
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
def available_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
device : Union[str, torch.device]
The device to put the loaded model
jit : bool
Whether to load the optimized JIT model or more hackable non-JIT model (default).
download_root: str
path to download the model files; by default, it uses "~/.cache/clip"
Returns
-------
model : torch.nn.Module
The CLIP model
preprocess : Callable[[PIL.Image], torch.Tensor]
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if name in _MODELS:
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
with open(model_path, 'rb') as opened_file:
try:
# loading JIT archive
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
jit = False
state_dict = torch.load(opened_file, map_location="cpu")
if not jit:
model = build_model(state_dict or model.state_dict()).to(device)
if str(device) == "cpu":
model.float()
return model, _transform(model.visual.input_resolution)
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
def patch_device(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 on CPU
if str(device) == "cpu":
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()
def patch_float(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
if inputs[i].node()["value"] == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
return model, _transform(model.input_resolution.item())
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
truncate: bool
Whether to truncate the text in case its encoding is longer than the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
"""
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder["<|startoftext|>"]
eot_token = _tokenizer.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
else:
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
if truncate:
tokens = tokens[:context_length]
tokens[-1] = eot_token
else:
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
result[i, :len(tokens)] = torch.tensor(tokens)
return result
This diff is collapsed.
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152-256-2+1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text
import os
import torch
import logging
import os
import platform
import socket
import sys
import time
from dotmap import DotMap
from hydra_node.models import StableDiffusionModel, DalleMiniModel, BasedformerModel, EmbedderModel
from hydra_node import lowvram
import traceback
import zlib
from pathlib import Path
from ldm.modules.attention import CrossAttention, HyperLogic
model_map = {
"stable-diffusion": StableDiffusionModel,
"dalle-mini": DalleMiniModel,
"basedformer": BasedformerModel,
"embedder": EmbedderModel,
}
def no_init(loading_code):
def dummy(self):
return
modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
original = {}
for mod in modules:
original[mod] = mod.reset_parameters
mod.reset_parameters = dummy
result = loading_code()
for mod in modules:
mod.reset_parameters = original[mod]
return result
def crc32(filename, chunksize=65536):
"""Compute the CRC-32 checksum of the contents of the given filename"""
with open(filename, "rb") as f:
checksum = 0
while (chunk := f.read(chunksize)) :
checksum = zlib.crc32(chunk, checksum)
return '%08X' % (checksum & 0xFFFFFFFF)
def load_modules(path):
path = Path(path)
modules = {}
if not path.is_dir():
return
for file in path.iterdir():
module = load_module(file, "cpu")
modules[file.stem] = module
print(f"Loaded module {file.stem}")
return modules
def load_module(path, device):
path = Path(path)
if not path.is_file():
print("Module path {} is not a file".format(path))
network = {
768: (HyperLogic(768).to(device), HyperLogic(768).to(device)),
1280: (HyperLogic(1280).to(device), HyperLogic(1280).to(device)),
640: (HyperLogic(640).to(device), HyperLogic(640).to(device)),
320: (HyperLogic(320).to(device), HyperLogic(320).to(device)),
}
state_dict = torch.load(path)
for key in state_dict.keys():
network[key][0].load_state_dict(state_dict[key][0])
network[key][1].load_state_dict(state_dict[key][1])
return network
def init_config_model():
config = DotMap()
config.savefiles = os.getenv("SAVE_FILES", False)
config.dtype = os.getenv("DTYPE", "float16")
config.device = os.getenv("DEVICE", "cuda")
config.amp = os.getenv("AMP", False)
if config.amp == "1":
config.amp = True
elif config.amp == "0":
config.amp = False
is_dev = ""
environment = "production"
if os.environ['DEV'] == "True":
is_dev = "_dev"
environment = "staging"
config.is_dev = is_dev
# Setup logger
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
fh = logging.StreamHandler()
fh_formatter = logging.Formatter(
"%(asctime)s %(levelname)s %(filename)s(%(process)d) - %(message)s"
)
fh.setFormatter(fh_formatter)
logger.addHandler(fh)
config.logger = logger
# Gather node information
config.cuda_dev = torch.cuda.current_device()
cpu_id = platform.processor()
if os.path.exists('/proc/cpuinfo'):
cpu_id = [line for line in open("/proc/cpuinfo", 'r').readlines() if
'model name' in line][0].rstrip().split(': ')[-1]
config.cpu_id = cpu_id
config.gpu_id = torch.cuda.get_device_name(config.cuda_dev)
config.node_id = platform.node()
# Report on our CUDA memory and model.
gb_gpu = int(torch.cuda.get_device_properties(
config.cuda_dev).total_memory / (1000 * 1000 * 1000))
logger.info(f"CPU: {config.cpu_id}")
logger.info(f"GPU: {config.gpu_id}")
logger.info(f"GPU RAM: {gb_gpu}gb")
config.model_name = os.environ['MODEL']
logger.info(f"MODEL: {config.model_name}")
# Resolve where we get our model and data from.
config.model_path = os.getenv('MODEL_PATH', None)
config.enable_ema = os.getenv('ENABLE_EMA', "1")
config.basedformer = os.getenv('BASEDFORMER', "0")
config.penultimate = os.getenv('PENULTIMATE', "0")
config.vae_path = os.getenv('VAE_PATH', None)
config.module_path = os.getenv('MODULE_PATH', None)
config.prior_path = os.getenv('PRIOR_PATH', None)
config.default_config = os.getenv('DEFAULT_CONFIG', None)
config.quality_hack = os.getenv('QUALITY_HACK', "0")
config.clip_contexts = os.getenv('CLIP_CONTEXTS', "1")
try:
config.clip_contexts = int(config.clip_contexts)
if config.clip_contexts < 1 or config.clip_contexts > 10:
config.clip_contexts = 1
except:
config.clip_contexts = 1
# Misc settings
config.model_alias = os.getenv('MODEL_ALIAS')
# Instantiate our actual model.
load_time = time.time()
model_hash = None
try:
if config.model_name != "dalle-mini":
model = no_init(lambda: model_map[config.model_name](config))
else:
model = model_map[config.model_name](config)
except Exception as e:
traceback.print_exc()
logger.error(f"Failed to load model: {str(e)}")
#exit gunicorn
sys.exit(4)
if config.model_name == "stable-diffusion":
folder = Path(config.model_path)
if (folder / "pruned.ckpt").is_file():
model_path = folder / "pruned.ckpt"
else:
model_path = folder / "model.ckpt"
model_hash = crc32(model_path)
#Load Modules
if config.module_path is not None:
modules = load_modules(config.module_path)
#attach it to the model
model.premodules = modules
lowvram.setup_for_low_vram(model.model, True)
config.model = model
time_load = time.time() - load_time
logger.info(f"Models loaded in {time_load:.2f}s")
return model, config, model_hash
# from github.com/AUTOMATIC1111/stable-diffusion-webui
import torch
from torch.nn.functional import silu
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
module_in_gpu = None
cpu = torch.device("cpu")
device = gpu = torch.device("cuda")
def send_everything_to_cpu():
global module_in_gpu
if module_in_gpu is not None:
module_in_gpu.to(cpu)
module_in_gpu = None
def setup_for_low_vram(sd_model, use_medvram):
parents = {}
def send_me_to_gpu(module, _):
"""send this module to GPU; send whatever tracked module was previous in GPU to CPU;
we add this as forward_pre_hook to a lot of modules and this way all but one of them will
be in CPU
"""
global module_in_gpu
module = parents.get(module, module)
if module_in_gpu == module:
return
if module_in_gpu is not None:
module_in_gpu.to(cpu)
module.to(gpu)
module_in_gpu = module
# see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
def first_stage_model_encode_wrap(self, encoder, x):
send_me_to_gpu(self, None)
return encoder(x)
def first_stage_model_decode_wrap(self, decoder, z):
send_me_to_gpu(self, None)
return decoder(z)
# remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
sd_model.to(device)
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
# register hooks for those the first two models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
else:
diff_model = sd_model.model.diffusion_model
# the third remaining model is still too big for 4 GB, so we also do the same for its submodules
# so that only one of them is in GPU at a time
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
sd_model.model.to(device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
# install hooks for bits of third model
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.input_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
ldm.modules.diffusionmodules.model.nonlinearity = silu
try:
import xformers
except ImportError:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
import math
import torch
from torch import einsum
from ldm.util import default
from einops import rearrange
# taken from https://github.com/Doggettx/stable-diffusion
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context) * self.scale
v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h*w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h*w) # b,c,hw
del k1
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
h2 = h_.reshape(b, c, h, w)
del h_
h3 = self.proj_out(h2)
del h2
h3 += x
return h3
This diff is collapsed.
import traceback
from dotmap import DotMap
import math
from io import BytesIO
import base64
import random
v1pp_defaults = {
'steps': 50,
'sampler': "plms",
'image': None,
'fixed_code': False,
'ddim_eta': 0.0,
'height': 512,
'width': 512,
'latent_channels': 4,
'downsampling_factor': 8,
'scale': 7.0,
'dynamic_threshold': None,
'seed': None,
'stage_two_seed': None,
'module': None,
'masks': None,
}
v1pp_forced_defaults = {
'latent_channels': 4,
'downsampling_factor': 8,
}
dalle_mini_defaults = {
'temp': 1.0,
'top_k': 256,
'scale': 16,
'grid_size': 4,
}
dalle_mini_forced_defaults = {
}
defaults = {
'stable-diffusion': (v1pp_defaults, v1pp_forced_defaults),
'dalle-mini': (dalle_mini_defaults, dalle_mini_forced_defaults),
'basedformer': ({}, {}),
'embedder': ({}, {}),
}
samplers = [
"plms",
"ddim",
"k_euler",
"k_euler_ancestral",
"k_heun",
"k_dpm_2",
"k_dpm_2_ancestral",
"k_lms"
]
def closest_multiple(num, mult):
num_int = int(num)
floor = math.floor(num_int / mult) * mult
ceil = math.ceil(num_int / mult) * mult
return floor if (num_int - floor) < (ceil - num_int) else ceil
def sanitize_stable_diffusion(request, config):
if request.steps > 50:
return False, "steps must be smaller than 50"
if request.width * request.height == 0:
return False, "width and height must be non-zero"
if request.width <= 0:
return False, "width must be positive"
if request.height <= 0:
return False, "height must be positive"
if request.steps <= 0:
return False, "steps must be positive"
if request.ddim_eta < 0:
return False, "ddim_eta shouldn't be negative"
if request.scale < 1.0:
return False, "scale should be at least 1.0"
if request.dynamic_threshold is not None and request.dynamic_threshold < 0:
return False, "dynamic_threshold shouldn't be negative"
if request.width * request.height >= 1024*1025:
return False, "width and height must be less than 1024*1025"
if request.strength < 0.0 or request.strength >= 1.0:
return False, "strength should be more than 0.0 and less than 1.0"
if request.noise < 0.0 or request.noise > 1.0:
return False, "noise should be more than 0.0 and less than 1.0"
if request.advanced:
request.width = closest_multiple(request.width // 2, 64)
request.height = closest_multiple(request.height // 2, 64)
if request.sampler not in samplers:
return False, "sampler should be one of {}".format(samplers)
if request.seed is None:
state = random.getstate()
request.seed = random.randint(0, 2**32)
random.setstate(state)
if request.module is not None:
if request.module not in config.model.premodules and request.module != "vanilla":
return False, "module should be one of: " + ", ".join(config.model.premodules)
max_gens = 100
if 0:
num_gen_tiers = [(1024*512, 4), (640*640, 6), (704*512, 8), (512*512, 16), (384*640, 18)]
pixel_count = request.width * request.height
for tier in num_gen_tiers:
if pixel_count <= tier[0]:
max_gens = tier[1]
else:
break
if request.n_samples > max_gens:
return False, f"requested more ({request.n_samples}) images than possible at this resolution"
if request.image is not None:
#decode from base64
try:
request.image = base64.b64decode(request.image.encode('utf-8'))
except Exception as e:
traceback.print_exc()
return False, "image is not valid base64"
#check if image is valid
try:
from PIL import Image
image = Image.open(BytesIO(request.image))
image.verify()
except Exception as e:
traceback.print_exc()
return False, "image is not valid"
#image is valid, load it again(still check again, verify() can't be sure as it doesn't decode.)
try:
image = Image.open(BytesIO(request.image))
image = image.convert('RGB')
image = image.resize((request.width, request.height), resample=Image.Resampling.LANCZOS)
request.image = image
except Exception as e:
traceback.print_exc()
return False, "Error while opening and cleaning image"
if request.masks is not None:
masks = request.masks
for x in range(len(masks)):
image = masks[x]["mask"]
try:
image_bytes = base64.b64decode(image.encode('utf-8'))
except Exception as e:
traceback.print_exc()
return False, "image is not valid base64"
try:
from PIL import Image
image = Image.open(BytesIO(image_bytes))
image.verify()
except Exception as e:
traceback.print_exc()
return False, "image is not valid"
#image is valid, load it again(still check again, verify() can't be sure as it doesn't decode.)
try:
image = Image.open(BytesIO(image_bytes))
#image = image.convert('RGB')
image = image.resize((request.width//request.downsampling_factor, request.height//request.downsampling_factor), resample=Image.Resampling.LANCZOS)
except Exception as e:
traceback.print_exc()
return False, "Error while opening and cleaning image"
masks[x]["mask"] = image
return True, request
def sanitize_dalle_mini(request):
return True, request
def sanitize_basedformer(request):
return True, request
def sanitize_embedder(request):
return True, request
def sanitize_input(config, request):
"""
Sanitize the input data and set defaults
"""
request = DotMap(request)
default, forced_default = defaults[config.model_name]
for k, v in default.items():
if k not in request:
request[k] = v
for k, v in forced_default.items():
request[k] = v
if config.model_name == 'stable-diffusion':
return sanitize_stable_diffusion(request, config)
elif config.model_name == 'dalle-mini':
return sanitize_dalle_mini(request)
elif config.model_name == 'basedformer':
return sanitize_basedformer(request)
elif config.model_name == "embedder":
return sanitize_embedder(request)
from . import augmentation, config, external, gns, layers, models, sampling, utils
from .layers import Denoiser
from functools import reduce
import math
import operator
import numpy as np
from skimage import transform
import torch
from torch import nn
def translate2d(tx, ty):
mat = [[1, 0, tx],
[0, 1, ty],
[0, 0, 1]]
return torch.tensor(mat, dtype=torch.float32)
def scale2d(sx, sy):
mat = [[sx, 0, 0],
[ 0, sy, 0],
[ 0, 0, 1]]
return torch.tensor(mat, dtype=torch.float32)
def rotate2d(theta):
mat = [[torch.cos(theta), torch.sin(-theta), 0],
[torch.sin(theta), torch.cos(theta), 0],
[ 0, 0, 1]]
return torch.tensor(mat, dtype=torch.float32)
class KarrasAugmentationPipeline:
def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8):
self.a_prob = a_prob
self.a_scale = a_scale
self.a_aniso = a_aniso
self.a_trans = a_trans
def __call__(self, image):
h, w = image.size
mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)]
# x-flip
a0 = torch.randint(2, []).float()
mats.append(scale2d(1 - 2 * a0, 1))
# y-flip
do = (torch.rand([]) < self.a_prob).float()
a1 = torch.randint(2, []).float() * do
mats.append(scale2d(1, 1 - 2 * a1))
# scaling
do = (torch.rand([]) < self.a_prob).float()
a2 = torch.randn([]) * do
mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))
# rotation
do = (torch.rand([]) < self.a_prob).float()
a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
mats.append(rotate2d(-a3))
# anisotropy
do = (torch.rand([]) < self.a_prob).float()
a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
a5 = torch.randn([]) * do
mats.append(rotate2d(a4))
mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
mats.append(rotate2d(-a4))
# translation
do = (torch.rand([]) < self.a_prob).float()
a6 = torch.randn([]) * do
a7 = torch.randn([]) * do
mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))
# form the transformation matrix and conditioning vector
mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5))
mat = reduce(operator.matmul, mats)
cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7])
# apply the transformation
image_orig = np.array(image, dtype=np.float32) / 255
if image_orig.ndim == 2:
image_orig = image_orig[..., None]
tf = transform.AffineTransform(mat.numpy())
image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
return image, image_orig, cond
class KarrasAugmentWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
if aug_cond is None:
aug_cond = input.new_zeros([input.shape[0], 9])
if mapping_cond is None:
mapping_cond = aug_cond
else:
mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1)
return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs)
def set_skip_stages(self, skip_stages):
return self.inner_model.set_skip_stages(skip_stages)
def set_patch_size(self, patch_size):
return self.inner_model.set_patch_size(patch_size)
from functools import partial
import json
from jsonmerge import merge
from . import augmentation, models, utils
def load_config(file):
defaults = {
'model': {
'sigma_data': 1.,
'patch_size': 1,
'dropout_rate': 0.,
'augment_prob': 0.,
'mapping_cond_dim': 0,
'unet_cond_dim': 0,
'cross_cond_dim': 0,
'cross_attn_depths': None,
'skip_stages': 0,
},
'dataset': {
'type': 'imagefolder',
},
'optimizer': {
'type': 'adamw',
'lr': 1e-4,
'betas': [0.95, 0.999],
'eps': 1e-6,
'weight_decay': 1e-3,
},
'lr_sched': {
'type': 'inverse',
'inv_gamma': 20000.,
'power': 1.,
'warmup': 0.99,
},
'ema_sched': {
'type': 'inverse',
'power': 0.6667,
'max_value': 0.9999
},
}
config = json.load(file)
return merge(defaults, config)
def make_model(config):
config = config['model']
assert config['type'] == 'image_v1'
model = models.ImageDenoiserModelV1(
config['input_channels'],
config['mapping_out'],
config['depths'],
config['channels'],
config['self_attn_depths'],
config['cross_attn_depths'],
patch_size=config['patch_size'],
dropout_rate=config['dropout_rate'],
mapping_cond_dim=config['mapping_cond_dim'] + 9,
unet_cond_dim=config['unet_cond_dim'],
cross_cond_dim=config['cross_cond_dim'],
skip_stages=config['skip_stages'],
)
model = augmentation.KarrasAugmentWrapper(model)
return model
def make_sample_density(config):
config = config['sigma_sample_density']
if config['type'] == 'lognormal':
loc = config['mean'] if 'mean' in config else config['loc']
scale = config['std'] if 'std' in config else config['scale']
return partial(utils.rand_log_normal, loc=loc, scale=scale)
if config['type'] == 'loglogistic':
loc = config['loc']
scale = config['scale']
min_value = config['min_value'] if 'min_value' in config else 0.
max_value = config['max_value'] if 'max_value' in config else float('inf')
return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
if config['type'] == 'loguniform':
min_value = config['min_value']
max_value = config['max_value']
return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value)
raise ValueError('Unknown sample density type')
import math
import torch
from torch import nn
from . import sampling, utils
class VDenoiser(nn.Module):
"""A v-diffusion-pytorch model wrapper for k-diffusion."""
def __init__(self, inner_model):
super().__init__()
self.inner_model = inner_model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in
def sigma_to_t(self, sigma):
return sigma.atan() / math.pi * 2
def t_to_sigma(self, t):
return (t * math.pi / 2).tan()
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
class DiscreteSchedule(nn.Module):
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
levels."""
def __init__(self, sigmas, quantize):
super().__init__()
self.register_buffer('sigmas', sigmas)
self.quantize = quantize
def get_sigmas(self, n=None):
if n is None:
return sampling.append_zero(self.sigmas.flip(0))
t_max = len(self.sigmas) - 1
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
return sampling.append_zero(self.t_to_sigma(t))
def sigma_to_t(self, sigma, quantize=None):
quantize = self.quantize if quantize is None else quantize
dists = torch.abs(sigma - self.sigmas[:, None])
if quantize:
return torch.argmin(dists, dim=0).view(sigma.shape)
low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
low, high = self.sigmas[low_idx], self.sigmas[high_idx]
w = (low - sigma) / (low - high)
w = w.clamp(0, 1)
t = (1 - w) * low_idx + w * high_idx
return t.view(sigma.shape)
def t_to_sigma(self, t):
t = t.float()
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
return (1 - w) * self.sigmas[low_idx] + w * self.sigmas[high_idx]
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
noise)."""
def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_out = -sigma
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_out, c_in
def get_eps(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)
def loss(self, input, noise, sigma, **kwargs):
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
return (eps - noise).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
return input + eps * c_out
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for OpenAI diffusion models."""
def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
super().__init__(model, alphas_cumprod, quantize=quantize)
self.has_learned_sigmas = has_learned_sigmas
def get_eps(self, *args, **kwargs):
model_output = self.inner_model(*args, **kwargs)
if self.has_learned_sigmas:
return model_output.chunk(2, dim=1)[0]
return model_output
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for CompVis diffusion models."""
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_eps(self, *args, **kwargs):
return self.inner_model.apply_model(*args, **kwargs)
import torch
from torch import nn
class DDPGradientStatsHook:
def __init__(self, ddp_module):
try:
ddp_module.register_comm_hook(self, self._hook_fn)
except AttributeError:
raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
self._clear_state()
def _clear_state(self):
self.bucket_sq_norms_small_batch = []
self.bucket_sq_norms_large_batch = []
@staticmethod
def _hook_fn(self, bucket):
buf = bucket.buffer()
self.bucket_sq_norms_small_batch.append(buf.pow(2).sum())
fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future()
def callback(fut):
buf = fut.value()[0]
self.bucket_sq_norms_large_batch.append(buf.pow(2).sum())
return buf
return fut.then(callback)
def get_stats(self):
sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
self._clear_state()
return torch.stack([sq_norm_small_batch, sq_norm_large_batch])[None]
class GradientNoiseScale:
"""Calculates the gradient noise scale (1 / SNR), or critical batch size,
from _An Empirical Model of Large-Batch Training_,
https://arxiv.org/abs/1812.06162).
Args:
beta (float): The decay factor for the exponential moving averages used to
calculate the gradient noise scale.
Default: 0.9998
eps (float): Added for numerical stability.
Default: 1e-8
"""
def __init__(self, beta=0.9998, eps=1e-8):
self.beta = beta
self.eps = eps
self.ema_sq_norm = 0.
self.ema_var = 0.
self.beta_cumprod = 1.
self.gradient_noise_scale = float('nan')
def state_dict(self):
"""Returns the state of the object as a :class:`dict`."""
return dict(self.__dict__.items())
def load_state_dict(self, state_dict):
"""Loads the object's state.
Args:
state_dict (dict): object state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch):
"""Updates the state with a new batch's gradient statistics, and returns the
current gradient noise scale.
Args:
sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
per sample gradients.
sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
per sample gradients.
n_small_batch (int): The batch size of the individual microbatch or per sample
gradients (1 if per sample).
n_large_batch (int): The total batch size of the mean of the microbatch or
per sample gradients.
"""
est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch)
self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
self.beta_cumprod *= self.beta
self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps)
return self.gradient_noise_scale
def get_gns(self):
"""Returns the current gradient noise scale."""
return self.gradient_noise_scale
def get_stats(self):
"""Returns the current (debiased) estimates of the squared mean gradient
and gradient variance."""
return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod)
import math
from einops import rearrange, repeat
import torch
from torch import nn
from torch.nn import functional as F
from . import utils
# Karras et al. preconditioned denoiser
class Denoiser(nn.Module):
"""A Karras et al. preconditioner for denoising diffusion models."""
def __init__(self, inner_model, sigma_data=1.):
super().__init__()
self.inner_model = inner_model
self.sigma_data = sigma_data
def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.inner_model(noised_input * c_in, sigma, **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip
# Residual blocks
class ResidualBlock(nn.Module):
def __init__(self, *main, skip=None):
super().__init__()
self.main = nn.Sequential(*main)
self.skip = skip if skip else nn.Identity()
def forward(self, input):
return self.main(input) + self.skip(input)
# Noise level (and other) conditioning
class ConditionedModule(nn.Module):
pass
class UnconditionedModule(ConditionedModule):
def __init__(self, module):
self.module = module
def forward(self, input, cond):
return self.module(input)
class ConditionedSequential(nn.Sequential, ConditionedModule):
def forward(self, input, cond):
for module in self:
if isinstance(module, ConditionedModule):
input = module(input, cond)
else:
input = module(input)
return input
class ConditionedResidualBlock(ConditionedModule):
def __init__(self, *main, skip=None):
super().__init__()
self.main = ConditionedSequential(*main)
self.skip = skip if skip else nn.Identity()
def forward(self, input, cond):
skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input)
return self.main(input, cond) + skip
class AdaGN(ConditionedModule):
def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'):
super().__init__()
self.num_groups = num_groups
self.eps = eps
self.cond_key = cond_key
self.mapper = nn.Linear(feats_in, c_out * 2)
def forward(self, input, cond):
weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1)
input = F.group_norm(input, self.num_groups, eps=self.eps)
return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1)
# Attention
class SelfAttention2d(ConditionedModule):
def __init__(self, c_in, n_head, norm, dropout_rate=0.):
super().__init__()
assert c_in % n_head == 0
self.norm_in = norm(c_in)
self.n_head = n_head
self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
self.out_proj = nn.Conv2d(c_in, c_in, 1)
self.dropout = nn.Dropout(dropout_rate)
def forward(self, input, cond):
n, c, h, w = input.shape
qkv = self.qkv_proj(self.norm_in(input, cond))
qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3)
q, k, v = qkv.chunk(3, dim=1)
scale = k.shape[3] ** -0.25
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
att = self.dropout(att)
y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w])
return input + self.out_proj(y)
class CrossAttention2d(ConditionedModule):
def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0.,
cond_key='cross', cond_key_padding='cross_padding'):
super().__init__()
assert c_dec % n_head == 0
self.cond_key = cond_key
self.cond_key_padding = cond_key_padding
self.norm_enc = nn.LayerNorm(c_enc)
self.norm_dec = norm_dec(c_dec)
self.n_head = n_head
self.q_proj = nn.Conv2d(c_dec, c_dec, 1)
self.kv_proj = nn.Linear(c_enc, c_dec * 2)
self.out_proj = nn.Conv2d(c_dec, c_dec, 1)
self.dropout = nn.Dropout(dropout_rate)
def forward(self, input, cond):
n, c, h, w = input.shape
q = self.q_proj(self.norm_dec(input, cond))
q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3)
kv = self.kv_proj(self.norm_enc(cond[self.cond_key]))
kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2)
k, v = kv.chunk(2, dim=1)
scale = k.shape[3] ** -0.25
att = ((q * scale) @ (k.transpose(2, 3) * scale))
att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000
att = att.softmax(3)
att = self.dropout(att)
y = (att @ v).transpose(2, 3)
y = y.contiguous().view([n, c, h, w])
return input + self.out_proj(y)
# Downsampling/upsampling
_kernels = {
'linear':
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
'cubic':
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
0.43359375, 0.11328125, -0.03515625, -0.01171875],
'lanczos3':
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
_kernels['bilinear'] = _kernels['linear']
_kernels['bicubic'] = _kernels['cubic']
class Downsample2d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect'):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([_kernels[kernel]])
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
def forward(self, x):
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv2d(x, weight, stride=2)
class Upsample2d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect'):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([_kernels[kernel]]) * 2
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
def forward(self, x):
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
# Embeddings
class FourierFeatures(nn.Module):
def __init__(self, in_features, out_features, std=1.):
super().__init__()
assert out_features % 2 == 0
self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std)
def forward(self, input):
f = 2 * math.pi * input @ self.weight.T
return torch.cat([f.cos(), f.sin()], dim=-1)
# U-Nets
class UNet(ConditionedModule):
def __init__(self, d_blocks, u_blocks, skip_stages=0):
super().__init__()
self.d_blocks = nn.ModuleList(d_blocks)
self.u_blocks = nn.ModuleList(u_blocks)
self.skip_stages = skip_stages
def forward(self, input, cond):
skips = []
for block in self.d_blocks[self.skip_stages:]:
input = block(input, cond)
skips.append(input)
for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))):
input = block(input, cond, skip if i > 0 else None)
return input
from .image_v1 import ImageDenoiserModelV1
import math
import torch
from torch import nn
from torch.nn import functional as F
from .. import layers, utils
class ResConvBlock(layers.ConditionedResidualBlock):
def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.):
skip = None if c_in == c_out else nn.Conv2d(c_in, c_out, 1, bias=False)
super().__init__(
layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)),
nn.GELU(),
nn.Conv2d(c_in, c_mid, 3, padding=1),
nn.Dropout2d(dropout_rate, inplace=True),
layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)),
nn.GELU(),
nn.Conv2d(c_mid, c_out, 3, padding=1),
nn.Dropout2d(dropout_rate, inplace=True),
skip=skip)
class DBlock(layers.ConditionedSequential):
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0):
modules = [nn.Identity()]
for i in range(n_layers):
my_c_in = c_in if i == 0 else c_mid
my_c_out = c_mid if i < n_layers - 1 else c_out
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
if self_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
if cross_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
super().__init__(*modules)
self.set_downsample(downsample)
def set_downsample(self, downsample):
self[0] = layers.Downsample2d() if downsample else nn.Identity()
return self
class UBlock(layers.ConditionedSequential):
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0):
modules = []
for i in range(n_layers):
my_c_in = c_in if i == 0 else c_mid
my_c_out = c_mid if i < n_layers - 1 else c_out
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
if self_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
if cross_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
modules.append(nn.Identity())
super().__init__(*modules)
self.set_upsample(upsample)
def forward(self, input, cond, skip=None):
if skip is not None:
input = torch.cat([input, skip], dim=1)
return super().forward(input, cond)
def set_upsample(self, upsample):
self[-1] = layers.Upsample2d() if upsample else nn.Identity()
return self
class MappingNet(nn.Sequential):
def __init__(self, feats_in, feats_out, n_layers=2):
layers = []
for i in range(n_layers):
layers.append(nn.Linear(feats_in if i == 0 else feats_out, feats_out))
layers.append(nn.GELU())
super().__init__(*layers)
for layer in self:
if isinstance(layer, nn.Linear):
nn.init.orthogonal_(layer.weight)
class ImageDenoiserModelV1(nn.Module):
def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0):
super().__init__()
self.c_in = c_in
self.channels = channels
self.unet_cond_dim = unet_cond_dim
self.patch_size = patch_size
self.timestep_embed = layers.FourierFeatures(1, feats_in)
if mapping_cond_dim > 0:
self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
self.mapping = MappingNet(feats_in, feats_in)
self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1)
self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2, 1)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
if cross_cond_dim == 0:
cross_attn_depths = [False] * len(self_attn_depths)
d_blocks, u_blocks = [], []
for i in range(len(depths)):
my_c_in = channels[max(0, i - 1)]
d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
for i in range(len(depths)):
my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i]
my_c_out = channels[max(0, i - 1)]
u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages)
def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None):
c_noise = sigma.log() / 4
timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond)
mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
cond = {'cond': mapping_out}
if unet_cond is not None:
input = torch.cat([input, unet_cond], dim=1)
if cross_cond is not None:
cond['cross'] = cross_cond
cond['cross_padding'] = cross_cond_padding
if self.patch_size > 1:
input = F.pixel_unshuffle(input, self.patch_size)
input = self.proj_in(input)
input = self.u_net(input, cond)
input = self.proj_out(input)
if self.patch_size > 1:
input = F.pixel_shuffle(input, self.patch_size)
return input
def set_skip_stages(self, skip_stages):
self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1)
self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
self.u_net.skip_stages = skip_stages
for i, block in enumerate(self.u_net.d_blocks):
block.set_downsample(i > skip_stages)
for i, block in enumerate(reversed(self.u_net.u_blocks)):
block.set_upsample(i > skip_stages)
return self
def set_patch_size(self, patch_size):
self.patch_size = patch_size
self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1)
self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2, 1)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
This diff is collapsed.
This diff is collapsed.
import os
import numpy as np
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
class Txt2ImgIterableBaseDataset(IterableDataset):
'''
Define an interface to make the IterableDatasets for text2img data chainable
'''
def __init__(self, num_records=0, valid_ids=None, size=256):
super().__init__()
self.num_records = num_records
self.valid_ids = valid_ids
self.sample_ids = valid_ids
self.size = size
print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
def __len__(self):
return self.num_records
@abstractmethod
def __iter__(self):
pass
class PRNGMixin(object):
"""
Adds a prng property which is a numpy RandomState which gets
reinitialized whenever the pid changes to avoid synchronized sampling
behavior when used in conjunction with multiprocessing.
"""
@property
def prng(self):
currentpid = os.getpid()
if getattr(self, "_initpid", None) != currentpid:
self._initpid = currentpid
self._prng = np.random.RandomState()
return self._prng
This diff is collapsed.
import numpy as np
import random
import string
from torch.utils.data import Dataset, Subset
class DummyData(Dataset):
def __init__(self, length, size):
self.length = length
self.size = size
def __len__(self):
return self.length
def __getitem__(self, i):
x = np.random.randn(*self.size)
letters = string.ascii_lowercase
y = ''.join(random.choice(string.ascii_lowercase) for i in range(10))
return {"jpg": x, "txt": y}
class DummyDataWithEmbeddings(Dataset):
def __init__(self, length, size, emb_size):
self.length = length
self.size = size
self.emb_size = emb_size
def __len__(self):
return self.length
def __getitem__(self, i):
x = np.random.randn(*self.size)
y = np.random.randn(*self.emb_size).astype(np.float32)
return {"jpg": x, "txt": y}
This diff is collapsed.
from PIL import Image, ImageDraw
import numpy as np
settings = {
"256narrow": {
"p_irr": 1,
"min_n_irr": 4,
"max_n_irr": 50,
"max_l_irr": 40,
"max_w_irr": 10,
"min_n_box": None,
"max_n_box": None,
"min_s_box": None,
"max_s_box": None,
"marg": None,
},
"256train": {
"p_irr": 0.5,
"min_n_irr": 1,
"max_n_irr": 5,
"max_l_irr": 200,
"max_w_irr": 100,
"min_n_box": 1,
"max_n_box": 4,
"min_s_box": 30,
"max_s_box": 150,
"marg": 10,
},
"512train": { # TODO: experimental
"p_irr": 0.5,
"min_n_irr": 1,
"max_n_irr": 5,
"max_l_irr": 450,
"max_w_irr": 250,
"min_n_box": 1,
"max_n_box": 4,
"min_s_box": 30,
"max_s_box": 300,
"marg": 10,
},
"512train-large": { # TODO: experimental
"p_irr": 0.5,
"min_n_irr": 1,
"max_n_irr": 5,
"max_l_irr": 450,
"max_w_irr": 400,
"min_n_box": 1,
"max_n_box": 4,
"min_s_box": 75,
"max_s_box": 450,
"marg": 10,
},
}
def gen_segment_mask(mask, start, end, brush_width):
mask = mask > 0
mask = (255 * mask).astype(np.uint8)
mask = Image.fromarray(mask)
draw = ImageDraw.Draw(mask)
draw.line([start, end], fill=255, width=brush_width, joint="curve")
mask = np.array(mask) / 255
return mask
def gen_box_mask(mask, masked):
x_0, y_0, w, h = masked
mask[y_0:y_0 + h, x_0:x_0 + w] = 1
return mask
def gen_round_mask(mask, masked, radius):
x_0, y_0, w, h = masked
xy = [(x_0, y_0), (x_0 + w, y_0 + w)]
mask = mask > 0
mask = (255 * mask).astype(np.uint8)
mask = Image.fromarray(mask)
draw = ImageDraw.Draw(mask)
draw.rounded_rectangle(xy, radius=radius, fill=255)
mask = np.array(mask) / 255
return mask
def gen_large_mask(prng, img_h, img_w,
marg, p_irr, min_n_irr, max_n_irr, max_l_irr, max_w_irr,
min_n_box, max_n_box, min_s_box, max_s_box):
"""
img_h: int, an image height
img_w: int, an image width
marg: int, a margin for a box starting coordinate
p_irr: float, 0 <= p_irr <= 1, a probability of a polygonal chain mask
min_n_irr: int, min number of segments
max_n_irr: int, max number of segments
max_l_irr: max length of a segment in polygonal chain
max_w_irr: max width of a segment in polygonal chain
min_n_box: int, min bound for the number of box primitives
max_n_box: int, max bound for the number of box primitives
min_s_box: int, min length of a box side
max_s_box: int, max length of a box side
"""
mask = np.zeros((img_h, img_w))
uniform = prng.randint
if np.random.uniform(0, 1) < p_irr: # generate polygonal chain
n = uniform(min_n_irr, max_n_irr) # sample number of segments
for _ in range(n):
y = uniform(0, img_h) # sample a starting point
x = uniform(0, img_w)
a = uniform(0, 360) # sample angle
l = uniform(10, max_l_irr) # sample segment length
w = uniform(5, max_w_irr) # sample a segment width
# draw segment starting from (x,y) to (x_,y_) using brush of width w
x_ = x + l * np.sin(a)
y_ = y + l * np.cos(a)
mask = gen_segment_mask(mask, start=(x, y), end=(x_, y_), brush_width=w)
x, y = x_, y_
else: # generate Box masks
n = uniform(min_n_box, max_n_box) # sample number of rectangles
for _ in range(n):
h = uniform(min_s_box, max_s_box) # sample box shape
w = uniform(min_s_box, max_s_box)
x_0 = uniform(marg, img_w - marg - w) # sample upper-left coordinates of box
y_0 = uniform(marg, img_h - marg - h)
if np.random.uniform(0, 1) < 0.5:
mask = gen_box_mask(mask, masked=(x_0, y_0, w, h))
else:
r = uniform(0, 60) # sample radius
mask = gen_round_mask(mask, masked=(x_0, y_0, w, h), radius=r)
return mask
make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"])
make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"])
make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"])
make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"])
MASK_MODES = {
"256train": make_lama_mask,
"256narrow": make_narrow_lama_mask,
"512train": make_512_lama_mask,
"512train-large": make_512_lama_mask_large
}
if __name__ == "__main__":
import sys
out = sys.argv[1]
prng = np.random.RandomState(1)
kwargs = settings["256train"]
mask = gen_large_mask(prng, 256, 256, **kwargs)
mask = (255 * mask).astype(np.uint8)
mask = Image.fromarray(mask)
mask.save(out)
This diff is collapsed.
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class LSUNBase(Dataset):
def __init__(self,
txt_file,
data_root,
size=None,
interpolation="bicubic",
flip_p=0.5
):
self.data_paths = txt_file
self.data_root = data_root
with open(self.data_paths, "r") as f:
self.image_paths = f.read().splitlines()
self._length = len(self.image_paths)
self.labels = {
"relative_file_path_": [l for l in self.image_paths],
"file_path_": [os.path.join(self.data_root, l)
for l in self.image_paths],
}
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
return self._length
def __getitem__(self, i):
example = dict((k, self.labels[k][i]) for k in self.labels)
image = Image.open(example["file_path_"])
if not image.mode == "RGB":
image = image.convert("RGB")
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
class LSUNChurchesTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
class LSUNChurchesValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
flip_p=flip_p, **kwargs)
class LSUNBedroomsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
class LSUNBedroomsValidation(LSUNBase):
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
flip_p=flip_p, **kwargs)
class LSUNCatsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
class LSUNCatsValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
flip_p=flip_p, **kwargs)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
#!/bin/bash
set -ex
virtualenv venv
. venv/bin/activate
pip install -r requirements.txt
This diff is collapsed.
self.__BUILD_MANIFEST={__rewrites:{beforeFiles:[],afterFiles:[],fallback:[]},"/":["static/chunks/pages/index-67838cc9a284731b.js"],"/_error":["static/chunks/pages/_error-41335750e27ece98.js"],sortedPages:["/","/_app","/_error"]},self.__BUILD_MANIFEST_CB&&self.__BUILD_MANIFEST_CB();
\ No newline at end of file
self.__SSG_MANIFEST=new Set,self.__SSG_MANIFEST_CB&&self.__SSG_MANIFEST_CB();
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment