novelai-storage / Hydra Node Http · Commit 4e346860

Commit 4e346860, authored Aug 26, 2022 by kurumuz
Parent: 53864bfe

    DS and verifytoken

Showing 6 changed files with 188 additions and 27 deletions (+188 −27)
Dockerfile              +1   −0
hydra_node/config.py    +9   −4
hydra_node/models.py    +85  −0
hydra_node/sanitize.py  +8   −1
main.py                 +78  −22
run_basedformer.sh      +7   −0
Dockerfile
@@ -26,6 +26,7 @@ RUN pip3 install -e stable-diffusion-private-hypernets/.
 RUN pip3 install https://github.com/crowsonkb/k-diffusion/archive/481677d114f6ea445aa009cf5bd7a9cdee909e47.zip
 RUN pip3 install simplejpeg
 RUN pip3 install min-dalle
+RUN pip3 install https://github.com/microsoft/DeepSpeed/archive/55b7b9e008943b8b93d4903d90b255313bb9d82c.zip
 #Open ports
 EXPOSE 8080
hydra_node/config.py
@@ -10,13 +10,17 @@ from dotmap import DotMap
 from icecream import ic
 from sentry_sdk import capture_exception
 from sentry_sdk.integrations.threading import ThreadingIntegration
-from hydra_node.models import StableDiffusionModel, DalleMiniModel
+from hydra_node.models import StableDiffusionModel, DalleMiniModel, BasedformerModel
 import traceback
 import zlib
 from pathlib import Path
 from ldm.modules.attention import CrossAttention, HyperLogic
 
-model_map = {
-    "stable-diffusion": StableDiffusionModel,
-    "dalle-mini": DalleMiniModel
-}
+model_map = {
+    "stable-diffusion": StableDiffusionModel,
+    "dalle-mini": DalleMiniModel,
+    "basedformer": BasedformerModel,
+}
 
 def no_init(loading_code):
     def dummy(self):
@@ -143,7 +147,8 @@ def init_config_model():
     # Instantiate our actual model.
     load_time = time.time()
     model_hash = None
     try:
-        model = no_init(lambda: model_map[config.model_name](config))
+        if config.model_name != "dalle-mini":
+            model = no_init(lambda: model_map[config.model_name](config))
@@ -170,7 +175,7 @@ def init_config_model():
         modules = load_modules(config.module_path)
         #attach it to the model
         model.premodules = modules
     config.model = model
     # Mark that our model is loaded.
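The hunk above ends just as no_init is being defined. For context, this helper loads a model while skipping PyTorch's default weight initialization; the weights are immediately overwritten by a checkpoint, so initializing them first is wasted work. The diff does not show the body, so the following is only a sketch of the common form of this idiom, with the list of patched layer types an assumption rather than something taken from this repo:

import torch

def no_init(loading_code):
    def dummy(self):
        return

    # Temporarily replace parameter (re)initialization with a no-op for the
    # layer types the checkpoint will overwrite, then restore the originals.
    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    originals = {mod: mod.reset_parameters for mod in modules}
    for mod in modules:
        mod.reset_parameters = dummy
    try:
        result = loading_code()
    finally:
        for mod in modules:
            mod.reset_parameters = originals[mod]
    return result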
hydra_node/models.py
@@ -496,4 +496,89 @@ class DalleMiniModel(nn.Module):
         return images
 
+def apply_temp(logits, temperature):
+    logits = logits / temperature
+    return logits
+
+@torch.no_grad()
+def generate(forward, prompt_tokens, tokenizer, tokens_to_generate=50, ds=False, ops_list=[{"temp": 0.9}], hypernetwork=None, non_deterministic=False, fully_deterministic=False):
+    in_tokens = prompt_tokens
+    context = prompt_tokens
+    generated = torch.zeros(len(ops_list), 0, dtype=torch.long).to(in_tokens.device)
+    kv = None
+    if non_deterministic:
+        torch.seed()
+    #soft_required = ["top_k", "top_p"]
+    op_map = {
+        "temp": apply_temp,
+    }
+
+    for _ in range(tokens_to_generate):
+        if ds:
+            logits, kv = forward(in_tokens, past_key_values=kv, use_cache=True)
+        else:
+            logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
+        logits = logits[:, -1, :] #get the last token in the seq
+        logits = torch.log_softmax(logits, dim=-1)
+
+        batch = []
+        for i, ops in enumerate(ops_list):
+            item = logits[i, ...].unsqueeze(0)
+            ctx = context[i, ...].unsqueeze(0)
+            for op, value in ops.items():
+                if op == "rep_pen":
+                    item = op_map[op](ctx, item, **value)
+                else:
+                    item = op_map[op](item, value)
+            batch.append(item)
+
+        logits = torch.cat(batch, dim=0)
+        logits = torch.softmax(logits, dim=-1)
+
+        #fully_deterministic makes it deterministic across the batch
+        if fully_deterministic:
+            logits = logits.split(1, dim=0)
+            logit_list = []
+            for logit in logits:
+                torch.manual_seed(69)
+                logit_list.append(torch.multinomial(logit, 1))
+            logits = torch.cat(logit_list, dim=0)
+        else:
+            logits = torch.multinomial(logits, 1)
+
+        if logits[0, 0] == 48585:
+            if generated[0, -1] == 1400:
+                pass
+            elif generated[0, -1] == 3363:
+                return "safe", "none"
+            else:
+                return "notsafe", tokenizer.decode(generated.squeeze()).split("Output: ")[-1]
+
+        generated = torch.cat([generated, logits], dim=-1)
+        context = torch.cat([context, logits], dim=-1)
+        in_tokens = logits
+
+    return "unknown", tokenizer.decode(generated.squeeze())
+
+class BasedformerModel(nn.Module):
+    def __init__(self, config):
+        nn.Module.__init__(self)
+        from basedformer import lm_utils
+        from transformers import GPT2TokenizerFast
+        self.config = config
+        self.model = lm_utils.load_from_path(config.model_path).half().cuda()
+        self.model = self.model.convert_to_ds()
+        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+
+    @torch.no_grad()
+    def sample(self, request):
+        prompt = request.prompt
+        prompt = self.tokenizer.encode("Input: " + prompt, return_tensors='pt').cuda().long()
+        prompt = torch.cat([prompt, torch.tensor([[49527]], dtype=torch.long).cuda()], dim=1)
+        is_safe, corrected = generate(self.model.module, prompt, self.tokenizer, tokens_to_generate=150, ds=True)
+        return is_safe, corrected
\ No newline at end of file
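Two things are worth noting in the hunk above. First, the ds=True branch (used by BasedformerModel.sample) calls the model through the HF-style past_key_values/use_cache interface, which appears to be what the DeepSpeed-converted module from convert_to_ds() exposes; the non-ds branch uses basedformer's own cache/kv/hypernetwork keywords. Second, the op loop special-cases a "rep_pen" op, but op_map only registers "temp", so an ops_list entry carrying rep_pen would raise KeyError as committed. The following is a sketch of the kind of function that slot appears to expect: a CTRL-style repetition penalty whose name, signature, and formula are assumptions, not part of this commit. It matches the call shape op_map[op](ctx, item, **value), so it would be registered as op_map["rep_pen"] = apply_rep_pen and invoked with value being something like {"penalty": 1.2}:

def apply_rep_pen(ctx, logits, penalty=1.2):
    # Hypothetical op for the "rep_pen" slot: dampen the scores of tokens
    # already present in the context. `logits` is a (1, vocab) tensor of
    # log-probabilities (log_softmax output, so every entry is <= 0) and
    # `ctx` is a (1, seq) tensor of token ids, matching the loop above.
    for token in set(ctx.squeeze(0).tolist()):
        # Multiplying a negative log-prob by penalty > 1 pushes it further
        # below zero, i.e. makes the repeated token less likely.
        logits[0, token] = logits[0, token] * penalty
    return logits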
hydra_node/sanitize.py
@@ -41,6 +41,7 @@ dalle_mini_forced_defaults = {
 defaults = {
     'stable-diffusion': (v1pp_defaults, v1pp_forced_defaults),
     'dalle-mini': (dalle_mini_defaults, dalle_mini_forced_defaults),
+    'basedformer': ({}, {}),
 }
 
 samplers = [
@@ -185,6 +186,9 @@ def sanitize_stable_diffusion(request, config):
 def sanitize_dalle_mini(request):
     return True, request
 
+def sanitize_basedformer(request):
+    return True, request
+
 def sanitize_input(config, request):
     """
     Sanitize the input data and set defaults
@@ -202,4 +206,7 @@ def sanitize_input(config, request):
         return sanitize_stable_diffusion(request, config)
     elif config.model_name == 'dalle-mini':
-        return sanitize_dalle_mini(request)
\ No newline at end of file
+        return sanitize_dalle_mini(request)
+    elif config.model_name == 'basedformer':
+        return sanitize_basedformer(request)
\ No newline at end of file
main.py
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, Request, Depends
 from pydantic import BaseModel
 from fastapi.responses import HTMLResponse, PlainTextResponse, Response
 from fastapi.exceptions import HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from sentry_sdk import capture_exception
 from sentry_sdk import capture_message
@@ -24,6 +25,10 @@ from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 import json
 
+TOKEN = os.getenv("TOKEN", None)
+print(TOKEN)
+
 print("Starting Hydra Node HTTP")
 #Initialize model and config
 model, config, model_hash = init_config_model()
 logger = config.logger
@@ -32,6 +37,41 @@ mainpid = config.mainpid
 hostname = socket.gethostname()
 sent_first_message = False
 
+def auth_required(handler):
+    async def wrapper(raw_request: Request, *args, **kwargs):
+        if TOKEN:
+            print("got here")
+            authorization = raw_request.headers.get("authorization")
+            if authorization is None or authorization != "Bearer " + TOKEN:
+                return ErrorOutput(error="invalid token")
+        return await handler(*args, **kwargs)
+
+    # Fix signature of wrapper
+    import inspect
+    wrapper.__signature__ = inspect.Signature(
+        parameters=[
+            # Use all parameters from handler
+            *inspect.signature(handler).parameters.values(),
+            # Skip *args and **kwargs from wrapper parameters:
+            *filter(
+                lambda p: p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD),
+                inspect.signature(wrapper).parameters.values()
+            )
+        ],
+        return_annotation=inspect.signature(handler).return_annotation,
+    )
+    return wrapper
+
+def verify_token(req: Request):
+    valid = "Authorization" in req.headers and req.headers["Authorization"] == "Bearer " + TOKEN
+    if not valid:
+        raise HTTPException(status_code=401, detail="Unauthorized")
+    return True
+
 #Initialize fastapi
 app = FastAPI()
@@ -85,12 +125,20 @@ class GenerationRequest(BaseModel):
     module: str = None
     masks: List[Masker] = None
 
+class TextRequest(BaseModel):
+    prompt: str
+
+class TextOutput(BaseModel):
+    is_safe: str
+    corrected_text: str
+
 class GenerationOutput(BaseModel):
     output: List[str]
 
 class ErrorOutput(BaseModel):
     error: str
 
+@auth_required
 @app.post('/generate-stream')
 def generate(request: GenerationRequest):
     t = time.perf_counter()
@@ -158,27 +206,7 @@ def generate(request: GenerationRequest):
         os.kill(mainpid, signal.SIGTERM)
         return {"error": str(e)}
 
-'''
-@app.post('/image-to-image')
-def image_to_image(request: GenerationRequest):
-    #prompt is a base64 encoded image
-    try:
-        output = sanitize_input(config, request)
-        if output[0]:
-            request = output[1]
-        else:
-            return ErrorOutput(error=output[1])
-        image = base64.b64decode(request.prompt)
-        image = simplejpeg.decode_jpeg(image)
-        image = model.image_to_image(image, request)
-        image = simplejpeg.encode_jpeg(image, quality=95)
-        #get base64 of image
-        image = base64.b64encode(image).decode("ascii")
-        return GenerationOutput(output=[image])
-'''
+@auth_required
 @app.post('/generate', response_model=Union[GenerationOutput, ErrorOutput])
 def generate(request: GenerationRequest):
     t = time.perf_counter()
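An aside on the @auth_required stack in the hunks above (an observation about FastAPI semantics, not something the commit states): FastAPI's route decorators register the function they receive and return it unchanged, so with @auth_required placed above @app.post the registered handler is the unwrapped function and the token check never runs for those routes; the wrapper also awaits its handler, so it only suits async endpoints. That is presumably why the new /generate-text route below uses Depends(verify_token) instead. A minimal sketch of the ordering that would make the decorator take effect; the /protected route is hypothetical, purely for illustration:

# Decorators apply bottom-up: auth_required wraps first, then app.post
# registers the wrapper, so the token check actually runs per request.
@app.post('/protected')
@auth_required
async def protected():
    # FastAPI injects the Request into the wrapper's raw_request parameter
    # via the __signature__ fixed up inside auth_required.
    return {"ok": True}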
@@ -221,5 +249,33 @@ def generate(request: GenerationRequest):
         os.kill(mainpid, signal.SIGTERM)
         return {"error": str(e)}
 
+@app.post('/generate-text', response_model=Union[TextOutput, ErrorOutput])
+def generate_text(request: TextRequest, authorized: bool = Depends(verify_token)):
+    t = time.perf_counter()
+    try:
+        output = sanitize_input(config, request)
+        if output[0]:
+            request = output[1]
+        else:
+            return ErrorOutput(error=output[1])
+
+        is_safe, corrected_text = model.sample(request)
+
+        process_time = time.perf_counter() - t
+        logger.info(f"Request took {process_time:0.3f} seconds")
+
+        return TextOutput(is_safe=is_safe, corrected_text=corrected_text)
+
+    except Exception as e:
+        traceback.print_exc()
+        capture_exception(e)
+        logger.error(str(e))
+        e_s = str(e)
+        gc.collect()
+        if "CUDA out of memory" in e_s or \
+                "an illegal memory access" in e_s or "CUDA" in e_s:
+            logger.error("GPU error, committing seppuku.")
+            os.kill(mainpid, signal.SIGTERM)
+        return ErrorOutput(error=str(e))
+
 if __name__ == "__main__":
     uvicorn.run("main:app", host="0.0.0.0", port=80, log_level="info")
\ No newline at end of file
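To exercise the new endpoint end to end, a hypothetical client call is sketched below; the port and token values mirror run_basedformer.sh that follows, and the requests library is used purely for illustration, not as a dependency of this service:

# Hypothetical client for the new /generate-text route; assumes the node
# was started via run_basedformer.sh (bind 0.0.0.0:4315, TOKEN="test_token").
import requests

resp = requests.post(
    "http://localhost:4315/generate-text",
    headers={"Authorization": "Bearer test_token"},
    json={"prompt": "some text to check"},
)
# verify_token returns 401 without a valid Bearer token; otherwise the body
# is a TextOutput ({"is_safe": ..., "corrected_text": ...}) or an ErrorOutput.
print(resp.status_code, resp.json())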
run_basedformer.sh (new file, mode 100644)
+export MODEL="basedformer"
+export DEV="True"
+export MODEL_PATH="/home/xuser/nvme1/workspace/arda/basedformer/models/gptj-imagegen-mitigation/final"
+export TRANSFORMERS_CACHE="/home/xuser/nvme1/transformer_cache"
+export SENTRY_URL="https://49ca8adcf4444f82a10eae1b3fd4182f@o846434.ingest.sentry.io/6612448"
+export TOKEN="test_token"
+gunicorn main:app --workers 1 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:4315
\ No newline at end of file