Commit 0c5522ea authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge branch 'master' into training-help-text

parents 858462f7 2273e752
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: bug-report
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Desktop (please complete the following information):**
- OS: [e.g. Windows, Linux]
- Browser [e.g. chrome, safari]
- Commit revision [looks like this: e68484500f76a33ba477d5a99340ab30451e557b; can be seen when launching webui.bat, or obtained manually by running `git rev-parse HEAD`]
**Additional context**
Add any other context about the problem here.
name: Bug Report
description: You think somethings is broken in the UI
title: "[Bug]: "
labels: ["bug-report"]
body:
- type: checkboxes
attributes:
label: Is there an existing issue for this?
description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
options:
- label: I have searched the existing issues and checked the recent builds/commits
required: true
- type: markdown
attributes:
value: |
*Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible**
- type: textarea
id: what-did
attributes:
label: What happened?
description: Tell us what happened in a very clear and simple way
validations:
required: true
- type: textarea
id: steps
attributes:
label: Steps to reproduce the problem
description: Please provide us with precise step by step information on how to reproduce the bug
value: |
1. Go to ....
2. Press ....
3. ...
validations:
required: true
- type: textarea
id: what-should
attributes:
label: What should have happened?
description: tell what you think the normal behavior should be
validations:
required: true
- type: input
id: commit
attributes:
label: Commit where the problem happens
description: Which commit are you running ? (copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
- type: dropdown
id: platforms
attributes:
label: What platforms do you use to access UI ?
multiple: true
options:
- Windows
- Linux
- MacOS
- iOS
- Android
- Other/Cloud
- type: dropdown
id: browsers
attributes:
label: What browsers do you use to access the UI ?
multiple: true
options:
- Mozilla Firefox
- Google Chrome
- Brave
- Apple Safari
- Microsoft Edge
- type: textarea
id: cmdargs
attributes:
label: Command Line Arguments
description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below
render: Shell
- type: textarea
id: misc
attributes:
label: Additional information, context and logs
description: Please provide us with any relevant additional info, context or log output.
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: 'suggestion'
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.
name: Feature request
description: Suggest an idea for this project
title: "[Feature Request]: "
labels: ["suggestion"]
body:
- type: checkboxes
attributes:
label: Is there an existing issue for this?
description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit.
options:
- label: I have searched the existing issues and checked the recent builds/commits
required: true
- type: markdown
attributes:
value: |
*Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible*
- type: textarea
id: feature
attributes:
label: What would your feature do ?
description: Tell us about your feature in a very clear and simple way, and what problem it would solve
validations:
required: true
- type: textarea
id: workflow
attributes:
label: Proposed workflow
description: Please provide us with step by step information on how you'd like the feature to be accessed and used
value: |
1. Go to ....
2. Press ....
3. ...
validations:
required: true
- type: textarea
id: misc
attributes:
label: Additional information
description: Add any other context or screenshots about the feature request here.
...@@ -27,3 +27,4 @@ __pycache__ ...@@ -27,3 +27,4 @@ __pycache__
notification.mp3 notification.mp3
/SwinIR /SwinIR
/textual_inversion /textual_inversion
.vscode
\ No newline at end of file
...@@ -70,6 +70,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web ...@@ -70,6 +70,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens) - No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args) - DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args) - [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
## Installation and Running ## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
......
...@@ -91,6 +91,8 @@ titles = { ...@@ -91,6 +91,8 @@ titles = {
"Weighted sum": "Result = A * (1 - M) + B * M", "Weighted sum": "Result = A * (1 - M) + B * M",
"Add difference": "Result = A + (B - C) * M", "Add difference": "Result = A + (B - C) * M",
"Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
} }
......
...@@ -156,9 +156,15 @@ def prepare_enviroment(): ...@@ -156,9 +156,15 @@ def prepare_enviroment():
if not is_installed("clip"): if not is_installed("clip"):
run_pip(f"install {clip_package}", "clip") run_pip(f"install {clip_package}", "clip")
if (not is_installed("xformers") or reinstall_xformers) and xformers and platform.python_version().startswith("3.10"): if (not is_installed("xformers") or reinstall_xformers) and xformers:
if platform.system() == "Windows": if platform.system() == "Windows":
run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers") if platform.python_version().startswith("3.10"):
run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
else:
print("Installation of xformers is not supported in this version of Python.")
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
if not is_installed("xformers"):
exit(0)
elif platform.system() == "Linux": elif platform.system() == "Linux":
run_pip("install xformers", "xformers") run_pip("install xformers", "xformers")
......
...@@ -39,9 +39,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -39,9 +39,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if input_dir == '': if input_dir == '':
return outputs, "Please select an input directory.", '' return outputs, "Please select an input directory.", ''
image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)] image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
for img in image_list: for img in image_list:
image = Image.open(img) try:
image = Image.open(img)
except Exception:
continue
imageArr.append(image) imageArr.append(image)
imageNameArr.append(img) imageNameArr.append(img)
else: else:
...@@ -118,10 +121,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -118,10 +121,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
while len(cached_images) > 2: while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))] del cached_images[next(iter(cached_images.keys()))]
images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, if opts.use_original_name_batch and image_name != None:
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, basename = os.path.splitext(os.path.basename(image_name))[0]
forced_filename=image_name if opts.use_original_name_batch else None) else:
basename = ''
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
if opts.enable_pnginfo: if opts.enable_pnginfo:
image.info = existing_pnginfo image.info = existing_pnginfo
......
...@@ -45,11 +45,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model ...@@ -45,11 +45,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else: else:
prompt += ("" if prompt == "" else "\n") + line prompt += ("" if prompt == "" else "\n") + line
if len(prompt) > 0: res["Prompt"] = prompt
res["Prompt"] = prompt res["Negative prompt"] = negative_prompt
if len(negative_prompt) > 0:
res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline): for k, v in re_param.findall(lastline):
m = re_imagesize.match(v) m = re_imagesize.match(v)
......
...@@ -22,40 +22,57 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler ...@@ -22,40 +22,57 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module): class HypernetworkModule(torch.nn.Module):
multiplier = 1.0 multiplier = 1.0
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False): def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
super().__init__() super().__init__()
if layer_structure is not None:
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
else: assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
layer_structure = parse_layer_structure(dim, state_dict)
linears = [] linears = []
for i in range(len(layer_structure) - 1): for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
if activation_func == "relu":
linears.append(torch.nn.ReLU())
elif activation_func == "leakyrelu":
linears.append(torch.nn.LeakyReLU())
elif activation_func == 'linear' or activation_func is None:
pass
else:
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
if add_layer_norm: if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
self.linear = torch.nn.Sequential(*linears) self.linear = torch.nn.Sequential(*linears)
if state_dict is not None: if state_dict is not None:
try: self.fix_old_state_dict(state_dict)
self.load_state_dict(state_dict) self.load_state_dict(state_dict)
except RuntimeError:
self.try_load_previous(state_dict)
else: else:
for layer in self.linear: for layer in self.linear:
layer.weight.data.normal_(mean = 0.0, std = 0.01) if type(layer) == torch.nn.Linear:
layer.bias.data.zero_() layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
self.to(devices.device) self.to(devices.device)
def try_load_previous(self, state_dict): def fix_old_state_dict(self, state_dict):
states = self.state_dict() changes = {
states['linear.0.bias'].copy_(state_dict['linear1.bias']) 'linear1.bias': 'linear.0.bias',
states['linear.0.weight'].copy_(state_dict['linear1.weight']) 'linear1.weight': 'linear.0.weight',
states['linear.1.bias'].copy_(state_dict['linear2.bias']) 'linear2.bias': 'linear.1.bias',
states['linear.1.weight'].copy_(state_dict['linear2.weight']) 'linear2.weight': 'linear.1.weight',
}
for fr, to in changes.items():
x = state_dict.get(fr, None)
if x is None:
continue
del state_dict[fr]
state_dict[to] = x
def forward(self, x): def forward(self, x):
return x + self.linear(x) * self.multiplier return x + self.linear(x) * self.multiplier
...@@ -63,7 +80,8 @@ class HypernetworkModule(torch.nn.Module): ...@@ -63,7 +80,8 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self): def trainables(self):
layer_structure = [] layer_structure = []
for layer in self.linear: for layer in self.linear:
layer_structure += [layer.weight, layer.bias] if type(layer) == torch.nn.Linear:
layer_structure += [layer.weight, layer.bias]
return layer_structure return layer_structure
...@@ -71,23 +89,11 @@ def apply_strength(value=None): ...@@ -71,23 +89,11 @@ def apply_strength(value=None):
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
def parse_layer_structure(dim, state_dict):
i = 0
layer_structure = [1]
while (key := "linear.{}.weight".format(i)) in state_dict:
weight = state_dict[key]
layer_structure.append(len(weight) // dim)
i += 1
return layer_structure
class Hypernetwork: class Hypernetwork:
filename = None filename = None
name = None name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False): def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
self.filename = None self.filename = None
self.name = name self.name = name
self.layers = {} self.layers = {}
...@@ -96,11 +102,12 @@ class Hypernetwork: ...@@ -96,11 +102,12 @@ class Hypernetwork:
self.sd_checkpoint_name = None self.sd_checkpoint_name = None
self.layer_structure = layer_structure self.layer_structure = layer_structure
self.add_layer_norm = add_layer_norm self.add_layer_norm = add_layer_norm
self.activation_func = activation_func
for size in enable_sizes or []: for size in enable_sizes or []:
self.layers[size] = ( self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
) )
def weights(self): def weights(self):
...@@ -123,6 +130,7 @@ class Hypernetwork: ...@@ -123,6 +130,7 @@ class Hypernetwork:
state_dict['name'] = self.name state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure state_dict['layer_structure'] = self.layer_structure
state_dict['is_layer_norm'] = self.add_layer_norm state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['activation_func'] = self.activation_func
state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
...@@ -135,17 +143,19 @@ class Hypernetwork: ...@@ -135,17 +143,19 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu') state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.activation_func = state_dict.get('activation_func', None)
for size, sd in state_dict.items(): for size, sd in state_dict.items():
if type(size) == int: if type(size) == int:
self.layers[size] = ( self.layers[size] = (
HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]), HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]), HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
) )
self.name = state_dict.get('name', self.name) self.name = state_dict.get('name', self.name)
self.step = state_dict.get('step', 0) self.step = state_dict.get('step', 0)
self.layer_structure = state_dict.get('layer_structure', None)
self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.sd_checkpoint = state_dict.get('sd_checkpoint', None) self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
...@@ -244,7 +254,11 @@ def stack_conds(conds): ...@@ -244,7 +254,11 @@ def stack_conds(conds):
return torch.stack(conds) return torch.stack(conds)
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
assert hypernetwork_name, 'hypernetwork not selected' assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None) path = shared.hypernetworks.get(hypernetwork_name, None)
...@@ -287,6 +301,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -287,6 +301,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
last_saved_file = "<none>" last_saved_file = "<none>"
last_saved_image = "<none>" last_saved_image = "<none>"
forced_filename = "<none>"
ititial_step = hypernetwork.step or 0 ititial_step = hypernetwork.step or 0
if ititial_step > steps: if ititial_step > steps:
...@@ -334,7 +349,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -334,7 +349,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
}) })
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
last_saved_image = os.path.join(images_dir, forced_filename)
optimizer.zero_grad() optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.cond_stage_model.to(devices.device)
...@@ -370,7 +386,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -370,7 +386,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if image is not None: if image is not None:
shared.state.current_image = image shared.state.current_image = image
image.save(last_saved_image) last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
last_saved_image += f", prompt: {preview_text}" last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step shared.state.job_no = hypernetwork.step
......
...@@ -10,19 +10,20 @@ from modules import sd_hijack, shared, devices ...@@ -10,19 +10,20 @@ from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False): def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
if not overwrite_old: if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists" assert not os.path.exists(fn), f"file {fn} already exists"
if type(layer_structure) == str: if type(layer_structure) == str:
layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure))) layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork( hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
name=name, name=name,
enable_sizes=[int(x) for x in enable_sizes], enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure, layer_structure=layer_structure,
add_layer_norm=add_layer_norm, add_layer_norm=add_layer_norm,
activation_func=activation_func,
) )
hypernet.save(fn) hypernet.save(fn)
......
...@@ -28,9 +28,11 @@ class InterrogateModels: ...@@ -28,9 +28,11 @@ class InterrogateModels:
clip_preprocess = None clip_preprocess = None
categories = None categories = None
dtype = None dtype = None
running_on_cpu = None
def __init__(self, content_dir): def __init__(self, content_dir):
self.categories = [] self.categories = []
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
if os.path.exists(content_dir): if os.path.exists(content_dir):
for filename in os.listdir(content_dir): for filename in os.listdir(content_dir):
...@@ -53,7 +55,11 @@ class InterrogateModels: ...@@ -53,7 +55,11 @@ class InterrogateModels:
def load_clip_model(self): def load_clip_model(self):
import clip import clip
model, preprocess = clip.load(clip_model_name) if self.running_on_cpu:
model, preprocess = clip.load(clip_model_name, device="cpu")
else:
model, preprocess = clip.load(clip_model_name)
model.eval() model.eval()
model = model.to(devices.device_interrogate) model = model.to(devices.device_interrogate)
...@@ -62,14 +68,14 @@ class InterrogateModels: ...@@ -62,14 +68,14 @@ class InterrogateModels:
def load(self): def load(self):
if self.blip_model is None: if self.blip_model is None:
self.blip_model = self.load_blip_model() self.blip_model = self.load_blip_model()
if not shared.cmd_opts.no_half: if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.blip_model = self.blip_model.half() self.blip_model = self.blip_model.half()
self.blip_model = self.blip_model.to(devices.device_interrogate) self.blip_model = self.blip_model.to(devices.device_interrogate)
if self.clip_model is None: if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model() self.clip_model, self.clip_preprocess = self.load_clip_model()
if not shared.cmd_opts.no_half: if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.clip_model = self.clip_model.half() self.clip_model = self.clip_model.half()
self.clip_model = self.clip_model.to(devices.device_interrogate) self.clip_model = self.clip_model.to(devices.device_interrogate)
......
...@@ -304,7 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration ...@@ -304,7 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Size": f"{p.width}x{p.height}", "Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')), "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]),
"Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch), "Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
...@@ -540,17 +540,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -540,17 +540,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def create_dummy_mask(self, x, width=None, height=None):
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
height = height or self.height
width = width or self.width
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
else:
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
return image_conditioning
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr: if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
return samples return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
...@@ -587,7 +607,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -587,7 +607,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None x = None
devices.torch_gc() devices.torch_gc()
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
return samples return samples
...@@ -613,6 +633,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -613,6 +633,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.inpainting_mask_invert = inpainting_mask_invert self.inpainting_mask_invert = inpainting_mask_invert
self.mask = None self.mask = None
self.nmask = None self.nmask = None
self.image_conditioning = None
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
...@@ -714,10 +735,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -714,10 +735,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3: elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask self.init_latent = self.init_latent * self.mask
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
if self.image_mask is not None:
conditioning_mask = np.array(self.image_mask.convert("L"))
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
else:
conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
# Create another latent image, this time with a masked version of the original input.
conditioning_mask = conditioning_mask.to(image.device)
conditioning_image = image * (1.0 - conditioning_mask)
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
else:
self.image_conditioning = torch.zeros(
self.init_latent.shape[0], 5, 1, 1,
dtype=self.init_latent.dtype,
device=self.init_latent.device
)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning) samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None: if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask samples = samples * self.nmask + self.init_latent * self.mask
......
This diff is collapsed.
...@@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config ...@@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices from modules import shared, modelloader, devices
from modules.paths import models_path from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion" model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir)) model_path = os.path.abspath(os.path.join(models_path, model_dir))
...@@ -203,14 +204,26 @@ def load_model_weights(model, checkpoint_info): ...@@ -203,14 +204,26 @@ def load_model_weights(model, checkpoint_info):
model.sd_checkpoint_info = checkpoint_info model.sd_checkpoint_info = checkpoint_info
def load_model(): def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack from modules import lowvram, sd_hijack
checkpoint_info = select_checkpoint() checkpoint_info = checkpoint_info or select_checkpoint()
if checkpoint_info.config != shared.cmd_opts.config: if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}") print(f"Loading config from: {checkpoint_info.config}")
sd_config = OmegaConf.load(checkpoint_info.config) sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info):
# Hardcoded config for now...
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
sd_config.model.params.use_ema = False
sd_config.model.params.conditioning_key = "hybrid"
sd_config.model.params.unet_config.params.in_channels = 9
# Create a "fake" config with a different name so that we know to unload it when switching models.
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
do_inpainting_hijack()
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info) load_model_weights(sd_model, checkpoint_info)
...@@ -234,9 +247,9 @@ def reload_model_weights(sd_model, info=None): ...@@ -234,9 +247,9 @@ def reload_model_weights(sd_model, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename: if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config: if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear() checkpoints_loaded.clear()
shared.sd_model = load_model() shared.sd_model = load_model(checkpoint_info)
return shared.sd_model return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
......
...@@ -117,6 +117,8 @@ class VanillaStableDiffusionSampler: ...@@ -117,6 +117,8 @@ class VanillaStableDiffusionSampler:
self.config = None self.config = None
self.last_latent = None self.last_latent = None
self.conditioning_key = sd_model.model.conditioning_key
def number_of_needed_noises(self, p): def number_of_needed_noises(self, p):
return 0 return 0
...@@ -136,6 +138,12 @@ class VanillaStableDiffusionSampler: ...@@ -136,6 +138,12 @@ class VanillaStableDiffusionSampler:
if self.stop_at is not None and self.step > self.stop_at: if self.stop_at is not None and self.step > self.stop_at:
raise InterruptedException raise InterruptedException
# Have to unwrap the inpainting conditioning here to perform pre-processing
image_conditioning = None
if isinstance(cond, dict):
image_conditioning = cond["c_concat"][0]
cond = cond["c_crossattn"][0]
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
...@@ -157,6 +165,12 @@ class VanillaStableDiffusionSampler: ...@@ -157,6 +165,12 @@ class VanillaStableDiffusionSampler:
img_orig = self.sampler.model.q_sample(self.init_latent, ts) img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec x_dec = img_orig * self.mask + self.nmask * x_dec
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
# Note that they need to be lists because it just concatenates them later.
if image_conditioning is not None:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
if self.mask is not None: if self.mask is not None:
...@@ -182,7 +196,7 @@ class VanillaStableDiffusionSampler: ...@@ -182,7 +196,7 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps) steps, t_enc = setup_img2img_steps(p, steps)
self.initialize(p) self.initialize(p)
...@@ -196,20 +210,33 @@ class VanillaStableDiffusionSampler: ...@@ -196,20 +210,33 @@ class VanillaStableDiffusionSampler:
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x self.init_latent = x
self.last_latent = x
self.step = 0 self.step = 0
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
self.initialize(p) self.initialize(p)
self.init_latent = None self.init_latent = None
self.last_latent = x
self.step = 0 self.step = 0
steps = steps or p.steps steps = steps or p.steps
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
# existing code fails with certain step counts, like 9 # existing code fails with certain step counts, like 9
try: try:
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
...@@ -228,7 +255,7 @@ class CFGDenoiser(torch.nn.Module): ...@@ -228,7 +255,7 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None self.init_latent = None
self.step = 0 self.step = 0
def forward(self, x, sigma, uncond, cond, cond_scale): def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped: if state.interrupted or state.skipped:
raise InterruptedException raise InterruptedException
...@@ -239,28 +266,29 @@ class CFGDenoiser(torch.nn.Module): ...@@ -239,28 +266,29 @@ class CFGDenoiser(torch.nn.Module):
repeats = [len(conds_list[i]) for i in range(batch_size)] repeats = [len(conds_list[i]) for i in range(batch_size)]
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
if tensor.shape[1] == uncond.shape[1]: if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond]) cond_in = torch.cat([tensor, uncond])
if shared.batch_cond_uncond: if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond=cond_in) x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
else: else:
x_out = torch.zeros_like(x_in) x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size): for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset a = batch_offset
b = a + batch_size b = a + batch_size
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b]) x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
else: else:
x_out = torch.zeros_like(x_in) x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size): for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset a = batch_offset
b = min(a + batch_size, tensor.shape[0]) b = min(a + batch_size, tensor.shape[0])
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b]) x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond) x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
denoised_uncond = x_out[-uncond.shape[0]:] denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond) denoised = torch.clone(denoised_uncond)
...@@ -306,6 +334,8 @@ class KDiffusionSampler: ...@@ -306,6 +334,8 @@ class KDiffusionSampler:
self.config = None self.config = None
self.last_latent = None self.last_latent = None
self.conditioning_key = sd_model.model.conditioning_key
def callback_state(self, d): def callback_state(self, d):
step = d['i'] step = d['i']
latent = d["denoised"] latent = d["denoised"]
...@@ -361,7 +391,7 @@ class KDiffusionSampler: ...@@ -361,7 +391,7 @@ class KDiffusionSampler:
return extra_params_kwargs return extra_params_kwargs
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps) steps, t_enc = setup_img2img_steps(p, steps)
if p.sampler_noise_scheduler_override: if p.sampler_noise_scheduler_override:
...@@ -388,12 +418,18 @@ class KDiffusionSampler: ...@@ -388,12 +418,18 @@ class KDiffusionSampler:
extra_params_kwargs['sigmas'] = sigma_sched extra_params_kwargs['sigmas'] = sigma_sched
self.model_wrap_cfg.init_latent = x self.model_wrap_cfg.init_latent = x
self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
steps = steps or p.steps steps = steps or p.steps
if p.sampler_noise_scheduler_override: if p.sampler_noise_scheduler_override:
...@@ -414,7 +450,13 @@ class KDiffusionSampler: ...@@ -414,7 +450,13 @@ class KDiffusionSampler:
else: else:
extra_params_kwargs['sigmas'] = sigmas extra_params_kwargs['sigmas'] = sigmas
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples return samples
...@@ -83,7 +83,7 @@ class PersonalizedBase(Dataset): ...@@ -83,7 +83,7 @@ class PersonalizedBase(Dataset):
self.dataset.append(entry) self.dataset.append(entry)
assert len(self.dataset) > 1, "No images have been found in the dataset." assert len(self.dataset) > 0, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size self.length = len(self.dataset) * repeats // batch_size
self.initial_indexes = np.arange(len(self.dataset)) self.initial_indexes = np.arange(len(self.dataset))
...@@ -91,7 +91,7 @@ class PersonalizedBase(Dataset): ...@@ -91,7 +91,7 @@ class PersonalizedBase(Dataset):
self.shuffle() self.shuffle()
def shuffle(self): def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
def create_text(self, filename_text): def create_text(self, filename_text):
text = random.choice(self.lines) text = random.choice(self.lines)
......
...@@ -268,8 +268,13 @@ def calc_time_left(progress, threshold, label, force_display): ...@@ -268,8 +268,13 @@ def calc_time_left(progress, threshold, label, force_display):
time_since_start = time.time() - shared.state.time_start time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress) eta = (time_since_start/progress)
eta_relative = eta-time_since_start eta_relative = eta-time_since_start
if (eta_relative > threshold and progress > 0.02) or force_display: if (eta_relative > threshold and progress > 0.02) or force_display:
return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) if eta_relative > 3600:
return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
elif eta_relative > 60:
return label + time.strftime('%M:%S', time.gmtime(eta_relative))
else:
return label + time.strftime('%Ss', time.gmtime(eta_relative))
else: else:
return "" return ""
...@@ -285,7 +290,7 @@ def check_progress_call(id_part): ...@@ -285,7 +290,7 @@ def check_progress_call(id_part):
if shared.state.sampling_steps > 0: if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
time_left = calc_time_left( progress, 60, " ETA:", shared.state.time_left_force_display ) time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display )
if time_left != "": if time_left != "":
shared.state.time_left_force_display = True shared.state.time_left_force_display = True
...@@ -293,7 +298,7 @@ def check_progress_call(id_part): ...@@ -293,7 +298,7 @@ def check_progress_call(id_part):
progressbar = "" progressbar = ""
if opts.show_progressbar: if opts.show_progressbar:
progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:hidden;width:{progress * 100}%">{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}</div></div>""" progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}</div></div>"""
image = gr_show(False) image = gr_show(False)
preview_visibility = gr_show(False) preview_visibility = gr_show(False)
...@@ -1221,6 +1226,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1221,6 +1226,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
...@@ -1311,6 +1317,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1311,6 +1317,7 @@ def create_ui(wrap_gradio_gpu_call):
overwrite_old_hypernetwork, overwrite_old_hypernetwork,
new_hypernetwork_layer_structure, new_hypernetwork_layer_structure,
new_hypernetwork_add_layer_norm, new_hypernetwork_add_layer_norm,
new_hypernetwork_activation_func,
], ],
outputs=[ outputs=[
train_hypernetwork_name, train_hypernetwork_name,
......
...@@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs): ...@@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs):
if info is None: if info is None:
raise RuntimeError(f"Unknown checkpoint: {x}") raise RuntimeError(f"Unknown checkpoint: {x}")
modules.sd_models.reload_model_weights(shared.sd_model, info) modules.sd_models.reload_model_weights(shared.sd_model, info)
p.sd_model = shared.sd_model
def confirm_checkpoints(p, xs): def confirm_checkpoints(p, xs):
......
...@@ -118,7 +118,8 @@ def api_only(): ...@@ -118,7 +118,8 @@ def api_only():
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
def webui(launch_api=False): def webui():
launch_api = cmd_opts.api
initialize() initialize()
while 1: while 1:
...@@ -158,4 +159,4 @@ if __name__ == "__main__": ...@@ -158,4 +159,4 @@ if __name__ == "__main__":
if cmd_opts.nowebui: if cmd_opts.nowebui:
api_only() api_only()
else: else:
webui(cmd_opts.api) webui()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment