Commit 4b5d5964 authored by kurumuz's avatar kurumuz

changes to search

parent 16a0c79d
......@@ -726,7 +726,7 @@ class EmbedderModel(nn.Module):
results = []
embedding = self([text])
#print(embedding.dtype)
k = 20
k = 15
D, I = self.knn.search(embedding, k)
D, I = D.squeeze(), I.squeeze()
for i, id in enumerate(I):
......@@ -736,12 +736,19 @@ class EmbedderModel(nn.Module):
results.append([tag, count, prob])
found = found[:5]
for i, result in enumerate(found):
if result[0] in [x[0] for x in results]:
found.remove(result)
results_tags = [x[0] for x in found]
for result in results.copy():
if result[0] in results_tags:
results.remove(result)
results = sorted(results, key=lambda x: x[2], reverse=True)
#filter results for >0.5 confidence unless it has the search text in it and confidence is >0.4
results = [x for x in results if x[2] > 0.5 or (x[2] > 0.4 and text in x[0])]
found = sorted(found, key=lambda x: x[1], reverse=True)
print(found)
print(results)
if len(found) > 0:
results = results[:-len(found)]
#results = results[:-len(found)]
results = found + results
#max 10 results
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment