Commit 2d0b32de authored by novelailab

optimizer save works

parent 971ed5dc
@@ -22,10 +22,13 @@ def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr, cosine_warmup=Fals
 class BasedOptimizer:
     def __init__(self, parameters, config, optimizer, init=True):
         if init:
-            self.init_config(config)
-            self.init_optimizer(parameters, optimizer)
+            self.config = config
+            self.optimizer_name = optimizer
+            self.parameters = parameters
+            self.init_config()
+            self.init_optimizer()
-    def init_config(self, config):
+    def init_config(self):
         defaults = {
             "lr": 6e-4,
             "end_lr": 6e-4,
@@ -46,18 +49,18 @@ class BasedOptimizer:
         for k, v in defaults.items():
             setattr(self, k, v)
-        for k, v in config.items():
+        for k, v in self.config.items():
             setattr(self, k, v)
-    def init_optimizer(self, parameters, optimizer_name):
-        if optimizer_name == "adamw":
+    def init_optimizer(self):
+        if self.optimizer_name == "adamw":
             self.optimizer = optim.AdamW(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
-        elif optimizer_name == "adamw8bit":
+        elif self.optimizer_name == "adamw8bit":
             import bitsandbytes as bnb
             self.optimizer = bnb.optim.Adam8bit(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
-        elif optimizer_name == "adafactor":
+        elif self.optimizer_name == "adafactor":
             try:
                 from transformers.optimization import Adafactor
@@ -68,14 +71,14 @@ class BasedOptimizer:
     def step(self, dry_run=False, scaler=None):
-        self.curr_lr = lr_schedule(self.curr_step, self.warmup_steps, self.anneal_steps, self.lr, self.end_lr)
         if not dry_run:
             if scaler:
                 scaler.step(self.optimizer)
             else:
                 self.optimizer.step()
+        self.curr_lr = lr_schedule(self.curr_step, self.warmup_steps, self.anneal_steps, self.lr, self.end_lr)
         self.curr_step = self.curr_step + 1
         if not self.max_lr:
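step() derives curr_lr from lr_schedule(), whose body is outside this diff. The snippet below is only a sketch of one plausible schedule matching that signature (linear warmup to lr, then cosine anneal toward end_lr; cosine_warmup is ignored here), not the repository's actual implementation.

import math

def example_lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr, cosine_warmup=False):
    # Illustrative only: linear warmup followed by cosine anneal to end_lr.
    if step < warmup_steps:
        return lr * step / max(1, warmup_steps)
    progress = min(1.0, (step - warmup_steps) / max(1, anneal_steps))
    return end_lr + 0.5 * (lr - end_lr) * (1 + math.cos(math.pi * progress))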
@@ -102,7 +105,10 @@ class BasedOptimizer:
         path = path / "opt"
         path.mkdir(parents=True, exist_ok=True)
         torch.save(self.optimizer.state_dict(), path / "opt_states.pt")
+        # clean the optimizer and parameters from the dict.
+        del self.optimizer
+        del self.parameters
         metadata = self.__dict__
         with open(path / "opt_metadata.pkl", 'wb') as f:
             pickle.dump(metadata, f)
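save() now drops self.optimizer and self.parameters before pickling __dict__, so the checkpoint ends up as two files: opt/opt_states.pt (the torch state dict) and opt/opt_metadata.pkl (the remaining attributes such as config, optimizer_name, curr_step and curr_lr). A matching load path is not shown in this diff; the helper below is a hypothetical sketch of how resuming could work, assuming the init=False constructor branch from the first hunk and re-running init_optimizer() once the metadata is restored.

import pickle
import torch
from pathlib import Path

def load_based_optimizer(parameters, save_folder):
    # Hypothetical helper (not part of this commit): rebuild a BasedOptimizer
    # from the opt_states.pt / opt_metadata.pkl files written by save().
    path = Path(save_folder) / "opt"
    opt = BasedOptimizer(parameters, None, None, init=False)
    with open(path / "opt_metadata.pkl", "rb") as f:
        opt.__dict__.update(pickle.load(f))  # lr, curr_step, optimizer_name, config, ...
    opt.parameters = parameters
    opt.init_optimizer()  # recreate the wrapped torch optimizer
    opt.optimizer.load_state_dict(torch.load(path / "opt_states.pt"))
    return opt

opt.step() could then continue from the restored opt.curr_step.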
@@ -13,14 +13,16 @@ import torch
 # Does this work with other block_sizes? doesn't seem to.
 class FbDataset(data.Dataset):
-    def __init__(self, block_size, map_file, max_samples=None):
+    def __init__(self, block_size, map_file, max_samples=None, skip=0):
         self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
         self.samples = self.npz.shape[0]
         if max_samples is not None:
             self.samples = min(self.samples, int(max_samples))
-        self.skip = 0
+        self.skip = skip
     def __len__(self):
         return self.samples
     def __getitem__(self, _id):
         nth = _id + self.skip
         data = torch.tensor(self.npz[nth].astype(np.int64))
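FbDataset now exposes the previously hard-coded skip offset as a constructor argument, which makes it possible to resume reading a memory-mapped token file from a given sample index. Below is a minimal usage sketch; the file name and the import are placeholders, since the dataset module's path isn't shown in this diff. Note that __len__ still reports the full sample count, so a large skip can push nth past the end of the memmap unless max_samples is set accordingly.

from torch.utils import data
# from <dataset module> import FbDataset  # module path not shown in this diff

# Resume from sample 1000 of a uint16 token memmap (placeholder file name).
train_set = FbDataset(block_size=2048, map_file="dataset/train.map", skip=1000)
loader = data.DataLoader(train_set, batch_size=4, shuffle=False)
for batch in loader:
    break  # __getitem__ reads memmap row _id + skip as int64 tokens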
@@ -12,7 +12,7 @@ train_config = {
 }
 model = torch.nn.Linear(10, 100)
-save_folder = "models/test_optimizer2"
+save_folder = "models/test_optimizer6"
 if not os.path.isdir(save_folder + "/opt"):
     opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
 else:
@@ -21,8 +21,9 @@ else:
 wandb.init(project="opt-test", name="test")
 for x in tqdm(range(opt.curr_step, 100)):
-    print(f"Step {opt.curr_step}: LR {opt.curr_lr}")
-    wandb.log({"lr": opt.curr_lr})
     opt.step(dry_run=True)
-    #if x == 60:
-        #opt.save(Path(save_folder))
+    # current step gets iterated before the logging, so negate 1.
+    print(f"Step {opt.curr_step - 1}: LR {opt.curr_lr}")
+    wandb.log({"lr": opt.curr_lr})
+    if x == 60:
+        opt.save(Path(save_folder))
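Since the loop runs with dry_run=True, the wrapped torch optimizer never actually steps; the script only exercises the LR schedule, the logging off-by-one fix, and the save path at step 60. A quick sanity check of the resulting checkpoint, reading back the two files written by save() from the save_folder used above, might look like:

import pickle
import torch

state = torch.load("models/test_optimizer6/opt/opt_states.pt")
with open("models/test_optimizer6/opt/opt_metadata.pkl", "rb") as f:
    meta = pickle.load(f)
# metadata is the optimizer's __dict__ minus the torch optimizer and parameters
print(meta["curr_step"], meta["curr_lr"])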