Commit 4e346860 authored by kurumuz

DS and verifytoken

parent 53864bfe
@@ -26,6 +26,7 @@ RUN pip3 install -e stable-diffusion-private-hypernets/.
RUN pip3 install https://github.com/crowsonkb/k-diffusion/archive/481677d114f6ea445aa009cf5bd7a9cdee909e47.zip
RUN pip3 install simplejpeg
RUN pip3 install min-dalle
RUN pip3 install https://github.com/microsoft/DeepSpeed/archive/55b7b9e008943b8b93d4903d90b255313bb9d82c.zip
#Open ports
EXPOSE 8080
@@ -10,13 +10,17 @@ from dotmap import DotMap
from icecream import ic
from sentry_sdk import capture_exception
from sentry_sdk.integrations.threading import ThreadingIntegration
from hydra_node.models import StableDiffusionModel, DalleMiniModel
from hydra_node.models import StableDiffusionModel, DalleMiniModel, BasedformerModel
import traceback
import zlib
from pathlib import Path
from ldm.modules.attention import CrossAttention, HyperLogic
model_map = {"stable-diffusion": StableDiffusionModel, "dalle-mini": DalleMiniModel}
model_map = {
    "stable-diffusion": StableDiffusionModel,
    "dalle-mini": DalleMiniModel,
    "basedformer": BasedformerModel,
}
def no_init(loading_code):
    def dummy(self):
@@ -143,7 +147,8 @@ def init_config_model():
    # Instantiate our actual model.
    load_time = time.time()
    model_hash = None
    try:
        if config.model_name != "dalle-mini":
            model = no_init(lambda: model_map[config.model_name](config))
@@ -170,7 +175,7 @@ def init_config_model():
modules = load_modules(config.module_path)
#attach it to the model
model.premodules = modules
config.model = model
# Mark that our model is loaded.
@@ -496,4 +496,89 @@ class DalleMiniModel(nn.Module):
return images
def apply_temp(logits, temperature):
    logits = logits / temperature
    return logits
@torch.no_grad()
def generate(forward, prompt_tokens, tokenizer, tokens_to_generate=50, ds=False, ops_list=[{"temp": 0.9}], hypernetwork=None, non_deterministic=False, fully_deterministic=False):
    in_tokens = prompt_tokens
    context = prompt_tokens
    generated = torch.zeros(len(ops_list), 0, dtype=torch.long).to(in_tokens.device)
    kv = None

    if non_deterministic:
        torch.seed()

    #soft_required = ["top_k", "top_p"]
    op_map = {
        "temp": apply_temp,
    }

    for _ in range(tokens_to_generate):
        if ds:
            logits, kv = forward(in_tokens, past_key_values=kv, use_cache=True)
        else:
            logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)

        logits = logits[:, -1, :] #get the last token in the seq
        logits = torch.log_softmax(logits, dim=-1)

        batch = []
        for i, ops in enumerate(ops_list):
            item = logits[i, ...].unsqueeze(0)
            ctx = context[i, ...].unsqueeze(0)
            for op, value in ops.items():
                if op == "rep_pen":
                    item = op_map[op](ctx, item, **value)
                else:
                    item = op_map[op](item, value)

            batch.append(item)

        logits = torch.cat(batch, dim=0)
        logits = torch.softmax(logits, dim=-1)

        #fully_deterministic makes it deterministic across the batch
        if fully_deterministic:
            logits = logits.split(1, dim=0)
            logit_list = []
            for logit in logits:
                torch.manual_seed(69)
                logit_list.append(torch.multinomial(logit, 1))

            logits = torch.cat(logit_list, dim=0)
        else:
            logits = torch.multinomial(logits, 1)

        if logits[0, 0] == 48585:
            if generated[0, -1] == 1400:
                pass
            elif generated[0, -1] == 3363:
                return "safe", "none"
            else:
                return "notsafe", tokenizer.decode(generated.squeeze()).split("Output: ")[-1]

        generated = torch.cat([generated, logits], dim=-1)
        context = torch.cat([context, logits], dim=-1)
        in_tokens = logits

    return "unknown", tokenizer.decode(generated.squeeze())
class BasedformerModel(nn.Module):
    def __init__(self, config):
        nn.Module.__init__(self)
        from basedformer import lm_utils
        from transformers import GPT2TokenizerFast
        self.config = config
        self.model = lm_utils.load_from_path(config.model_path).half().cuda()
        self.model = self.model.convert_to_ds()
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    @torch.no_grad()
    def sample(self, request):
        prompt = request.prompt
        prompt = self.tokenizer.encode("Input: " + prompt, return_tensors='pt').cuda().long()
        prompt = torch.cat([prompt, torch.tensor([[49527]], dtype=torch.long).cuda()], dim=1)
        is_safe, corrected = generate(self.model.module, prompt, self.tokenizer, tokens_to_generate=150, ds=True)
        return is_safe, corrected
\ No newline at end of file
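Note on the sampling loop above: the op_map/ops_list mechanism in generate applies each named op to the per-item logits before sampling; only "temp" is wired up in this commit, but the commented-out soft_required hint suggests top_k/top_p style filters would fit the same shape. A minimal sketch of a hypothetical top_k op follows (apply_top_k and the extended op_map are assumptions, not part of the commit):

# Hypothetical extension of the sampling op pipeline: a top-k filter with the
# same (logits, value) -> logits shape as apply_temp. Illustration only.
import torch

def apply_top_k(logits, k):
    # Keep the k highest-scoring entries per row; push the rest to -inf so
    # they receive zero probability after the softmax in generate().
    values, _ = torch.topk(logits, k, dim=-1)
    cutoff = values[..., -1, None]
    return logits.masked_fill(logits < cutoff, float("-inf"))

# op_map = {"temp": apply_temp, "top_k": apply_top_k}
# ops_list = [{"temp": 0.9, "top_k": 40}]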
@@ -41,6 +41,7 @@ dalle_mini_forced_defaults = {
defaults = {
    'stable-diffusion': (v1pp_defaults, v1pp_forced_defaults),
    'dalle-mini': (dalle_mini_defaults, dalle_mini_forced_defaults),
    'basedformer': ({}, {}),
}
samplers = [
@@ -185,6 +186,9 @@ def sanitize_stable_diffusion(request, config):
def sanitize_dalle_mini(request):
    return True, request

def sanitize_basedformer(request):
    return True, request
def sanitize_input(config, request):
    """
    Sanitize the input data and set defaults
@@ -202,4 +206,7 @@ def sanitize_input(config, request):
        return sanitize_stable_diffusion(request, config)
    elif config.model_name == 'dalle-mini':
        return sanitize_dalle_mini(request)
\ No newline at end of file
        return sanitize_dalle_mini(request)
    elif config.model_name == 'basedformer':
        return sanitize_basedformer(request)
\ No newline at end of file
from fastapi import FastAPI, Request
from fastapi import FastAPI, Request, Depends
from pydantic import BaseModel
from fastapi.responses import HTMLResponse, PlainTextResponse, Response
from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from sentry_sdk import capture_exception
from sentry_sdk import capture_message
@@ -24,6 +25,10 @@ from PIL import Image
from PIL.PngImagePlugin import PngInfo
import json
TOKEN = os.getenv("TOKEN", None)
print(TOKEN)
print("Starting Hydra Node HTTP")
#Initialize model and config
model, config, model_hash = init_config_model()
logger = config.logger
@@ -32,6 +37,41 @@ mainpid = config.mainpid
hostname = socket.gethostname()
sent_first_message = False
def auth_required(handler):
    async def wrapper(raw_request: Request, *args, **kwargs):
        if TOKEN:
            print("got here")
            authorization = raw_request.headers.get("authorization")
            if authorization is None or authorization != "Bearer "+TOKEN:
                return ErrorOutput(error="invalid token")

        return await handler(*args, **kwargs)

    # Fix signature of wrapper
    import inspect
    wrapper.__signature__ = inspect.Signature(
        parameters = [
            # Use all parameters from handler
            *inspect.signature(handler).parameters.values(),

            # Skip *args and **kwargs from wrapper parameters:
            *filter(
                lambda p: p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD),
                inspect.signature(wrapper).parameters.values()
            )
        ],
        return_annotation = inspect.signature(handler).return_annotation,
    )

    return wrapper
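The inspect.Signature rewrite above matters because FastAPI derives request parsing and docs from the endpoint's visible signature, and a bare (*args, **kwargs) wrapper would hide the handler's parameters. A standalone sketch with hypothetical names (handler and wrapper here are placeholders, not the commit's functions):

# Illustration only: rebuild a wrapper's signature from the wrapped handler,
# dropping the wrapper's *args/**kwargs, as auth_required does above.
import inspect

def handler(request: str):
    ...

def wrapper(raw_request, *args, **kwargs):
    ...

print(inspect.signature(wrapper))   # (raw_request, *args, **kwargs)

wrapper.__signature__ = inspect.Signature(
    parameters=[
        *inspect.signature(handler).parameters.values(),
        *[p for p in inspect.signature(wrapper).parameters.values()
          if p.kind not in (inspect.Parameter.VAR_POSITIONAL,
                            inspect.Parameter.VAR_KEYWORD)],
    ],
)

print(inspect.signature(wrapper))   # (request: str, raw_request)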
def verify_token(req: Request):
    valid = "Authorization" in req.headers and req.headers["Authorization"] == "Bearer "+TOKEN
    if not valid:
        raise HTTPException(
            status_code=401,
            detail="Unauthorized"
        )

    return True
#Initialize fastapi
app = FastAPI()
@@ -85,12 +125,20 @@ class GenerationRequest(BaseModel):
    module: str = None
    masks: List[Masker] = None

class TextRequest(BaseModel):
    prompt: str

class TextOutput(BaseModel):
    is_safe: str
    corrected_text: str

class GenerationOutput(BaseModel):
    output: List[str]

class ErrorOutput(BaseModel):
    error: str
@auth_required
@app.post('/generate-stream')
def generate(request: GenerationRequest):
    t = time.perf_counter()
@@ -158,27 +206,7 @@ def generate(request: GenerationRequest):
os.kill(mainpid, signal.SIGTERM)
return {"error": str(e)}
'''
@app.post('/image-to-image')
def image_to_image(request: GenerationRequest):
    #prompt is a base64 encoded image
    try:
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        image = base64.b64decode(request.prompt)
        image = simplejpeg.decode_jpeg(image)
        image = model.image_to_image(image, request)
        image = simplejpeg.encode_jpeg(image, quality=95)
        #get base64 of image
        image = base64.b64encode(image).decode("ascii")
        return GenerationOutput(output=[image])
'''
@auth_required
@app.post('/generate', response_model=Union[GenerationOutput, ErrorOutput])
def generate(request: GenerationRequest):
    t = time.perf_counter()
@@ -221,5 +249,33 @@ def generate(request: GenerationRequest):
os.kill(mainpid, signal.SIGTERM)
return {"error": str(e)}
@app.post('/generate-text', response_model=Union[TextOutput, ErrorOutput])
def generate_text(request: TextRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        is_safe, corrected_text = model.sample(request)

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")

        return TextOutput(is_safe=is_safe, corrected_text=corrected_text)

    except Exception as e:
        traceback.print_exc()
        capture_exception(e)
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            logger.error("GPU error, committing seppuku.")
            os.kill(mainpid, signal.SIGTERM)

        return ErrorOutput(error=str(e))
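A minimal client-side sketch of the new endpoint, assuming the service is running via the run script below (port 4315, TOKEN=test_token); the requests call is an illustration, not part of the commit:

# Hypothetical client call: verify_token expects "Authorization: Bearer <TOKEN>",
# otherwise the dependency raises a 401 before the handler runs.
import requests

resp = requests.post(
    "http://localhost:4315/generate-text",
    headers={"Authorization": "Bearer test_token"},
    json={"prompt": "some text to check"},
)
print(resp.status_code, resp.json())  # TextOutput fields, {"error": ...}, or a 401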
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=80, log_level="info")
\ No newline at end of file
export MODEL="basedformer"
export DEV="True"
export MODEL_PATH="/home/xuser/nvme1/workspace/arda/basedformer/models/gptj-imagegen-mitigation/final"
export TRANSFORMERS_CACHE="/home/xuser/nvme1/transformer_cache"
export SENTRY_URL="https://49ca8adcf4444f82a10eae1b3fd4182f@o846434.ingest.sentry.io/6612448"
export TOKEN="test_token"
gunicorn main:app --workers 1 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:4315
\ No newline at end of file