How To: Vector embeddings and semantic similarity WITHOUT OpenAI
Why avoid OpenAI
OpenAI is great, but it costs money.
Learning to do things with open-source models is a valuable skill. Open-source models are the future; as a leaked internal Google memo recently put it, "We Have No Moat, And Neither Does OpenAI".
SentenceTransformerService
I'm just going to share the code.
If you have any questions, plug them into ChatGPT and maybe share the answers in the comments below.
First, install the sentence-transformers and torch libraries:

```shell
pip install sentence-transformers
pip install torch
```
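Optionally, you can confirm that both packages are importable before loading the model. This is a small stdlib-only sketch; `missing_packages` is a hypothetical helper name, not part of either library. Note that the pip package `sentence-transformers` is imported as `sentence_transformers` (with an underscore).

```python
import importlib.util
from typing import List, Sequence


def missing_packages(names: Sequence[str] = ("sentence_transformers", "torch")) -> List[str]:
    """Return the top-level import names from `names` that cannot be found."""
    # find_spec returns None when no module with that name is installed
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages()
if missing:
    print("Missing:", ", ".join(missing), "- re-run the pip commands above")
else:
    print("sentence-transformers and torch are importable")
```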
Then, you can use them like so.
```python
from sentence_transformers import SentenceTransformer, util  # type: ignore
from torch import Tensor


class SentenceTransformerService:
    """
    A service that uses the MiniLM model to convert text to embeddings and then compares the
    embeddings to compute similarity.
    """

    def __init__(self):
        # The MiniLM model converts text to 384-dimensional embeddings
        self.model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    def compare(self, text1: str, text2: str) -> float:
        # Convert the two sentences into two embeddings
        embedding_1: Tensor = self.model.encode(text1, convert_to_tensor=True)  # type: ignore
        embedding_2: Tensor = self.model.encode(text2, convert_to_tensor=True)  # type: ignore

        # Compute the cosine similarity (cos_sim returns a 1x1 matrix for two single embeddings)
        similarity = util.cos_sim(embedding_1, embedding_2)[0][0].item()

        # Returns a float between -1.0 and 1.0; 1.0 means the embeddings point in the
        # same direction (e.g. identical sentences)
        return similarity
```
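To see what the cosine-similarity step actually computes, here is a minimal pure-Python sketch of the formula cos(a, b) = (a · b) / (‖a‖ ‖b‖). The function name `cosine_similarity` and the toy 3-dimensional vectors are illustrative assumptions, not part of sentence-transformers. The library's `util` cosine helper computes the same quantity on tensors and returns a matrix of pairwise scores, which is why `compare` reads the single value at `[0][0]`.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between vectors a and b: dot(a, b) / (|a| * |b|)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Toy 3-dimensional vectors (real MiniLM embeddings have 384 dimensions)
print(cosine_similarity([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))     # same direction: about 1.0
print(cosine_similarity([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]))     # orthogonal: 0.0
print(cosine_similarity([1.0, 2.0, 3.0], [-1.0, -2.0, -3.0]))  # opposite: about -1.0
```

Because only the direction of the vectors matters, scaling a vector (for example doubling every component) leaves the score unchanged, and the result always stays between -1.0 and 1.0.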
The results
Here is what it returns for a few pairs (the scores are shown as comments in the code below).
"hello" compared with itself scores 1.0, since the two inputs are identical.
"hello" and "hi" are close in meaning, so they score about 0.81.
"hello" and "world" are unrelated, so they score only about 0.35.
```python
sts = SentenceTransformerService()
print(sts.compare("hello", "hello"))     # Output: 1.0
print(sts.compare("hello", "hi"))        # Output: 0.8071529865264893
print(sts.compare("hello", "hey"))       # Output: 0.702325701713562
print(sts.compare("hello", "good day"))  # Output: 0.5354946255683899
print(sts.compare("hello", "hola"))      # Output: 0.42029112577438354
print(sts.compare("hello", "world"))     # Output: 0.34536516666412354
```
Not bad for a local model, eh?
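A natural next step is comparing one query against several candidates and keeping the best match, a simple form of semantic search. Below is a minimal sketch under that assumption: `rank_by_similarity` is a hypothetical helper, not part of sentence-transformers, and it accepts any scoring function, so you can pass `sts.compare`. To keep the sketch runnable without downloading the model, it is demonstrated with a hypothetical `toy_score` (character-set overlap), which is only a stand-in and captures no meaning.

```python
from typing import Callable, List, Sequence, Tuple


def rank_by_similarity(
    query: str,
    candidates: Sequence[str],
    score: Callable[[str, str], float],
) -> List[Tuple[str, float]]:
    """Score each candidate against the query; return (candidate, score) pairs, best first."""
    scored = [(candidate, score(query, candidate)) for candidate in candidates]
    # Highest score (most similar) first; equal scores keep their input order
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


def toy_score(a: str, b: str) -> float:
    """Hypothetical stand-in scorer: Jaccard overlap of the two strings' character sets."""
    set_a, set_b = set(a), set(b)
    union = set_a | set_b
    return len(set_a & set_b) / len(union) if union else 1.0


# With the real model, pass sts.compare instead of toy_score
for candidate, value in rank_by_similarity("hello", ["xyz", "world", "help"], toy_score):
    print(f"{candidate}: {value:.2f}")
```

With the real model you would call `rank_by_similarity("hello", ["world", "hi", "hola"], sts.compare)`; going by the scores reported above, "hi" should come out on top and "world" last. Note that `compare` re-encodes the query on every call, so for long candidate lists it is cheaper to encode everything in one batch with `model.encode`.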
Best of luck!