Commit f7c8c4fe authored by kurumuz

smh

parent 6a53993a
......@@ -3,7 +3,9 @@ from torch.nn.parameter import Parameter
from basedformer import models
def GPTJTransform(model):
import deepspeed
from deepspeed.module_inject import DSPolicy
class BasedformerGPTJLayerPolicy(DSPolicy):
_orig_layer_class = None
#can't have original layer class because in transformerfork all models are just one class
......@@ -52,8 +54,6 @@ def GPTJTransform(model):
model.forward = model.forward_ds
model.get_embeds = model.get_embeds_ds
import deepspeed
from deepspeed.module_inject import DSPolicy
model = deepspeed.init_inference(
model,
mp_size=1,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment