Commit 7ae5e142 authored by kurumuz's avatar kurumuz

multi-knn

parent fb3c04d0
......@@ -709,14 +709,18 @@ class EmbedderModel(nn.Module):
import requests
knn_folder = config.knn_folder
self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').cuda()
self.index = pickle.load(requests.get(f"{knn_folder}/index.pkl", stream='True').raw)
self.tag_count = pickle.load(requests.get(f"{knn_folder}/all_tags.pkl", stream='True').raw)
r = requests.get(f"{knn_folder}/knn.index", stream='True')
with open("knn.index", "wb") as f:
f.write(r.content)
self.knn = faiss.read_index("knn.index")
self.tag_count_sorted = sorted(self.tag_count.items(), key=lambda x: x[1], reverse=True)
self.indexes = {}
for folder in knn_folder.split(","):
name, url = folder.split(":")
index = pickle.load(requests.get(f"{url}/index.pkl", stream='True').raw)
tag_count = pickle.load(requests.get(f"{url}/all_tags.pkl", stream='True').raw)
tag_count_sorted = sorted(tag_count.items(), key=lambda x: x[1], reverse=True)
r = requests.get(f"{url}/knn.index", stream='True')
with open("knn.index", "wb") as f:
f.write(r.content)
knn = faiss.read_index("knn.index")
self.indexes[name] = [index, tag_count, tag_count_sorted, knn]
def __call__(self, sentences):
with torch.no_grad():
......@@ -725,9 +729,11 @@ class EmbedderModel(nn.Module):
def get_top_k(self, request):
text = request.prompt
model = request.model
index, tag_count, tag_count_sorted, knn = self.indexes[model]
#check if text is a substring in tag_count.keys()
found = []
for tag, count in self.tag_count_sorted:
for tag, count in tag_count_sorted:
if len(tag) >= len(text) and tag.startswith(text):
found.append([tag, count, 0])
......@@ -735,11 +741,11 @@ class EmbedderModel(nn.Module):
embedding = self([text])
#print(embedding.dtype)
k = 15
D, I = self.knn.search(embedding, k)
D, I = knn.search(embedding, k)
D, I = D.squeeze(), I.squeeze()
for i, id in enumerate(I):
tag = self.index[id]
count = self.tag_count[tag]
tag = index[id]
count = tag_count[tag]
prob = D[i]
results.append([tag, count, prob])
......@@ -753,8 +759,6 @@ class EmbedderModel(nn.Module):
#filter results for >0.5 confidence unless it has the search text in it and confidence is >0.4
results = [x for x in results if x[2] > 0.5 or (x[2] > 0.4 and text in x[0])]
found = sorted(found, key=lambda x: x[1], reverse=True)
print(found)
print(results)
if len(found) > 0:
#results = results[:-len(found)]
results = found + results
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment