black python formatting
Makefile (new file, 5 lines)
@@ -0,0 +1,5 @@
+black:
+	black --fast chroma-server chroma-client
+
+check_black:
+	black --check --fast chroma-server chroma-client
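Usage note: `make black` rewrites the chroma-server and chroma-client trees in place, while `make check_black` only verifies formatting and exits non-zero when a file would be changed, which makes it suitable as a CI gate (assuming black is installed in the active environment).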
@@ -2,6 +2,7 @@ import requests
 import json
 from typing import Union
 
+
 class Chroma:
 
     _api_url = "http://localhost:8000/api/v1"
@@ -15,127 +16,142 @@ class Chroma:
         self.url = url
 
     def count(self):
-        '''
+        """
         Returns the number of embeddings in the database
-        '''
+        """
        x = requests.get(self._api_url + "/count")
        return x.json()
 
     def fetch(self, where_filter={}, sort=None, limit=None):
-        '''
+        """
         Fetches embeddings from the database
-        '''
-        x = requests.get(self._api_url + "/fetch", data=json.dumps({
-            "where_filter":json.dumps(where_filter),
-            "sort":sort,
-            "limit":limit
-        }))
+        """
+        x = requests.get(
+            self._api_url + "/fetch",
+            data=json.dumps(
+                {"where_filter": json.dumps(where_filter), "sort": sort, "limit": limit}
+            ),
+        )
         return x.json()
 
     def process(self):
-        '''
+        """
         Processes embeddings in the database
         - currently this only runs hnswlib, doesnt return anything
-        '''
+        """
         requests.get(self._api_url + "/process")
         return True
 
     def reset(self):
-        '''
+        """
         Resets the database
-        '''
+        """
         return requests.get(self._api_url + "/reset")
 
     def persist(self):
-        '''
+        """
         Persists the database to disk in the .chroma folder inside chroma-server
-        '''
+        """
         return requests.get(self._api_url + "/persist")
 
     def rand(self):
-        '''
+        """
         Stubbed out sampling endpoint, returns a random bisection of the database
-        '''
+        """
         x = requests.get(self._api_url + "/rand")
         return x.json()
 
     def heartbeat(self):
-        '''
+        """
         Returns the current server time in milliseconds to check if the server is alive
-        '''
+        """
         x = requests.get(self._api_url)
         return x.json()
 
-    def log(self,
-            embedding_data: list,
-            input_uri: list,
+    def log(
+        self,
+        embedding_data: list,
+        input_uri: list,
         dataset: list = None,
-            category_name: list = None):
-        '''
+        category_name: list = None,
+    ):
+        """
         Logs a batch of embeddings to the database
         - pass in column oriented data lists
-        '''
+        """
 
-        x = requests.post(self._api_url + "/add", data = json.dumps({
-            "embedding_data": embedding_data,
-            "input_uri": input_uri,
-            "dataset": dataset,
-            "category_name": category_name
-        }) )
+        x = requests.post(
+            self._api_url + "/add",
+            data=json.dumps(
+                {
+                    "embedding_data": embedding_data,
+                    "input_uri": input_uri,
+                    "dataset": dataset,
+                    "category_name": category_name,
+                }
+            ),
+        )
 
         if x.status_code == 201:
             return True
         else:
             return False
 
     def log_training(self, embedding_data: list, input_uri: list, category_name: list):
-        '''
+        """
         Small wrapper around log() to log a batch of training embedding
         - sets dataset to "training"
-        '''
+        """
         return self.log(
-            embedding_data=embedding_data,
-            input_uri=input_uri,
+            embedding_data=embedding_data,
+            input_uri=input_uri,
             dataset="training",
-            category_name=category_name
+            category_name=category_name,
         )
 
     def log_production(self, embedding_data: list, input_uri: list, category_name: list):
-        '''
+        """
         Small wrapper around log() to log a batch of production embedding
         - sets dataset to "production"
-        '''
+        """
         return self.log(
-            embedding_data=embedding_data,
-            input_uri=input_uri,
+            embedding_data=embedding_data,
+            input_uri=input_uri,
             dataset="production",
-            category_name=category_name
+            category_name=category_name,
         )
 
     def log_triage(self, embedding_data: list, input_uri: list, category_name: list):
-        '''
+        """
         Small wrapper around log() to log a batch of triage embedding
         - sets dataset to "triage"
-        '''
+        """
         return self.log(
-            embedding_data=embedding_data,
-            input_uri=input_uri,
+            embedding_data=embedding_data,
+            input_uri=input_uri,
             dataset="triage",
-            category_name=category_name
+            category_name=category_name,
         )
 
-    def get_nearest_neighbors(self, embedding, n_results=10, category_name=None, dataset="training"):
-        '''
+    def get_nearest_neighbors(
+        self, embedding, n_results=10, category_name=None, dataset="training"
+    ):
+        """
         Gets the nearest neighbors of a single embedding
-        '''
-        x = requests.post(self._api_url + "/get_nearest_neighbors", data = json.dumps({
-            "embedding": embedding,
-            "n_results": n_results,
-            "category_name": category_name,
-            "dataset": dataset
-        }) )
+        """
+        x = requests.post(
+            self._api_url + "/get_nearest_neighbors",
+            data=json.dumps(
+                {
+                    "embedding": embedding,
+                    "n_results": n_results,
+                    "category_name": category_name,
+                    "dataset": dataset,
+                }
+            ),
+        )
 
         if x.status_code == 200:
             return x.json()
         else:
-            return False
+            return False
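Taken together, the reformatted client is a thin requests wrapper over the HTTP API. A minimal usage sketch (assumptions: a server running on localhost:8000 as in `_api_url`, and the `chroma_client` import path used by the client tests below):

    from chroma_client import Chroma

    chroma = Chroma()
    chroma.log_training(
        embedding_data=[[1.0, 2.0, 3.0]],
        input_uri=["https://example.com/img.png"],
        category_name=["person"],
    )
    chroma.process()  # builds the server-side hnswlib index
    neighbors = chroma.get_nearest_neighbors([1.0, 2.0, 3.0], n_results=1)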
@@ -4,11 +4,14 @@ from chroma_client import Chroma
 import pytest
 import time
 from httpx import AsyncClient
 
 # from ..api import app # this wont work because i moved the file
 
+
 @pytest.fixture
 def anyio_backend():
-    return 'asyncio'
+    return "asyncio"
 
+
 def test_init():
     chroma = Chroma()
@@ -25,4 +28,3 @@ def test_init():
     # chroma = Chroma(url="http://test/api/v1")
     # response = await chroma.count()
     # raise Exception("response" + response)
-
@@ -1,8 +1,8 @@
-
 import random
 
+
 def rand_bisectional_subsample(data):
     """
     Randomly bisectionally subsample a list of data to size.
     """
-    return data.sample(frac=0.5, replace=True, random_state=1)
+    return data.sample(frac=0.5, replace=True, random_state=1)
@@ -2,22 +2,23 @@ import numpy as np
 import json
 import ast
 
+
 def class_distances(data):
-    ''''
+    """'
     This is all very subject to change, so essentially just copy and paste from what we had before
-    '''
+    """
 
     return False
 
 # def unpack_annotations(embeddings):
 #     annotations = [json.loads(embedding['infer'])["annotations"]for embedding in embeddings]
-#     annotations = [annotation for annotation_list in annotations for annotation in annotation_list]
+#     annotations = [annotation for annotation_list in annotations for annotation in annotation_list]
 #     # Unpack embedding data
 #     embeddings = [embedding["embedding_data"] for embedding in embeddings]
 #     embedding_vectors_by_category = {}
 #     for embedding_annotation_pair in zip(embeddings, annotations):
 #         data = np.array(embedding_annotation_pair[0])
-#         category = embedding_annotation_pair[1]['category_id']
+#         category = embedding_annotation_pair[1]['category_id']
 #         if category in embedding_vectors_by_category.keys():
 #             embedding_vectors_by_category[category] = np.append(
 #                 embedding_vectors_by_category[category], data[np.newaxis, :], axis=0
@@ -84,5 +85,5 @@ def class_distances(data):
 
 # if (len(inferences) == 0):
 #     raise Exception("No inferences found for datapoint")
 
-# return output_distances
+# return output_distances
@@ -11,7 +11,6 @@ from chroma_server.types import AddEmbedding, QueryEmbedding
 from chroma_server.utils import logger
 
 
-
 # Boot script
 db = DuckDB
 ann_index = Hnswlib
@@ -34,104 +33,112 @@ if os.path.exists(".chroma/index.bin"):
     app._ann_index.load(app._db.count(), len(app._db.fetch(limit=1).embedding_data))
 
 
 # API Endpoints
 
 
 @app.get("/api/v1")
 async def root():
-    '''
+    """
     Heartbeat endpoint
-    '''
+    """
     return {"nanosecond heartbeat": int(1000 * time.time_ns())}
 
+
 @app.post("/api/v1/add", status_code=status.HTTP_201_CREATED)
 async def add_to_db(new_embedding: AddEmbedding):
-    '''
+    """
     Save embedding to database
     - supports single or batched embeddings
-    '''
+    """
 
     app._db.add_batch(
-        new_embedding.embedding_data,
-        new_embedding.input_uri,
+        new_embedding.embedding_data,
+        new_embedding.input_uri,
         new_embedding.dataset,
-        new_embedding.custom_quality_score,
-        new_embedding.category_name
-    )
+        new_embedding.custom_quality_score,
+        new_embedding.category_name,
+    )
 
     return {"response": "Added record to database"}
 
+
 @app.get("/api/v1/process")
 async def process():
-    '''
+    """
     Currently generates an index for the embedding db
-    '''
+    """
     app._ann_index.run(app._db.fetch())
 
+
 @app.get("/api/v1/fetch")
 async def fetch(where_filter={}, sort=None, limit=None):
-    '''
+    """
     Fetches embeddings from the database
     - enables filtering by where_filter, sorting by key, and limiting the number of results
-    '''
+    """
     return app._db.fetch(where_filter, sort, limit).to_dict(orient="records")
 
+
 @app.get("/api/v1/count")
 async def count():
-    '''
+    """
     Returns the number of records in the database
-    '''
-    return ({"count": app._db.count()})
+    """
+    return {"count": app._db.count()}
 
+
 @app.get("/api/v1/persist")
 async def persist():
-    '''
+    """
     Persist the database and index to disk
-    '''
+    """
     if not os.path.exists(".chroma"):
         os.mkdir(".chroma")
 
     app._db.persist()
     app._ann_index.persist()
     return True
 
+
 @app.get("/api/v1/reset")
 async def reset():
-    '''
+    """
     Reset the database and index
-    '''
+    """
     shutil.rmtree(".chroma", ignore_errors=True)
     app._db = db()
     app._ann_index = ann_index()
     return True
 
+
 @app.get("/api/v1/rand")
 async def rand(where_filter={}, sort=None, limit=None):
-    '''
+    """
     Randomly bisection the database
-    '''
+    """
     results = app._db.fetch(where_filter, sort, limit)
     rand = rand_bisectional_subsample(results)
     return rand.to_dict(orient="records")
 
+
 @app.post("/api/v1/get_nearest_neighbors")
 async def get_nearest_neighbors(embedding: QueryEmbedding):
-    '''
+    """
     return the distance, database ids, and embedding themselves for the input embedding
-    '''
+    """
     ids = None
     filter_by_where = {}
     if embedding.category_name is not None:
-        filter_by_where['category_name'] = embedding.category_name
+        filter_by_where["category_name"] = embedding.category_name
     if embedding.dataset is not None:
-        filter_by_where['dataset'] = embedding.dataset
+        filter_by_where["dataset"] = embedding.dataset
 
     if filter_by_where is not None:
         ids = app._db.fetch(filter_by_where)["id"].tolist()
 
     nn = app._ann_index.get_nearest_neighbors(embedding.embedding, embedding.n_results, ids)
     return {
         "ids": nn[0].tolist()[0],
         "embeddings": app._db.get_by_ids(nn[0].tolist()[0]).to_dict(orient="records"),
-        "distances": nn[1].tolist()[0]
-    }
+        "distances": nn[1].tolist()[0],
+    }
@@ -1,6 +1,7 @@
 from abc import abstractmethod
 
-class Database():
+
+class Database:
     @abstractmethod
     def __init__(self):
         pass
@@ -23,4 +24,4 @@ class Database():
 
     @abstractmethod
     def load(self):
-        pass
+        pass
@@ -4,12 +4,14 @@ import duckdb
 import numpy as np
 import pandas as pd
 
+
 class DuckDB(Database):
     _conn = None
 
     def __init__(self):
         self._conn = duckdb.connect()
-        self._conn.execute('''
+        self._conn.execute(
+            """
         CREATE TABLE embeddings (
             id integer PRIMARY KEY,
             embedding_data REAL[],
@@ -18,32 +20,39 @@ class DuckDB(Database):
             custom_quality_score REAL,
             category_name STRING
         )
-        ''')
+        """
+        )
 
         # ids to manage internal bookkeeping and *nothing else*, users should not have to care about these ids
-        self._conn.execute('''
+        self._conn.execute(
+            """
         CREATE SEQUENCE seq_id START 1;
-        ''')
+        """
+        )
 
-        self._conn.execute('''
+        self._conn.execute(
+            """
         -- change the default null sorting order to either NULLS FIRST and NULLS LAST
         PRAGMA default_null_order='NULLS LAST';
         -- change the default sorting order to either DESC or ASC
         PRAGMA default_order='DESC';
-        ''')
+        """
+        )
         return
 
-    def add_batch(self, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
-        '''
+    def add_batch(
+        self, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None
+    ):
+        """
         Add embeddings to the database
         This accepts both a single input and a list of inputs
-        '''
+        """
 
         # create list of the types of all inputs
         types = [type(x).__name__ for x in [embedding_data, input_uri]]
 
         # if all of the types are 'list' - do batch mode
-        if all(x == 'list' for x in types):
+        if all(x == "list" for x in types):
            lengths = [len(x) for x in [embedding_data, input_uri]]
 
            # accepts some inputs as str or none, and this multiples them out to the correct length
@@ -57,41 +66,56 @@ class DuckDB(Database):
            # we have to move from column to row format for duckdb
            data_to_insert = []
            for i in range(lengths[0]):
-                data_to_insert.append([embedding_data[i], input_uri[i], dataset[i], custom_quality_score[i], category_name[i]])
+                data_to_insert.append(
+                    [
+                        embedding_data[i],
+                        input_uri[i],
+                        dataset[i],
+                        custom_quality_score[i],
+                        category_name[i],
+                    ]
+                )
 
             if all(x == lengths[0] for x in lengths):
-                self._conn.executemany('''
-                    INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)''',
-                    data_to_insert
+                self._conn.executemany(
+                    """
+                    INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)""",
+                    data_to_insert,
                 )
                 return
 
         # if any of the types are 'list' - throw an error
         if any(x == list for x in [input_uri, dataset, custom_quality_score, category_name]):
-            raise Exception("Invalid input types. One input is a list where others are not: " + str(types))
+            raise Exception(
+                "Invalid input types. One input is a list where others are not: " + str(types)
+            )
 
         # single insert mode
         # This should never fire because we do everything in batch mode, but given the mode away from duckdb likely, I am just leaving it in
-        self._conn.execute('''
-            INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)''',
-            [embedding_data, input_uri, dataset, custom_quality_score, category_name]
+        self._conn.execute(
+            """
+            INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)""",
+            [embedding_data, input_uri, dataset, custom_quality_score, category_name],
        )
 
-    def count(self):
-        return self._conn.execute('''
-            SELECT COUNT(*) FROM embeddings;
-        ''').fetchone()[0]
-
-    def update(self, data): # call this update_custom_quality_score! that is all it does
-        '''
+    def count(self):
+        return self._conn.execute(
+            """
+            SELECT COUNT(*) FROM embeddings;
+        """
+        ).fetchone()[0]
+
+    def update(self, data):  # call this update_custom_quality_score! that is all it does
+        """
         I was not able to figure out (yet) how to do a bulk update in duckdb
         This is going to be fairly slow
-        '''
-        for element in data:
-            if element['custom_quality_score'] is None:
+        """
+        for element in data:
+            if element["custom_quality_score"] is None:
                 continue
-            self._conn.execute(f'''
-                UPDATE embeddings SET custom_quality_score={element['custom_quality_score']} WHERE id={element['id']}'''
+            self._conn.execute(
+                f"""
+                UPDATE embeddings SET custom_quality_score={element['custom_quality_score']} WHERE id={element['id']}"""
            )
 
     def fetch(self, where_filter={}, sort=None, limit=None):
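Aside on the batch branch in add_batch above: callers send column-oriented lists, and the loop pivots them into one row per embedding before duckdb's executemany. A sketch with hypothetical values, not taken from the repo:

    embedding_data = [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]]
    input_uri = ["https://example.com/a", "https://example.com/b"]
    dataset = ["training", "training"]  # scalar inputs are multiplied out to this length upstream
    custom_quality_score = [None, None]
    category_name = ["person", "person"]
    # data_to_insert -> [[[1.1, 2.3, 3.2], "https://example.com/a", "training", None, "person"],
    #                    [[1.2, 2.24, 3.2], "https://example.com/b", "training", None, "person"]]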
@@ -99,12 +123,12 @@ class DuckDB(Database):
         if where_filter is not None:
             if not isinstance(where_filter, dict):
                 raise Exception("Invalid where_filter: " + str(where_filter))
 
             # ensure where_filter is a flat dict
             for key in where_filter:
                 if isinstance(where_filter[key], dict):
                     raise Exception("Invalid where_filter: " + str(where_filter))
 
-            where_filter = " AND ".join([f"{key} = '{value}'" for key, value in where_filter.items()])
+            where_filter = " AND ".join([f"{key} = '{value}'" for key, value in where_filter.items()])
 
         if where_filter:
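For concreteness, the WHERE fragment that the join in fetch builds, shown with a hypothetical filter rather than one from the repo:

    where_filter = {"dataset": "training", "category_name": "person"}
    " AND ".join([f"{key} = '{value}'" for key, value in where_filter.items()])
    # -> "dataset = 'training' AND category_name = 'person'"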
@@ -116,7 +140,9 @@ class DuckDB(Database):
         if limit is not None or isinstance(limit, int):
             where_filter += f" LIMIT {limit}"
 
-        return self._conn.execute(f'''
+        return (
+            self._conn.execute(
+                f"""
         SELECT
             id,
             embedding_data,
@@ -127,41 +153,49 @@ class DuckDB(Database):
         FROM
             embeddings
         {where_filter}
-        ''').fetchdf().replace({np.nan: None}) # replace nan with None for json serialization
+        """
+            )
+            .fetchdf()
+            .replace({np.nan: None})
+        )  # replace nan with None for json serialization
 
     def delete_batch(self, batch):
         raise NotImplementedError
 
     def persist(self):
-        '''
+        """
         Persist the database to disk
-        '''
+        """
         if self._conn is None:
             return
 
-        self._conn.execute('''
+        self._conn.execute(
+            """
         COPY
             (SELECT * FROM embeddings)
         TO '.chroma/chroma.parquet'
             (FORMAT PARQUET);
-        ''')
+        """
+        )
 
     def load(self, path=".chroma/chroma.parquet"):
-        '''
+        """
         Load the database from disk
-        '''
+        """
         self._conn.execute(f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');")
 
     def get_by_ids(self, ids=list):
         # select from duckdb table where ids are in the list
         if not isinstance(ids, list):
             raise Exception("ids must be a list")
 
         if not ids:
            # create an empty pandas dataframe
            return pd.DataFrame()
 
-        return self._conn.execute(f'''
+        return (
+            self._conn.execute(
+                f"""
         SELECT
             id,
             embedding_data,
@@ -173,4 +207,8 @@ class DuckDB(Database):
             embeddings
         WHERE
             id IN ({','.join([str(x) for x in ids])})
-        ''').fetchdf().replace({np.nan: None}) # replace nan with None for json serialization
+        """
+            )
+            .fetchdf()
+            .replace({np.nan: None})
+        )  # replace nan with None for json serialization
@@ -1,6 +1,7 @@
 from abc import abstractmethod
 
-class Index():
+
+class Index:
     @abstractmethod
     def __init__(self):
         pass
@@ -23,4 +24,4 @@ class Index():
 
     @abstractmethod
     def load(self):
-        pass
+        pass
@@ -3,6 +3,7 @@ import numpy as np
 from chroma_server.index.abstract import Index
 from chroma_server.utils import logger
 
+
 class Hnswlib(Index):
 
     _index = None
@@ -14,20 +15,22 @@ class Hnswlib(Index):
         # more comments available at the source: https://github.com/nmslib/hnswlib
 
         # We split the data in two batches:
-        data1 = embedding_data['embedding_data'].to_numpy().tolist()
+        data1 = embedding_data["embedding_data"].to_numpy().tolist()
         dim = len(data1[0])
-        num_elements = len(data1)
+        num_elements = len(data1)
         # logger.debug("dimensionality is:", dim)
         # logger.debug("total number of elements is:", num_elements)
         # logger.debug("max elements", num_elements//2)
 
-        concatted_data = data1
+        concatted_data = data1
         # logger.debug("concatted_data", len(concatted_data))
 
-        p = hnswlib.Index(space='l2', dim=dim) # # Declaring index, possible options are l2, cosine or ip
-        p.init_index(max_elements=len(data1), ef_construction=100, M=16) # Initing index
+        p = hnswlib.Index(
+            space="l2", dim=dim
+        )  # # Declaring index, possible options are l2, cosine or ip
+        p.init_index(max_elements=len(data1), ef_construction=100, M=16)  # Initing index
         p.set_ef(10)  # Controlling the recall by setting ef:
-        p.set_num_threads(4) # Set number of threads used during batch search/construction
+        p.set_num_threads(4)  # Set number of threads used during batch search/construction
 
         # logger.debug("Adding first batch of elements", (len(data1)))
         p.add_items(data1, embedding_data["id"])
@@ -37,12 +40,15 @@ class Hnswlib(Index):
         # logger.debug("database_ids", database_ids)
         # logger.debug("distances", distances)
         # logger.debug(len(distances))
-        logger.debug("Recall for the first batch:" + str(np.mean(database_ids.reshape(-1) == np.arange(len(data1)))))
+        logger.debug(
+            "Recall for the first batch:"
+            + str(np.mean(database_ids.reshape(-1) == np.arange(len(data1))))
+        )
 
         self._index = p
 
     def fetch(self, query):
-        raise NotImplementedError
+        raise NotImplementedError
 
     def delete_batch(self, batch):
         raise NotImplementedError
@@ -51,12 +57,12 @@ class Hnswlib(Index):
         if self._index is None:
             return
         self._index.save_index(".chroma/index.bin")
-        logger.debug('Index saved to .chroma/index.bin')
+        logger.debug("Index saved to .chroma/index.bin")
 
     def load(self, elements, dimensionality):
-        p = hnswlib.Index(space='l2', dim= dimensionality)
+        p = hnswlib.Index(space="l2", dim=dimensionality)
         self._index = p
-        self._index.load_index(".chroma/index.bin", max_elements= elements)
+        self._index.load_index(".chroma/index.bin", max_elements=elements)
 
     # do knn_query on hnswlib to get nearest neighbors
     def get_nearest_neighbors(self, query, k, ids=None):
@@ -4,32 +4,45 @@ from httpx import AsyncClient
 
 from ..api import app
 
+
 @pytest.fixture
 def anyio_backend():
-    return 'asyncio'
+    return "asyncio"
 
+
 @pytest.mark.anyio
 async def test_root():
     async with AsyncClient(app=app, base_url="http://test") as ac:
         response = await ac.get("/api/v1")
         assert response.status_code == 200
-        assert abs(response.json()["nanosecond heartbeat"] - int(1000 * time.time_ns())) < 3_000_000_000 # a billion nanoseconds = 3s
+        assert (
+            abs(response.json()["nanosecond heartbeat"] - int(1000 * time.time_ns())) < 3_000_000_000
+        )  # a billion nanoseconds = 3s
 
+
 async def post_one_record(ac):
-    return await ac.post("/api/v1/add", json={
-        "embedding_data": [1.02, 2.03, 3.03],
-        "input_uri": "https://example.com",
-        "dataset": "coco",
-        "category_name": "person"
-    })
+    return await ac.post(
+        "/api/v1/add",
+        json={
+            "embedding_data": [1.02, 2.03, 3.03],
+            "input_uri": "https://example.com",
+            "dataset": "coco",
+            "category_name": "person",
+        },
+    )
 
+
 async def post_batch_records(ac):
-    return await ac.post("/api/v1/add", json={
-        "embedding_data": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
-        "input_uri": ["https://example.com", "https://example.com"],
-        "dataset": "training",
-        "category_name": "person"
-    })
+    return await ac.post(
+        "/api/v1/add",
+        json={
+            "embedding_data": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
+            "input_uri": ["https://example.com", "https://example.com"],
+            "dataset": "training",
+            "category_name": "person",
+        },
+    )
 
+
 @pytest.mark.anyio
 async def test_add_to_db():
@@ -38,6 +51,7 @@ async def test_add_to_db():
         assert response.status_code == 201
         assert response.json() == {"response": "Added record to database"}
 
+
 @pytest.mark.anyio
 async def test_add_to_db_batch():
     async with AsyncClient(app=app, base_url="http://test") as ac:
@@ -45,6 +59,7 @@ async def test_add_to_db_batch():
         assert response.status_code == 201
         assert response.json() == {"response": "Added record to database"}
 
+
 @pytest.mark.anyio
 async def test_fetch_from_db():
     async with AsyncClient(app=app, base_url="http://test") as ac:
@@ -53,15 +68,17 @@ async def test_fetch_from_db():
         assert response.status_code == 200
         assert len(response.json()) == 1
 
+
 @pytest.mark.anyio
 async def test_count_from_db():
     async with AsyncClient(app=app, base_url="http://test") as ac:
-        await ac.get("/api/v1/reset") # reset db
+        await ac.get("/api/v1/reset")  # reset db
         await post_batch_records(ac)
         response = await ac.get("/api/v1/count")
         assert response.status_code == 200
         assert response.json() == {"count": 2}
 
+
 @pytest.mark.anyio
 async def test_reset_db():
     async with AsyncClient(app=app, base_url="http://test") as ac:
@@ -74,25 +91,19 @@ async def test_reset_db():
         response = await ac.get("/api/v1/count")
         assert response.json() == {"count": 0}
 
+
 @pytest.mark.anyio
 async def test_get_nearest_neighbors():
     async with AsyncClient(app=app, base_url="http://test") as ac:
         await ac.get("/api/v1/reset")
         await post_batch_records(ac)
         await ac.get("/api/v1/process")
-        response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1})
+        response = await ac.post(
+            "/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1}
+        )
         assert response.status_code == 200
         assert len(response.json()["ids"]) == 1
 
-@pytest.mark.anyio
-async def test_get_nearest_neighbors_filter():
-    async with AsyncClient(app=app, base_url="http://test") as ac:
-        await ac.get("/api/v1/reset")
-        await post_batch_records(ac)
-        await ac.get("/api/v1/process")
-        response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "dataset": "training", "category_name": "monkey"})
-        assert response.status_code == 200
-        assert len(response.json()["ids"]) == 0
-
+
 @pytest.mark.anyio
 async def test_get_nearest_neighbors_filter():
@@ -100,7 +111,34 @@ async def test_get_nearest_neighbors_filter():
         await ac.get("/api/v1/reset")
         await post_batch_records(ac)
         await ac.get("/api/v1/process")
-        response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 2, "dataset": "training", "category_name": "person"})
+        response = await ac.post(
+            "/api/v1/get_nearest_neighbors",
+            json={
+                "embedding": [1.1, 2.3, 3.2],
+                "n_results": 1,
+                "dataset": "training",
+                "category_name": "monkey",
+            },
+        )
+        assert response.status_code == 200
+        assert len(response.json()["ids"]) == 0
+
+
+@pytest.mark.anyio
+async def test_get_nearest_neighbors_filter():
+    async with AsyncClient(app=app, base_url="http://test") as ac:
+        await ac.get("/api/v1/reset")
+        await post_batch_records(ac)
+        await ac.get("/api/v1/process")
+        response = await ac.post(
+            "/api/v1/get_nearest_neighbors",
+            json={
+                "embedding": [1.1, 2.3, 3.2],
+                "n_results": 2,
+                "dataset": "training",
+                "category_name": "person",
+            },
+        )
         assert response.status_code == 200
         assert len(response.json()["ids"]) == 2
@@ -118,4 +156,4 @@ async def test_get_nearest_neighbors_filter():
 
 # Purposefully untested
 # - process
-# - rand
+# - rand
@@ -6,9 +6,10 @@ class AddEmbedding(BaseModel):
     embedding_data: list
     input_uri: Union[str, list]
     dataset: Union[str, list] = None
-    custom_quality_score: Union[float, list] = None
+    custom_quality_score: Union[float, list] = None
     category_name: Union[str, list] = None
 
+
 class QueryEmbedding(BaseModel):
     embedding: list
     n_results: int = 10
@@ -1,5 +1,6 @@
 import logging
 
+
 def setup_logging():
     logging.basicConfig(filename="chroma_logs.log")
     logger = logging.getLogger("Chroma")
@@ -7,4 +8,5 @@ def setup_logging():
     logger.debug("Logger created")
     return logger
 
-logger = setup_logging()
+
+logger = setup_logging()
pyproject.toml (new file, 8 lines)
@@ -0,0 +1,8 @@
+[tool.black]
+line-length = 100
+
+# Black will refuse to run if it's not this version.
+required-version = "22.6.0"
+
+# Ensure black's output will be compatible with all listed versions.
+target-version = ['py36', 'py37', 'py38', 'py39', 'py310']
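Pinning required-version means every contributor runs exactly the same formatter release, so `make check_black` (see the Makefile above) cannot disagree between machines, and line-length = 100 is why the reflowed lines throughout this diff wrap at 100 characters rather than black's default of 88.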