mirror of https://github.com/placeholder-soft/chroma.git
synced 2026-01-12 17:02:54 +08:00
works
@@ -6,7 +6,10 @@ class Chroma:
     _api_url = "http://localhost:8000/api/v1"

-    def __init__(self, url=None):
+    # we enable the user to set the space_key in the constructor
+    _space_key = None
+
+    def __init__(self, url=None, app=None, model_version=None, layer=None):
         """Initialize Chroma client"""

         if isinstance(url, str) and url.startswith("http"):
@@ -14,30 +17,54 @@ class Chroma:
             self.url = url

-    def count(self):
+        if app and model_version and layer:
+            self._space_key = app + "_" + model_version + "_" + layer
+
+    def set_context(self, app, model_version, layer):
+        '''
+        Sets the context of the client
+        '''
+        self._space_key = app + "_" + model_version + "_" + layer
+
+    def set_space_key(self, space_key):
+        '''
+        Sets the space key for the client
+        '''
+        self._space_key = space_key
+
+    def get_context(self):
+        '''
+        Returns the space key
+        '''
+        return self._space_key
+
+    def count(self, space_key=None):
         '''
         Returns the number of embeddings in the database
         '''
-        x = requests.get(self._api_url + "/count")
+        payload = json.dumps({"space_key": space_key or self._space_key})
+        x = requests.get(self._api_url + "/count", data=payload)
         return x.json()

     def fetch(self, where_filter={}, sort=None, limit=None):
         '''
         Fetches embeddings from the database
         '''
+        where_filter["space_key"] = self._space_key
+
         x = requests.get(self._api_url + "/fetch", data=json.dumps({
-            "where_filter": json.dumps(where_filter),
+            "where_filter": where_filter,
             "sort": sort,
             "limit": limit
         }))
         return x.json()

-    def process(self):
+    def process(self, space_key=None):
         '''
         Processes embeddings in the database
         - currently this only runs hnswlib and doesn't return anything
         '''
-        requests.get(self._api_url + "/process")
+        requests.get(self._api_url + "/process", data=json.dumps({"space_key": space_key or self._space_key}))
         return True

     def reset(self):
@@ -46,19 +73,6 @@ class Chroma:
         '''
         return requests.get(self._api_url + "/reset")

-    def persist(self):
-        '''
-        Persists the database to disk in the .chroma folder inside chroma-server
-        '''
-        return requests.get(self._api_url + "/persist")
-
-    def rand(self):
-        '''
-        Stubbed out sampling endpoint, returns a random bisection of the database
-        '''
-        x = requests.get(self._api_url + "/rand")
-        return x.json()
-
     def heartbeat(self):
         '''
         Returns the current server time in milliseconds to check if the server is alive
@@ -70,13 +84,22 @@ class Chroma:
             embedding_data: list,
             input_uri: list,
             dataset: list = None,
-            category_name: list = None):
+            category_name: list = None,
+            space_keys: list = None
+            ):
         '''
         Logs a batch of embeddings to the database
         - pass in column oriented data lists
         '''
+
+        if not space_keys:
+            if isinstance(dataset, list):
+                space_keys = [self._space_key] * len(dataset)
+            else:
+                space_keys = self._space_key
+
         x = requests.post(self._api_url + "/add", data=json.dumps({
+            "space_key": space_keys,
             "embedding_data": embedding_data,
             "input_uri": input_uri,
             "dataset": dataset,
@@ -88,7 +111,7 @@ class Chroma:
         else:
             return False

-    def log_training(self, embedding_data: list, input_uri: list, category_name: list):
+    def log_training(self, embedding_data: list, input_uri: list, category_name: list, space_keys: list = None):
         '''
         Small wrapper around log() to log a batch of training embeddings
         - sets dataset to "training"
@@ -97,10 +120,11 @@ class Chroma:
             embedding_data=embedding_data,
             input_uri=input_uri,
             dataset="training",
-            category_name=category_name
+            category_name=category_name,
+            space_keys=space_keys
         )

-    def log_production(self, embedding_data: list, input_uri: list, category_name: list):
+    def log_production(self, embedding_data: list, input_uri: list, category_name: list, space_keys: list = None):
         '''
         Small wrapper around log() to log a batch of production embeddings
         - sets dataset to "production"
@@ -109,10 +133,11 @@ class Chroma:
             embedding_data=embedding_data,
             input_uri=input_uri,
             dataset="production",
-            category_name=category_name
+            category_name=category_name,
+            space_keys=space_keys
         )

-    def log_triage(self, embedding_data: list, input_uri: list, category_name: list):
+    def log_triage(self, embedding_data: list, input_uri: list, category_name: list, space_keys: list = None):
         '''
         Small wrapper around log() to log a batch of triage embeddings
         - sets dataset to "triage"
@@ -121,14 +146,19 @@ class Chroma:
             embedding_data=embedding_data,
             input_uri=input_uri,
             dataset="triage",
-            category_name=category_name
+            category_name=category_name,
+            space_keys=space_keys
         )

-    def get_nearest_neighbors(self, embedding, n_results=10, category_name=None, dataset="training"):
+    def get_nearest_neighbors(self, embedding, n_results=10, category_name=None, dataset="training", space_key=None):
         '''
         Gets the nearest neighbors of a single embedding
         '''
+        if not space_key:
+            space_key = self._space_key
+
         x = requests.post(self._api_url + "/get_nearest_neighbors", data=json.dumps({
+            "space_key": space_key,
             "embedding": embedding,
             "n_results": n_results,
             "category_name": category_name,
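For orientation, here is how the reworked client is meant to be driven end to end -- a minimal sketch assuming a chroma server on localhost:8000 and the method names from the diff above (the values are illustrative):

    from chroma_client import Chroma

    # scoping the client derives space_key = "yolov3_1_1" internally
    client = Chroma(app="yolov3", model_version="1", layer="1")

    client.log_training(
        embedding_data=[[1, 2, 3, 4, 5]],
        input_uri=["/images/1"],
        category_name=["spoon"],
    )
    client.process()        # builds the ANN index for this space_key
    print(client.count())   # count is now scoped to the client's space_key
    print(client.get_nearest_neighbors([1, 2, 3, 4, 5], n_results=2))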
@@ -7,11 +7,9 @@ from fastapi import FastAPI, Response, status
 from chroma_server.db.clickhouse import Clickhouse
 from chroma_server.index.hnswlib import Hnswlib
 from chroma_server.algorithms.rand_subsample import rand_bisectional_subsample
-from chroma_server.types import AddEmbedding, QueryEmbedding
+from chroma_server.types import AddEmbedding, QueryEmbedding, ProcessEmbedding, FetchEmbedding, CountEmbedding
 from chroma_server.utils import logger
-
-

 # Boot script
 db = Clickhouse
 ann_index = Hnswlib
@@ -22,20 +20,15 @@ app = FastAPI(debug=True)
 app._db = db()
 app._ann_index = ann_index()

-if not os.path.exists(".chroma"):
-    os.mkdir(".chroma")
-
-if os.path.exists(".chroma/chroma.parquet"):
-    logger.info("Loading existing chroma database")
-    app._db.load()
-
-if os.path.exists(".chroma/index.bin"):
-    logger.info("Loading existing chroma index")
-    app._ann_index.load(app._db.count(), len(app._db.fetch(limit=1).embedding_data))
+# scoping
+# an embedding space is specific to a particular trained model and layer
+# instead of making the user manage this complexity, we will handle some conventions here
+# that being said, we will only store a single string, "space_key", in the db, which the user can, in principle, override
+# - embeddings are always written with and fetched from the same space_key
+# - indexes are specific to a space_key and a timestamp
+# - the client can handle the app + model_version + layer => space_key string generation
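To make the convention concrete, the key derivation the comments describe amounts to a single join; the helper name below is illustrative, not part of the commit:

    def make_space_key(app: str, model_version: str, layer: str) -> str:
        # same convention as the client: app + "_" + model_version + "_" + layer
        return "_".join([app, model_version, layer])

    assert make_space_key("yolov3", "1", "1") == "yolov3_1_1"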
 # API Endpoints

 @app.get("/api/v1")
 async def root():
     '''
@@ -49,78 +42,65 @@ async def add_to_db(new_embedding: AddEmbedding):
     Save embedding to database
     - supports single or batched embeddings
     '''
+    print("add_to_db, new_embedding.space_key", new_embedding, new_embedding.space_key)
+
     app._db.add_batch(
+        new_embedding.space_key,
         new_embedding.embedding_data,
         new_embedding.input_uri,
         new_embedding.dataset,
         new_embedding.custom_quality_score,
         new_embedding.category_name
     )

     return {"response": "Added record to database"}

 @app.get("/api/v1/process")
-async def process():
+async def process(process_embedding: ProcessEmbedding):
     '''
     Currently generates an index for the embedding db
     '''
-    app._ann_index.run(app._db.fetch())
+    where_filter = {"space_key": process_embedding.space_key}
+    print("process, where_filter", where_filter)
+    app._ann_index.run(process_embedding.space_key, app._db.fetch(where_filter))

 @app.get("/api/v1/fetch")
-async def fetch(where_filter={}, sort=None, limit=None):
+async def fetch(fetch_embedding: FetchEmbedding):
     '''
     Fetches embeddings from the database
     - enables filtering by where_filter, sorting by key, and limiting the number of results
     '''
-    return app._db.fetch(where_filter, sort, limit).to_dict(orient="records")
+    return app._db.fetch(fetch_embedding.where_filter, fetch_embedding.sort, fetch_embedding.limit)

 @app.get("/api/v1/count")
-async def count():
+async def count(count_embedding: CountEmbedding):
     '''
     Returns the number of records in the database
     '''
-    return ({"count": app._db.count()})
-
-@app.get("/api/v1/persist")
-async def persist():
-    '''
-    Persist the database and index to disk
-    '''
-    if not os.path.exists(".chroma"):
-        os.mkdir(".chroma")
-
-    app._db.persist()
-    app._ann_index.persist()
-    return True
+    return {"count": app._db.count(space_key=count_embedding.space_key)}

 @app.get("/api/v1/reset")
 async def reset():
     '''
     Reset the database and index
     '''
-    shutil.rmtree(".chroma", ignore_errors=True)
-    app._db = db()
+    app._db.reset()
     app._ann_index = ann_index()
     return True

 @app.get("/api/v1/rand")
 async def rand(where_filter={}, sort=None, limit=None):
     '''
     Randomly bisect the database
     '''
     results = app._db.fetch(where_filter, sort, limit)
     rand = rand_bisectional_subsample(results)
     return rand.to_dict(orient="records")

 @app.post("/api/v1/get_nearest_neighbors")
 async def get_nearest_neighbors(embedding: QueryEmbedding):
     '''
     Returns the distances, database ids, and the embeddings themselves for the input embedding
     '''
+    print("get_nearest_neighbors, embedding.space_key", embedding.space_key)
+    if embedding.space_key is None:
+        return {"error": "space_key is required"}
+
     ids = None
     filter_by_where = {}
+    filter_by_where["space_key"] = embedding.space_key
     if embedding.category_name is not None:
         filter_by_where['category_name'] = embedding.category_name
     if embedding.dataset is not None:
@@ -128,9 +108,9 @@ async def get_nearest_neighbors(embedding: QueryEmbedding):

     if filter_by_where is not None:
         results = app._db.fetch(filter_by_where)
-        ids = [str(item[0]) for item in results]
+        ids = [str(item[1]) for item in results]  # 1 is the uuid column

-    uuids, distances = app._ann_index.get_nearest_neighbors(embedding.embedding, embedding.n_results, ids)
+    uuids, distances = app._ann_index.get_nearest_neighbors(embedding.space_key, embedding.embedding, embedding.n_results, ids)
     return {
         "ids": uuids,
         "embeddings": app._db.get_by_ids(uuids),
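For reference, the raw request the client issues against this endpoint looks roughly like the following (URL and payload fields as in the diff; values illustrative):

    import json
    import requests

    resp = requests.post(
        "http://localhost:8000/api/v1/get_nearest_neighbors",
        data=json.dumps({
            "space_key": "yolov3_1_1",   # required by the new handler
            "embedding": [1, 2, 3, 4, 5],
            "n_results": 2,
            "category_name": None,
            "dataset": "training",
        }),
    )
    print(resp.json()["ids"])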
@@ -1,5 +1,6 @@
 from abc import abstractmethod

+# TODO: update this to match the clickhouse implementation
 class Database():
     @abstractmethod
     def __init__(self):
@@ -11,16 +12,4 @@ class Database():

     @abstractmethod
     def fetch(self, query):
         pass
-
-    @abstractmethod
-    def delete_batch(self, batch):
-        pass
-
-    @abstractmethod
-    def persist(self):
-        pass
-
-    @abstractmethod
-    def load(self):
-        pass
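A minimal sketch of what the slimmed-down contract asks of a backend, using the same method shapes as the clickhouse implementation below (the in-memory class is illustrative, not part of the commit):

    class InMemoryDatabase(Database):
        def __init__(self):
            self._rows = []

        def add_batch(self, space_key, embedding_data, input_uri,
                      dataset=None, custom_quality_score=None, category_name=None):
            # column-oriented inputs; one space_key per row
            for i in range(len(embedding_data)):
                self._rows.append((space_key[i], embedding_data[i], input_uri[i]))

        def count(self, space_key=None):
            return sum(1 for row in self._rows if space_key in (None, row[0]))

        def fetch(self, query):
            return [row for row in self._rows if row[0] == query.get("space_key")]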
@@ -10,6 +10,7 @@ class Clickhouse(Database):

     def _create_table_embeddings(self):
         self._conn.execute('''CREATE TABLE IF NOT EXISTS embeddings (
+            space_key String,
             uuid UUID,
             embedding_data Array(Float64),
             input_uri String,
@@ -24,22 +25,21 @@ class Clickhouse(Database):
         self._conn = client
         self._create_table_embeddings()

-    def add_batch(self, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
+    def add_batch(self, space_key, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
         data_to_insert = []
         for i in range(len(embedding_data)):
-            data_to_insert.append([uuid.uuid4(), embedding_data[i], input_uri[i], dataset[i], category_name[i]])
+            data_to_insert.append([space_key[i], uuid.uuid4(), embedding_data[i], input_uri[i], dataset[i], category_name[i]])

         self._conn.execute('''
-            INSERT INTO embeddings (uuid, embedding_data, input_uri, dataset, category_name) VALUES''', data_to_insert)
+            INSERT INTO embeddings (space_key, uuid, embedding_data, input_uri, dataset, category_name) VALUES''', data_to_insert)

-    def count(self):
-        return self._conn.execute('SELECT COUNT() FROM embeddings')
-
-    def update(self, data): # call this update_custom_quality_score! that is all it does
-        pass
+    def count(self, space_key=None):
+        return self._conn.execute(f"SELECT COUNT() FROM embeddings WHERE space_key = '{space_key}'")[0][0]

     def fetch(self, where_filter={}, sort=None, limit=None):
+        if where_filter["space_key"] is None:
+            return {"error": "space_key is required"}
+
         s3 = time.time()
         # check to see if the query is a dict and is a flat list of key-value pairs
         if where_filter is not None:
@@ -62,8 +62,11 @@ class Clickhouse(Database):
         if limit is not None or isinstance(limit, int):
             where_filter += f" LIMIT {limit}"

+        print("where_filter", where_filter)
+
         val = self._conn.execute(f'''
             SELECT
+                space_key,
                 uuid,
                 embedding_data,
                 input_uri,
@@ -78,18 +81,10 @@ class Clickhouse(Database):

         return val

-    def delete_batch(self, batch):
-        pass
-
-    def persist(self):
-        pass
-
-    def load(self, path=".chroma/chroma.parquet"):
-        pass
-
     def get_by_ids(self, ids=list):
         return self._conn.execute(f'''
             SELECT
+                space_key,
                 uuid,
                 embedding_data,
                 input_uri,
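The fetch() path turns a flat dict into a SQL tail by string interpolation; a standalone sketch of that transformation, mirroring the code above and the duckdb version below (values are assumed trusted, since nothing is escaped):

    def build_query_tail(where_filter: dict, sort=None, limit=None) -> str:
        # flat key/value pairs only; nested dicts are rejected upstream
        clause = " AND ".join(f"{key} = '{value}'" for key, value in where_filter.items())
        if clause:
            clause = f"WHERE {clause}"
        if sort is not None:
            clause += f" ORDER BY {sort}"
        if limit is not None:
            clause += f" LIMIT {limit}"
        return clause

    print(build_query_tail({"space_key": "yolov3_1_1"}, limit=10))
    # WHERE space_key = 'yolov3_1_1' LIMIT 10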
@@ -1,176 +0,0 @@
-from os import EX_CANTCREAT
-from chroma_server.db.abstract import Database
-import duckdb
-# import numpy as np
-# import pandas as pd
-
-class DuckDB(Database):
-    _conn = None
-
-    def __init__(self):
-        self._conn = duckdb.connect()
-        self._conn.execute('''
-            CREATE TABLE embeddings (
-                id integer PRIMARY KEY,
-                embedding_data REAL[],
-                input_uri STRING,
-                dataset STRING,
-                custom_quality_score REAL,
-                category_name STRING
-            )
-        ''')
-
-        # ids to manage internal bookkeeping and *nothing else*, users should not have to care about these ids
-        self._conn.execute('''
-            CREATE SEQUENCE seq_id START 1;
-        ''')
-
-        self._conn.execute('''
-            -- change the default null sorting order to either NULLS FIRST and NULLS LAST
-            PRAGMA default_null_order='NULLS LAST';
-            -- change the default sorting order to either DESC or ASC
-            PRAGMA default_order='DESC';
-        ''')
-        return
-
-    def add_batch(self, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
-        '''
-        Add embeddings to the database
-        This accepts both a single input and a list of inputs
-        '''
-        # create a list of the types of all inputs
-        types = [type(x).__name__ for x in [embedding_data, input_uri]]
-
-        # if all of the types are 'list' - do batch mode
-        if all(x == 'list' for x in types):
-            lengths = [len(x) for x in [embedding_data, input_uri]]
-
-            # accepts some inputs as str or None, and this multiplies them out to the correct length
-            if custom_quality_score is None or isinstance(custom_quality_score, str):
-                custom_quality_score = [custom_quality_score] * lengths[0]
-            if category_name is None or isinstance(category_name, str):
-                category_name = [category_name] * lengths[0]
-            if dataset is None or isinstance(dataset, str):
-                dataset = [dataset] * lengths[0]
-
-            # we have to move from column to row format for duckdb
-            data_to_insert = []
-            for i in range(lengths[0]):
-                data_to_insert.append([embedding_data[i], input_uri[i], dataset[i], custom_quality_score[i], category_name[i]])
-
-            if all(x == lengths[0] for x in lengths):
-                self._conn.executemany('''
-                    INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)''',
-                    data_to_insert
-                )
-                return
-
-        # if any of the types are 'list' - throw an error
-        if any(x == list for x in [input_uri, dataset, custom_quality_score, category_name]):
-            raise Exception("Invalid input types. One input is a list where others are not: " + str(types))
-
-        # single insert mode
-        # This should never fire because we do everything in batch mode, but given the likely move away from duckdb, I am just leaving it in
-        self._conn.execute('''
-            INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)''',
-            [embedding_data, input_uri, dataset, custom_quality_score, category_name]
-        )
-
-    def count(self):
-        return self._conn.execute('''
-            SELECT COUNT(*) FROM embeddings;
-        ''').fetchone()[0]
-
-    def update(self, data): # call this update_custom_quality_score! that is all it does
-        '''
-        I was not able to figure out (yet) how to do a bulk update in duckdb
-        This is going to be fairly slow
-        '''
-        for element in data:
-            if element['custom_quality_score'] is None:
-                continue
-            self._conn.execute(f'''
-                UPDATE embeddings SET custom_quality_score={element['custom_quality_score']} WHERE id={element['id']}'''
-            )
-
-    def fetch(self, where_filter={}, sort=None, limit=None):
-        # check to see if the query is a dict and is a flat list of key-value pairs
-        if where_filter is not None:
-            if not isinstance(where_filter, dict):
-                raise Exception("Invalid where_filter: " + str(where_filter))
-
-            # ensure where_filter is a flat dict
-            for key in where_filter:
-                if isinstance(where_filter[key], dict):
-                    raise Exception("Invalid where_filter: " + str(where_filter))
-
-            where_filter = " AND ".join([f"{key} = '{value}'" for key, value in where_filter.items()])
-
-            if where_filter:
-                where_filter = f"WHERE {where_filter}"
-
-        if sort is not None:
-            where_filter += f" ORDER BY {sort}"
-
-        if limit is not None or isinstance(limit, int):
-            where_filter += f" LIMIT {limit}"
-
-        return self._conn.execute(f'''
-            SELECT
-                id,
-                embedding_data,
-                input_uri,
-                dataset,
-                custom_quality_score,
-                category_name
-            FROM
-                embeddings
-            {where_filter}
-        ''').fetchdf().replace({np.nan: None})  # replace NaN with None for JSON serialization
-
-    def delete_batch(self, batch):
-        raise NotImplementedError
-
-    def persist(self):
-        '''
-        Persist the database to disk
-        '''
-        if self._conn is None:
-            return
-
-        self._conn.execute('''
-            COPY
-                (SELECT * FROM embeddings)
-            TO '.chroma/chroma.parquet'
-                (FORMAT PARQUET);
-        ''')
-
-    def load(self, path=".chroma/chroma.parquet"):
-        '''
-        Load the database from disk
-        '''
-        self._conn.execute(f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');")
-
-    def get_by_ids(self, ids=list):
-        # select from the duckdb table where ids are in the list
-        if not isinstance(ids, list):
-            raise Exception("ids must be a list")
-
-        if not ids:
-            # create an empty pandas dataframe
-            return pd.DataFrame()
-
-        return self._conn.execute(f'''
-            SELECT
-                id,
-                embedding_data,
-                input_uri,
-                dataset,
-                custom_quality_score,
-                category_name
-            FROM
-                embeddings
-            WHERE
-                id IN ({','.join([str(x) for x in ids])})
-        ''').fetchdf().replace({np.nan: None})  # replace NaN with None for JSON serialization
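Although this backend is deleted here, the broadcast-then-zip pattern in its add_batch is worth noting; a distilled sketch (names illustrative):

    def broadcast(value, n):
        # multiply a scalar (or None) out to a column of length n
        return [value] * n if value is None or isinstance(value, str) else value

    embedding_data = [[1, 2], [3, 4]]
    input_uri = ["/a", "/b"]
    dataset = broadcast("training", len(embedding_data))
    rows = [list(cols) for cols in zip(embedding_data, input_uri, dataset)]
    # rows == [[[1, 2], '/a', 'training'], [[3, 4], '/b', 'training']]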
@@ -1,4 +1,5 @@
 import hnswlib
+import pickle
 import time
 import numpy as np
 from chroma_server.index.abstract import Index
@@ -6,38 +7,37 @@ from chroma_server.utils import logger

 class Hnswlib(Index):

+    # we cache the index and mappers for the latest space_key
+    _space_key = None
     _index = None

     # these data structures enable us to map between uuids and ids
     # - our uuids are strings (clickhouse doesn't do autoincrementing ids for performance)
     # - but hnswlib uses integers only as ids
-    # - so this is a temporary bandaid.
-    # TODO: this should get written to disk somehow or the index will become useless after a restart
+    # - so this is a bandaid.
     _id_to_uuid = {}
     _uuid_to_id = {}

     def __init__(self):
         pass

-    def run(self, embedding_data):
+    def run(self, space_key, embedding_data):
         # more comments available at the source: https://github.com/nmslib/hnswlib
+        self._space_key = space_key

         s1 = time.time()
         uuids = []
         embeddings = []
         ids = []
         i = 0
         for embedding in embedding_data:
-            uuids.append(str(embedding[0]))
-            embeddings.append((embedding[1]))
+            uuids.append(str(embedding[1]))
+            embeddings.append((embedding[2]))
             ids.append(i)
-            self._id_to_uuid[i] = str(embedding[0])
-            self._uuid_to_id[str(embedding[0])] = i
+            self._id_to_uuid[i] = str(embedding[1])
+            self._uuid_to_id[str(embedding[1])] = i
             i += 1

         print('time to create uuids and embeddings: ', time.time() - s1)

         # We split the data in two batches:
         data1 = embeddings
         dim = len(data1[0])
         num_elements = len(data1)
@@ -67,25 +67,40 @@ class Hnswlib(Index):

         self._index = p

-    def fetch(self, query):
-        raise NotImplementedError
+        self.save()

-    def delete_batch(self, batch):
-        raise NotImplementedError
-
-    def persist(self):
+    def save(self):
         if self._index is None:
             return
-        self._index.save_index(".chroma/index.bin")
-        logger.debug('Index saved to .chroma/index.bin')
+        self._index.save_index(f"/index_data/index_{self._space_key}.bin")

-    def load(self, elements, dimensionality):
+        # pickle the mappers
+        with open(f"/index_data/id_to_uuid_{self._space_key}.pkl", 'wb') as f:
+            pickle.dump(self._id_to_uuid, f, pickle.HIGHEST_PROTOCOL)
+        with open(f"/index_data/uuid_to_id_{self._space_key}.pkl", 'wb') as f:
+            pickle.dump(self._uuid_to_id, f, pickle.HIGHEST_PROTOCOL)
+
+        logger.debug('Index saved to /index_data/index.bin')
+
+    def load(self, space_key, elements, dimensionality):
         p = hnswlib.Index(space='l2', dim=dimensionality)
         self._index = p
-        self._index.load_index(".chroma/index.bin", max_elements=elements)
+        self._index.load_index(f"/index_data/index_{space_key}.bin", max_elements=elements)
+
+        # unpickle the mappers
+        with open(f"/index_data/id_to_uuid_{space_key}.pkl", 'rb') as f:
+            self._id_to_uuid = pickle.load(f)
+        with open(f"/index_data/uuid_to_id_{space_key}.pkl", 'rb') as f:
+            self._uuid_to_id = pickle.load(f)
+
+        self._space_key = space_key

     # do knn_query on hnswlib to get nearest neighbors
-    def get_nearest_neighbors(self, query, k, uuids=None):
+    def get_nearest_neighbors(self, space_key, query, k, uuids=None):
+
+        if self._space_key != space_key:
+            # TODO: deal with this magic number....
+            self.load(space_key, 500_000, len(query))

         s2 = time.time()
         # get ids from uuids
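The id <-> uuid bandaid above is just a pair of dicts kept in lockstep; a small round-trip sketch (the uuids are made up):

    _id_to_uuid = {0: "6c8f3a2e", 1: "f04b91d7"}
    _uuid_to_id = {u: i for i, u in _id_to_uuid.items()}

    labels = [1, 0]                             # integer labels from hnswlib's knn_query
    uuids = [_id_to_uuid[l] for l in labels]    # back to database uuids
    assert [_uuid_to_id[u] for u in uuids] == labels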
@@ -3,6 +3,7 @@ from typing import Union, Any

 # type supports single and batch mode
 class AddEmbedding(BaseModel):
+    space_key: Union[str, list]
     embedding_data: list
     input_uri: Union[str, list]
     dataset: Union[str, list] = None
@@ -10,7 +11,19 @@ class AddEmbedding(BaseModel):
     category_name: Union[str, list] = None

 class QueryEmbedding(BaseModel):
+    space_key: str = None
     embedding: list
     n_results: int = 10
     category_name: str = None
     dataset: str = None
+
+class ProcessEmbedding(BaseModel):
+    space_key: str = None
+
+class FetchEmbedding(BaseModel):
+    where_filter: dict = {}
+    sort: str = None
+    limit: int = None
+
+class CountEmbedding(BaseModel):
+    space_key: str = None
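A quick sanity check of the new request models (assuming pydantic v1 semantics, which the bare `= None` defaults suggest):

    from chroma_server.types import AddEmbedding

    payload = AddEmbedding(
        space_key=["yolov3_1_1"],
        embedding_data=[[1, 2, 3, 4, 5]],
        input_uri=["/images/1"],
        dataset=["training"],
        category_name=["spoon"],
    )
    print(payload.space_key)  # ['yolov3_1_1']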
@@ -11,6 +11,7 @@ services:
       dockerfile: Dockerfile
     volumes:
       - ./chroma-server/:/chroma-server/
+      - index_data:/index_data
    command: uvicorn chroma_server:app --reload --workers 1 --host 0.0.0.0 --port 8000
    # env_file:
    #   - ./chroma-server/.env
@@ -36,4 +37,6 @@ services:
 volumes:
   clickhouse_data:
     driver: local
+  index_data:
+    driver: local
@@ -1,5 +1,36 @@
+from hashlib import new
 import chroma_client
+import pandas as pd
 from chroma_client import Chroma

-new_labels = chroma_client.fetch_new_labels()
-print(new_labels)
+client1 = Chroma(app="yolov3", model_version="1", layer="1")
+# client1.reset()
+
+knife_embedding = [0.2310010939836502, -0.3462161719799042, 0.29164767265319824, -0.09828940033912659, 1.814868450164795, -10.517369270324707, -13.531850814819336, -12.730537414550781, -13.011675834655762, -10.257010459899902, -13.779699325561523, -11.963963508605957, -13.948140144348145, -12.46799087524414, -14.569470405578613, -16.388280868530273, -13.76762580871582, -12.192169189453125, -12.204055786132812, -12.259000778198242, -13.696036338806152, -14.609177589416504, -16.951879501342773, -17.096384048461914, -14.355693817138672, -16.643482208251953, -14.270745277404785, -14.375198364257812, -14.381218910217285, -13.475995063781738, -12.694938659667969, -10.011992454528809, -9.770626068115234, -13.155019760131836, -16.136341094970703, -6.552414417266846, -11.243837356567383, -16.678457260131836, -14.629229545593262, -10.052337646484375, -15.451828956604004, -12.561151504516602, -11.68396282196045, -11.975972175598145, -11.09926986694336, -13.060500144958496, -12.075592994689941, -1.0808746814727783, 1.7046797275543213, -3.8080708980560303, -11.401922225952148, -12.184720039367676, -13.262567520141602, -11.299583435058594, -13.654638290405273, -10.767330169677734, -9.012763977050781, -10.202326774597168, -10.088111877441406, -13.247991561889648, -9.651527404785156, -11.903244972229004, -13.922954559326172, -17.37179946899414, -12.51513385772705, -7.8046746253967285, -14.406414985656738, -13.172696113586426, -11.194984436035156, -12.029500961303711, -10.996524810791016, -10.828441619873047, -8.673471450805664, -13.800869941711426, -9.680946350097656, -12.964024543762207, -9.694372177124023, -13.132003784179688, -9.38864803314209, -14.305071830749512, -14.4693603515625, -5.0566205978393555, -15.685358047485352, -12.493011474609375, -8.424881935119629]
+
+get_nearest_neighbors = client1.get_nearest_neighbors(knife_embedding, 4, None, "training")
+res_df = pd.DataFrame(get_nearest_neighbors['embeddings'])
+print(res_df.head())
+
+# client1.log([[1,2,3,4,5]], ["/images/1"], ["training"], ['spoon'])
+# client1.log([[1,2,3,4,5]], ["/images/2"], ["training"], ['spoon'])
+# client1.log([[1,2,3,4,5]], ["/images/3"], ["training"], ['spoon'])
+# client1.log([[1,2,3,4,5]], ["/images/4"], ["training"], ['spoon'])
+# client1.log([[1,2,3,4,5]], ["/prod/1"], ["test"], ['spoon'])
+# client1.log([[1,2,3,4,5]], ["/prod/2"], ["test"], ['spoon'])
+# print("context", client1.get_context())
+# print("context", client1.heartbeat())
+# print("layer 1", client1.count(client1.get_context()))
+# # print("fetch", client1.fetch())
+# print(client1.process())
+# print(client1.get_nearest_neighbors([1,2,3,4,5], 2))
+
+# client1.set_context("test", "1", "2")
+# client1.log([[1,2,3,4,5]], ["/images/1"], ["training"], ['knife'])
+# client1.log([[1,2,3,4,5]], ["/images/4"], ["training"], ['knife'])
+# client1.log([[1,2,3,4,5]], ["/prod/2"], ["test"], ['knife'])
+# print("context", client1.get_context())
+# print("context", client1.heartbeat())
+# print("layer 1", client1.count(client1.get_context()))
+# # print("fetch", client1.fetch())
+# print(client1.process())
+# print(client1.get_nearest_neighbors([1,2,3,4,5], 2))
@@ -19,7 +19,7 @@ if __name__ == "__main__":

     data_length = len(df)

-    chroma = Chroma()
+    chroma = Chroma(app="yolov3", model_version="1", layer="1")
     chroma.reset()  # make sure we are using a fresh db
     allstart = time.time()
     start = time.time()