Commit 4b5d5964 authored by kurumuz's avatar kurumuz

changes to search

parent 16a0c79d
...@@ -726,7 +726,7 @@ class EmbedderModel(nn.Module): ...@@ -726,7 +726,7 @@ class EmbedderModel(nn.Module):
results = [] results = []
embedding = self([text]) embedding = self([text])
#print(embedding.dtype) #print(embedding.dtype)
k = 20 k = 15
D, I = self.knn.search(embedding, k) D, I = self.knn.search(embedding, k)
D, I = D.squeeze(), I.squeeze() D, I = D.squeeze(), I.squeeze()
for i, id in enumerate(I): for i, id in enumerate(I):
...@@ -736,12 +736,19 @@ class EmbedderModel(nn.Module): ...@@ -736,12 +736,19 @@ class EmbedderModel(nn.Module):
results.append([tag, count, prob]) results.append([tag, count, prob])
found = found[:5] found = found[:5]
for i, result in enumerate(found): results_tags = [x[0] for x in found]
if result[0] in [x[0] for x in results]: for result in results.copy():
found.remove(result) if result[0] in results_tags:
results.remove(result)
results = sorted(results, key=lambda x: x[2], reverse=True)
#filter results for >0.5 confidence unless it has the search text in it and confidence is >0.4
results = [x for x in results if x[2] > 0.5 or (x[2] > 0.4 and text in x[0])]
found = sorted(found, key=lambda x: x[1], reverse=True)
print(found)
print(results)
if len(found) > 0: if len(found) > 0:
results = results[:-len(found)] #results = results[:-len(found)]
results = found + results results = found + results
#max 10 results #max 10 results
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment