Commit 12f4f476 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #1795 from MarkovInequality/learnschedule

Added learning_rate scheduling for TI
parents d7474a51 419e539f
......@@ -189,8 +189,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
losses = torch.zeros((32,))
last_saved_file = "<none>"
......@@ -200,12 +198,27 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps:
return embedding, filename
tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
epoch_len = (tr_img_len * num_repeats) + tr_img_len
scheduleIter = iter(LearnSchedule(learn_rate, steps, ititial_step))
(learn_rate, end_step) = next(scheduleIter)
print(f'Training at rate of {learn_rate} until step {end_step}')
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text, _) in pbar:
embedding.step = i + ititial_step
if embedding.step > steps:
break
if embedding.step > end_step:
try:
(learn_rate, end_step) = next(scheduleIter)
except:
break
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
for pg in optimizer.param_groups:
pg['lr'] = learn_rate
if shared.state.interrupted:
break
......@@ -276,3 +289,36 @@ Last saved image: {html.escape(last_saved_image)}<br/>
return embedding, filename
class LearnSchedule:
def __init__(self, learn_rate, max_steps, cur_step=0):
pairs = learn_rate.split(',')
self.rates = []
self.it = 0
self.maxit = 0
for i, pair in enumerate(pairs):
tmp = pair.split(':')
if len(tmp) == 2:
step = int(tmp[1])
if step > cur_step:
self.rates.append((float(tmp[0]), min(step, max_steps)))
self.maxit += 1
if step > max_steps:
return
elif step == -1:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
else:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
def __iter__(self):
return self
def __next__(self):
if self.it < self.maxit:
self.it += 1
return self.rates[self.it - 1]
else:
raise StopIteration
......@@ -1070,7 +1070,7 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value = "5.0e-03")
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment