Commit 8b26deda authored by Wes Brown

Revert mostly to `x=` assignment form.

parent 8073ccfc
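This commit moves the hypernetwork `forward` methods back from deeply nested call expressions to step-by-step `x = ...` assignments. As a minimal sketch of the two styles (a generic module with hypothetical `linear`, `activation`, and `linear2` attributes, not the exact classes changed below):

    # Nested form: the whole pipeline is one expression, which is harder to read
    # and to step through when inspecting intermediate tensors.
    def forward(self, x):
        return self.linear2(self.activation(self.linear(x.float()))).bfloat16()

    # `x =` assignment form: one stage per statement, so intermediates can be
    # examined and individual steps edited in isolation.
    def forward(self, x):
        x = x.float()
        x = self.linear(x)
        x = self.activation(x)
        x = self.linear2(x)
        return x.bfloat16()

The hunks below restore the second form (and, for the GRU variant, group the layers into `nn.Sequential` blocks that the rewritten `forward` calls).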
@@ -29,6 +29,7 @@ prompts = ["<|endoftext|>",
            "[ Tags:",
            "***"]

 def _init_weights(module):
     if isinstance(module, nn.Linear):
         module.weight.data.normal_(mean=0.0, std=0.02)
@@ -72,13 +73,19 @@ class HyperNetworkGRU(nn.Module):
             param.data.normal_(mean=0.0,
                                std=(0.02 / math.sqrt(2 * config["n_layer"])))
+        self.linear_gru = nn.Sequential(
+            self.linear1,
+            self.gru)
+        self.layernorm_linear = nn.Sequential(
+            self.ln_1,
+            self.linear2)

     def forward(self, x):
-        return ck(self.activation,
-                  self.linear2(
-                      self.ln_1(
-                          self.gru(
-                              self.linear1(
-                                  x.float()))[0]))).bfloat16()
+        x = x.float()
+        x = self.linear_gru.forward(x)[0]
+        x = ck(self.activation,
+               self.layernorm_linear.forward(x))
+        return x.bfloat16()


 class HyperNetwork(nn.Module):
@@ -96,11 +103,12 @@ class HyperNetwork(nn.Module):
                                std=(0.02 / math.sqrt(2 * config["n_layer"])))

     def forward(self, x):
-        x = self.linear2(
-            ck(self.activation,
-               self.linear(x.float())))
-        return x.mul(torch.sigmoid(x)).bfloat16()
+        x = x.float()
+        x = self.linear(x)
+        x = ck(self.activation, x)
+        x = self.linear2(x)
+        x = x.mul(torch.sigmoid(x))
+        return x.bfloat16()


 class HyperNetworkSingle(nn.Module):
     def __init__(self, config):
@@ -115,14 +123,12 @@ class HyperNetworkSingle(nn.Module):
         for param in self.linear.parameters():
             param.data.normal_(mean=0.0,
                                std=(0.02 / math.sqrt(2 * config["n_layer"])))
-        # state = self.state_dict()
-        # for k in state:
-        #     state[k] = state[k] * 1 / math.sqrt(2 * config["n_layer"])
-        # self.load_state_dict(state)

     def forward(self, x):
-        x = self.linear(x.float())
-        return x.mul(torch.sigmoid(x)).bfloat16()
+        x = x.float()
+        x = self.linear(x)
+        x = x.mul(torch.sigmoid(x))
+        return x.bfloat16()


 tokenizer = AutoTokenizer.from_pretrained('gpt2')
@@ -183,14 +189,17 @@ def report_console(data):
     print(colored("======================================================",
                   "red"))


 def make_hypernet_saver(train_config, hypernetwork):
     def hypernet_saver(id: str):
         save_folder = Path(train_config["save_path"]) / id
         save_folder.mkdir(parents=True, exist_ok=True)
         torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
         opt.save(save_folder / "opt")
     return hypernet_saver


 parser = argparse.ArgumentParser(description='Hypernetwork Finetuner')
 parser.add_argument('--run_name', type=str, help='the run name to use',
                     required=True)
@@ -383,4 +392,4 @@ for input_ids, labels in t:
     curr_step += 1

-hypernetwork_saver("final")
\ No newline at end of file
+hypernetwork_saver("final")
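For reference, `make_hypernet_saver` above returns a closure that the training loop calls as `hypernetwork_saver("final")` in the last hunk. A hedged sketch of the wiring, assuming the saver is built from the script's `train_config` and `hypernetwork` objects (the actual assignment lies outside the hunks shown here):

    # Assumed wiring (not visible in this diff): build the saver once, then call
    # it with a checkpoint id whenever a snapshot should be written.
    hypernetwork_saver = make_hypernet_saver(train_config, hypernetwork)
    hypernetwork_saver("final")  # writes <save_path>/final/hyper.pt and <save_path>/final/opt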