Commit dab9ddf7 authored by novelailab

Merge branch 'main' of https://github.com/NovelAI/hydra-node-http into main

parents 1503edda ddfc0428
...@@ -684,9 +684,9 @@ class EmbedderModel(nn.Module): ...@@ -684,9 +684,9 @@ class EmbedderModel(nn.Module):
import pickle import pickle
import requests import requests
self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').cuda() self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').cuda()
self.index = pickle.load(requests.get("https://f004.backblazeb2.com/file/naipublicbucketxyz/index.pkl", stream='True').raw) self.index = pickle.load(requests.get("https://f004.backblazeb2.com/file/naipublicbucketxyz/safe/index.pkl", stream='True').raw)
self.tag_count = pickle.load(requests.get("https://f004.backblazeb2.com/file/naipublicbucketxyz/all_tags.pkl", stream='True').raw) self.tag_count = pickle.load(requests.get("https://f004.backblazeb2.com/file/naipublicbucketxyz/safe/all_tags.pkl", stream='True').raw)
r = requests.get("https://f004.backblazeb2.com/file/naipublicbucketxyz/knn.index", stream='True') r = requests.get("https://f004.backblazeb2.com/file/naipublicbucketxyz/safe/knn.index", stream='True')
with open("knn.index", "wb") as f: with open("knn.index", "wb") as f:
f.write(r.content) f.write(r.content)
...@@ -703,8 +703,8 @@ class EmbedderModel(nn.Module): ...@@ -703,8 +703,8 @@ class EmbedderModel(nn.Module):
#check if text is a substring in tag_count.keys() #check if text is a substring in tag_count.keys()
found = [] found = []
for tag, count in self.tag_count_sorted: for tag, count in self.tag_count_sorted:
if tag.startswith(text): if len(tag) > len(text) and tag.startswith(text):
found.append([tag, count]) found.append([tag, count, 0])
results = [] results = []
embedding = self([text]) embedding = self([text])
...@@ -716,15 +716,18 @@ class EmbedderModel(nn.Module): ...@@ -716,15 +716,18 @@ class EmbedderModel(nn.Module):
tag = self.index[id] tag = self.index[id]
count = self.tag_count[tag] count = self.tag_count[tag]
prob = D[i] prob = D[i]
results.append([tag, count]) results.append([tag, count, prob])
print(results)
#sort results by count and prob after #sort results by count and prob after
results = sorted(results, key=lambda x: x[1], reverse=True) results = sorted(results, key=lambda x: x[1], reverse=True)
found = found[:5] found = found[:5]
for result in found: for result in found:
if result in results: if result[0] in results:
results.remove(result) results.remove(result)
results = results[:-len(found)] if len(found) > 0:
results = found + results results = results[:-len(found)]
results = found + results
return results return results
...@@ -76,6 +76,7 @@ class Masker(TypedDict): ...@@ -76,6 +76,7 @@ class Masker(TypedDict):
class Tags(TypedDict): class Tags(TypedDict):
tag: str tag: str
count: int count: int
confidence: float
class GenerationRequest(BaseModel): class GenerationRequest(BaseModel):
prompt: str prompt: str
...@@ -271,7 +272,7 @@ async def predict_tags(request: TextRequest, authorized: bool = Depends(verify_t ...@@ -271,7 +272,7 @@ async def predict_tags(request: TextRequest, authorized: bool = Depends(verify_t
process_time = time.perf_counter() - t process_time = time.perf_counter() - t
logger.info(f"Request took {process_time:0.3f} seconds") logger.info(f"Request took {process_time:0.3f} seconds")
return TagOutput(tags=[Tags(tag=tag, count=count) for tag, count in tags]) return TagOutput(tags=[Tags(tag=tag, count=count, confidence=confidence) for tag, count, confidence in tags])
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
......
# Export MODEL and DEV into the environment of the server process started
# below. How the application interprets these values is not visible here.
export MODEL="embedder"
export DEV="True"
# Serve the `app` object from the `main` module with one gunicorn worker of
# class uvicorn.workers.UvicornWorker, bound to all interfaces on port 4315.
gunicorn main:app --workers 1 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:4315
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment