Commit 2d0b32de authored by novelailab

optimizer save works

parent 971ed5dc
@@ -22,10 +22,13 @@ def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr, cosine_warmup=False):
 class BasedOptimizer:
     def __init__(self, parameters, config, optimizer, init=True):
         if init:
-            self.init_config(config)
-            self.init_optimizer(parameters, optimizer)
+            self.config = config
+            self.optimizer_name = optimizer
+            self.parameters = parameters
+            self.init_config()
+            self.init_optimizer()
 
-    def init_config(self, config):
+    def init_config(self):
         defaults = {
             "lr": 6e-4,
             "end_lr": 6e-4,
@@ -46,18 +49,18 @@ class BasedOptimizer:
         for k, v in defaults.items():
             setattr(self, k, v)
 
-        for k, v in config.items():
+        for k, v in self.config.items():
             setattr(self, k, v)
 
-    def init_optimizer(self, parameters, optimizer_name):
-        if optimizer_name == "adamw":
+    def init_optimizer(self):
+        if self.optimizer_name == "adamw":
             self.optimizer = optim.AdamW(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
 
-        elif optimizer_name == "adamw8bit":
+        elif self.optimizer_name == "adamw8bit":
             import bitsandbytes as bnb
             self.optimizer = bnb.optim.Adam8bit(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
 
-        elif optimizer_name == "adafactor":
+        elif self.optimizer_name == "adafactor":
             try:
                 from transformers.optimization import Adafactor
@@ -68,14 +71,14 @@ class BasedOptimizer:
     def step(self, dry_run=False, scaler=None):
+        self.curr_lr = lr_schedule(self.curr_step, self.warmup_steps, self.anneal_steps, self.lr, self.end_lr)
         if not dry_run:
             if scaler:
                 scaler.step(self.optimizer)
 
             else:
                 self.optimizer.step()
 
-        self.curr_lr = lr_schedule(self.curr_step, self.warmup_steps, self.anneal_steps, self.lr, self.end_lr)
         self.curr_step = self.curr_step + 1
 
         if not self.max_lr:
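With this hunk, lr_schedule is evaluated at the top of step() instead of after optimizer.step(), and curr_step is incremented afterwards, so once step() returns, curr_lr holds the rate that was applied to the step just taken. A minimal sketch of the resulting bookkeeping (assuming a fresh optimizer whose curr_step starts at 0, which is not shown in this hunk):

    opt.step(dry_run=True)   # computes curr_lr for step 0, then bumps curr_step to 1
    lr_used = opt.curr_lr    # rate applied on the step just taken, i.e. step opt.curr_step - 1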
@@ -102,7 +105,10 @@ class BasedOptimizer:
         path = path / "opt"
         path.mkdir(parents=True, exist_ok=True)
         torch.save(self.optimizer.state_dict(), path / "opt_states.pt")
+
+        #clean the optimizer and parameters from the dict.
         del self.optimizer
+        del self.parameters
         metadata = self.__dict__
         with open(path / "opt_metadata.pkl", 'wb') as f:
             pickle.dump(metadata, f)
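save() now writes the optimizer state to opt/opt_states.pt and pickles the remaining attributes (with optimizer and parameters stripped from __dict__) to opt/opt_metadata.pkl. The matching load path is not part of this diff; the following is only a hypothetical sketch of a counterpart, with load_based_optimizer an illustrative name rather than an API from this repository, assuming BasedOptimizer is in scope:

    import pickle
    import torch
    from pathlib import Path

    def load_based_optimizer(parameters, save_folder):
        # Hypothetical counterpart to save(); names and structure are assumptions.
        path = Path(save_folder) / "opt"
        opt = BasedOptimizer(parameters, {}, None, init=False)  # empty shell, nothing initialized
        with open(path / "opt_metadata.pkl", "rb") as f:
            opt.__dict__.update(pickle.load(f))                 # restores config, optimizer_name, curr_step, ...
        opt.parameters = parameters                             # re-attach; it was deleted before pickling
        opt.init_optimizer()                                    # rebuild the underlying torch optimizer
        opt.optimizer.load_state_dict(torch.load(path / "opt_states.pt"))
        return opt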
@@ -13,14 +13,16 @@ import torch
 # Does this work with other block_sizes? doesn't seem to.
 class FbDataset(data.Dataset):
-    def __init__(self, block_size, map_file, max_samples=None):
+    def __init__(self, block_size, map_file, max_samples=None, skip=0):
         self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
         self.samples = self.npz.shape[0]
         if max_samples is not None:
             self.samples = min(self.samples, int(max_samples))
-        self.skip = 0
+        self.skip = skip
 
     def __len__(self):
         return self.samples
 
     def __getitem__(self, _id):
         nth = _id + self.skip
         data = torch.tensor(self.npz[nth].astype(np.int64))
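The new skip argument lets a dataset start part-way through the memory-mapped file instead of always beginning at sample 0 (skip was previously hard-coded to 0): __getitem__ adds skip to every requested index. An illustrative use with placeholder values for the block size, map file, and number of samples already consumed:

    consumed = 4096                                          # samples seen before the interruption (placeholder)
    train_set = FbDataset(2049, "train.map", skip=consumed)  # block size and file name are placeholders
    first = train_set[0]                                     # actually reads row `consumed` of the memmap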
@@ -12,7 +12,7 @@ train_config = {
 }
 
 model = torch.nn.Linear(10, 100)
-save_folder = "models/test_optimizer2"
+save_folder = "models/test_optimizer6"
 if not os.path.isdir(save_folder + "/opt"):
     opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
 else:
@@ -21,8 +21,9 @@ else:
 wandb.init(project="opt-test", name="test")
 
 for x in tqdm(range(opt.curr_step, 100)):
-    print(f"Step {opt.curr_step}: LR {opt.curr_lr}")
-    wandb.log({"lr": opt.curr_lr})
     opt.step(dry_run=True)
-    #if x == 60:
-        #opt.save(Path(save_folder))
+    # current step gets iterated before the logging, so negate 1.
+    print(f"Step {opt.curr_step - 1}: LR {opt.curr_lr}")
+    wandb.log({"lr": opt.curr_lr})
+    if x == 60:
+        opt.save(Path(save_folder))
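The else branch earlier in this file (elided from the diff) is where a previously saved optimizer would be restored; because the loop iterates range(opt.curr_step, 100), a reloaded optimizer makes the run pick up at the saved step count. A hypothetical resume, reusing the illustrative loader sketched above:

    opt = load_based_optimizer(model.parameters(), save_folder)  # hypothetical helper, not from this repo
    for x in tqdm(range(opt.curr_step, 100)):                    # continues from the saved curr_step
        opt.step(dry_run=True)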