Commit 6be644fa authored by dan

Enable batch_size>1 for mixed-sized training

parent 50fb20ce
@@ -3,8 +3,10 @@ import numpy as np
 import PIL
 import torch
 from PIL import Image
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset, DataLoader, Sampler
 from torchvision import transforms
+from collections import defaultdict
+from random import shuffle, choices
 
 import random
 import tqdm
@@ -45,12 +47,12 @@ class PersonalizedBase(Dataset):
         assert data_root, 'dataset directory not specified'
         assert os.path.isdir(data_root), "Dataset directory doesn't exist"
         assert os.listdir(data_root), "Dataset directory is empty"
-        assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'
 
         self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
 
         self.shuffle_tags = shuffle_tags
         self.tag_drop_out = tag_drop_out
+        groups = defaultdict(list)
 
         print("Preparing dataset...")
         for path in tqdm.tqdm(self.image_paths):
@@ -103,13 +105,14 @@ class PersonalizedBase(Dataset):
             if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
                 with devices.autocast():
                     entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
-
+            groups[image.size].append(len(self.dataset))
             self.dataset.append(entry)
             del torchdata
             del latent_dist
             del latent_sample
 
         self.length = len(self.dataset)
+        self.groups = list(groups.values())
         assert self.length > 0, "No images have been found in the dataset."
         self.batch_size = min(batch_size, self.length)
         self.gradient_step = min(gradient_step, self.length // self.batch_size)
@@ -137,9 +140,34 @@ class PersonalizedBase(Dataset):
             entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
         return entry
 
+class GroupedBatchSampler(Sampler):
+    def __init__(self, data_source: PersonalizedBase, batch_size: int):
+        n = len(data_source)
+        self.groups = data_source.groups
+        self.len = n_batch = n // batch_size
+        expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
+        self.base = [int(e) // batch_size for e in expected]
+        self.n_rand_batches = nrb = n_batch - sum(self.base)
+        self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
+        self.batch_size = batch_size
+    def __len__(self):
+        return self.len
+    def __iter__(self):
+        b = self.batch_size
+        for g in self.groups:
+            shuffle(g)
+        batches = []
+        for g in self.groups:
+            batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
+        for _ in range(self.n_rand_batches):
+            rand_group = choices(self.groups, self.probs)[0]
+            batches.append(choices(rand_group, k=b))
+        shuffle(batches)
+        yield from batches
+
 class PersonalizedDataLoader(DataLoader):
     def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
-        super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
+        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
         if latent_sampling_method == "random":
             self.collate_fn = collate_wrapper_random
         else:
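
The quota arithmetic in GroupedBatchSampler.__init__ is terse, so here is a rough worked sketch of the same computation under a made-up size distribution (the group counts below are hypothetical, not from any real dataset): each size group first fills as many whole batches as its share of the epoch allows, and the fractional leftovers become a pool of randomly drawn batches weighted by each group's remainder.

# Quota math from GroupedBatchSampler.__init__, replayed with toy numbers.
batch_size = 4
group_counts = [21, 10, 5]        # hypothetical: images per size bucket

n = sum(group_counts)             # 36 images total
n_batch = n // batch_size         # 9 batches per epoch
# Each group's fair share of the epoch, measured in images:
expected = [c / n * n_batch * batch_size for c in group_counts]   # [21.0, 10.0, 5.0]
# Whole batches a group can fill on its own:
base = [int(e) // batch_size for e in expected]                   # [5, 2, 1]
# Leftover batches are drawn at random, weighted by each group's remainder
# (the real code guards the division with `if nrb > 0`; here nrb is 1):
n_rand_batches = n_batch - sum(base)                              # 1
probs = [e % batch_size / n_rand_batches / batch_size for e in expected]
print(base, n_rand_batches, probs)   # [5, 2, 1] 1 [0.25, 0.5, 0.25]

The probabilities sum to 1, so on average each group is sampled in proportion to the images it could not place into whole batches.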
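
A minimal usage sketch, assuming the GroupedBatchSampler above is in scope. The StubDataset below is hypothetical; it only mimics the two things the sampler actually reads from PersonalizedBase, namely len(dataset) and dataset.groups (one list of indices per image size):

from torch.utils.data import DataLoader

class StubDataset:
    # Hypothetical stand-in for PersonalizedBase.
    def __init__(self, groups):
        self.groups = groups
    def __len__(self):
        return sum(len(g) for g in self.groups)
    def __getitem__(self, i):
        return i

ds = StubDataset(groups=[[0, 1, 2, 3], [4, 5, 6, 7, 8, 9]])   # two size buckets
loader = DataLoader(ds, batch_sampler=GroupedBatchSampler(ds, batch_size=2))
for batch in loader:
    print(batch)   # every batch stays inside one bucket, e.g. tensor([5, 8])

Passing batch_sampler is what lets the commit drop shuffle=True and drop_last=True from the old super().__init__ call: in PyTorch, batch_sampler is mutually exclusive with batch_size, shuffle, and drop_last, and the sampler now handles both shuffling and whole-batch construction itself.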