Skip to content

Commit

Permalink
Formatted Python code inside notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanwebb authored Nov 4, 2024
1 parent 73ee19e commit f23c01b
Showing 1 changed file with 39 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@
"source": [
"model = SentenceTransformer(\n",
" # Remove the device=\"mps\" argument if running on a non-Mac device\n",
" \"nomic-ai/nomic-embed-text-v1.5\", trust_remote_code=True, device=\"mps\"\n",
" \"nomic-ai/nomic-embed-text-v1.5\",\n",
" trust_remote_code=True,\n",
" device=\"mps\",\n",
")"
]
},
Expand Down Expand Up @@ -134,10 +136,8 @@
"fields = [\n",
" FieldSchema(name=\"id\", dtype=DataType.INT64, is_primary=True, auto_id=True),\n",
" FieldSchema(name=\"title\", dtype=DataType.VARCHAR, max_length=256),\n",
"\n",
" # First sixth of unnormalized embedding vector\n",
" FieldSchema(name=\"head_embedding\", dtype=DataType.FLOAT_VECTOR, dim=search_dim),\n",
" \n",
" # Entire unnormalized embedding vector\n",
" FieldSchema(name=\"embedding\", dtype=DataType.FLOAT_VECTOR, dim=embedding_dim),\n",
"]\n",
Expand Down Expand Up @@ -171,7 +171,9 @@
"outputs": [],
"source": [
"index_params = client.prepare_index_params()\n",
"index_params.add_index(field_name=\"head_embedding\", index_type=\"FLAT\", metric_type=\"COSINE\")\n",
"index_params.add_index(\n",
" field_name=\"head_embedding\", index_type=\"FLAT\", metric_type=\"COSINE\"\n",
")\n",
"index_params.add_index(field_name=\"embedding\", index_type=\"FLAT\", metric_type=\"COSINE\")\n",
"client.create_index(collection_name, index_params)"
]
Expand Down Expand Up @@ -204,18 +206,14 @@
" # Output of embedding model is unnormalized\n",
" embeddings = model.encode(plot_summary, convert_to_tensor=True)\n",
" head_embeddings = embeddings[:, :search_dim]\n",
" \n",
"\n",
" data = [\n",
" {\n",
" \"title\": title,\n",
" \"head_embedding\": head.cpu().numpy(),\n",
" \"embedding\": embedding.cpu().numpy(),\n",
" }\n",
" for title, head, embedding in zip(\n",
" batch[\"Title\"],\n",
" head_embeddings, \n",
" embeddings\n",
" )\n",
" for title, head, embedding in zip(batch[\"Title\"], head_embeddings, embeddings)\n",
" ]\n",
" res = client.insert(collection_name=collection_name, data=data)"
]
Expand All @@ -237,14 +235,16 @@
"queries = [\n",
" \"An archaeologist searches for ancient artifacts while fighting Nazis.\",\n",
" \"A teenager fakes illness to get off school and have adventures with two friends.\",\n",
" \"A young couple with a kid look after a hotel during winter and the husband goes insane.\"\n",
" \"A young couple with a kid look after a hotel during winter and the husband goes insane.\",\n",
"]\n",
"\n",
"\n",
"# Search the database based on input text\n",
"def embed_search(data):\n",
" embeds = model.encode(data)\n",
" return [x for x in embeds]\n",
"\n",
"\n",
"# This particular model requires us to prefix 'search_query:' to queries\n",
"instruct_queries = [\"search_query: \" + q.strip() for q in queries]\n",
"search_data = embed_search(instruct_queries)\n",
Expand Down Expand Up @@ -307,12 +307,12 @@
],
"source": [
"for query, hits in zip(queries, res):\n",
" rows = [x['entity'] for x in hits][:5]\n",
" rows = [x[\"entity\"] for x in hits][:5]\n",
"\n",
" print(\"Query:\", query)\n",
" print(\"Results:\")\n",
" for row in rows:\n",
" print(row['title'].strip())\n",
" print(row[\"title\"].strip())\n",
" print()"
]
},
Expand All @@ -336,13 +336,9 @@
" Convert a Milvus search result to a Pandas dataframe. This function is specific to our data schema.\n",
"\n",
" \"\"\"\n",
" rows = [x['entity'] for x in hits]\n",
" rows = [x[\"entity\"] for x in hits]\n",
" rows_dict = [\n",
" {\n",
" \"title\": x['title'],\n",
" \"embedding\": torch.tensor(x['embedding'])\n",
" }\n",
" for x in rows\n",
" {\"title\": x[\"title\"], \"embedding\": torch.tensor(x[\"embedding\"])} for x in rows\n",
" ]\n",
" return pd.DataFrame.from_records(rows_dict)\n",
"\n",
Expand Down Expand Up @@ -517,12 +513,12 @@
],
"source": [
"for query, hits in zip(queries, res):\n",
" rows = [x['entity'] for x in hits]\n",
" rows = [x[\"entity\"] for x in hits]\n",
"\n",
" print(\"Query:\", query)\n",
" print(\"Results:\")\n",
" for row in rows:\n",
" print(row['title'].strip())\n",
" print(row[\"title\"].strip())\n",
" print()"
]
},
Expand Down Expand Up @@ -557,11 +553,13 @@
" \"A teenager fakes illness to get off school and have adventures with two friends.\"\n",
"]\n",
"\n",
"\n",
"# Search the database based on input text\n",
"def embed_search(data):\n",
" embeds = model.encode(data)\n",
" return [x for x in embeds]\n",
"\n",
"\n",
"instruct_queries = [\"search_query: \" + q.strip() for q in queries2]\n",
"search_data2 = embed_search(instruct_queries)\n",
"head_search2 = [x[:search_dim] for x in search_data2]\n",
Expand Down Expand Up @@ -592,11 +590,11 @@
],
"source": [
"for query, hits in zip(queries, res):\n",
" rows = [x['entity'] for x in hits]\n",
" rows = [x[\"entity\"] for x in hits]\n",
"\n",
" print(\"Query:\", queries2[0])\n",
" for idx, row in enumerate(rows):\n",
" if row['title'].strip() == \"Ferris Bueller's Day Off\":\n",
" if row[\"title\"].strip() == \"Ferris Bueller's Day Off\":\n",
" print(f\"Row {idx}: Ferris Bueller's Day Off\")"
]
},
Expand Down Expand Up @@ -684,7 +682,9 @@
"client.create_collection(collection_name=collection_name, schema=schema)\n",
"\n",
"index_params = client.prepare_index_params()\n",
"index_params.add_index(field_name=\"head_embedding\", index_type=\"FLAT\", metric_type=\"COSINE\")\n",
"index_params.add_index(\n",
" field_name=\"head_embedding\", index_type=\"FLAT\", metric_type=\"COSINE\"\n",
")\n",
"client.create_index(collection_name, index_params)"
]
},
Expand All @@ -709,18 +709,14 @@
" embeddings = model.encode(plot_summary, convert_to_tensor=True)\n",
" embeddings = torch.flip(embeddings, dims=[-1])\n",
" head_embeddings = embeddings[:, :search_dim]\n",
" \n",
"\n",
" data = [\n",
" {\n",
" \"title\": title,\n",
" \"head_embedding\": head.cpu().numpy(),\n",
" \"embedding\": embedding.cpu().numpy(),\n",
" }\n",
" for title, head, embedding in zip(\n",
" batch[\"Title\"],\n",
" head_embeddings, \n",
" embeddings\n",
" )\n",
" for title, head, embedding in zip(batch[\"Title\"], head_embeddings, embeddings)\n",
" ]\n",
" res = client.insert(collection_name=collection_name, data=data)"
]
Expand All @@ -733,7 +729,9 @@
"source": [
"# Reverse the dimension order of the query embeddings, then take their heads\n",
"\n",
"flip_search_data = [torch.flip(torch.tensor(x), dims=[-1]).cpu().numpy() for x in search_data]\n",
"flip_search_data = [\n",
" torch.flip(torch.tensor(x), dims=[-1]).cpu().numpy() for x in search_data\n",
"]\n",
"flip_head_search = [x[:search_dim] for x in flip_search_data]\n",
"\n",
"# Perform standard vector search on subset of embeddings\n",
Expand Down Expand Up @@ -794,7 +792,12 @@
"]\n",
"\n",
"for d in dfs_results:\n",
" print(d[\"query\"], \"\\n\", d[\"results\"][:7][\"title\"].to_string(index=False, header=False), \"\\n\")"
" print(\n",
" d[\"query\"],\n",
" \"\\n\",\n",
" d[\"results\"][:7][\"title\"].to_string(index=False, header=False),\n",
" \"\\n\",\n",
" )"
]
},
{
Expand All @@ -815,10 +818,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a comparison of our search results across methods:\n",
"<div style='margin: auto; width: 80%;'><img src='results-raiders-of-the-lost-ark.png' width='100%'></div>\n",
"<div style='margin: auto; width: 100%;'><img src='results-ferris-buellers-day-off.png' width='100%'></div>\n",
"<div style='margin: auto; width: 80%;'><img src='results-the-shining.png' width='100%'></div>\n",
"Here is a comparison of our search results across methods:\n",
"<div style='margin: auto; width: 80%;'><img src='results-raiders-of-the-lost-ark.png' width='100%'></div>\n",
"<div style='margin: auto; width: 100%;'><img src='results-ferris-buellers-day-off.png' width='100%'></div>\n",
"<div style='margin: auto; width: 80%;'><img src='results-the-shining.png' width='100%'></div>\n",
"We have shown how to use Matryoshka embeddings with Milvus to perform a more efficient semantic search algorithm called \"funnel search.\" We also explored the importance of the reranking and pruning steps of the algorithm, as well as a failure mode that occurs when the initial candidate list is too small. Finally, we discussed why the order of the dimensions matters when forming sub-embeddings - it must match the order the model was trained with. Or rather, it is only because the model was trained in a certain way that prefixes of the embeddings are meaningful. Now you know how to implement Matryoshka embeddings and funnel search to reduce the storage costs of semantic search without sacrificing too much retrieval performance!"
]
}
Expand Down

0 comments on commit f23c01b

Please sign in to comment.