Commit c99ffa47 authored by novelailab

fix checkpointing

parent 7e65fc56
@@ -16,6 +16,8 @@ import wandb
 from lm_arch.gpt2 import GPT2Model
 import numpy as np
 from transformers import AutoTokenizer
+from torch.utils.checkpoint import checkpoint as ck
+from math import log2, ceil
 
 def _init_weights(module):
     """Initialize the weights."""
@@ -29,6 +31,43 @@ def _init_weights(module):
         module.bias.data.zero_()
         module.weight.data.fill_(1.0)
 
+def shift_tokens(x, amt, eps = 1e-5):
+    n, device = x.shape[1], x.device
+
+    cumsum = x.cumsum(dim = 1)
+    *x, x_pass = x.chunk(amt + 1, dim = -1)
+    *x_cumsum, _ = cumsum.chunk(amt + 1, dim = -1)
+
+    amts = 2 ** torch.arange(amt)
+    amts = amts.tolist()
+    shifts = []
+    denom = torch.arange(n, device = device)
+
+    for x_chunk, x_cumsum_chunk, amt in zip(x, x_cumsum, amts):
+        shifted_chunk = shift(x_cumsum_chunk, amt, dim = -2) - shift(x_cumsum_chunk, 2 * amt, dim = -2)
+        shifted_denom = shift(denom, amt, dim = -1) - shift(denom, 2 * amt, dim = -1)
+        shifted_denom = rearrange(shifted_denom, 'n -> () n ()')
+        normed_shifted_x = shifted_chunk / (shifted_denom + eps)
+        shifts.append(normed_shifted_x)
+
+    return torch.cat((*shifts, x_pass), dim = -1)
+
+def discounted_cumsum(t, gamma):
+    try:
+        from torch_discounted_cumsum import discounted_cumsum_left
+    except ImportError:
+        print('unable to import torch_discounted_cumsum - please run `pip install torch-discounted-cumsum`')
+
+    b, n, d = t.shape
+    t = rearrange(t, 'b n d -> (b d) n')
+    t = discounted_cumsum_left(t, gamma)
+    t = rearrange(t, '(b d) n -> b n d', b = b)
+    return t
+
+def shift(x, amt, dim = -1):
+    return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value = 0.)
+
 class HyperNetwork(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -36,6 +75,7 @@ class HyperNetwork(nn.Module):
         self.linear = nn.Linear(embed_dim, embed_dim//4, bias=True)
         self.linear2 = nn.Linear(embed_dim//4, embed_dim, bias=True)
         self.activation = gelu_new
+        self.num_shifts = ceil(log2(2048)) - 1
         #self.linear.weight.data.normal_(mean=0.0, std=0.02)
         for module in self.modules():
             _init_weights(module)
@@ -48,8 +88,10 @@ class HyperNetwork(nn.Module):
         #self.load_state_dict(state)
 
     def forward(self, x):
-        x = self.linear(x.float())
-        x = self.activation(x)
+        x = x.float()
+        x = shift_tokens(x, self.num_shifts)
+        x = self.linear(x)
+        x = ck(self.activation, x)
         x = self.linear2(x)
         x = x.mul(torch.sigmoid(x))
         return x.bfloat16()
@@ -131,7 +173,7 @@ for input_ids, labels in t:
     loss = 0
     for x in range(train_config["gas"]):
         with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
-            logits = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=hypernetwork, act_ck=False)
+            logits = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=hypernetwork, act_ck=True)
             #print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
             logits = logits.view(-1, logits.shape[-1])
             gas_labels = labels[x*bs:(x+1)*bs, :].contiguous()
...
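For context on the helpers added in this file: `shift` is a zero-padded shift along a chosen trailing dimension, and `shift_tokens` applies it to a cumulative sum so each position mixes in length-normalised summaries of earlier tokens before the hypernetwork's first linear layer. A minimal standalone check of the padding-based shift (the toy tensor shape below is an illustrative assumption, not from the repository):

import torch
import torch.nn.functional as F

def shift(x, amt, dim = -1):
    # Same trick as the helper in the diff: pad `amt` zeros at the front of
    # `dim` and trim `amt` elements off the end, shifting values forward.
    return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value = 0.)

# Toy check: shift a (batch=1, seq=4, features=2) tensor one step along
# the sequence axis (dim=-2).
x = torch.arange(8.).reshape(1, 4, 2)
print(shift(x, 1, dim = -2))
# rows move down by one position and the first row is zero-filled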
@@ -212,7 +212,7 @@ class FeedForward(nn.Module):
     def forward(self, x, act_ck=False):
         x = self.ff1(x)
         if act_ck:
-            ck(self.activation, x)
+            x = ck(self.activation, x)
         else:
             x = self.activation(x)
         x = self.ff2(x)
...
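The FeedForward change above is the core of the "fix checkpointing" commit: `torch.utils.checkpoint.checkpoint` does not operate in place, it returns the output (recomputed during backward to save memory), so the result has to be assigned back to `x`; the training loop in the first file also flips `act_ck` to True so this path is actually exercised. A minimal sketch of the corrected pattern, using a hypothetical stand-in module rather than the repository's real FeedForward class:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as ck

class TinyFF(nn.Module):
    # Hypothetical stand-in for the FeedForward block touched by the diff.
    def __init__(self, d = 16):
        super().__init__()
        self.ff1 = nn.Linear(d, 4 * d)
        self.ff2 = nn.Linear(4 * d, d)
        self.activation = nn.GELU()

    def forward(self, x, act_ck = False):
        x = self.ff1(x)
        if act_ck:
            # checkpoint() returns the activation output; it must be
            # captured, not discarded, or the checkpoint has no effect.
            x = ck(self.activation, x)
        else:
            x = self.activation(x)
        return self.ff2(x)

x = torch.randn(2, 16, requires_grad = True)
TinyFF()(x, act_ck = True).sum().backward()  # gradients flow through the checkpointed activation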