Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
6be644fa
Commit
6be644fa
authored
Jan 11, 2023
by
dan
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Enable batch_size>1 for mixed-sized training
parent
50fb20ce
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
4 deletions
+32
-4
modules/textual_inversion/dataset.py
modules/textual_inversion/dataset.py
+32
-4
No files found.
modules/textual_inversion/dataset.py
View file @
6be644fa
...
...
@@ -3,8 +3,10 @@ import numpy as np
import
PIL
import
torch
from
PIL
import
Image
from
torch.utils.data
import
Dataset
,
DataLoader
from
torch.utils.data
import
Dataset
,
DataLoader
,
Sampler
from
torchvision
import
transforms
from
collections
import
defaultdict
from
random
import
shuffle
,
choices
import
random
import
tqdm
...
...
@@ -45,12 +47,12 @@ class PersonalizedBase(Dataset):
assert
data_root
,
'dataset directory not specified'
assert
os
.
path
.
isdir
(
data_root
),
"Dataset directory doesn't exist"
assert
os
.
listdir
(
data_root
),
"Dataset directory is empty"
assert
batch_size
==
1
or
not
varsize
,
'variable img size must have batch size 1'
self
.
image_paths
=
[
os
.
path
.
join
(
data_root
,
file_path
)
for
file_path
in
os
.
listdir
(
data_root
)]
self
.
shuffle_tags
=
shuffle_tags
self
.
tag_drop_out
=
tag_drop_out
groups
=
defaultdict
(
list
)
print
(
"Preparing dataset..."
)
for
path
in
tqdm
.
tqdm
(
self
.
image_paths
):
...
...
@@ -103,13 +105,14 @@ class PersonalizedBase(Dataset):
if
include_cond
and
not
(
self
.
tag_drop_out
!=
0
or
self
.
shuffle_tags
):
with
devices
.
autocast
():
entry
.
cond
=
cond_model
([
entry
.
cond_text
])
.
to
(
devices
.
cpu
)
.
squeeze
(
0
)
groups
[
image
.
size
]
.
append
(
len
(
self
.
dataset
))
self
.
dataset
.
append
(
entry
)
del
torchdata
del
latent_dist
del
latent_sample
self
.
length
=
len
(
self
.
dataset
)
self
.
groups
=
list
(
groups
.
values
())
assert
self
.
length
>
0
,
"No images have been found in the dataset."
self
.
batch_size
=
min
(
batch_size
,
self
.
length
)
self
.
gradient_step
=
min
(
gradient_step
,
self
.
length
//
self
.
batch_size
)
...
...
@@ -137,9 +140,34 @@ class PersonalizedBase(Dataset):
entry
.
latent_sample
=
shared
.
sd_model
.
get_first_stage_encoding
(
entry
.
latent_dist
)
.
to
(
devices
.
cpu
)
return
entry
class GroupedBatchSampler(Sampler):
    """Batch sampler that never mixes indices from different groups in one batch.

    ``data_source.groups`` is a list of index lists (the dataset builds one
    list per distinct image size).  Every batch yielded by this sampler is
    drawn from a single group, so all items in a batch share a size and can
    be stacked even when the dataset contains mixed-sized images.

    Each epoch consists of exactly ``len(data_source) // batch_size`` batches:

    * ``self.base[i]`` "sequential" batches per group, cut from a freshly
      shuffled copy of the group's order (no repeats inside these batches);
    * ``self.n_rand_batches`` extra batches, each sampled with replacement
      from one group picked at random with weights ``self.probs`` so that
      groups are represented roughly in proportion to their size.
    """

    def __init__(self, data_source: "PersonalizedBase", batch_size: int):
        """Pre-compute the per-group batch budget.

        Args:
            data_source: dataset exposing ``__len__`` and a ``groups``
                attribute (list of lists of dataset indices).
            batch_size: number of indices per yielded batch.
        """
        n = len(data_source)
        self.groups = data_source.groups
        self.len = n_batch = n // batch_size
        # Number of samples each group should contribute per epoch if the
        # n_batch * batch_size slots were split proportionally to group size.
        expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
        # Whole batches each group can fill on its own.
        self.base = [int(e) // batch_size for e in expected]
        # Slots left over after the whole batches; filled by random batches.
        self.n_rand_batches = nrb = n_batch - sum(self.base)
        # Weight of each group for the random batches: its fractional
        # (left-over) share of a batch, normalised over the random batches.
        self.probs = [e % batch_size / nrb / batch_size if nrb > 0 else 0 for e in expected]
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches produced per iteration."""
        return self.len

    def __iter__(self):
        """Yield lists of ``batch_size`` dataset indices, one group per batch.

        Note: the group lists are shuffled in place, i.e. the lists held by
        the data source are reordered on every iteration.
        """
        b = self.batch_size

        for g in self.groups:
            shuffle(g)

        batches = []
        # Emit exactly the budgeted number of sequential batches per group.
        # Using len(g) // b here instead would over-produce whenever several
        # groups have a remainder, making the epoch longer than __len__().
        for g, n_full in zip(self.groups, self.base):
            batches.extend(g[i * b:(i + 1) * b] for i in range(n_full))
        for _ in range(self.n_rand_batches):
            rand_group = choices(self.groups, self.probs)[0]
            batches.append(choices(rand_group, k=b))

        shuffle(batches)

        yield from batches
class
PersonalizedDataLoader
(
DataLoader
):
def
__init__
(
self
,
dataset
,
latent_sampling_method
=
"once"
,
batch_size
=
1
,
pin_memory
=
False
):
super
(
PersonalizedDataLoader
,
self
)
.
__init__
(
dataset
,
shuffle
=
True
,
drop_last
=
True
,
batch_size
=
batch_size
,
pin_memory
=
pin_memory
)
super
(
PersonalizedDataLoader
,
self
)
.
__init__
(
dataset
,
batch_sampler
=
GroupedBatchSampler
(
dataset
,
batch_size
)
,
pin_memory
=
pin_memory
)
if
latent_sampling_method
==
"random"
:
self
.
collate_fn
=
collate_wrapper_random
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment