Commit c272e1ae authored by kurumuz's avatar kurumuz

fix

parent a1a34f99
......@@ -703,13 +703,13 @@ class EmbedderModel(nn.Module):
#check if text is a substring in tag_count.keys()
found = []
for tag, count in self.tag_count_sorted:
if len(tag) > len(text) and tag.startswith(text):
if len(tag) >= len(text) and tag.startswith(text):
found.append([tag, count, 0])
results = []
embedding = self([text])
#print(embedding.dtype)
k = 10
k = 20
D, I = self.knn.search(embedding, k)
D, I = D.squeeze(), I.squeeze()
for i, id in enumerate(I):
......@@ -718,16 +718,16 @@ class EmbedderModel(nn.Module):
prob = D[i]
results.append([tag, count, prob])
print(results)
#sort results by count and prob after
results = sorted(results, key=lambda x: x[1], reverse=True)
found = found[:5]
for result in found:
if result[0] in results:
results.remove(result)
for i, result in enumerate(found):
if result[0] in [x[0] for x in results]:
found.remove(result)
if len(found) > 0:
results = results[:-len(found)]
results = found + results
#max 10 results
results = results[:10]
results = sorted(results, key=lambda x: x[1], reverse=True)
return results
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment