Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
8e7097d0
Commit
8e7097d0
authored
Oct 19, 2022
by
random_thoughtss
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added support for RunwayML inpainting model
parent
604620a7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
293 additions
and
15 deletions
+293
-15
modules/processing.py
modules/processing.py
+32
-2
modules/sd_hijack_inpainting.py
modules/sd_hijack_inpainting.py
+208
-0
modules/sd_models.py
modules/sd_models.py
+15
-1
modules/sd_samplers.py
modules/sd_samplers.py
+38
-12
No files found.
modules/processing.py
View file @
8e7097d0
...
...
@@ -546,7 +546,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if
not
self
.
enable_hr
:
x
=
create_random_tensors
([
opt_C
,
self
.
height
//
opt_f
,
self
.
width
//
opt_f
],
seeds
=
seeds
,
subseeds
=
subseeds
,
subseed_strength
=
self
.
subseed_strength
,
seed_resize_from_h
=
self
.
seed_resize_from_h
,
seed_resize_from_w
=
self
.
seed_resize_from_w
,
p
=
self
)
samples
=
self
.
sampler
.
sample
(
self
,
x
,
conditioning
,
unconditional_conditioning
)
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning
=
torch
.
zeros
(
x
.
shape
[
0
],
3
,
self
.
height
,
self
.
width
,
device
=
x
.
device
)
image_conditioning
=
self
.
sd_model
.
get_first_stage_encoding
(
self
.
sd_model
.
encode_first_stage
(
image_conditioning
))
# Add the fake full 1s mask to the first dimension.
image_conditioning
=
torch
.
nn
.
functional
.
pad
(
image_conditioning
,
(
0
,
0
,
0
,
0
,
1
,
0
),
value
=
1.0
)
image_conditioning
=
image_conditioning
.
to
(
x
.
dtype
)
samples
=
self
.
sampler
.
sample
(
self
,
x
,
conditioning
,
unconditional_conditioning
,
image_conditioning
=
image_conditioning
)
return
samples
x
=
create_random_tensors
([
opt_C
,
self
.
firstphase_height
//
opt_f
,
self
.
firstphase_width
//
opt_f
],
seeds
=
seeds
,
subseeds
=
subseeds
,
subseed_strength
=
self
.
subseed_strength
,
seed_resize_from_h
=
self
.
seed_resize_from_h
,
seed_resize_from_w
=
self
.
seed_resize_from_w
,
p
=
self
)
...
...
@@ -714,10 +723,31 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif
self
.
inpainting_fill
==
3
:
self
.
init_latent
=
self
.
init_latent
*
self
.
mask
if
self
.
image_mask
is
not
None
:
conditioning_mask
=
np
.
array
(
self
.
image_mask
.
convert
(
"L"
))
conditioning_mask
=
conditioning_mask
.
astype
(
np
.
float32
)
/
255.0
conditioning_mask
=
torch
.
from_numpy
(
conditioning_mask
[
None
,
None
])
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask
=
torch
.
round
(
conditioning_mask
)
else
:
conditioning_mask
=
torch
.
ones
(
1
,
1
,
*
image
.
shape
[
-
2
:])
# Create another latent image, this time with a masked version of the original input.
conditioning_mask
=
conditioning_mask
.
to
(
image
.
device
)
conditioning_image
=
image
*
(
1.0
-
conditioning_mask
)
conditioning_image
=
self
.
sd_model
.
get_first_stage_encoding
(
self
.
sd_model
.
encode_first_stage
(
conditioning_image
))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask
=
torch
.
nn
.
functional
.
interpolate
(
conditioning_mask
,
size
=
self
.
init_latent
.
shape
[
-
2
:])
conditioning_mask
=
conditioning_mask
.
expand
(
conditioning_image
.
shape
[
0
],
-
1
,
-
1
,
-
1
)
self
.
image_conditioning
=
torch
.
cat
([
conditioning_mask
,
conditioning_image
],
dim
=
1
)
self
.
image_conditioning
=
self
.
image_conditioning
.
to
(
shared
.
device
)
.
type
(
self
.
sd_model
.
dtype
)
def
sample
(
self
,
conditioning
,
unconditional_conditioning
,
seeds
,
subseeds
,
subseed_strength
):
x
=
create_random_tensors
([
opt_C
,
self
.
height
//
opt_f
,
self
.
width
//
opt_f
],
seeds
=
seeds
,
subseeds
=
subseeds
,
subseed_strength
=
self
.
subseed_strength
,
seed_resize_from_h
=
self
.
seed_resize_from_h
,
seed_resize_from_w
=
self
.
seed_resize_from_w
,
p
=
self
)
samples
=
self
.
sampler
.
sample_img2img
(
self
,
self
.
init_latent
,
x
,
conditioning
,
unconditional_conditioning
)
samples
=
self
.
sampler
.
sample_img2img
(
self
,
self
.
init_latent
,
x
,
conditioning
,
unconditional_conditioning
,
image_conditioning
=
self
.
image_conditioning
)
if
self
.
mask
is
not
None
:
samples
=
samples
*
self
.
nmask
+
self
.
init_latent
*
self
.
mask
...
...
modules/sd_hijack_inpainting.py
0 → 100644
View file @
8e7097d0
import
torch
import
numpy
as
np
from
tqdm
import
tqdm
from
einops
import
rearrange
,
repeat
from
omegaconf
import
ListConfig
from
types
import
MethodType
import
ldm.models.diffusion.ddpm
import
ldm.models.diffusion.ddim
from
ldm.models.diffusion.ddpm
import
LatentDiffusion
from
ldm.models.diffusion.ddim
import
DDIMSampler
,
noise_like
# =================================================================================================
# Monkey patch DDIMSampler methods from RunwayML repo directly.
# Adapted from:
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
# =================================================================================================
@
torch
.
no_grad
()
def
sample
(
self
,
S
,
batch_size
,
shape
,
conditioning
=
None
,
callback
=
None
,
normals_sequence
=
None
,
img_callback
=
None
,
quantize_x0
=
False
,
eta
=
0.
,
mask
=
None
,
x0
=
None
,
temperature
=
1.
,
noise_dropout
=
0.
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
verbose
=
True
,
x_T
=
None
,
log_every_t
=
100
,
unconditional_guidance_scale
=
1.
,
unconditional_conditioning
=
None
,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**
kwargs
):
if
conditioning
is
not
None
:
if
isinstance
(
conditioning
,
dict
):
ctmp
=
conditioning
[
list
(
conditioning
.
keys
())[
0
]]
while
isinstance
(
ctmp
,
list
):
ctmp
=
elf
.
inpainting_fill
==
2
:
self
.
init_latent
=
self
.
init_latent
*
self
.
mask
+
create_random_tensors
(
self
.
init_latent
.
shape
[
1
:],
all_seeds
[
0
:
self
.
init_latent
.
shape
[
0
]])
*
self
.
nmask
elif
self
.
inpainting_fill
==
3
:
self
.
init_latent
=
self
.
init_latent
*
self
.
mask
if
self
.
image_mask
is
not
None
:
conditioning_mask
=
np
.
array
(
self
.
image_mask
.
convert
(
"L"
))
conditioning_mask
=
conditioning_mask
.
astype
(
np
.
float32
)
/
255.0
conditioning_mask
=
torch
.
from_numpy
(
conditioning_mask
[
None
,
None
])
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask
=
torch
.
round
(
conditioning_mask
)
else
:
conditioning_mask
=
torch
.
ones
(
1
,
1
,
*
image
.
shape
[
-
2
:])
# Create another latent image, this time with a masked version of the original input.
conditioning_mask
=
conditioning_mask
.
to
(
image
.
device
)
conditioning_image
=
image
*
(
1.0
-
conditioning_mask
)
conditioning_image
=
self
.
sd_model
.
get_first_stage_encoding
(
self
.
sd_model
.
encode_first_stage
(
conditioning_image
))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask
=
torch
.
nn
.
functional
.
interpolate
(
conditioning_mask
,
size
=
self
.
init_latent
.
shape
[
-
2
:])
conditioning_mask
=
conditioning_mask
.
expand
(
conditioning_image
.
shape
[
0
],
-
1
,
-
1
,
-
1
)
self
.
image_conditioning
=
torch
.
cat
([
conditioning_mask
,
conditioning_image
],
dim
=
1
)
self
.
image_conditioning
=
self
.
image_conditioning
.
to
(
shared
.
device
)
.
type
(
self
.
sd_model
.
dtype
)
def
sample
(
self
,
conditioning
,
unconditional_conditioning
,
seeds
,
subseeds
,
subseed_strength
):
x
=
create_random_tensors
([
opctmp
[
0
]
cbs
=
ctmp
.
shape
[
0
]
if
cbs
!=
batch_size
:
print
(
f
"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else
:
if
conditioning
.
shape
[
0
]
!=
batch_size
:
print
(
f
"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
self
.
make_schedule
(
ddim_num_steps
=
S
,
ddim_eta
=
eta
,
verbose
=
verbose
)
# sampling
C
,
H
,
W
=
shape
size
=
(
batch_size
,
C
,
H
,
W
)
print
(
f
'Data shape for DDIM sampling is {size}, eta {eta}'
)
samples
,
intermediates
=
self
.
ddim_sampling
(
conditioning
,
size
,
callback
=
callback
,
img_callback
=
img_callback
,
quantize_denoised
=
quantize_x0
,
mask
=
mask
,
x0
=
x0
,
ddim_use_original_steps
=
False
,
noise_dropout
=
noise_dropout
,
temperature
=
temperature
,
score_corrector
=
score_corrector
,
corrector_kwargs
=
corrector_kwargs
,
x_T
=
x_T
,
log_every_t
=
log_every_t
,
unconditional_guidance_scale
=
unconditional_guidance_scale
,
unconditional_conditioning
=
unconditional_conditioning
,
)
return
samples
,
intermediates
@
torch
.
no_grad
()
def
p_sample_ddim
(
self
,
x
,
c
,
t
,
index
,
repeat_noise
=
False
,
use_original_steps
=
False
,
quantize_denoised
=
False
,
temperature
=
1.
,
noise_dropout
=
0.
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
unconditional_guidance_scale
=
1.
,
unconditional_conditioning
=
None
):
b
,
*
_
,
device
=
*
x
.
shape
,
x
.
device
if
unconditional_conditioning
is
None
or
unconditional_guidance_scale
==
1.
:
e_t
=
self
.
model
.
apply_model
(
x
,
t
,
c
)
else
:
x_in
=
torch
.
cat
([
x
]
*
2
)
t_in
=
torch
.
cat
([
t
]
*
2
)
if
isinstance
(
c
,
dict
):
assert
isinstance
(
unconditional_conditioning
,
dict
)
c_in
=
dict
()
for
k
in
c
:
if
isinstance
(
c
[
k
],
list
):
c_in
[
k
]
=
[
torch
.
cat
([
unconditional_conditioning
[
k
][
i
],
c
[
k
][
i
]])
for
i
in
range
(
len
(
c
[
k
]))
]
else
:
c_in
[
k
]
=
torch
.
cat
([
unconditional_conditioning
[
k
],
c
[
k
]])
else
:
c_in
=
torch
.
cat
([
unconditional_conditioning
,
c
])
e_t_uncond
,
e_t
=
self
.
model
.
apply_model
(
x_in
,
t_in
,
c_in
)
.
chunk
(
2
)
e_t
=
e_t_uncond
+
unconditional_guidance_scale
*
(
e_t
-
e_t_uncond
)
if
score_corrector
is
not
None
:
assert
self
.
model
.
parameterization
==
"eps"
e_t
=
score_corrector
.
modify_score
(
self
.
model
,
e_t
,
x
,
t
,
c
,
**
corrector_kwargs
)
alphas
=
self
.
model
.
alphas_cumprod
if
use_original_steps
else
self
.
ddim_alphas
alphas_prev
=
self
.
model
.
alphas_cumprod_prev
if
use_original_steps
else
self
.
ddim_alphas_prev
sqrt_one_minus_alphas
=
self
.
model
.
sqrt_one_minus_alphas_cumprod
if
use_original_steps
else
self
.
ddim_sqrt_one_minus_alphas
sigmas
=
self
.
model
.
ddim_sigmas_for_original_num_steps
if
use_original_steps
else
self
.
ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t
=
torch
.
full
((
b
,
1
,
1
,
1
),
alphas
[
index
],
device
=
device
)
a_prev
=
torch
.
full
((
b
,
1
,
1
,
1
),
alphas_prev
[
index
],
device
=
device
)
sigma_t
=
torch
.
full
((
b
,
1
,
1
,
1
),
sigmas
[
index
],
device
=
device
)
sqrt_one_minus_at
=
torch
.
full
((
b
,
1
,
1
,
1
),
sqrt_one_minus_alphas
[
index
],
device
=
device
)
# current prediction for x_0
pred_x0
=
(
x
-
sqrt_one_minus_at
*
e_t
)
/
a_t
.
sqrt
()
if
quantize_denoised
:
pred_x0
,
_
,
*
_
=
self
.
model
.
first_stage_model
.
quantize
(
pred_x0
)
# direction pointing to x_t
dir_xt
=
(
1.
-
a_prev
-
sigma_t
**
2
)
.
sqrt
()
*
e_t
noise
=
sigma_t
*
noise_like
(
x
.
shape
,
device
,
repeat_noise
)
*
temperature
if
noise_dropout
>
0.
:
noise
=
torch
.
nn
.
functional
.
dropout
(
noise
,
p
=
noise_dropout
)
x_prev
=
a_prev
.
sqrt
()
*
pred_x0
+
dir_xt
+
noise
return
x_prev
,
pred_x0
# =================================================================================================
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
# Adapted from:
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
# =================================================================================================
@
torch
.
no_grad
()
def
get_unconditional_conditioning
(
self
,
batch_size
,
null_label
=
None
):
if
null_label
is
not
None
:
xc
=
null_label
if
isinstance
(
xc
,
ListConfig
):
xc
=
list
(
xc
)
if
isinstance
(
xc
,
dict
)
or
isinstance
(
xc
,
list
):
c
=
self
.
get_learned_conditioning
(
xc
)
else
:
if
hasattr
(
xc
,
"to"
):
xc
=
xc
.
to
(
self
.
device
)
c
=
self
.
get_learned_conditioning
(
xc
)
else
:
# todo: get null label from cond_stage_model
raise
NotImplementedError
()
c
=
repeat
(
c
,
"1 ... -> b ..."
,
b
=
batch_size
)
.
to
(
self
.
device
)
return
c
class
LatentInpaintDiffusion
(
LatentDiffusion
):
def
__init__
(
self
,
concat_keys
=
(
"mask"
,
"masked_image"
),
masked_image_key
=
"masked_image"
,
*
args
,
**
kwargs
,
):
super
()
.
__init__
(
*
args
,
**
kwargs
)
self
.
masked_image_key
=
masked_image_key
assert
self
.
masked_image_key
in
concat_keys
self
.
concat_keys
=
concat_keys
def
should_hijack_inpainting
(
checkpoint_info
):
return
str
(
checkpoint_info
.
filename
)
.
endswith
(
"inpainting.ckpt"
)
and
not
checkpoint_info
.
config
.
endswith
(
"inpainting.yaml"
)
def
do_inpainting_hijack
():
ldm
.
models
.
diffusion
.
ddpm
.
get_unconditional_conditioning
=
get_unconditional_conditioning
ldm
.
models
.
diffusion
.
ddpm
.
LatentInpaintDiffusion
=
LatentInpaintDiffusion
ldm
.
models
.
diffusion
.
ddim
.
DDIMSampler
.
p_sample_ddim
=
p_sample_ddim
ldm
.
models
.
diffusion
.
ddim
.
DDIMSampler
.
sample
=
sample
\ No newline at end of file
modules/sd_models.py
View file @
8e7097d0
...
...
@@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config
from
modules
import
shared
,
modelloader
,
devices
from
modules.paths
import
models_path
from
modules.sd_hijack_inpainting
import
do_inpainting_hijack
,
should_hijack_inpainting
model_dir
=
"Stable-diffusion"
model_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
models_path
,
model_dir
))
...
...
@@ -211,6 +212,19 @@ def load_model():
print
(
f
"Loading config from: {checkpoint_info.config}"
)
sd_config
=
OmegaConf
.
load
(
checkpoint_info
.
config
)
if
should_hijack_inpainting
(
checkpoint_info
):
do_inpainting_hijack
()
# Hardcoded config for now...
sd_config
.
model
.
target
=
"ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
sd_config
.
model
.
params
.
use_ema
=
False
sd_config
.
model
.
params
.
conditioning_key
=
"hybrid"
sd_config
.
model
.
params
.
unet_config
.
params
.
in_channels
=
9
# Create a "fake" config with a different name so that we know to unload it when switching models.
checkpoint_info
=
checkpoint_info
.
_replace
(
config
=
checkpoint_info
.
config
.
replace
(
".yaml"
,
"-inpainting.yaml"
))
sd_model
=
instantiate_from_config
(
sd_config
.
model
)
load_model_weights
(
sd_model
,
checkpoint_info
)
...
...
@@ -234,7 +248,7 @@ def reload_model_weights(sd_model, info=None):
if
sd_model
.
sd_model_checkpoint
==
checkpoint_info
.
filename
:
return
if
sd_model
.
sd_checkpoint_info
.
config
!=
checkpoint_info
.
config
:
if
sd_model
.
sd_checkpoint_info
.
config
!=
checkpoint_info
.
config
or
should_hijack_inpainting
(
checkpoint_info
)
!=
should_hijack_inpainting
(
sd_model
.
sd_checkpoint_info
)
:
checkpoints_loaded
.
clear
()
shared
.
sd_model
=
load_model
()
return
shared
.
sd_model
...
...
modules/sd_samplers.py
View file @
8e7097d0
...
...
@@ -136,9 +136,15 @@ class VanillaStableDiffusionSampler:
if
self
.
stop_at
is
not
None
and
self
.
step
>
self
.
stop_at
:
raise
InterruptedException
# Have to unwrap the inpainting conditioning here to perform pre-preocessing
image_conditioning
=
None
if
isinstance
(
cond
,
dict
):
image_conditioning
=
cond
[
"c_concat"
][
0
]
cond
=
cond
[
"c_crossattn"
][
0
]
unconditional_conditioning
=
unconditional_conditioning
[
"c_crossattn"
][
0
]
conds_list
,
tensor
=
prompt_parser
.
reconstruct_multicond_batch
(
cond
,
self
.
step
)
unconditional_conditioning
=
prompt_parser
.
reconstruct_cond_batch
(
unconditional_conditioning
,
self
.
step
)
unconditional_conditioning
=
prompt_parser
.
reconstruct_cond_batch
(
unconditional_conditioning
,
self
.
step
)
assert
all
([
len
(
conds
)
==
1
for
conds
in
conds_list
]),
'composition via AND is not supported for DDIM/PLMS samplers'
cond
=
tensor
...
...
@@ -157,6 +163,10 @@ class VanillaStableDiffusionSampler:
img_orig
=
self
.
sampler
.
model
.
q_sample
(
self
.
init_latent
,
ts
)
x_dec
=
img_orig
*
self
.
mask
+
self
.
nmask
*
x_dec
if
image_conditioning
is
not
None
:
cond
=
{
"c_concat"
:
[
image_conditioning
],
"c_crossattn"
:
[
cond
]}
unconditional_conditioning
=
{
"c_concat"
:
[
image_conditioning
],
"c_crossattn"
:
[
unconditional_conditioning
]}
res
=
self
.
orig_p_sample_ddim
(
x_dec
,
cond
,
ts
,
unconditional_conditioning
=
unconditional_conditioning
,
*
args
,
**
kwargs
)
if
self
.
mask
is
not
None
:
...
...
@@ -182,7 +192,7 @@ class VanillaStableDiffusionSampler:
self
.
mask
=
p
.
mask
if
hasattr
(
p
,
'mask'
)
else
None
self
.
nmask
=
p
.
nmask
if
hasattr
(
p
,
'nmask'
)
else
None
def
sample_img2img
(
self
,
p
,
x
,
noise
,
conditioning
,
unconditional_conditioning
,
steps
=
None
):
def
sample_img2img
(
self
,
p
,
x
,
noise
,
conditioning
,
unconditional_conditioning
,
steps
=
None
,
image_conditioning
=
None
):
steps
,
t_enc
=
setup_img2img_steps
(
p
,
steps
)
self
.
initialize
(
p
)
...
...
@@ -202,7 +212,7 @@ class VanillaStableDiffusionSampler:
return
samples
def
sample
(
self
,
p
,
x
,
conditioning
,
unconditional_conditioning
,
steps
=
None
):
def
sample
(
self
,
p
,
x
,
conditioning
,
unconditional_conditioning
,
steps
=
None
,
image_conditioning
=
None
):
self
.
initialize
(
p
)
self
.
init_latent
=
None
...
...
@@ -210,6 +220,11 @@ class VanillaStableDiffusionSampler:
steps
=
steps
or
p
.
steps
# Wrap the conditioning models with additional image conditioning for inpainting model
if
image_conditioning
is
not
None
:
conditioning
=
{
"c_concat"
:
[
image_conditioning
],
"c_crossattn"
:
[
conditioning
]}
unconditional_conditioning
=
{
"c_concat"
:
[
image_conditioning
],
"c_crossattn"
:
[
unconditional_conditioning
]}
# existing code fails with certain step counts, like 9
try
:
samples_ddim
=
self
.
launch_sampling
(
steps
,
lambda
:
self
.
sampler
.
sample
(
S
=
steps
,
conditioning
=
conditioning
,
batch_size
=
int
(
x
.
shape
[
0
]),
shape
=
x
[
0
]
.
shape
,
verbose
=
False
,
unconditional_guidance_scale
=
p
.
cfg_scale
,
unconditional_conditioning
=
unconditional_conditioning
,
x_T
=
x
,
eta
=
self
.
eta
)[
0
])
...
...
@@ -228,7 +243,7 @@ class CFGDenoiser(torch.nn.Module):
self
.
init_latent
=
None
self
.
step
=
0
def
forward
(
self
,
x
,
sigma
,
uncond
,
cond
,
cond_scale
):
def
forward
(
self
,
x
,
sigma
,
uncond
,
cond
,
cond_scale
,
image_cond
):
if
state
.
interrupted
or
state
.
skipped
:
raise
InterruptedException
...
...
@@ -239,28 +254,29 @@ class CFGDenoiser(torch.nn.Module):
repeats
=
[
len
(
conds_list
[
i
])
for
i
in
range
(
batch_size
)]
x_in
=
torch
.
cat
([
torch
.
stack
([
x
[
i
]
for
_
in
range
(
n
)])
for
i
,
n
in
enumerate
(
repeats
)]
+
[
x
])
image_cond_in
=
torch
.
cat
([
torch
.
stack
([
image_cond
[
i
]
for
_
in
range
(
n
)])
for
i
,
n
in
enumerate
(
repeats
)]
+
[
image_cond
])
sigma_in
=
torch
.
cat
([
torch
.
stack
([
sigma
[
i
]
for
_
in
range
(
n
)])
for
i
,
n
in
enumerate
(
repeats
)]
+
[
sigma
])
if
tensor
.
shape
[
1
]
==
uncond
.
shape
[
1
]:
cond_in
=
torch
.
cat
([
tensor
,
uncond
])
if
shared
.
batch_cond_uncond
:
x_out
=
self
.
inner_model
(
x_in
,
sigma_in
,
cond
=
cond_in
)
x_out
=
self
.
inner_model
(
x_in
,
sigma_in
,
cond
=
{
"c_crossattn"
:
[
cond_in
],
"c_concat"
:
[
image_cond_in
]}
)
else
:
x_out
=
torch
.
zeros_like
(
x_in
)
for
batch_offset
in
range
(
0
,
x_out
.
shape
[
0
],
batch_size
):
a
=
batch_offset
b
=
a
+
batch_size
x_out
[
a
:
b
]
=
self
.
inner_model
(
x_in
[
a
:
b
],
sigma_in
[
a
:
b
],
cond
=
cond_in
[
a
:
b
]
)
x_out
[
a
:
b
]
=
self
.
inner_model
(
x_in
[
a
:
b
],
sigma_in
[
a
:
b
],
cond
=
{
"c_crossattn"
:
[
cond_in
[
a
:
b
]],
"c_concat"
:
[
image_cond_in
[
a
:
b
]]}
)
else
:
x_out
=
torch
.
zeros_like
(
x_in
)
batch_size
=
batch_size
*
2
if
shared
.
batch_cond_uncond
else
batch_size
for
batch_offset
in
range
(
0
,
tensor
.
shape
[
0
],
batch_size
):
a
=
batch_offset
b
=
min
(
a
+
batch_size
,
tensor
.
shape
[
0
])
x_out
[
a
:
b
]
=
self
.
inner_model
(
x_in
[
a
:
b
],
sigma_in
[
a
:
b
],
cond
=
tensor
[
a
:
b
]
)
x_out
[
a
:
b
]
=
self
.
inner_model
(
x_in
[
a
:
b
],
sigma_in
[
a
:
b
],
cond
=
{
"c_crossattn"
:
[
tensor
[
a
:
b
]],
"c_concat"
:
[
image_cond_in
[
a
:
b
]]}
)
x_out
[
-
uncond
.
shape
[
0
]:]
=
self
.
inner_model
(
x_in
[
-
uncond
.
shape
[
0
]:],
sigma_in
[
-
uncond
.
shape
[
0
]:],
cond
=
uncond
)
x_out
[
-
uncond
.
shape
[
0
]:]
=
self
.
inner_model
(
x_in
[
-
uncond
.
shape
[
0
]:],
sigma_in
[
-
uncond
.
shape
[
0
]:],
cond
=
{
"c_crossattn"
:
[
uncond
],
"c_concat"
:
[
image_cond_in
[
-
uncond
.
shape
[
0
]:]]}
)
denoised_uncond
=
x_out
[
-
uncond
.
shape
[
0
]:]
denoised
=
torch
.
clone
(
denoised_uncond
)
...
...
@@ -361,7 +377,7 @@ class KDiffusionSampler:
return
extra_params_kwargs
def
sample_img2img
(
self
,
p
,
x
,
noise
,
conditioning
,
unconditional_conditioning
,
steps
=
None
):
def
sample_img2img
(
self
,
p
,
x
,
noise
,
conditioning
,
unconditional_conditioning
,
steps
=
None
,
image_conditioning
=
None
):
steps
,
t_enc
=
setup_img2img_steps
(
p
,
steps
)
if
p
.
sampler_noise_scheduler_override
:
...
...
@@ -389,11 +405,16 @@ class KDiffusionSampler:
self
.
model_wrap_cfg
.
init_latent
=
x
samples
=
self
.
launch_sampling
(
steps
,
lambda
:
self
.
func
(
self
.
model_wrap_cfg
,
xi
,
extra_args
=
{
'cond'
:
conditioning
,
'uncond'
:
unconditional_conditioning
,
'cond_scale'
:
p
.
cfg_scale
},
disable
=
False
,
callback
=
self
.
callback_state
,
**
extra_params_kwargs
))
samples
=
self
.
launch_sampling
(
steps
,
lambda
:
self
.
func
(
self
.
model_wrap_cfg
,
xi
,
extra_args
=
{
'cond'
:
conditioning
,
'image_cond'
:
image_conditioning
,
'uncond'
:
unconditional_conditioning
,
'cond_scale'
:
p
.
cfg_scale
},
disable
=
False
,
callback
=
self
.
callback_state
,
**
extra_params_kwargs
))
return
samples
def
sample
(
self
,
p
,
x
,
conditioning
,
unconditional_conditioning
,
steps
=
None
):
def
sample
(
self
,
p
,
x
,
conditioning
,
unconditional_conditioning
,
steps
=
None
,
image_conditioning
=
None
):
steps
=
steps
or
p
.
steps
if
p
.
sampler_noise_scheduler_override
:
...
...
@@ -414,7 +435,12 @@ class KDiffusionSampler:
else
:
extra_params_kwargs
[
'sigmas'
]
=
sigmas
samples
=
self
.
launch_sampling
(
steps
,
lambda
:
self
.
func
(
self
.
model_wrap_cfg
,
x
,
extra_args
=
{
'cond'
:
conditioning
,
'uncond'
:
unconditional_conditioning
,
'cond_scale'
:
p
.
cfg_scale
},
disable
=
False
,
callback
=
self
.
callback_state
,
**
extra_params_kwargs
))
samples
=
self
.
launch_sampling
(
steps
,
lambda
:
self
.
func
(
self
.
model_wrap_cfg
,
x
,
extra_args
=
{
'cond'
:
conditioning
,
'image_cond'
:
image_conditioning
,
'uncond'
:
unconditional_conditioning
,
'cond_scale'
:
p
.
cfg_scale
},
disable
=
False
,
callback
=
self
.
callback_state
,
**
extra_params_kwargs
))
return
samples
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment