Commit 6be644fa authored by dan's avatar dan

Enable batch_size>1 for mixed-sized training

parent 50fb20ce
......@@ -3,8 +3,10 @@ import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
from collections import defaultdict
from random import shuffle, choices
import random
import tqdm
......@@ -45,12 +47,12 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
groups = defaultdict(list)
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
......@@ -103,13 +105,14 @@ class PersonalizedBase(Dataset):
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
with devices.autocast():
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
groups[image.size].append(len(self.dataset))
self.dataset.append(entry)
del torchdata
del latent_dist
del latent_sample
self.length = len(self.dataset)
self.groups = list(groups.values())
assert self.length > 0, "No images have been found in the dataset."
self.batch_size = min(batch_size, self.length)
self.gradient_step = min(gradient_step, self.length // self.batch_size)
......@@ -137,9 +140,34 @@ class PersonalizedBase(Dataset):
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
return entry
class GroupedBatchSampler(Sampler):
def __init__(self, data_source: PersonalizedBase, batch_size: int):
n = len(data_source)
self.groups = data_source.groups
self.len = n_batch = n // batch_size
expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
self.base = [int(e) // batch_size for e in expected]
self.n_rand_batches = nrb = n_batch - sum(self.base)
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
self.batch_size = batch_size
def __len__(self):
return self.len
def __iter__(self):
b = self.batch_size
for g in self.groups:
shuffle(g)
batches = []
for g in self.groups:
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
for _ in range(self.n_rand_batches):
rand_group = choices(self.groups, self.probs)[0]
batches.append(choices(rand_group, k=b))
shuffle(batches)
yield from batches
class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
if latent_sampling_method == "random":
self.collate_fn = collate_wrapper_random
else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment