Commit 6a53993a authored by kurumuz

terrible hack :?

parent c20db765
from deepspeed.module_inject import DSPolicy
import torch
from torch.nn.parameter import Parameter
from basedformer import models
class BasedformerGPTJLayerPolicy(DSPolicy):
    _orig_layer_class = None
    # can't have an original layer class because in transformerfork all models are just one class
    # needs some config from the model.config, including:
    # rotary_dim, layer_norm_epsilon

    def __init__(self, client_module, inference=True):
        super().__init__(inference, scale_attention=True)
        self.client_module = client_module

    def get_hidden_heads(self):
        return self.client_module.attn.q_proj.weight.shape[1], \
               self.client_module.attn.n_head

    def attention(self):
        qw = self.client_module.attn.q_proj.weight
        kw = self.client_module.attn.k_proj.weight
        vw = self.client_module.attn.v_proj.weight

        # fuse the separate q/k/v projection weights into the single qkv weight DeepSpeed expects
        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)

        return self.linear_layer, \
               qkvw, \
               None, \
               self.client_module.attn.out_proj.weight, \
               None, \
               self.scale_attention, \
               self.is_megatron_v2

    def mlp(self):
        return self.linear_layer, \
               self.client_module.ff.ff1.weight, \
               self.client_module.ff.ff1.bias, \
               self.client_module.ff.ff2.weight, \
               self.client_module.ff.ff2.bias

    def layerNorm(self):
        # GPT-J-style blocks share one pre-attention LayerNorm between attention and MLP,
        # so no separate MLP LayerNorm weights are returned
        return None, \
               None, \
               self.client_module.ln_preattn.weight, \
               self.client_module.ln_preattn.bias


def GPTJTransform(model):
    model.config.rotary_dim = model.layers[0].attn.rotary_dim
    model.config.layer_norm_epsilon = 1e-5
......
@@ -52,6 +53,7 @@ def GPTJTransform(model):
    model.get_embeds = model.get_embeds_ds
    import deepspeed
    from deepspeed.module_inject import DSPolicy
    model = deepspeed.init_inference(
        model,
        mp_size=1,
......
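For context, a sketch of how a custom DSPolicy like BasedformerGPTJLayerPolicy is typically wired into DeepSpeed: deepspeed.init_inference accepts an injection_policy dictionary mapping the client layer class to its policy. The layer class path (models.gptj.GPTJLayer), the dtype, and the model variable below are illustrative assumptions, not taken from this commit; the arguments actually passed here are in the elided lines above.

# Sketch only -- not the call made in this commit (its remaining arguments are elided above).
# "models.gptj.GPTJLayer" is a hypothetical stand-in for basedformer's transformer block class,
# and "model" is assumed to be an already-loaded basedformer GPT-J model.
import torch
import deepspeed
from basedformer import models

engine = deepspeed.init_inference(
    model,
    mp_size=1,            # single GPU, matching the call above
    dtype=torch.float16,  # assumption
    injection_policy={models.gptj.GPTJLayer: BasedformerGPTJLayerPolicy},
)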