Commit 09e4a19b authored by novelailab

refactor kinda done, yeet main

parent 840dd7f4
-from main import *
+from lm_arch import lm_class
from lm_arch.utils import *
import time
import torch
from time import perf_counter, perf_counter_ns
import numpy as np
from tqdm import tqdm
@@ -65,7 +67,8 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
with torch.no_grad():
-    based_model = load_gpt_j().cuda().half().eval()
+    based_model = lm_class.load_gpt_j().cuda().half().eval()
    based_model = based_model.lm
    print("Loaded based model")
    hf_model = no_init(lambda: AutoModelForCausalLM.from_pretrained('/home/xuser/models/j6b_ckpt_14001')).cuda().half().eval()
    print("Loaded hf model")
@@ -79,7 +79,7 @@ def _attn(query, key, value, causal_mask, masked_bias,
class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, hidden_dim, n_head, device, dtype):
-        super().__init__(self)
+        nn.Module.__init__(self)
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
            1, 1, max_positions, max_positions).bool()
@@ -143,7 +143,7 @@ class SelfAttention(nn.Module):
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, activation, device, dtype):
-        super().__init__(self)
+        nn.Module.__init__(self)
        self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
        self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
        self.activation = activation
@@ -159,7 +159,7 @@ class FeedForward(nn.Module):
class GPTJLayer(nn.Module):
    def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
-        super().__init__(self)
+        nn.Module.__init__(self)
        self.hidden_dim = hidden_dim
        self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
@@ -204,8 +204,8 @@ class GPTJLayer(nn.Module):
        return x

class GPTJModel(nn.Module):
-    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation, Layer, device, dtype):
-        super().__init__(self)
+    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTJLayer, device="cuda", dtype=torch.float16):
+        nn.Module.__init__(self)
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
@@ -225,4 +225,4 @@ class GPTJModel(nn.Module):
        for layer_id, layer in enumerate(self.layers):
            x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.ln_final(x)
        return x
\ No newline at end of file
@@ -2,11 +2,13 @@ from lm_arch import utils
import math
import torch
from torch import nn
from lm_arch import gptj
import os

#Having common BaseLM functionality in this class instead of the torch LM itself makes sense.
class BaseLM(nn.Module):
    def __init__(self, config=None, lm=None):
        nn.Module.__init__(self)
        self.config = config
        self.lm = lm
@@ -44,13 +46,13 @@ class BaseLM(nn.Module):
        return model

    @classmethod
-    def load(cls, config, path=None, state_dict=None, strict=False):
+    def load(cls, model_class, config, path=None, state_dict=None, strict=False):
        # I am kinda sad that we will not have a load function in lm object itself.
        # might be better to add load functions to that as well but not sure.
        if path:
            state_dict = utils.SplitCheckpoint(path, device="cuda")
-        lm = config.model_class(**config)
+        lm = model_class(**config)
        model = cls(config, lm)
        model.lm.load_state_dict(state_dict, strict=strict)
        return model
@@ -59,11 +61,25 @@ class BaseLM(nn.Module):
        if self.lm is None:
            print("No LM object to save. Please first init a model.")
            return

        try: os.mkdir(path)
        except: pass
        checkpoint = {}
        for i, x in enumerate(self.lm.state_dict().items()):
            checkpoint[x[0]] = f"{path}/b{i}.pt"
            torch.save(x[1], f"{path}/b{i}.pt")
        torch.save(checkpoint, f"{path}/m.pt")
\ No newline at end of file
def load_gpt_j(path="models/6b", state_dict=None):
    config = {
        "n_layer": 28,
        "n_head": 16,
        "hidden_dim": 4096,
        "vocab_dim": 50400,
        "eps": 1e-5,
        "activation": gptj.gelu_new,
        "Layer": gptj.GPTJLayer
    }
    model = BaseLM.load(gptj.GPTJModel, config, path, state_dict)
    return model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
import os
from pathlib import Path
import math

def no_init(loading_code):
    def dummy(self):
        return

    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {}
    for mod in modules:
        original[mod] = mod.reset_parameters
        mod.reset_parameters = dummy

    result = loading_code()
    for mod in modules:
        mod.reset_parameters = original[mod]

    return result
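# Illustrative usage (a sketch, not part of the original file): no_init() is intended for
# constructing a model whose weights are immediately overwritten by a checkpoint, so the
# default reset_parameters() work can be skipped, e.g.:
#   model = no_init(lambda: GPTModel(**config))
#   model.load_state_dict(SplitCheckpoint("models/6b", device="cuda"), strict=False)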
SPLIT_WEIGHTS_NAME = "m.pt"

class SplitCheckpoint(MutableMapping):
    def __init__(self, name_or_path, device="cpu", subfolder=None):
        self.device = device
        localpath = Path(name_or_path)
        if subfolder is not None:
            localpath = localpath / subfolder
        if os.path.isfile(localpath):
            self.chkpt_dir = localpath.parent
            self.remote = False
        elif os.path.isfile(localpath / SPLIT_WEIGHTS_NAME):
            self.chkpt_dir = localpath
            self.checkpoint = torch.load(str(localpath / SPLIT_WEIGHTS_NAME))
            self.remote = False
        self.checkpoint = self._load(SPLIT_WEIGHTS_NAME, None)

    def _load(self, name, shape, **kwparams):
        path = str(self.chkpt_dir / name)
        return torch.load(path, **kwparams)

    def __len__(self):
        return len(self.checkpoint)

    def __getitem__(self, key):
        name = self.checkpoint[key]
        if type(name) is tuple:
            return self._load(name[0].split('/')[-1], name[1], map_location=self.device)
        else:
            return self._load(name.split('/')[-1], None, map_location=self.device)

    def __setitem__(self, key, value):
        return

    def __delitem__(self, key):
        return

    def keys(self):
        return self.checkpoint.keys()

    def __iter__(self):
        for key in self.checkpoint:
            yield (key, self.__getitem__(key))

    def __copy__(self):
        return SplitCheckpoint(self.chkpt_dir, device=self.device)

    def copy(self):
        return SplitCheckpoint(self.chkpt_dir, device=self.device)
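# Illustrative usage (sketch): m.pt maps each parameter name to a per-tensor file written by
# the save() methods below, and tensors are only read from disk when a key is accessed:
#   state_dict = SplitCheckpoint("models/6b", device="cuda")
#   model.load_state_dict(state_dict, strict=False)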
#TODO: Might change with non einsum functions?
def get_logits(x, embedding):
    return embedding(x)

def gelu_new(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

def gelu_jax(x):
    sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
    cdf = 0.5 * (1.0 + torch.tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3))))
    return x * cdf

def fixed_pos_embedding(dim=None, seq_len=None, x=None):
    if x is None:
        x = torch.empty(0)
    inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2) / dim)).to(x.dtype).to(x.device)
    sinusoid_inp = torch.einsum('i , j -> i j', torch.arange(seq_len).to(x.device), inv_freq).float()
    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)

def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d j -> ... (d j)')

def apply_rotary_pos_emb(x, sincos, offset=0):
    sin, cos = map(lambda t: repeat(t[offset:x.shape[1]+offset, :], "n d -> () n () (d j)", j=2), sincos)
    return (x * cos) + (rotate_every_two(x) * sin)
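# Shape sketch (illustrative): for rotary_dim=64 and seq_len=128,
# fixed_pos_embedding(dim=64, seq_len=128) returns sin/cos of shape (128, 32); the repeat()
# above broadcasts them to (1, 128, 1, 64), so apply_rotary_pos_emb() expects x laid out as
# (batch, seq, n_head, rotary_dim) and returns a tensor of the same shape, e.g.:
#   sin, cos = fixed_pos_embedding(dim=64, seq_len=128)
#   q_rot = apply_rotary_pos_emb(torch.randn(1, 128, 16, 64), (sin, cos), offset=0)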
def _split_heads(tensor, num_heads, attn_head_size, rotary):
    """
    Splits hidden_size dim into attn_head_size and num_heads
    """
    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
    tensor = tensor.view(*new_shape)
    if rotary:
        return tensor
    if len(tensor.shape) == 5:
        return tensor.permute(0, 1, 3, 2, 4)  # (batch, blocks, head, block_length, head_features)
    elif len(tensor.shape) == 4:
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
    else:
        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")

def _merge_heads(tensor, num_heads, attn_head_size):
    """
    Merges attn_head_size dim and num_attn_heads dim into hidden_size
    """
    if len(tensor.shape) == 5:
        tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
    elif len(tensor.shape) == 4:
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
    else:
        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
    return tensor.view(new_shape)

def _attn(query, key, value, causal_mask, masked_bias,
          attention_mask=None, scale_attn=None):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    attn_weights = attn_weights / scale_attn

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)
    attn_output = torch.matmul(attn_weights, value).to(value.dtype)
    return attn_output
class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, hidden_dim, n_head, device="cuda", dtype=torch.float16):
        super(SelfAttention, self).__init__()
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
            1, 1, max_positions, max_positions).bool()
        self.head_dim = hidden_dim // n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False))  #-1e10 is what mtj uses.
        attn_bias = False
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
        self.register_buffer("sin", sin)
        self.register_buffer("cos", cos)

    def forward(self, x):
        query = self.q_proj(x)
        key = self.k_proj(x)
        value = self.v_proj(x)

        query = _split_heads(query, self.n_head, self.head_dim, True)
        key = _split_heads(key, self.n_head, self.head_dim, True)
        value = _split_heads(value, self.n_head, self.head_dim, False)

        offset = 0
        if self.rotary_dim < self.head_dim:
            k_rot = key[:, :, :, :self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim:]
            q_rot = query[:, :, :, :self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim:]

            k_rot = apply_rotary_pos_emb(k_rot, (self.sin, self.cos), offset=offset).to(k_rot.dtype)
            q_rot = apply_rotary_pos_emb(q_rot, (self.sin, self.cos), offset=offset).to(q_rot.dtype)

            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)
        else:
            key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
            query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)

        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)

        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

        x = _attn(
            query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
        )

        x = _merge_heads(x, self.n_head, self.head_dim)
        x = self.out_proj(x)
        return x
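# Illustrative usage (sketch; assumes a CUDA device and fp16 weights, matching the defaults above):
#   attn = SelfAttention(hidden_dim=4096, n_head=16)
#   y = attn(torch.randn(1, 128, 4096, device="cuda", dtype=torch.float16))  # -> (1, 128, 4096)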
class FeedForward(nn.Module):
    def __init__(self, dim=768, hidden_dim=768*4, activation=nn.GELU(), device="cuda", dtype=torch.float16):
        super(FeedForward, self).__init__()
        self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
        self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
        self.activation = activation

    def forward(self, x, act_ck=False):
        x = self.ff1(x)
        if act_ck:
            x = ck(self.activation, x)
        else:
            x = self.activation(x)
        x = self.ff2(x)
        return x

class GPTLayer(nn.Module):
    def __init__(self, attn=SelfAttention, ff=FeedForward, hidden_dim=768, n_head=4, eps=1e-6, activation=nn.GELU(), device="cuda", dtype=torch.float16):
        super(GPTLayer, self).__init__()
        self.hidden_dim = hidden_dim
        self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
        self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
        self.tick = True

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5):
        residual = x

        if act_ck:
            x = ck(self.ln_preattn, x)
            attn_out = ck(self.attn, x)
        else:
            x = self.ln_preattn(x)
            attn_out = self.attn(x)

        if hypernetwork:
            if diff_hypernets:
                if interleaving_layers and layer_id % every_n == 0:
                    if self.tick:
                        hyper_out = hypernetwork[0](x)
                        self.tick = False
                    else:
                        hyper_out = hypernetwork[1](x)
                        self.tick = True
                elif layer_id % every_n == 0:
                    hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
            else:
                if layer_id % every_n == 0:
                    hyper_out = hypernetwork(x)

        ff_out = self.ff(x, act_ck)
        #order of addition matters, i had no idea... fixed a bug here.
        x = attn_out + ff_out + residual
        #x = residual + attn_out + ff_out -> doesn't match.
        if hypernetwork and layer_id % every_n == 0:
            x = x + hyper_out
        return x
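# Note (added for clarity): attention and feed-forward both read the same pre-attention
# LayerNorm output and are summed with the residual in a single step (GPT-J-style parallel
# block); the comment above is about the fp16 summation order needed to match the reference.
# Illustrative construction, mirroring how GPTModel builds its layers below:
#   layer = GPTLayer(attn=SelfAttention, ff=FeedForward, hidden_dim=768, n_head=12,
#                    eps=1e-5, activation=gelu_new, device="cuda", dtype=torch.float16)
#   y = layer(torch.randn(1, 128, 768, device="cuda", dtype=torch.float16), layer_id=0)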
# Can access and change every module from here, as both Layer class and ff and attn classes are passed from GPTModel.
class GPTModel(nn.Module):
    def __init__(self, hidden_dim=512, n_layer=12, n_head=4, vocab_dim=50400, eps=1e-4, activation=nn.GELU(), Layer=GPTLayer, device="cuda", dtype=torch.float16):
        super(GPTModel, self).__init__()
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
        self.layers = nn.ModuleList([])
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
        for _ in range(n_layer):
            self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))

    #TODO: Decouple more, maybe even init everything here, not sure. Not modular enough yet.
    #TODO: Do we want to pass a config object everywhere? I don't exactly like that but passing a lot of variables is a bit ugly too.
    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        for name, p in module.named_parameters():
            if ("ff2" in name or "out_proj" in name) and "weight" in name:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * self.n_layer)))

    def get_embeds(self, x, hypernetwork=None, act_ck=False):
        x = self.vocab_embed(x)
        for layer_id, layer in enumerate(self.layers):
            x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.ln_final(x)
        return x

    def forward(self, x, hypernetwork=None, act_ck=False):
        x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.lm_head(x)
        return x.float()

    @classmethod
    def load(cls, config, path=None, state_dict=None):
        if path:
            state_dict = SplitCheckpoint(path, device="cuda")
        model = no_init(lambda: cls(**config))
        model.load_state_dict(state_dict, strict=False)
        return model

    @classmethod
    def init(cls, config):
        model = cls(**config)
        return model

    @classmethod
    def neox_init(cls, config):
        model = cls(**config)
        modules = [*model.layers[:-1], model.vocab_embed, model.ln_final, model.lm_head]
        init = small_init_method(config["hidden_dim"])
        for param in model.parameters():
            init(param)
        return model

    @classmethod
    def simple_init(cls, config):
        model = cls(**config)
        state = model.state_dict()
        for k in state:
            state[k] = state[k] / math.sqrt(2 * config["n_layer"])
        model.load_state_dict(state)
        return model

    @classmethod
    def gpt2_init(cls, config):
        model = cls(**config)
        for module in model.modules():
            model._init_weights(module)
        return model

    def save(self, path):
        try: os.mkdir(path)
        except: pass
        checkpoint = {}
        for i, x in enumerate(self.state_dict().items()):
            checkpoint[x[0]] = f"{path}/b{i}.pt"
            torch.save(x[1], f"{path}/b{i}.pt")
        torch.save(checkpoint, f"{path}/m.pt")
# TODO: Do we want to have the LM head as a separate class? Or just a function? I think we might be better off with a function here and maybe
# also for the self attention, we can just write a function that gets fed in the q, k, v.
def wang_init_method(n_layers, dim):
    std = 2 / n_layers / math.sqrt(dim)
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_

# Stolen from NeoX. For the 20B run wang_init was used on the output layer and small_init on the rest of the layers.
def small_init_method(dim):
    """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution."""
    std = math.sqrt(2 / (5 * dim))
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_
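# A hedged sketch of the scheme the NeoX comment above describes (one possible reading;
# note that neox_init() as written applies small_init to every parameter instead):
#   def neox_style_init(model, config):
#       wang = wang_init_method(config["n_layer"], config["hidden_dim"])
#       small = small_init_method(config["hidden_dim"])
#       for name, p in model.named_parameters():
#           if "weight" in name and ("ff2" in name or "out_proj" in name):
#               wang(p)   # the output projections of each block
#           else:
#               small(p)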
def load_gpt_j(path="models/6b", state_dict=None):
    config = {
        "n_layer": 28,
        "n_head": 16,
        "hidden_dim": 4096,
        "vocab_dim": 50400,
        "eps": 1e-5,
        "activation": gelu_new,
        "Layer": GPTLayer
    }
    model = GPTModel.load(config, path, state_dict)
    return model
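# Illustrative usage (sketch; assumes a split checkpoint at models/6b written by GPTModel.save()):
#   model = load_gpt_j("models/6b").cuda().half().eval()
#   logits = model(input_ids)  # input_ids: LongTensor of shape (batch, seq) on the same device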
def init_6b():
    config = {
        "n_layer": 28,
        "n_head": 16,
        "hidden_dim": 4096,
        "vocab_dim": 50400,
        "eps": 1e-5,
        "activation": gelu_new,
        "Layer": GPTLayer
    }
    model = GPTModel.init(config)
    return model

def init_125m():
    config = {
        "n_layer": 12,
        "n_head": 12,
        "hidden_dim": 768,
        "vocab_dim": 50400,
        "eps": 1e-5,
        "activation": gelu_new,
        "Layer": GPTLayer
    }
    model = GPTModel.init(config)
    return model

def init_1_3b():
    config = {
        "n_layer": 24,
        "n_head": 16,
        "hidden_dim": 2048,
        "vocab_dim": 50400,
        "eps": 1e-5,
        "activation": gelu_new,
        "Layer": GPTLayer
    }
    model = GPTModel(**config)
    return model
\ No newline at end of file
# run bash: -b
# run command: default
# kill: -k name
# start pod: -s name
# gpu: -g --gpu
# amount: -n
# cpu cores: -c
# amount of ram: -r
# image: -i
from novelutils.novelfra import *
from pyfra import *
import argparse
import sys

parser = argparse.ArgumentParser(description='Novelfra utility tool for launching pods and deployments on kubernetes with pyfra.')
parser.add_argument('name', nargs="?", type=str, help='Deployment name')
-# Make the default the last one we used.
+#TODO: Make the default the last one we used.
parser.add_argument('--service', action="store_true", help="""Create a service with the deployment. If a service is not
created you won't be able to access the pod outside from the kube network.""")
parser.add_argument('-b', '--bash', action="store_true", help='Run bash instead of python3.')
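# Example invocation (hypothetical; the script name and only the flags visible above are
# assumed, exact semantics depend on the rest of the parser):
#   python3 novelfra.py mypod --service -b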