Commit c20db765 authored by kurumuz

make DSinference work with basedformer

parent 6ccb1e1f
......@@ -6,6 +6,7 @@ from . import alibi
from . import vit
from . import resnet
from . import fast
from . import ds_strats
MODEL_MAP = {
"gptj": gptj.GPTJModel,
......
......@@ -3,6 +3,13 @@ import torch.nn.functional as F
from dataclasses import dataclass
from dotmap import DotMap
import math
from basedformer import models
class ConfigClass:
def __init__(self, config):
#set all the key and values in config to attributes of this class
for key, value in config.items():
setattr(self, key, value)
class BaseModel(nn.Module):
def __init__(self, user_config, **kwargs):
......@@ -61,7 +68,8 @@ class BaseModel(nn.Module):
for k, v in self.user_config.items():
full_config[k] = v
full_config = DotMap(full_config)
#full_config = DotMap(full_config)
full_config = ConfigClass(full_config)
return full_config
def forward_with_hidden_states(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
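A minimal illustration of the new ConfigClass above (not part of this commit): it lifts the keys of a config dict into plain attributes, so downstream code, including the DeepSpeed injection policy added below, which reads fields such as config.rotary_dim, can use attribute access. Unlike a DotMap it raises AttributeError for missing keys instead of silently auto-creating them, which may be why it replaces the DotMap here.

# Illustration only; ConfigClass is the class added in this diff.
cfg = ConfigClass({"n_layer": 28, "hidden_dim": 4096})
assert cfg.n_layer == 28   # plain attribute access instead of cfg["n_layer"]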
......@@ -119,4 +127,36 @@ class BaseModel(nn.Module):
if cache:
return x, kv_new
else:
return x, None
\ No newline at end of file
return x, None
def get_embeds_ds(self, x, past_key_values=None, use_cache=True):
if past_key_values is None:
past_key_values = [None] * self.n_layer
kv_new = []
x = self.vocab_embed(x)
for layer_id, layer in enumerate(self.layers):
x = layer(x, layer_past=past_key_values[layer_id], use_cache=use_cache)
kv_new.append(x[1])
x = x[0]
x = self.ln_final(x)
if use_cache:
return x, kv_new
else:
return x, None
def forward_ds(self, x, past_key_values=None, use_cache=True):
x, kv = self.get_embeds_ds(x, past_key_values=past_key_values, use_cache=use_cache)
x = self.lm_head(x)
return x, kv
def convert_to_ds(self):
convert_func = models.ds_strats.model_map[self.config.Layer]
model = convert_func(self)
return model
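A hedged end-to-end sketch of the new DeepSpeed path added above; `model` is assumed to be an already-constructed models.gptj.GPTJModel in fp16 on a single GPU (loading it is outside the scope of this diff).

import torch

# convert_to_ds() looks up the transform for self.config.Layer in
# ds_strats.model_map (new file below) and returns a DeepSpeed inference
# engine whose forward has been rebound to forward_ds.
ds_model = model.convert_to_ds()

tokens = torch.randint(0, 50256, (1, 8), device="cuda")
logits, kv = ds_model(tokens, past_key_values=None, use_cache=True)
print(logits.shape)   # (batch, seq, vocab)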
from deepspeed.module_inject import DSPolicy
import torch
from torch.nn.parameter import Parameter
from basedformer import models
class BasedformerGPTJLayerPolicy(DSPolicy):
_orig_layer_class = None
#can't have original layer class because in transformerfork all models are just one class
#needs some config from the model.config, including:
#rotary_dim, layer_norm_epsilon
def __init__(self, client_module, inference=True):
super().__init__(inference, scale_attention=True)
self.client_module = client_module
def get_hidden_heads(self):
return self.client_module.attn.q_proj.weight.shape[1], \
self.client_module.attn.n_head
def attention(self):
qw = self.client_module.attn.q_proj.weight
kw = self.client_module.attn.k_proj.weight
vw = self.client_module.attn.v_proj.weight
qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
return self.linear_layer, \
qkvw, \
None, \
self.client_module.attn.out_proj.weight, \
None, \
self.scale_attention, \
self.is_megatron_v2
def mlp(self):
return self.linear_layer, \
self.client_module.ff.ff1.weight, \
self.client_module.ff.ff1.bias, \
self.client_module.ff.ff2.weight, \
self.client_module.ff.ff2.bias
def layerNorm(self):
return None, \
None, \
self.client_module.ln_preattn.weight, \
self.client_module.ln_preattn.bias
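For orientation: the None slots above follow from GPT-J's architecture as DeepSpeed expects it. The q/k/v and output projections carry no biases, so the bias slots in attention() are None, and because attention and the MLP share a single pre-attention LayerNorm (parallel residual), the first two layerNorm() slots, which in DeepSpeed's GPT-J handling would hold a separate post-attention LayerNorm, stay empty.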
def GPTJTransform(model):
model.config.rotary_dim = model.layers[0].attn.rotary_dim
model.config.layer_norm_epsilon = 1e-5
model.forward = model.forward_ds
model.get_embeds = model.get_embeds_ds
import deepspeed
model = deepspeed.init_inference(
model,
mp_size=1,
dtype=torch.float16,
replace_method="auto",
injection_policy={models.gptj.GPTJLayer: BasedformerGPTJLayerPolicy},
replace_with_kernel_inject=True,
enable_cuda_graph=True,
)
return model
model_map = {
models.gptj.GPTJLayer: GPTJTransform,
}
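The model_map dispatch mirrors convert_to_ds() defined on the base model above. A hedged sketch of using it directly, assuming `model` is a GPTJModel whose config.Layer is models.gptj.GPTJLayer; note that GPTJTransform hard-codes mp_size=1 and torch.float16, so the model should already be half precision on a single GPU.

transform = model_map[models.gptj.GPTJLayer]   # -> GPTJTransform
ds_model = transform(model)                    # DeepSpeed engine wrapping `model`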
......@@ -245,7 +245,8 @@ class GPTJModel(base_lm.BaseModel):
'activation': gelu_new,
'SelfAttention': SelfAttention,
'FeedForward': FeedForward,
'masked_softmax_fusion': True,
'q_only': False,
'masked_softmax_fusion': False,
}
base_lm.BaseModel.__init__(self, user_config, **kwargs)
if self.config.masked_softmax_fusion:
......
......@@ -174,7 +174,7 @@ def generate_greedy(forward, prompt_tokens, tokens_to_generate=50, hypernetwork=
return generated
@torch.no_grad()
def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0.9}], hypernetwork=None, non_deterministic=False, fully_deterministic=False):
def generate(forward, prompt_tokens, tokens_to_generate=50, ds=False, ops_list=[{"temp": 0.9}], hypernetwork=None, non_deterministic=False, fully_deterministic=False):
in_tokens = prompt_tokens
context = prompt_tokens
generated = torch.zeros(len(ops_list), 0, dtype=torch.long).to(in_tokens.device)
......@@ -192,7 +192,10 @@ def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0
}
for _ in range(tokens_to_generate):
logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
if ds:
logits, kv = forward(in_tokens, past_key_values=kv, use_cache=True)
else:
logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
logits = logits[:, -1, :] #get the last token in the seq
logits = torch.log_softmax(logits, dim=-1)
#if kv[0][0].shape[0] == 1 and (kv[0][0].shape[0] != len(ops_list)):
......
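A hedged usage sketch of the new ds flag: with ds=True the per-step call uses the past_key_values/use_cache keywords of forward_ds instead of the kv/cache/hypernetwork keywords of the default forward. Here prompt_tokens is assumed to be a LongTensor of token ids on the model's device, and ds_model the engine returned by convert_to_ds().

out = generate(ds_model, prompt_tokens, tokens_to_generate=50,
               ds=True, ops_list=[{"temp": 0.9}])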