Commit c272e1ae authored by kurumuz's avatar kurumuz

fix

parent a1a34f99
...@@ -703,13 +703,13 @@ class EmbedderModel(nn.Module): ...@@ -703,13 +703,13 @@ class EmbedderModel(nn.Module):
#check if text is a substring in tag_count.keys() #check if text is a substring in tag_count.keys()
found = [] found = []
for tag, count in self.tag_count_sorted: for tag, count in self.tag_count_sorted:
if len(tag) > len(text) and tag.startswith(text): if len(tag) >= len(text) and tag.startswith(text):
found.append([tag, count, 0]) found.append([tag, count, 0])
results = [] results = []
embedding = self([text]) embedding = self([text])
#print(embedding.dtype) #print(embedding.dtype)
k = 10 k = 20
D, I = self.knn.search(embedding, k) D, I = self.knn.search(embedding, k)
D, I = D.squeeze(), I.squeeze() D, I = D.squeeze(), I.squeeze()
for i, id in enumerate(I): for i, id in enumerate(I):
...@@ -718,16 +718,16 @@ class EmbedderModel(nn.Module): ...@@ -718,16 +718,16 @@ class EmbedderModel(nn.Module):
prob = D[i] prob = D[i]
results.append([tag, count, prob]) results.append([tag, count, prob])
print(results)
#sort results by count and prob after
results = sorted(results, key=lambda x: x[1], reverse=True)
found = found[:5] found = found[:5]
for result in found: for i, result in enumerate(found):
if result[0] in results: if result[0] in [x[0] for x in results]:
results.remove(result) found.remove(result)
if len(found) > 0: if len(found) > 0:
results = results[:-len(found)] results = results[:-len(found)]
results = found + results results = found + results
#max 10 results
results = results[:10]
results = sorted(results, key=lambda x: x[1], reverse=True)
return results return results
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment