Commit 3b246976 authored by novelailab's avatar novelailab

bfloat16 model again

parent 5ff36559
......@@ -146,7 +146,7 @@ class HyperNetworkSingle(nn.Module):
#x = shift_tokens(x, self.num_shifts)
x = self.linear(x)
x = x.mul(torch.sigmoid(x))
return x.half()
return x.bfloat16()
model_config = {
......@@ -194,7 +194,7 @@ gas = train_config["gas"]
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
#model = GPTModel.gpt2_init(model_config).cuda().float()
model = load_gpt_j().cuda().half()
model = load_gpt_j().cuda().bfloat16()
for param in model.parameters():
param.requires_grad = False
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment