Commit dab9ddf7 authored by novelailab's avatar novelailab

Merge branch 'main' of https://github.com/NovelAI/hydra-node-http into main

parents 1503edda ddfc0428
......@@ -684,9 +684,9 @@ class EmbedderModel(nn.Module):
import pickle
import requests
self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').cuda()
self.index = pickle.load(requests.get("https://f004.backblazeb2.com/file/naipublicbucketxyz/index.pkl", stream='True').raw)
self.tag_count = pickle.load(requests.get("https://f004.backblazeb2.com/file/naipublicbucketxyz/all_tags.pkl", stream='True').raw)
r = requests.get("https://f004.backblazeb2.com/file/naipublicbucketxyz/knn.index", stream='True')
self.index = pickle.load(requests.get("https://f004.backblazeb2.com/file/naipublicbucketxyz/safe/index.pkl", stream='True').raw)
self.tag_count = pickle.load(requests.get("https://f004.backblazeb2.com/file/naipublicbucketxyz/safe/all_tags.pkl", stream='True').raw)
r = requests.get("https://f004.backblazeb2.com/file/naipublicbucketxyz/safe/knn.index", stream='True')
with open("knn.index", "wb") as f:
f.write(r.content)
......@@ -703,8 +703,8 @@ class EmbedderModel(nn.Module):
#check if text is a substring in tag_count.keys()
found = []
for tag, count in self.tag_count_sorted:
if tag.startswith(text):
found.append([tag, count])
if len(tag) > len(text) and tag.startswith(text):
found.append([tag, count, 0])
results = []
embedding = self([text])
......@@ -716,15 +716,18 @@ class EmbedderModel(nn.Module):
tag = self.index[id]
count = self.tag_count[tag]
prob = D[i]
results.append([tag, count])
results.append([tag, count, prob])
print(results)
#sort results by count and prob after
results = sorted(results, key=lambda x: x[1], reverse=True)
found = found[:5]
for result in found:
if result in results:
if result[0] in results:
results.remove(result)
results = results[:-len(found)]
results = found + results
if len(found) > 0:
results = results[:-len(found)]
results = found + results
return results
......@@ -76,6 +76,7 @@ class Masker(TypedDict):
class Tags(TypedDict):
tag: str
count: int
confidence: float
class GenerationRequest(BaseModel):
prompt: str
......@@ -271,7 +272,7 @@ async def predict_tags(request: TextRequest, authorized: bool = Depends(verify_t
process_time = time.perf_counter() - t
logger.info(f"Request took {process_time:0.3f} seconds")
return TagOutput(tags=[Tags(tag=tag, count=count) for tag, count in tags])
return TagOutput(tags=[Tags(tag=tag, count=count, confidence=confidence) for tag, count, confidence in tags])
except Exception as e:
traceback.print_exc()
......
export MODEL="embedder"
export DEV="True"
gunicorn main:app --workers 1 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:4315
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment