Commit 27400dd0 authored by kurumuz's avatar kurumuz

compiled DS

parent 4e38eaac
......@@ -27,7 +27,10 @@ RUN pip3 install https://github.com/crowsonkb/k-diffusion/archive/481677d114f6ea
RUN pip3 install simplejpeg
RUN pip3 install min-dalle
#RUN pip3 install https://github.com/microsoft/DeepSpeed/archive/55b7b9e008943b8b93d4903d90b255313bb9d82c.zip
#basedformer
RUN pip3 install https://www.dropbox.com/s/8ozhhbo1g7y5dsz/basedformer-f7c8c4fe12f8a0acf6588d8d09a8b9b0481895e3.zip?dl=1
#built DS
RUN pip3 install https://www.dropbox.com/s/euzpgpfrs9isf1z/deepspeed-0.7.3%2B55b7b9e0-cp38-cp38-linux_x86_64.whl?dl=0
#Open ports
EXPOSE 8080
......
......@@ -573,7 +573,7 @@ class BasedformerModel(nn.Module):
from transformers import GPT2TokenizerFast
self.config = config
self.model = lm_utils.load_from_path(config.model_path).half().cuda()
#self.model = self.model.convert_to_ds()
self.model = self.model.convert_to_ds()
self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
@torch.no_grad()
......@@ -581,5 +581,5 @@ class BasedformerModel(nn.Module):
prompt = request.prompt
prompt = self.tokenizer.encode("Input: " + prompt, return_tensors='pt').cuda().long()
prompt = torch.cat([prompt, torch.tensor([[49527]], dtype=torch.long).cuda()], dim=1)
is_safe, corrected = generate(self.model, prompt, self.tokenizer, tokens_to_generate=150, ds=False)
is_safe, corrected = generate(self.model.module, prompt, self.tokenizer, tokens_to_generate=150, ds=True)
return is_safe, corrected
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment