Commit b75b004f authored by AUTOMATIC1111's avatar AUTOMATIC1111

lora extension rework to include other types of networks

parent 7d26c479
from modules import extra_networks, shared
import lora
import networks
class ExtraNetworkLora(extra_networks.ExtraNetwork):
......@@ -9,7 +9,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_lora
if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional):
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
......@@ -21,12 +21,12 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
names.append(params.items[0])
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
lora.load_loras(names, multipliers)
networks.load_networks(names, multipliers)
if shared.opts.lora_add_hashes_to_infotext:
lora_hashes = []
for item in lora.loaded_loras:
shorthash = item.lora_on_disk.shorthash
network_hashes = []
for item in networks.loaded_networks:
shorthash = item.network_on_disk.shorthash
if not shorthash:
continue
......@@ -36,10 +36,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
alias = alias.replace(":", "").replace(",", "")
lora_hashes.append(f"{alias}: {shorthash}")
network_hashes.append(f"{alias}: {shorthash}")
if lora_hashes:
p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
if network_hashes:
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
def deactivate(self, p):
pass
import torch
def make_weight_cp(t, wa, wb):
temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
return torch.einsum('i j k l, i r -> r j k l', temp, wa)
def rebuild_conventional(up, down, shape, dyn_dim=None):
up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1)
if dyn_dim is not None:
up = up[:, :dyn_dim]
down = down[:dyn_dim, :]
return (up @ down).reshape(shape)
import os
from collections import namedtuple
import torch
from modules import devices, sd_models, cache, errors, hashes, shared
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
class NetworkOnDisk:
def __init__(self, name, filename):
self.name = name
self.filename = filename
self.metadata = {}
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
def read_metadata():
metadata = sd_models.read_metadata_from_safetensors(filename)
metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text
return metadata
if self.is_safetensors:
try:
self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
except Exception as e:
errors.display(e, f"reading lora {filename}")
if self.metadata:
m = {}
for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
m[k] = v
self.metadata = m
self.alias = self.metadata.get('ss_output_name', self.name)
self.hash = None
self.shorthash = None
self.set_hash(
self.metadata.get('sshs_model_hash') or
hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
''
)
def set_hash(self, v):
self.hash = v
self.shorthash = self.hash[0:12]
if self.shorthash:
import networks
networks.available_network_hash_lookup[self.shorthash] = self
def read_hash(self):
if not self.hash:
self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
def get_alias(self):
import networks
if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
return self.name
else:
return self.alias
class Network: # LoraModule
def __init__(self, name, network_on_disk: NetworkOnDisk):
self.name = name
self.network_on_disk = network_on_disk
self.multiplier = 1.0
self.modules = {}
self.mtime = None
self.mentioned_name = None
"""the text that was used to add the network to prompt - can be either name or an alias"""
class ModuleType:
def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
return None
class NetworkModule:
def __init__(self, net: Network, weights: NetworkWeights):
self.network = net
self.network_key = weights.network_key
self.sd_key = weights.sd_key
self.sd_module = weights.sd_module
def calc_updown(self, target):
raise NotImplementedError()
def forward(self, x, y):
raise NotImplementedError()
import lyco_helpers
import network
import network_lyco
class ModuleTypeHada(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]):
return NetworkModuleHada(net, weights)
return None
class NetworkModuleHada(network_lyco.NetworkModuleLyco):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
if hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
self.w1a = weights.w["hada_w1_a"]
self.w1b = weights.w["hada_w1_b"]
self.dim = self.w1b.shape[0]
self.w2a = weights.w["hada_w2_a"]
self.w2b = weights.w["hada_w2_b"]
self.t1 = weights.w.get("hada_t1")
self.t2 = weights.w.get("hada_t2")
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
self.scale = weights.w["scale"].item() if "scale" in weights.w else None
def calc_updown(self, orig_weight):
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
output_shape = [w1a.size(0), w1b.size(1)]
if self.t1 is not None:
output_shape = [w1a.size(1), w1b.size(1)]
t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
output_shape += t1.shape[2:]
else:
if len(w1b.shape) == 4:
output_shape += w1b.shape[2:]
updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
if self.t2 is not None:
t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
else:
updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
updown = updown1 * updown2
return self.finalize_updown(updown, orig_weight, output_shape)
import torch
import network
from modules import devices
class ModuleTypeLora(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
return NetworkModuleLora(net, weights)
return None
class NetworkModuleLora(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
self.up = self.create_module(weights.w["lora_up.weight"])
self.down = self.create_module(weights.w["lora_down.weight"])
self.alpha = weights.w["alpha"] if "alpha" in weights.w else None
def create_module(self, weight, none_ok=False):
if weight is None and none_ok:
return None
if type(self.sd_module) == torch.nn.Linear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(self.sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(self.sd_module) == torch.nn.MultiheadAttention:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
else:
print(f'Network layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')
return None
with torch.no_grad():
module.weight.copy_(weight)
module.to(device=devices.cpu, dtype=devices.dtype)
module.weight.requires_grad_(False)
return module
def calc_updown(self, target):
up = self.up.weight.to(target.device, dtype=target.dtype)
down = self.down.weight.to(target.device, dtype=target.dtype)
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else:
updown = up @ down
updown = updown * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0)
return updown
def forward(self, x, y):
self.up.to(device=devices.device)
self.down.to(device=devices.device)
return y + self.up(self.down(x)) * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0)
import torch
import lyco_helpers
import network
from modules import devices
class NetworkModuleLyco(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
if hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
self.dim = None
self.bias = weights.w.get("bias")
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
self.scale = weights.w["scale"].item() if "scale" in weights.w else None
def finalize_updown(self, updown, orig_weight, output_shape):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
updown = updown.reshape(output_shape)
if len(output_shape) == 4:
updown = updown.reshape(output_shape)
if orig_weight.size().numel() == updown.size().numel():
updown = updown.reshape(orig_weight.shape)
scale = (
self.scale if self.scale is not None
else self.alpha / self.dim if self.dim is not None and self.alpha is not None
else 1.0
)
return updown * scale * self.network.multiplier
import os
import re
import network
import network_lora
import network_hada
import torch
from typing import Union
from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes, cache
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
module_types = [
network_lora.ModuleTypeLora(),
network_hada.ModuleTypeHada(),
]
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
......@@ -79,81 +88,8 @@ def convert_diffusers_name_to_compvis(key, is_sd2):
return key
class LoraOnDisk:
def __init__(self, name, filename):
self.name = name
self.filename = filename
self.metadata = {}
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
def read_metadata():
metadata = sd_models.read_metadata_from_safetensors(filename)
metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text
return metadata
if self.is_safetensors:
try:
self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
except Exception as e:
errors.display(e, f"reading lora {filename}")
if self.metadata:
m = {}
for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
m[k] = v
self.metadata = m
self.alias = self.metadata.get('ss_output_name', self.name)
self.hash = None
self.shorthash = None
self.set_hash(
self.metadata.get('sshs_model_hash') or
hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
''
)
def set_hash(self, v):
self.hash = v
self.shorthash = self.hash[0:12]
if self.shorthash:
available_lora_hash_lookup[self.shorthash] = self
def read_hash(self):
if not self.hash:
self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
def get_alias(self):
if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases:
return self.name
else:
return self.alias
class LoraModule:
def __init__(self, name, lora_on_disk: LoraOnDisk):
self.name = name
self.lora_on_disk = lora_on_disk
self.multiplier = 1.0
self.modules = {}
self.mtime = None
self.mentioned_name = None
"""the text that was used to add lora to prompt - can be either name or an alias"""
class LoraUpDownModule:
def __init__(self):
self.up = None
self.down = None
self.alpha = None
def assign_lora_names_to_compvis_modules(sd_model):
lora_layer_mapping = {}
def assign_network_names_to_compvis_modules(sd_model):
network_layer_mapping = {}
if shared.sd_model.is_sdxl:
for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
......@@ -161,166 +97,132 @@ def assign_lora_names_to_compvis_modules(sd_model):
continue
for name, module in embedder.wrapped.named_modules():
lora_name = f'{i}_{name.replace(".", "_")}'
lora_layer_mapping[lora_name] = module
module.lora_layer_name = lora_name
network_name = f'{i}_{name.replace(".", "_")}'
network_layer_mapping[network_name] = module
module.network_layer_name = network_name
else:
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
lora_name = name.replace(".", "_")
lora_layer_mapping[lora_name] = module
module.lora_layer_name = lora_name
network_name = name.replace(".", "_")
network_layer_mapping[network_name] = module
module.network_layer_name = network_name
for name, module in shared.sd_model.model.named_modules():
lora_name = name.replace(".", "_")
lora_layer_mapping[lora_name] = module
module.lora_layer_name = lora_name
network_name = name.replace(".", "_")
network_layer_mapping[network_name] = module
module.network_layer_name = network_name
sd_model.lora_layer_mapping = lora_layer_mapping
sd_model.network_layer_mapping = network_layer_mapping
def load_lora(name, lora_on_disk):
lora = LoraModule(name, lora_on_disk)
lora.mtime = os.path.getmtime(lora_on_disk.filename)
def load_network(name, network_on_disk):
net = network.Network(name, network_on_disk)
net.mtime = os.path.getmtime(network_on_disk.filename)
sd = sd_models.read_state_dict(lora_on_disk.filename)
sd = sd_models.read_state_dict(network_on_disk.filename)
# this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
if not hasattr(shared.sd_model, 'lora_layer_mapping'):
assign_lora_names_to_compvis_modules(shared.sd_model)
if not hasattr(shared.sd_model, 'network_layer_mapping'):
assign_network_names_to_compvis_modules(shared.sd_model)
keys_failed_to_match = {}
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
matched_networks = {}
for key_lora, weight in sd.items():
key_lora_without_lora_parts, lora_key = key_lora.split(".", 1)
for key_network, weight in sd.items():
key_network_without_network_parts, network_part = key_network.split(".", 1)
key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2)
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
if sd_module is None:
m = re_x_proj.match(key)
if m:
sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
# SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
if sd_module is None and "lora_unet" in key_lora_without_lora_parts:
key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model")
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts:
key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model")
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
if sd_module is None and "lora_unet" in key_network_without_network_parts:
key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
if sd_module is None:
keys_failed_to_match[key_lora] = key
keys_failed_to_match[key_network] = key
continue
lora_module = lora.modules.get(key, None)
if lora_module is None:
lora_module = LoraUpDownModule()
lora.modules[key] = lora_module
if key not in matched_networks:
matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
if lora_key == "alpha":
lora_module.alpha = weight.item()
continue
if type(sd_module) == torch.nn.Linear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.MultiheadAttention:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
else:
print(f'Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
raise AssertionError(f"Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}")
matched_networks[key].w[network_part] = weight
with torch.no_grad():
module.weight.copy_(weight)
for key, weights in matched_networks.items():
net_module = None
for nettype in module_types:
net_module = nettype.create_module(net, weights)
if net_module is not None:
break
module.to(device=devices.cpu, dtype=devices.dtype)
if net_module is None:
raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")
if lora_key == "lora_up.weight":
lora_module.up = module
elif lora_key == "lora_down.weight":
lora_module.down = module
else:
raise AssertionError(f"Bad Lora layer name: {key_lora} - must end in lora_up.weight, lora_down.weight or alpha")
net.modules[key] = net_module
if keys_failed_to_match:
print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
return lora
return net
def load_loras(names, multipliers=None):
def load_networks(names, multipliers=None):
already_loaded = {}
for lora in loaded_loras:
if lora.name in names:
already_loaded[lora.name] = lora
for net in loaded_networks:
if net.name in names:
already_loaded[net.name] = net
loaded_loras.clear()
loaded_networks.clear()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
if any(x is None for x in loras_on_disk):
list_available_loras()
networks_on_disk = [available_network_aliases.get(name, None) for name in names]
if any(x is None for x in networks_on_disk):
list_available_networks()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
networks_on_disk = [available_network_aliases.get(name, None) for name in names]
failed_to_load_loras = []
failed_to_load_networks = []
for i, name in enumerate(names):
lora = already_loaded.get(name, None)
net = already_loaded.get(name, None)
lora_on_disk = loras_on_disk[i]
network_on_disk = networks_on_disk[i]
if lora_on_disk is not None:
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
if network_on_disk is not None:
if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
try:
lora = load_lora(name, lora_on_disk)
net = load_network(name, network_on_disk)
except Exception as e:
errors.display(e, f"loading Lora {lora_on_disk.filename}")
errors.display(e, f"loading network {network_on_disk.filename}")
continue
lora.mentioned_name = name
net.mentioned_name = name
lora_on_disk.read_hash()
network_on_disk.read_hash()
if lora is None:
failed_to_load_loras.append(name)
print(f"Couldn't find Lora with name {name}")
if net is None:
failed_to_load_networks.append(name)
print(f"Couldn't find network with name {name}")
continue
lora.multiplier = multipliers[i] if multipliers else 1.0
loaded_loras.append(lora)
if failed_to_load_loras:
sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras))
def lora_calc_updown(lora, module, target):
with torch.no_grad():
up = module.up.weight.to(target.device, dtype=target.dtype)
down = module.down.weight.to(target.device, dtype=target.dtype)
net.multiplier = multipliers[i] if multipliers else 1.0
loaded_networks.append(net)
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else:
updown = up @ down
updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
return updown
if failed_to_load_networks:
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "lora_weights_backup", None)
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "network_weights_backup", None)
if weights_backup is None:
return
......@@ -332,144 +234,148 @@ def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linea
self.weight.copy_(weights_backup)
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
"""
Applies the currently selected set of Loras to the weights of torch layer self.
If weights already have this particular set of loras applied, does nothing.
If not, restores orginal weights from backup and alters weights according to loras.
Applies the currently selected set of networks to the weights of torch layer self.
If weights already have this particular set of networks applied, does nothing.
If not, restores orginal weights from backup and alters weights according to networks.
"""
lora_layer_name = getattr(self, 'lora_layer_name', None)
if lora_layer_name is None:
network_layer_name = getattr(self, 'network_layer_name', None)
if network_layer_name is None:
return
current_names = getattr(self, "lora_current_names", ())
wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
current_names = getattr(self, "network_current_names", ())
wanted_names = tuple((x.name, x.multiplier) for x in loaded_networks)
weights_backup = getattr(self, "lora_weights_backup", None)
weights_backup = getattr(self, "network_weights_backup", None)
if weights_backup is None:
if isinstance(self, torch.nn.MultiheadAttention):
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
else:
weights_backup = self.weight.to(devices.cpu, copy=True)
self.lora_weights_backup = weights_backup
self.network_weights_backup = weights_backup
if current_names != wanted_names:
lora_restore_weights_from_backup(self)
network_restore_weights_from_backup(self)
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
for net in loaded_networks:
module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight'):
self.weight += lora_calc_updown(lora, module, self.weight)
continue
with torch.no_grad():
updown = module.calc_updown(self.weight)
module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
module_out = lora.modules.get(lora_layer_name + "_out_proj", None)
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
self.weight += updown
self.in_proj_weight += updown_qkv
self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
continue
module_q = net.modules.get(network_layer_name + "_q_proj", None)
module_k = net.modules.get(network_layer_name + "_k_proj", None)
module_v = net.modules.get(network_layer_name + "_v_proj", None)
module_out = net.modules.get(network_layer_name + "_out_proj", None)
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
with torch.no_grad():
updown_q = module_q.calc_updown(self.in_proj_weight)
updown_k = module_k.calc_updown(self.in_proj_weight)
updown_v = module_v.calc_updown(self.in_proj_weight)
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
self.in_proj_weight += updown_qkv
self.out_proj.weight += module_out.calc_updown(self.out_proj.weight)
continue
if module is None:
continue
print(f'failed to calculate lora weights for layer {lora_layer_name}')
print(f'failed to calculate network weights for layer {network_layer_name}')
self.lora_current_names = wanted_names
self.network_current_names = wanted_names
def lora_forward(module, input, original_forward):
def network_forward(module, input, original_forward):
"""
Old way of applying Lora by executing operations during layer's forward.
Stacking many loras this way results in big performance degradation.
"""
if len(loaded_loras) == 0:
if len(loaded_networks) == 0:
return original_forward(module, input)
input = devices.cond_cast_unet(input)
lora_restore_weights_from_backup(module)
lora_reset_cached_weight(module)
network_restore_weights_from_backup(module)
network_reset_cached_weight(module)
res = original_forward(module, input)
y = original_forward(module, input)
lora_layer_name = getattr(module, 'lora_layer_name', None)
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
network_layer_name = getattr(module, 'network_layer_name', None)
for lora in loaded_networks:
module = lora.modules.get(network_layer_name, None)
if module is None:
continue
module.up.to(device=devices.device)
module.down.to(device=devices.device)
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
y = module.forward(y, input)
return res
return y
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
self.lora_current_names = ()
self.lora_weights_backup = None
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
self.network_current_names = ()
self.network_weights_backup = None
def lora_Linear_forward(self, input):
def network_Linear_forward(self, input):
if shared.opts.lora_functional:
return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
return network_forward(self, input, torch.nn.Linear_forward_before_network)
lora_apply_weights(self)
network_apply_weights(self)
return torch.nn.Linear_forward_before_lora(self, input)
return torch.nn.Linear_forward_before_network(self, input)
def lora_Linear_load_state_dict(self, *args, **kwargs):
lora_reset_cached_weight(self)
def network_Linear_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
def lora_Conv2d_forward(self, input):
def network_Conv2d_forward(self, input):
if shared.opts.lora_functional:
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
lora_apply_weights(self)
network_apply_weights(self)
return torch.nn.Conv2d_forward_before_lora(self, input)
return torch.nn.Conv2d_forward_before_network(self, input)
def lora_Conv2d_load_state_dict(self, *args, **kwargs):
lora_reset_cached_weight(self)
def network_Conv2d_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
def lora_MultiheadAttention_forward(self, *args, **kwargs):
lora_apply_weights(self)
def network_MultiheadAttention_forward(self, *args, **kwargs):
network_apply_weights(self)
return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)
return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
lora_reset_cached_weight(self)
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)
return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
def list_available_loras():
available_loras.clear()
available_lora_aliases.clear()
forbidden_lora_aliases.clear()
available_lora_hash_lookup.clear()
forbidden_lora_aliases.update({"none": 1, "Addams": 1})
def list_available_networks():
available_networks.clear()
available_network_aliases.clear()
forbidden_network_aliases.clear()
available_network_hash_lookup.clear()
forbidden_network_aliases.update({"none": 1, "Addams": 1})
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
......@@ -480,21 +386,21 @@ def list_available_loras():
name = os.path.splitext(os.path.basename(filename))[0]
try:
entry = LoraOnDisk(name, filename)
entry = network.NetworkOnDisk(name, filename)
except OSError: # should catch FileNotFoundError and PermissionError etc.
errors.report(f"Failed to load LoRA {name} from {filename}", exc_info=True)
errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
continue
available_loras[name] = entry
available_networks[name] = entry
if entry.alias in available_lora_aliases:
forbidden_lora_aliases[entry.alias.lower()] = 1
if entry.alias in available_network_aliases:
forbidden_network_aliases[entry.alias.lower()] = 1
available_lora_aliases[name] = entry
available_lora_aliases[entry.alias] = entry
available_network_aliases[name] = entry
available_network_aliases[entry.alias] = entry
re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
def infotext_pasted(infotext, params):
......@@ -516,7 +422,7 @@ def infotext_pasted(infotext, params):
if name is None:
continue
m = re_lora_name.match(name)
m = re_network_name.match(name)
if m:
name = m.group(1)
......@@ -528,10 +434,10 @@ def infotext_pasted(infotext, params):
params["Prompt"] += "\n" + "".join(added)
available_loras = {}
available_lora_aliases = {}
available_lora_hash_lookup = {}
forbidden_lora_aliases = {}
loaded_loras = []
available_networks = {}
available_network_aliases = {}
loaded_networks = []
available_network_hash_lookup = {}
forbidden_network_aliases = {}
list_available_loras()
list_available_networks()
......@@ -4,18 +4,19 @@ import torch
import gradio as gr
from fastapi import FastAPI
import lora
import network
import networks
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def unload():
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
def before_ui():
......@@ -23,50 +24,50 @@ def before_ui():
extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
if not hasattr(torch.nn, 'Linear_forward_before_network'):
torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
torch.nn.Linear.forward = lora.lora_Linear_forward
torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
torch.nn.Linear.forward = networks.network_Linear_forward
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
script_callbacks.on_infotext_pasted(lora.infotext_pasted)
script_callbacks.on_infotext_pasted(networks.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
"sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks),
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
}))
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
"lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
"lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
}))
def create_lora_json(obj: lora.LoraOnDisk):
def create_lora_json(obj: network.NetworkOnDisk):
return {
"name": obj.name,
"alias": obj.alias,
......@@ -75,17 +76,17 @@ def create_lora_json(obj: lora.LoraOnDisk):
}
def api_loras(_: gr.Blocks, app: FastAPI):
def api_networks(_: gr.Blocks, app: FastAPI):
@app.get("/sdapi/v1/loras")
async def get_loras():
return [create_lora_json(obj) for obj in lora.available_loras.values()]
return [create_lora_json(obj) for obj in networks.available_networks.values()]
@app.post("/sdapi/v1/refresh-loras")
async def refresh_loras():
return lora.list_available_loras()
return networks.list_available_networks()
script_callbacks.on_app_started(api_loras)
script_callbacks.on_app_started(api_networks)
re_lora = re.compile("<lora:([^:]+):")
......@@ -98,19 +99,19 @@ def infotext_pasted(infotext, d):
hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
def lora_replacement(m):
def network_replacement(m):
alias = m.group(1)
shorthash = hashes.get(alias)
if shorthash is None:
return m.group(0)
lora_on_disk = lora.available_lora_hash_lookup.get(shorthash)
if lora_on_disk is None:
network_on_disk = networks.available_network_hash_lookup.get(shorthash)
if network_on_disk is None:
return m.group(0)
return f'<lora:{lora_on_disk.get_alias()}:'
return f'<lora:{network_on_disk.get_alias()}:'
d["Prompt"] = re.sub(re_lora, lora_replacement, d["Prompt"])
d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
script_callbacks.on_infotext_pasted(infotext_pasted)
import os
import lora
import networks
from modules import shared, ui_extra_networks
from modules.ui_extra_networks import quote_js
......@@ -11,10 +11,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
super().__init__('Lora')
def refresh(self):
lora.list_available_loras()
networks.list_available_networks()
def create_item(self, name, index=None):
lora_on_disk = lora.available_loras.get(name)
lora_on_disk = networks.available_networks.get(name)
path, ext = os.path.splitext(lora_on_disk.filename)
......@@ -43,7 +43,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
return item
def list_items(self):
for index, name in enumerate(lora.available_loras):
for index, name in enumerate(networks.available_networks):
item = self.create_item(name, index)
yield item
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment