black python formatting

This commit is contained in:
Jeffrey Huber
2022-10-31 12:42:34 -07:00
parent 6b87b2ed48
commit 30fbd06bb1
14 changed files with 317 additions and 191 deletions

5
Makefile Normal file
View File

@@ -0,0 +1,5 @@
black:
black --fast chroma-server chroma-client
check_black:
black --check --fast chroma-server chroma-client

View File

@@ -2,6 +2,7 @@ import requests
import json import json
from typing import Union from typing import Union
class Chroma: class Chroma:
_api_url = "http://localhost:8000/api/v1" _api_url = "http://localhost:8000/api/v1"
@@ -15,127 +16,142 @@ class Chroma:
self.url = url self.url = url
def count(self): def count(self):
''' """
Returns the number of embeddings in the database Returns the number of embeddings in the database
''' """
x = requests.get(self._api_url + "/count") x = requests.get(self._api_url + "/count")
return x.json() return x.json()
def fetch(self, where_filter={}, sort=None, limit=None): def fetch(self, where_filter={}, sort=None, limit=None):
''' """
Fetches embeddings from the database Fetches embeddings from the database
''' """
x = requests.get(self._api_url + "/fetch", data=json.dumps({ x = requests.get(
"where_filter":json.dumps(where_filter), self._api_url + "/fetch",
"sort":sort, data=json.dumps(
"limit":limit {"where_filter": json.dumps(where_filter), "sort": sort, "limit": limit}
})) ),
)
return x.json() return x.json()
def process(self): def process(self):
''' """
Processes embeddings in the database Processes embeddings in the database
- currently this only runs hnswlib, doesnt return anything - currently this only runs hnswlib, doesnt return anything
''' """
requests.get(self._api_url + "/process") requests.get(self._api_url + "/process")
return True return True
def reset(self): def reset(self):
''' """
Resets the database Resets the database
''' """
return requests.get(self._api_url + "/reset") return requests.get(self._api_url + "/reset")
def persist(self): def persist(self):
''' """
Persists the database to disk in the .chroma folder inside chroma-server Persists the database to disk in the .chroma folder inside chroma-server
''' """
return requests.get(self._api_url + "/persist") return requests.get(self._api_url + "/persist")
def rand(self): def rand(self):
''' """
Stubbed out sampling endpoint, returns a random bisection of the database Stubbed out sampling endpoint, returns a random bisection of the database
''' """
x = requests.get(self._api_url + "/rand") x = requests.get(self._api_url + "/rand")
return x.json() return x.json()
def heartbeat(self): def heartbeat(self):
''' """
Returns the current server time in milliseconds to check if the server is alive Returns the current server time in milliseconds to check if the server is alive
''' """
x = requests.get(self._api_url) x = requests.get(self._api_url)
return x.json() return x.json()
def log(self, def log(
embedding_data: list, self,
input_uri: list, embedding_data: list,
input_uri: list,
dataset: list = None, dataset: list = None,
category_name: list = None): category_name: list = None,
''' ):
"""
Logs a batch of embeddings to the database Logs a batch of embeddings to the database
- pass in column oriented data lists - pass in column oriented data lists
''' """
x = requests.post(self._api_url + "/add", data = json.dumps({ x = requests.post(
"embedding_data": embedding_data, self._api_url + "/add",
"input_uri": input_uri, data=json.dumps(
"dataset": dataset, {
"category_name": category_name "embedding_data": embedding_data,
}) ) "input_uri": input_uri,
"dataset": dataset,
"category_name": category_name,
}
),
)
if x.status_code == 201: if x.status_code == 201:
return True return True
else: else:
return False return False
def log_training(self, embedding_data: list, input_uri: list, category_name: list): def log_training(self, embedding_data: list, input_uri: list, category_name: list):
''' """
Small wrapper around log() to log a batch of training embedding Small wrapper around log() to log a batch of training embedding
- sets dataset to "training" - sets dataset to "training"
''' """
return self.log( return self.log(
embedding_data=embedding_data, embedding_data=embedding_data,
input_uri=input_uri, input_uri=input_uri,
dataset="training", dataset="training",
category_name=category_name category_name=category_name,
) )
def log_production(self, embedding_data: list, input_uri: list, category_name: list): def log_production(self, embedding_data: list, input_uri: list, category_name: list):
''' """
Small wrapper around log() to log a batch of production embedding Small wrapper around log() to log a batch of production embedding
- sets dataset to "production" - sets dataset to "production"
''' """
return self.log( return self.log(
embedding_data=embedding_data, embedding_data=embedding_data,
input_uri=input_uri, input_uri=input_uri,
dataset="production", dataset="production",
category_name=category_name category_name=category_name,
) )
def log_triage(self, embedding_data: list, input_uri: list, category_name: list): def log_triage(self, embedding_data: list, input_uri: list, category_name: list):
''' """
Small wrapper around log() to log a batch of triage embedding Small wrapper around log() to log a batch of triage embedding
- sets dataset to "triage" - sets dataset to "triage"
''' """
return self.log( return self.log(
embedding_data=embedding_data, embedding_data=embedding_data,
input_uri=input_uri, input_uri=input_uri,
dataset="triage", dataset="triage",
category_name=category_name category_name=category_name,
) )
def get_nearest_neighbors(self, embedding, n_results=10, category_name=None, dataset="training"): def get_nearest_neighbors(
''' self, embedding, n_results=10, category_name=None, dataset="training"
):
"""
Gets the nearest neighbors of a single embedding Gets the nearest neighbors of a single embedding
''' """
x = requests.post(self._api_url + "/get_nearest_neighbors", data = json.dumps({ x = requests.post(
"embedding": embedding, self._api_url + "/get_nearest_neighbors",
"n_results": n_results, data=json.dumps(
"category_name": category_name, {
"dataset": dataset "embedding": embedding,
}) ) "n_results": n_results,
"category_name": category_name,
"dataset": dataset,
}
),
)
if x.status_code == 200: if x.status_code == 200:
return x.json() return x.json()
else: else:
return False return False

View File

@@ -4,11 +4,14 @@ from chroma_client import Chroma
import pytest import pytest
import time import time
from httpx import AsyncClient from httpx import AsyncClient
# from ..api import app # this wont work because i moved the file # from ..api import app # this wont work because i moved the file
@pytest.fixture @pytest.fixture
def anyio_backend(): def anyio_backend():
return 'asyncio' return "asyncio"
def test_init(): def test_init():
chroma = Chroma() chroma = Chroma()
@@ -25,4 +28,3 @@ def test_init():
# chroma = Chroma(url="http://test/api/v1") # chroma = Chroma(url="http://test/api/v1")
# response = await chroma.count() # response = await chroma.count()
# raise Exception("response" + response) # raise Exception("response" + response)

View File

@@ -1,8 +1,8 @@
import random import random
def rand_bisectional_subsample(data): def rand_bisectional_subsample(data):
""" """
Randomly bisectionally subsample a list of data to size. Randomly bisectionally subsample a list of data to size.
""" """
return data.sample(frac=0.5, replace=True, random_state=1) return data.sample(frac=0.5, replace=True, random_state=1)

View File

@@ -2,22 +2,23 @@ import numpy as np
import json import json
import ast import ast
def class_distances(data): def class_distances(data):
'''' """'
This is all very subject to change, so essentially just copy and paste from what we had before This is all very subject to change, so essentially just copy and paste from what we had before
''' """
return False return False
# def unpack_annotations(embeddings): # def unpack_annotations(embeddings):
# annotations = [json.loads(embedding['infer'])["annotations"]for embedding in embeddings] # annotations = [json.loads(embedding['infer'])["annotations"]for embedding in embeddings]
# annotations = [annotation for annotation_list in annotations for annotation in annotation_list] # annotations = [annotation for annotation_list in annotations for annotation in annotation_list]
# # Unpack embedding data # # Unpack embedding data
# embeddings = [embedding["embedding_data"] for embedding in embeddings] # embeddings = [embedding["embedding_data"] for embedding in embeddings]
# embedding_vectors_by_category = {} # embedding_vectors_by_category = {}
# for embedding_annotation_pair in zip(embeddings, annotations): # for embedding_annotation_pair in zip(embeddings, annotations):
# data = np.array(embedding_annotation_pair[0]) # data = np.array(embedding_annotation_pair[0])
# category = embedding_annotation_pair[1]['category_id'] # category = embedding_annotation_pair[1]['category_id']
# if category in embedding_vectors_by_category.keys(): # if category in embedding_vectors_by_category.keys():
# embedding_vectors_by_category[category] = np.append( # embedding_vectors_by_category[category] = np.append(
# embedding_vectors_by_category[category], data[np.newaxis, :], axis=0 # embedding_vectors_by_category[category], data[np.newaxis, :], axis=0
@@ -84,5 +85,5 @@ def class_distances(data):
# if (len(inferences) == 0): # if (len(inferences) == 0):
# raise Exception("No inferences found for datapoint") # raise Exception("No inferences found for datapoint")
# return output_distances # return output_distances

View File

@@ -11,7 +11,6 @@ from chroma_server.types import AddEmbedding, QueryEmbedding
from chroma_server.utils import logger from chroma_server.utils import logger
# Boot script # Boot script
db = DuckDB db = DuckDB
ann_index = Hnswlib ann_index = Hnswlib
@@ -34,104 +33,112 @@ if os.path.exists(".chroma/index.bin"):
app._ann_index.load(app._db.count(), len(app._db.fetch(limit=1).embedding_data)) app._ann_index.load(app._db.count(), len(app._db.fetch(limit=1).embedding_data))
# API Endpoints # API Endpoints
@app.get("/api/v1") @app.get("/api/v1")
async def root(): async def root():
''' """
Heartbeat endpoint Heartbeat endpoint
''' """
return {"nanosecond heartbeat": int(1000 * time.time_ns())} return {"nanosecond heartbeat": int(1000 * time.time_ns())}
@app.post("/api/v1/add", status_code=status.HTTP_201_CREATED) @app.post("/api/v1/add", status_code=status.HTTP_201_CREATED)
async def add_to_db(new_embedding: AddEmbedding): async def add_to_db(new_embedding: AddEmbedding):
''' """
Save embedding to database Save embedding to database
- supports single or batched embeddings - supports single or batched embeddings
''' """
app._db.add_batch( app._db.add_batch(
new_embedding.embedding_data, new_embedding.embedding_data,
new_embedding.input_uri, new_embedding.input_uri,
new_embedding.dataset, new_embedding.dataset,
new_embedding.custom_quality_score, new_embedding.custom_quality_score,
new_embedding.category_name new_embedding.category_name,
) )
return {"response": "Added record to database"} return {"response": "Added record to database"}
@app.get("/api/v1/process") @app.get("/api/v1/process")
async def process(): async def process():
''' """
Currently generates an index for the embedding db Currently generates an index for the embedding db
''' """
app._ann_index.run(app._db.fetch()) app._ann_index.run(app._db.fetch())
@app.get("/api/v1/fetch") @app.get("/api/v1/fetch")
async def fetch(where_filter={}, sort=None, limit=None): async def fetch(where_filter={}, sort=None, limit=None):
''' """
Fetches embeddings from the database Fetches embeddings from the database
- enables filtering by where_filter, sorting by key, and limiting the number of results - enables filtering by where_filter, sorting by key, and limiting the number of results
''' """
return app._db.fetch(where_filter, sort, limit).to_dict(orient="records") return app._db.fetch(where_filter, sort, limit).to_dict(orient="records")
@app.get("/api/v1/count") @app.get("/api/v1/count")
async def count(): async def count():
''' """
Returns the number of records in the database Returns the number of records in the database
''' """
return ({"count": app._db.count()}) return {"count": app._db.count()}
@app.get("/api/v1/persist") @app.get("/api/v1/persist")
async def persist(): async def persist():
''' """
Persist the database and index to disk Persist the database and index to disk
''' """
if not os.path.exists(".chroma"): if not os.path.exists(".chroma"):
os.mkdir(".chroma") os.mkdir(".chroma")
app._db.persist() app._db.persist()
app._ann_index.persist() app._ann_index.persist()
return True return True
@app.get("/api/v1/reset") @app.get("/api/v1/reset")
async def reset(): async def reset():
''' """
Reset the database and index Reset the database and index
''' """
shutil.rmtree(".chroma", ignore_errors=True) shutil.rmtree(".chroma", ignore_errors=True)
app._db = db() app._db = db()
app._ann_index = ann_index() app._ann_index = ann_index()
return True return True
@app.get("/api/v1/rand") @app.get("/api/v1/rand")
async def rand(where_filter={}, sort=None, limit=None): async def rand(where_filter={}, sort=None, limit=None):
''' """
Randomly bisection the database Randomly bisection the database
''' """
results = app._db.fetch(where_filter, sort, limit) results = app._db.fetch(where_filter, sort, limit)
rand = rand_bisectional_subsample(results) rand = rand_bisectional_subsample(results)
return rand.to_dict(orient="records") return rand.to_dict(orient="records")
@app.post("/api/v1/get_nearest_neighbors") @app.post("/api/v1/get_nearest_neighbors")
async def get_nearest_neighbors(embedding: QueryEmbedding): async def get_nearest_neighbors(embedding: QueryEmbedding):
''' """
return the distance, database ids, and embedding themselves for the input embedding return the distance, database ids, and embedding themselves for the input embedding
''' """
ids = None ids = None
filter_by_where = {} filter_by_where = {}
if embedding.category_name is not None: if embedding.category_name is not None:
filter_by_where['category_name'] = embedding.category_name filter_by_where["category_name"] = embedding.category_name
if embedding.dataset is not None: if embedding.dataset is not None:
filter_by_where['dataset'] = embedding.dataset filter_by_where["dataset"] = embedding.dataset
if filter_by_where is not None: if filter_by_where is not None:
ids = app._db.fetch(filter_by_where)["id"].tolist() ids = app._db.fetch(filter_by_where)["id"].tolist()
nn = app._ann_index.get_nearest_neighbors(embedding.embedding, embedding.n_results, ids) nn = app._ann_index.get_nearest_neighbors(embedding.embedding, embedding.n_results, ids)
return { return {
"ids": nn[0].tolist()[0], "ids": nn[0].tolist()[0],
"embeddings": app._db.get_by_ids(nn[0].tolist()[0]).to_dict(orient="records"), "embeddings": app._db.get_by_ids(nn[0].tolist()[0]).to_dict(orient="records"),
"distances": nn[1].tolist()[0] "distances": nn[1].tolist()[0],
} }

View File

@@ -1,6 +1,7 @@
from abc import abstractmethod from abc import abstractmethod
class Database():
class Database:
@abstractmethod @abstractmethod
def __init__(self): def __init__(self):
pass pass
@@ -23,4 +24,4 @@ class Database():
@abstractmethod @abstractmethod
def load(self): def load(self):
pass pass

View File

@@ -4,12 +4,14 @@ import duckdb
import numpy as np import numpy as np
import pandas as pd import pandas as pd
class DuckDB(Database): class DuckDB(Database):
_conn = None _conn = None
def __init__(self): def __init__(self):
self._conn = duckdb.connect() self._conn = duckdb.connect()
self._conn.execute(''' self._conn.execute(
"""
CREATE TABLE embeddings ( CREATE TABLE embeddings (
id integer PRIMARY KEY, id integer PRIMARY KEY,
embedding_data REAL[], embedding_data REAL[],
@@ -18,32 +20,39 @@ class DuckDB(Database):
custom_quality_score REAL, custom_quality_score REAL,
category_name STRING category_name STRING
) )
''') """
)
# ids to manage internal bookkeeping and *nothing else*, users should not have to care about these ids # ids to manage internal bookkeeping and *nothing else*, users should not have to care about these ids
self._conn.execute(''' self._conn.execute(
"""
CREATE SEQUENCE seq_id START 1; CREATE SEQUENCE seq_id START 1;
''') """
)
self._conn.execute(''' self._conn.execute(
"""
-- change the default null sorting order to either NULLS FIRST and NULLS LAST -- change the default null sorting order to either NULLS FIRST and NULLS LAST
PRAGMA default_null_order='NULLS LAST'; PRAGMA default_null_order='NULLS LAST';
-- change the default sorting order to either DESC or ASC -- change the default sorting order to either DESC or ASC
PRAGMA default_order='DESC'; PRAGMA default_order='DESC';
''') """
)
return return
def add_batch(self, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None): def add_batch(
''' self, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None
):
"""
Add embeddings to the database Add embeddings to the database
This accepts both a single input and a list of inputs This accepts both a single input and a list of inputs
''' """
# create list of the types of all inputs # create list of the types of all inputs
types = [type(x).__name__ for x in [embedding_data, input_uri]] types = [type(x).__name__ for x in [embedding_data, input_uri]]
# if all of the types are 'list' - do batch mode # if all of the types are 'list' - do batch mode
if all(x == 'list' for x in types): if all(x == "list" for x in types):
lengths = [len(x) for x in [embedding_data, input_uri]] lengths = [len(x) for x in [embedding_data, input_uri]]
# accepts some inputs as str or none, and this multiples them out to the correct length # accepts some inputs as str or none, and this multiples them out to the correct length
@@ -57,41 +66,56 @@ class DuckDB(Database):
# we have to move from column to row format for duckdb # we have to move from column to row format for duckdb
data_to_insert = [] data_to_insert = []
for i in range(lengths[0]): for i in range(lengths[0]):
data_to_insert.append([embedding_data[i], input_uri[i], dataset[i], custom_quality_score[i], category_name[i]]) data_to_insert.append(
[
embedding_data[i],
input_uri[i],
dataset[i],
custom_quality_score[i],
category_name[i],
]
)
if all(x == lengths[0] for x in lengths): if all(x == lengths[0] for x in lengths):
self._conn.executemany(''' self._conn.executemany(
INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)''', """
data_to_insert INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)""",
data_to_insert,
) )
return return
# if any of the types are 'list' - throw an error # if any of the types are 'list' - throw an error
if any(x == list for x in [input_uri, dataset, custom_quality_score, category_name]): if any(x == list for x in [input_uri, dataset, custom_quality_score, category_name]):
raise Exception("Invalid input types. One input is a list where others are not: " + str(types)) raise Exception(
"Invalid input types. One input is a list where others are not: " + str(types)
)
# single insert mode # single insert mode
# This should never fire because we do everything in batch mode, but given the mode away from duckdb likely, I am just leaving it in # This should never fire because we do everything in batch mode, but given the mode away from duckdb likely, I am just leaving it in
self._conn.execute(''' self._conn.execute(
INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)''', """
[embedding_data, input_uri, dataset, custom_quality_score, category_name] INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)""",
[embedding_data, input_uri, dataset, custom_quality_score, category_name],
) )
def count(self):
return self._conn.execute('''
SELECT COUNT(*) FROM embeddings;
''').fetchone()[0]
def update(self, data): # call this update_custom_quality_score! that is all it does def count(self):
''' return self._conn.execute(
"""
SELECT COUNT(*) FROM embeddings;
"""
).fetchone()[0]
def update(self, data): # call this update_custom_quality_score! that is all it does
"""
I was not able to figure out (yet) how to do a bulk update in duckdb I was not able to figure out (yet) how to do a bulk update in duckdb
This is going to be fairly slow This is going to be fairly slow
''' """
for element in data: for element in data:
if element['custom_quality_score'] is None: if element["custom_quality_score"] is None:
continue continue
self._conn.execute(f''' self._conn.execute(
UPDATE embeddings SET custom_quality_score={element['custom_quality_score']} WHERE id={element['id']}''' f"""
UPDATE embeddings SET custom_quality_score={element['custom_quality_score']} WHERE id={element['id']}"""
) )
def fetch(self, where_filter={}, sort=None, limit=None): def fetch(self, where_filter={}, sort=None, limit=None):
@@ -99,12 +123,12 @@ class DuckDB(Database):
if where_filter is not None: if where_filter is not None:
if not isinstance(where_filter, dict): if not isinstance(where_filter, dict):
raise Exception("Invalid where_filter: " + str(where_filter)) raise Exception("Invalid where_filter: " + str(where_filter))
# ensure where_filter is a flat dict # ensure where_filter is a flat dict
for key in where_filter: for key in where_filter:
if isinstance(where_filter[key], dict): if isinstance(where_filter[key], dict):
raise Exception("Invalid where_filter: " + str(where_filter)) raise Exception("Invalid where_filter: " + str(where_filter))
where_filter = " AND ".join([f"{key} = '{value}'" for key, value in where_filter.items()]) where_filter = " AND ".join([f"{key} = '{value}'" for key, value in where_filter.items()])
if where_filter: if where_filter:
@@ -116,7 +140,9 @@ class DuckDB(Database):
if limit is not None or isinstance(limit, int): if limit is not None or isinstance(limit, int):
where_filter += f" LIMIT {limit}" where_filter += f" LIMIT {limit}"
return self._conn.execute(f''' return (
self._conn.execute(
f"""
SELECT SELECT
id, id,
embedding_data, embedding_data,
@@ -127,41 +153,49 @@ class DuckDB(Database):
FROM FROM
embeddings embeddings
{where_filter} {where_filter}
''').fetchdf().replace({np.nan: None}) # replace nan with None for json serialization """
)
.fetchdf()
.replace({np.nan: None})
) # replace nan with None for json serialization
def delete_batch(self, batch): def delete_batch(self, batch):
raise NotImplementedError raise NotImplementedError
def persist(self): def persist(self):
''' """
Persist the database to disk Persist the database to disk
''' """
if self._conn is None: if self._conn is None:
return return
self._conn.execute(''' self._conn.execute(
"""
COPY COPY
(SELECT * FROM embeddings) (SELECT * FROM embeddings)
TO '.chroma/chroma.parquet' TO '.chroma/chroma.parquet'
(FORMAT PARQUET); (FORMAT PARQUET);
''') """
)
def load(self, path=".chroma/chroma.parquet"): def load(self, path=".chroma/chroma.parquet"):
''' """
Load the database from disk Load the database from disk
''' """
self._conn.execute(f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');") self._conn.execute(f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');")
def get_by_ids(self, ids=list): def get_by_ids(self, ids=list):
# select from duckdb table where ids are in the list # select from duckdb table where ids are in the list
if not isinstance(ids, list): if not isinstance(ids, list):
raise Exception("ids must be a list") raise Exception("ids must be a list")
if not ids: if not ids:
# create an empty pandas dataframe # create an empty pandas dataframe
return pd.DataFrame() return pd.DataFrame()
return self._conn.execute(f''' return (
self._conn.execute(
f"""
SELECT SELECT
id, id,
embedding_data, embedding_data,
@@ -173,4 +207,8 @@ class DuckDB(Database):
embeddings embeddings
WHERE WHERE
id IN ({','.join([str(x) for x in ids])}) id IN ({','.join([str(x) for x in ids])})
''').fetchdf().replace({np.nan: None}) # replace nan with None for json serialization """
)
.fetchdf()
.replace({np.nan: None})
) # replace nan with None for json serialization

View File

@@ -1,6 +1,7 @@
from abc import abstractmethod from abc import abstractmethod
class Index():
class Index:
@abstractmethod @abstractmethod
def __init__(self): def __init__(self):
pass pass
@@ -23,4 +24,4 @@ class Index():
@abstractmethod @abstractmethod
def load(self): def load(self):
pass pass

View File

@@ -3,6 +3,7 @@ import numpy as np
from chroma_server.index.abstract import Index from chroma_server.index.abstract import Index
from chroma_server.utils import logger from chroma_server.utils import logger
class Hnswlib(Index): class Hnswlib(Index):
_index = None _index = None
@@ -14,20 +15,22 @@ class Hnswlib(Index):
# more comments available at the source: https://github.com/nmslib/hnswlib # more comments available at the source: https://github.com/nmslib/hnswlib
# We split the data in two batches: # We split the data in two batches:
data1 = embedding_data['embedding_data'].to_numpy().tolist() data1 = embedding_data["embedding_data"].to_numpy().tolist()
dim = len(data1[0]) dim = len(data1[0])
num_elements = len(data1) num_elements = len(data1)
# logger.debug("dimensionality is:", dim) # logger.debug("dimensionality is:", dim)
# logger.debug("total number of elements is:", num_elements) # logger.debug("total number of elements is:", num_elements)
# logger.debug("max elements", num_elements//2) # logger.debug("max elements", num_elements//2)
concatted_data = data1 concatted_data = data1
# logger.debug("concatted_data", len(concatted_data)) # logger.debug("concatted_data", len(concatted_data))
p = hnswlib.Index(space='l2', dim=dim) # # Declaring index, possible options are l2, cosine or ip p = hnswlib.Index(
p.init_index(max_elements=len(data1), ef_construction=100, M=16) # Initing index space="l2", dim=dim
) # # Declaring index, possible options are l2, cosine or ip
p.init_index(max_elements=len(data1), ef_construction=100, M=16) # Initing index
p.set_ef(10) # Controlling the recall by setting ef: p.set_ef(10) # Controlling the recall by setting ef:
p.set_num_threads(4) # Set number of threads used during batch search/construction p.set_num_threads(4) # Set number of threads used during batch search/construction
# logger.debug("Adding first batch of elements", (len(data1))) # logger.debug("Adding first batch of elements", (len(data1)))
p.add_items(data1, embedding_data["id"]) p.add_items(data1, embedding_data["id"])
@@ -37,12 +40,15 @@ class Hnswlib(Index):
# logger.debug("database_ids", database_ids) # logger.debug("database_ids", database_ids)
# logger.debug("distances", distances) # logger.debug("distances", distances)
# logger.debug(len(distances)) # logger.debug(len(distances))
logger.debug("Recall for the first batch:" + str(np.mean(database_ids.reshape(-1) == np.arange(len(data1))))) logger.debug(
"Recall for the first batch:"
+ str(np.mean(database_ids.reshape(-1) == np.arange(len(data1))))
)
self._index = p self._index = p
def fetch(self, query): def fetch(self, query):
raise NotImplementedError raise NotImplementedError
def delete_batch(self, batch): def delete_batch(self, batch):
raise NotImplementedError raise NotImplementedError
@@ -51,12 +57,12 @@ class Hnswlib(Index):
if self._index is None: if self._index is None:
return return
self._index.save_index(".chroma/index.bin") self._index.save_index(".chroma/index.bin")
logger.debug('Index saved to .chroma/index.bin') logger.debug("Index saved to .chroma/index.bin")
def load(self, elements, dimensionality): def load(self, elements, dimensionality):
p = hnswlib.Index(space='l2', dim= dimensionality) p = hnswlib.Index(space="l2", dim=dimensionality)
self._index = p self._index = p
self._index.load_index(".chroma/index.bin", max_elements= elements) self._index.load_index(".chroma/index.bin", max_elements=elements)
# do knn_query on hnswlib to get nearest neighbors # do knn_query on hnswlib to get nearest neighbors
def get_nearest_neighbors(self, query, k, ids=None): def get_nearest_neighbors(self, query, k, ids=None):

View File

@@ -4,32 +4,45 @@ from httpx import AsyncClient
from ..api import app from ..api import app
@pytest.fixture @pytest.fixture
def anyio_backend(): def anyio_backend():
return 'asyncio' return "asyncio"
@pytest.mark.anyio @pytest.mark.anyio
async def test_root(): async def test_root():
async with AsyncClient(app=app, base_url="http://test") as ac: async with AsyncClient(app=app, base_url="http://test") as ac:
response = await ac.get("/api/v1") response = await ac.get("/api/v1")
assert response.status_code == 200 assert response.status_code == 200
assert abs(response.json()["nanosecond heartbeat"] - int(1000 * time.time_ns())) < 3_000_000_000 # a billion nanoseconds = 3s assert (
abs(response.json()["nanosecond heartbeat"] - int(1000 * time.time_ns())) < 3_000_000_000
) # a billion nanoseconds = 3s
async def post_one_record(ac): async def post_one_record(ac):
return await ac.post("/api/v1/add", json={ return await ac.post(
"embedding_data": [1.02, 2.03, 3.03], "/api/v1/add",
"input_uri": "https://example.com", json={
"dataset": "coco", "embedding_data": [1.02, 2.03, 3.03],
"category_name": "person" "input_uri": "https://example.com",
}) "dataset": "coco",
"category_name": "person",
},
)
async def post_batch_records(ac): async def post_batch_records(ac):
return await ac.post("/api/v1/add", json={ return await ac.post(
"embedding_data": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]], "/api/v1/add",
"input_uri": ["https://example.com", "https://example.com"], json={
"dataset": "training", "embedding_data": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"category_name": "person" "input_uri": ["https://example.com", "https://example.com"],
}) "dataset": "training",
"category_name": "person",
},
)
@pytest.mark.anyio @pytest.mark.anyio
async def test_add_to_db(): async def test_add_to_db():
@@ -38,6 +51,7 @@ async def test_add_to_db():
assert response.status_code == 201 assert response.status_code == 201
assert response.json() == {"response": "Added record to database"} assert response.json() == {"response": "Added record to database"}
@pytest.mark.anyio @pytest.mark.anyio
async def test_add_to_db_batch(): async def test_add_to_db_batch():
async with AsyncClient(app=app, base_url="http://test") as ac: async with AsyncClient(app=app, base_url="http://test") as ac:
@@ -45,6 +59,7 @@ async def test_add_to_db_batch():
assert response.status_code == 201 assert response.status_code == 201
assert response.json() == {"response": "Added record to database"} assert response.json() == {"response": "Added record to database"}
@pytest.mark.anyio @pytest.mark.anyio
async def test_fetch_from_db(): async def test_fetch_from_db():
async with AsyncClient(app=app, base_url="http://test") as ac: async with AsyncClient(app=app, base_url="http://test") as ac:
@@ -53,15 +68,17 @@ async def test_fetch_from_db():
assert response.status_code == 200 assert response.status_code == 200
assert len(response.json()) == 1 assert len(response.json()) == 1
@pytest.mark.anyio @pytest.mark.anyio
async def test_count_from_db(): async def test_count_from_db():
async with AsyncClient(app=app, base_url="http://test") as ac: async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.get("/api/v1/reset") # reset db await ac.get("/api/v1/reset") # reset db
await post_batch_records(ac) await post_batch_records(ac)
response = await ac.get("/api/v1/count") response = await ac.get("/api/v1/count")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"count": 2} assert response.json() == {"count": 2}
@pytest.mark.anyio @pytest.mark.anyio
async def test_reset_db(): async def test_reset_db():
async with AsyncClient(app=app, base_url="http://test") as ac: async with AsyncClient(app=app, base_url="http://test") as ac:
@@ -74,25 +91,19 @@ async def test_reset_db():
response = await ac.get("/api/v1/count") response = await ac.get("/api/v1/count")
assert response.json() == {"count": 0} assert response.json() == {"count": 0}
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_nearest_neighbors(): async def test_get_nearest_neighbors():
async with AsyncClient(app=app, base_url="http://test") as ac: async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.get("/api/v1/reset") await ac.get("/api/v1/reset")
await post_batch_records(ac) await post_batch_records(ac)
await ac.get("/api/v1/process") await ac.get("/api/v1/process")
response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1}) response = await ac.post(
"/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1}
)
assert response.status_code == 200 assert response.status_code == 200
assert len(response.json()["ids"]) == 1 assert len(response.json()["ids"]) == 1
@pytest.mark.anyio
async def test_get_nearest_neighbors_filter():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.get("/api/v1/reset")
await post_batch_records(ac)
await ac.get("/api/v1/process")
response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "dataset": "training", "category_name": "monkey"})
assert response.status_code == 200
assert len(response.json()["ids"]) == 0
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_nearest_neighbors_filter(): async def test_get_nearest_neighbors_filter():
@@ -100,7 +111,34 @@ async def test_get_nearest_neighbors_filter():
await ac.get("/api/v1/reset") await ac.get("/api/v1/reset")
await post_batch_records(ac) await post_batch_records(ac)
await ac.get("/api/v1/process") await ac.get("/api/v1/process")
response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 2, "dataset": "training", "category_name": "person"}) response = await ac.post(
"/api/v1/get_nearest_neighbors",
json={
"embedding": [1.1, 2.3, 3.2],
"n_results": 1,
"dataset": "training",
"category_name": "monkey",
},
)
assert response.status_code == 200
assert len(response.json()["ids"]) == 0
@pytest.mark.anyio
async def test_get_nearest_neighbors_filter():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.get("/api/v1/reset")
await post_batch_records(ac)
await ac.get("/api/v1/process")
response = await ac.post(
"/api/v1/get_nearest_neighbors",
json={
"embedding": [1.1, 2.3, 3.2],
"n_results": 2,
"dataset": "training",
"category_name": "person",
},
)
assert response.status_code == 200 assert response.status_code == 200
assert len(response.json()["ids"]) == 2 assert len(response.json()["ids"]) == 2
@@ -118,4 +156,4 @@ async def test_get_nearest_neighbors_filter():
# Purposefully untested # Purposefully untested
# - process # - process
# - rand # - rand

View File

@@ -6,9 +6,10 @@ class AddEmbedding(BaseModel):
embedding_data: list embedding_data: list
input_uri: Union[str, list] input_uri: Union[str, list]
dataset: Union[str, list] = None dataset: Union[str, list] = None
custom_quality_score: Union[float, list] = None custom_quality_score: Union[float, list] = None
category_name: Union[str, list] = None category_name: Union[str, list] = None
class QueryEmbedding(BaseModel): class QueryEmbedding(BaseModel):
embedding: list embedding: list
n_results: int = 10 n_results: int = 10

View File

@@ -1,5 +1,6 @@
import logging import logging
def setup_logging(): def setup_logging():
logging.basicConfig(filename="chroma_logs.log") logging.basicConfig(filename="chroma_logs.log")
logger = logging.getLogger("Chroma") logger = logging.getLogger("Chroma")
@@ -7,4 +8,5 @@ def setup_logging():
logger.debug("Logger created") logger.debug("Logger created")
return logger return logger
logger = setup_logging()
logger = setup_logging()

8
pyproject.toml Normal file
View File

@@ -0,0 +1,8 @@
[tool.black]
line-length = 100
# Black will refuse to run if it's not this version.
required-version = "22.6.0"
# Ensure black's output will be compatible with all listed versions.
target-version = ['py36', 'py37', 'py38', 'py39', 'py310']