mirror of
https://github.com/placeholder-soft/chroma.git
synced 2026-04-30 04:45:01 +08:00
black python formatting
This commit is contained in:
5
Makefile
Normal file
5
Makefile
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
black:
|
||||||
|
black --fast chroma-server chroma-client
|
||||||
|
|
||||||
|
check_black:
|
||||||
|
black --check --fast chroma-server chroma-client
|
||||||
@@ -2,6 +2,7 @@ import requests
|
|||||||
import json
|
import json
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
class Chroma:
|
class Chroma:
|
||||||
|
|
||||||
_api_url = "http://localhost:8000/api/v1"
|
_api_url = "http://localhost:8000/api/v1"
|
||||||
@@ -15,127 +16,142 @@ class Chroma:
|
|||||||
self.url = url
|
self.url = url
|
||||||
|
|
||||||
def count(self):
|
def count(self):
|
||||||
'''
|
"""
|
||||||
Returns the number of embeddings in the database
|
Returns the number of embeddings in the database
|
||||||
'''
|
"""
|
||||||
x = requests.get(self._api_url + "/count")
|
x = requests.get(self._api_url + "/count")
|
||||||
return x.json()
|
return x.json()
|
||||||
|
|
||||||
def fetch(self, where_filter={}, sort=None, limit=None):
|
def fetch(self, where_filter={}, sort=None, limit=None):
|
||||||
'''
|
"""
|
||||||
Fetches embeddings from the database
|
Fetches embeddings from the database
|
||||||
'''
|
"""
|
||||||
x = requests.get(self._api_url + "/fetch", data=json.dumps({
|
x = requests.get(
|
||||||
"where_filter":json.dumps(where_filter),
|
self._api_url + "/fetch",
|
||||||
"sort":sort,
|
data=json.dumps(
|
||||||
"limit":limit
|
{"where_filter": json.dumps(where_filter), "sort": sort, "limit": limit}
|
||||||
}))
|
),
|
||||||
|
)
|
||||||
return x.json()
|
return x.json()
|
||||||
|
|
||||||
def process(self):
|
def process(self):
|
||||||
'''
|
"""
|
||||||
Processes embeddings in the database
|
Processes embeddings in the database
|
||||||
- currently this only runs hnswlib, doesnt return anything
|
- currently this only runs hnswlib, doesnt return anything
|
||||||
'''
|
"""
|
||||||
requests.get(self._api_url + "/process")
|
requests.get(self._api_url + "/process")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
'''
|
"""
|
||||||
Resets the database
|
Resets the database
|
||||||
'''
|
"""
|
||||||
return requests.get(self._api_url + "/reset")
|
return requests.get(self._api_url + "/reset")
|
||||||
|
|
||||||
def persist(self):
|
def persist(self):
|
||||||
'''
|
"""
|
||||||
Persists the database to disk in the .chroma folder inside chroma-server
|
Persists the database to disk in the .chroma folder inside chroma-server
|
||||||
'''
|
"""
|
||||||
return requests.get(self._api_url + "/persist")
|
return requests.get(self._api_url + "/persist")
|
||||||
|
|
||||||
def rand(self):
|
def rand(self):
|
||||||
'''
|
"""
|
||||||
Stubbed out sampling endpoint, returns a random bisection of the database
|
Stubbed out sampling endpoint, returns a random bisection of the database
|
||||||
'''
|
"""
|
||||||
x = requests.get(self._api_url + "/rand")
|
x = requests.get(self._api_url + "/rand")
|
||||||
return x.json()
|
return x.json()
|
||||||
|
|
||||||
def heartbeat(self):
|
def heartbeat(self):
|
||||||
'''
|
"""
|
||||||
Returns the current server time in milliseconds to check if the server is alive
|
Returns the current server time in milliseconds to check if the server is alive
|
||||||
'''
|
"""
|
||||||
x = requests.get(self._api_url)
|
x = requests.get(self._api_url)
|
||||||
return x.json()
|
return x.json()
|
||||||
|
|
||||||
def log(self,
|
def log(
|
||||||
embedding_data: list,
|
self,
|
||||||
input_uri: list,
|
embedding_data: list,
|
||||||
|
input_uri: list,
|
||||||
dataset: list = None,
|
dataset: list = None,
|
||||||
category_name: list = None):
|
category_name: list = None,
|
||||||
'''
|
):
|
||||||
|
"""
|
||||||
Logs a batch of embeddings to the database
|
Logs a batch of embeddings to the database
|
||||||
- pass in column oriented data lists
|
- pass in column oriented data lists
|
||||||
'''
|
"""
|
||||||
|
|
||||||
x = requests.post(self._api_url + "/add", data = json.dumps({
|
x = requests.post(
|
||||||
"embedding_data": embedding_data,
|
self._api_url + "/add",
|
||||||
"input_uri": input_uri,
|
data=json.dumps(
|
||||||
"dataset": dataset,
|
{
|
||||||
"category_name": category_name
|
"embedding_data": embedding_data,
|
||||||
}) )
|
"input_uri": input_uri,
|
||||||
|
"dataset": dataset,
|
||||||
|
"category_name": category_name,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
if x.status_code == 201:
|
if x.status_code == 201:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def log_training(self, embedding_data: list, input_uri: list, category_name: list):
|
def log_training(self, embedding_data: list, input_uri: list, category_name: list):
|
||||||
'''
|
"""
|
||||||
Small wrapper around log() to log a batch of training embedding
|
Small wrapper around log() to log a batch of training embedding
|
||||||
- sets dataset to "training"
|
- sets dataset to "training"
|
||||||
'''
|
"""
|
||||||
return self.log(
|
return self.log(
|
||||||
embedding_data=embedding_data,
|
embedding_data=embedding_data,
|
||||||
input_uri=input_uri,
|
input_uri=input_uri,
|
||||||
dataset="training",
|
dataset="training",
|
||||||
category_name=category_name
|
category_name=category_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
def log_production(self, embedding_data: list, input_uri: list, category_name: list):
|
def log_production(self, embedding_data: list, input_uri: list, category_name: list):
|
||||||
'''
|
"""
|
||||||
Small wrapper around log() to log a batch of production embedding
|
Small wrapper around log() to log a batch of production embedding
|
||||||
- sets dataset to "production"
|
- sets dataset to "production"
|
||||||
'''
|
"""
|
||||||
return self.log(
|
return self.log(
|
||||||
embedding_data=embedding_data,
|
embedding_data=embedding_data,
|
||||||
input_uri=input_uri,
|
input_uri=input_uri,
|
||||||
dataset="production",
|
dataset="production",
|
||||||
category_name=category_name
|
category_name=category_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
def log_triage(self, embedding_data: list, input_uri: list, category_name: list):
|
def log_triage(self, embedding_data: list, input_uri: list, category_name: list):
|
||||||
'''
|
"""
|
||||||
Small wrapper around log() to log a batch of triage embedding
|
Small wrapper around log() to log a batch of triage embedding
|
||||||
- sets dataset to "triage"
|
- sets dataset to "triage"
|
||||||
'''
|
"""
|
||||||
return self.log(
|
return self.log(
|
||||||
embedding_data=embedding_data,
|
embedding_data=embedding_data,
|
||||||
input_uri=input_uri,
|
input_uri=input_uri,
|
||||||
dataset="triage",
|
dataset="triage",
|
||||||
category_name=category_name
|
category_name=category_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_nearest_neighbors(self, embedding, n_results=10, category_name=None, dataset="training"):
|
def get_nearest_neighbors(
|
||||||
'''
|
self, embedding, n_results=10, category_name=None, dataset="training"
|
||||||
|
):
|
||||||
|
"""
|
||||||
Gets the nearest neighbors of a single embedding
|
Gets the nearest neighbors of a single embedding
|
||||||
'''
|
"""
|
||||||
x = requests.post(self._api_url + "/get_nearest_neighbors", data = json.dumps({
|
x = requests.post(
|
||||||
"embedding": embedding,
|
self._api_url + "/get_nearest_neighbors",
|
||||||
"n_results": n_results,
|
data=json.dumps(
|
||||||
"category_name": category_name,
|
{
|
||||||
"dataset": dataset
|
"embedding": embedding,
|
||||||
}) )
|
"n_results": n_results,
|
||||||
|
"category_name": category_name,
|
||||||
|
"dataset": dataset,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
if x.status_code == 200:
|
if x.status_code == 200:
|
||||||
return x.json()
|
return x.json()
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -4,11 +4,14 @@ from chroma_client import Chroma
|
|||||||
import pytest
|
import pytest
|
||||||
import time
|
import time
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
# from ..api import app # this wont work because i moved the file
|
# from ..api import app # this wont work because i moved the file
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def anyio_backend():
|
def anyio_backend():
|
||||||
return 'asyncio'
|
return "asyncio"
|
||||||
|
|
||||||
|
|
||||||
def test_init():
|
def test_init():
|
||||||
chroma = Chroma()
|
chroma = Chroma()
|
||||||
@@ -25,4 +28,3 @@ def test_init():
|
|||||||
# chroma = Chroma(url="http://test/api/v1")
|
# chroma = Chroma(url="http://test/api/v1")
|
||||||
# response = await chroma.count()
|
# response = await chroma.count()
|
||||||
# raise Exception("response" + response)
|
# raise Exception("response" + response)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
||||||
def rand_bisectional_subsample(data):
|
def rand_bisectional_subsample(data):
|
||||||
"""
|
"""
|
||||||
Randomly bisectionally subsample a list of data to size.
|
Randomly bisectionally subsample a list of data to size.
|
||||||
"""
|
"""
|
||||||
return data.sample(frac=0.5, replace=True, random_state=1)
|
return data.sample(frac=0.5, replace=True, random_state=1)
|
||||||
|
|||||||
@@ -2,22 +2,23 @@ import numpy as np
|
|||||||
import json
|
import json
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
|
|
||||||
def class_distances(data):
|
def class_distances(data):
|
||||||
''''
|
"""'
|
||||||
This is all very subject to change, so essentially just copy and paste from what we had before
|
This is all very subject to change, so essentially just copy and paste from what we had before
|
||||||
'''
|
"""
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# def unpack_annotations(embeddings):
|
# def unpack_annotations(embeddings):
|
||||||
# annotations = [json.loads(embedding['infer'])["annotations"]for embedding in embeddings]
|
# annotations = [json.loads(embedding['infer'])["annotations"]for embedding in embeddings]
|
||||||
# annotations = [annotation for annotation_list in annotations for annotation in annotation_list]
|
# annotations = [annotation for annotation_list in annotations for annotation in annotation_list]
|
||||||
# # Unpack embedding data
|
# # Unpack embedding data
|
||||||
# embeddings = [embedding["embedding_data"] for embedding in embeddings]
|
# embeddings = [embedding["embedding_data"] for embedding in embeddings]
|
||||||
# embedding_vectors_by_category = {}
|
# embedding_vectors_by_category = {}
|
||||||
# for embedding_annotation_pair in zip(embeddings, annotations):
|
# for embedding_annotation_pair in zip(embeddings, annotations):
|
||||||
# data = np.array(embedding_annotation_pair[0])
|
# data = np.array(embedding_annotation_pair[0])
|
||||||
# category = embedding_annotation_pair[1]['category_id']
|
# category = embedding_annotation_pair[1]['category_id']
|
||||||
# if category in embedding_vectors_by_category.keys():
|
# if category in embedding_vectors_by_category.keys():
|
||||||
# embedding_vectors_by_category[category] = np.append(
|
# embedding_vectors_by_category[category] = np.append(
|
||||||
# embedding_vectors_by_category[category], data[np.newaxis, :], axis=0
|
# embedding_vectors_by_category[category], data[np.newaxis, :], axis=0
|
||||||
@@ -84,5 +85,5 @@ def class_distances(data):
|
|||||||
|
|
||||||
# if (len(inferences) == 0):
|
# if (len(inferences) == 0):
|
||||||
# raise Exception("No inferences found for datapoint")
|
# raise Exception("No inferences found for datapoint")
|
||||||
|
|
||||||
# return output_distances
|
# return output_distances
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from chroma_server.types import AddEmbedding, QueryEmbedding
|
|||||||
from chroma_server.utils import logger
|
from chroma_server.utils import logger
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Boot script
|
# Boot script
|
||||||
db = DuckDB
|
db = DuckDB
|
||||||
ann_index = Hnswlib
|
ann_index = Hnswlib
|
||||||
@@ -34,104 +33,112 @@ if os.path.exists(".chroma/index.bin"):
|
|||||||
app._ann_index.load(app._db.count(), len(app._db.fetch(limit=1).embedding_data))
|
app._ann_index.load(app._db.count(), len(app._db.fetch(limit=1).embedding_data))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# API Endpoints
|
# API Endpoints
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/v1")
|
@app.get("/api/v1")
|
||||||
async def root():
|
async def root():
|
||||||
'''
|
"""
|
||||||
Heartbeat endpoint
|
Heartbeat endpoint
|
||||||
'''
|
"""
|
||||||
return {"nanosecond heartbeat": int(1000 * time.time_ns())}
|
return {"nanosecond heartbeat": int(1000 * time.time_ns())}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/v1/add", status_code=status.HTTP_201_CREATED)
|
@app.post("/api/v1/add", status_code=status.HTTP_201_CREATED)
|
||||||
async def add_to_db(new_embedding: AddEmbedding):
|
async def add_to_db(new_embedding: AddEmbedding):
|
||||||
'''
|
"""
|
||||||
Save embedding to database
|
Save embedding to database
|
||||||
- supports single or batched embeddings
|
- supports single or batched embeddings
|
||||||
'''
|
"""
|
||||||
|
|
||||||
app._db.add_batch(
|
app._db.add_batch(
|
||||||
new_embedding.embedding_data,
|
new_embedding.embedding_data,
|
||||||
new_embedding.input_uri,
|
new_embedding.input_uri,
|
||||||
new_embedding.dataset,
|
new_embedding.dataset,
|
||||||
new_embedding.custom_quality_score,
|
new_embedding.custom_quality_score,
|
||||||
new_embedding.category_name
|
new_embedding.category_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"response": "Added record to database"}
|
return {"response": "Added record to database"}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/v1/process")
|
@app.get("/api/v1/process")
|
||||||
async def process():
|
async def process():
|
||||||
'''
|
"""
|
||||||
Currently generates an index for the embedding db
|
Currently generates an index for the embedding db
|
||||||
'''
|
"""
|
||||||
app._ann_index.run(app._db.fetch())
|
app._ann_index.run(app._db.fetch())
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/v1/fetch")
|
@app.get("/api/v1/fetch")
|
||||||
async def fetch(where_filter={}, sort=None, limit=None):
|
async def fetch(where_filter={}, sort=None, limit=None):
|
||||||
'''
|
"""
|
||||||
Fetches embeddings from the database
|
Fetches embeddings from the database
|
||||||
- enables filtering by where_filter, sorting by key, and limiting the number of results
|
- enables filtering by where_filter, sorting by key, and limiting the number of results
|
||||||
'''
|
"""
|
||||||
return app._db.fetch(where_filter, sort, limit).to_dict(orient="records")
|
return app._db.fetch(where_filter, sort, limit).to_dict(orient="records")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/v1/count")
|
@app.get("/api/v1/count")
|
||||||
async def count():
|
async def count():
|
||||||
'''
|
"""
|
||||||
Returns the number of records in the database
|
Returns the number of records in the database
|
||||||
'''
|
"""
|
||||||
return ({"count": app._db.count()})
|
return {"count": app._db.count()}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/v1/persist")
|
@app.get("/api/v1/persist")
|
||||||
async def persist():
|
async def persist():
|
||||||
'''
|
"""
|
||||||
Persist the database and index to disk
|
Persist the database and index to disk
|
||||||
'''
|
"""
|
||||||
if not os.path.exists(".chroma"):
|
if not os.path.exists(".chroma"):
|
||||||
os.mkdir(".chroma")
|
os.mkdir(".chroma")
|
||||||
|
|
||||||
app._db.persist()
|
app._db.persist()
|
||||||
app._ann_index.persist()
|
app._ann_index.persist()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/v1/reset")
|
@app.get("/api/v1/reset")
|
||||||
async def reset():
|
async def reset():
|
||||||
'''
|
"""
|
||||||
Reset the database and index
|
Reset the database and index
|
||||||
'''
|
"""
|
||||||
shutil.rmtree(".chroma", ignore_errors=True)
|
shutil.rmtree(".chroma", ignore_errors=True)
|
||||||
app._db = db()
|
app._db = db()
|
||||||
app._ann_index = ann_index()
|
app._ann_index = ann_index()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/v1/rand")
|
@app.get("/api/v1/rand")
|
||||||
async def rand(where_filter={}, sort=None, limit=None):
|
async def rand(where_filter={}, sort=None, limit=None):
|
||||||
'''
|
"""
|
||||||
Randomly bisection the database
|
Randomly bisection the database
|
||||||
'''
|
"""
|
||||||
results = app._db.fetch(where_filter, sort, limit)
|
results = app._db.fetch(where_filter, sort, limit)
|
||||||
rand = rand_bisectional_subsample(results)
|
rand = rand_bisectional_subsample(results)
|
||||||
return rand.to_dict(orient="records")
|
return rand.to_dict(orient="records")
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/v1/get_nearest_neighbors")
|
@app.post("/api/v1/get_nearest_neighbors")
|
||||||
async def get_nearest_neighbors(embedding: QueryEmbedding):
|
async def get_nearest_neighbors(embedding: QueryEmbedding):
|
||||||
'''
|
"""
|
||||||
return the distance, database ids, and embedding themselves for the input embedding
|
return the distance, database ids, and embedding themselves for the input embedding
|
||||||
'''
|
"""
|
||||||
ids = None
|
ids = None
|
||||||
filter_by_where = {}
|
filter_by_where = {}
|
||||||
if embedding.category_name is not None:
|
if embedding.category_name is not None:
|
||||||
filter_by_where['category_name'] = embedding.category_name
|
filter_by_where["category_name"] = embedding.category_name
|
||||||
if embedding.dataset is not None:
|
if embedding.dataset is not None:
|
||||||
filter_by_where['dataset'] = embedding.dataset
|
filter_by_where["dataset"] = embedding.dataset
|
||||||
|
|
||||||
if filter_by_where is not None:
|
if filter_by_where is not None:
|
||||||
ids = app._db.fetch(filter_by_where)["id"].tolist()
|
ids = app._db.fetch(filter_by_where)["id"].tolist()
|
||||||
|
|
||||||
nn = app._ann_index.get_nearest_neighbors(embedding.embedding, embedding.n_results, ids)
|
nn = app._ann_index.get_nearest_neighbors(embedding.embedding, embedding.n_results, ids)
|
||||||
return {
|
return {
|
||||||
"ids": nn[0].tolist()[0],
|
"ids": nn[0].tolist()[0],
|
||||||
"embeddings": app._db.get_by_ids(nn[0].tolist()[0]).to_dict(orient="records"),
|
"embeddings": app._db.get_by_ids(nn[0].tolist()[0]).to_dict(orient="records"),
|
||||||
"distances": nn[1].tolist()[0]
|
"distances": nn[1].tolist()[0],
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
|
||||||
class Database():
|
|
||||||
|
class Database:
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
@@ -23,4 +24,4 @@ class Database():
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load(self):
|
def load(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -4,12 +4,14 @@ import duckdb
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
class DuckDB(Database):
|
class DuckDB(Database):
|
||||||
_conn = None
|
_conn = None
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._conn = duckdb.connect()
|
self._conn = duckdb.connect()
|
||||||
self._conn.execute('''
|
self._conn.execute(
|
||||||
|
"""
|
||||||
CREATE TABLE embeddings (
|
CREATE TABLE embeddings (
|
||||||
id integer PRIMARY KEY,
|
id integer PRIMARY KEY,
|
||||||
embedding_data REAL[],
|
embedding_data REAL[],
|
||||||
@@ -18,32 +20,39 @@ class DuckDB(Database):
|
|||||||
custom_quality_score REAL,
|
custom_quality_score REAL,
|
||||||
category_name STRING
|
category_name STRING
|
||||||
)
|
)
|
||||||
''')
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
# ids to manage internal bookkeeping and *nothing else*, users should not have to care about these ids
|
# ids to manage internal bookkeeping and *nothing else*, users should not have to care about these ids
|
||||||
self._conn.execute('''
|
self._conn.execute(
|
||||||
|
"""
|
||||||
CREATE SEQUENCE seq_id START 1;
|
CREATE SEQUENCE seq_id START 1;
|
||||||
''')
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
self._conn.execute('''
|
self._conn.execute(
|
||||||
|
"""
|
||||||
-- change the default null sorting order to either NULLS FIRST and NULLS LAST
|
-- change the default null sorting order to either NULLS FIRST and NULLS LAST
|
||||||
PRAGMA default_null_order='NULLS LAST';
|
PRAGMA default_null_order='NULLS LAST';
|
||||||
-- change the default sorting order to either DESC or ASC
|
-- change the default sorting order to either DESC or ASC
|
||||||
PRAGMA default_order='DESC';
|
PRAGMA default_order='DESC';
|
||||||
''')
|
"""
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
def add_batch(self, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
|
def add_batch(
|
||||||
'''
|
self, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
Add embeddings to the database
|
Add embeddings to the database
|
||||||
This accepts both a single input and a list of inputs
|
This accepts both a single input and a list of inputs
|
||||||
'''
|
"""
|
||||||
|
|
||||||
# create list of the types of all inputs
|
# create list of the types of all inputs
|
||||||
types = [type(x).__name__ for x in [embedding_data, input_uri]]
|
types = [type(x).__name__ for x in [embedding_data, input_uri]]
|
||||||
|
|
||||||
# if all of the types are 'list' - do batch mode
|
# if all of the types are 'list' - do batch mode
|
||||||
if all(x == 'list' for x in types):
|
if all(x == "list" for x in types):
|
||||||
lengths = [len(x) for x in [embedding_data, input_uri]]
|
lengths = [len(x) for x in [embedding_data, input_uri]]
|
||||||
|
|
||||||
# accepts some inputs as str or none, and this multiples them out to the correct length
|
# accepts some inputs as str or none, and this multiples them out to the correct length
|
||||||
@@ -57,41 +66,56 @@ class DuckDB(Database):
|
|||||||
# we have to move from column to row format for duckdb
|
# we have to move from column to row format for duckdb
|
||||||
data_to_insert = []
|
data_to_insert = []
|
||||||
for i in range(lengths[0]):
|
for i in range(lengths[0]):
|
||||||
data_to_insert.append([embedding_data[i], input_uri[i], dataset[i], custom_quality_score[i], category_name[i]])
|
data_to_insert.append(
|
||||||
|
[
|
||||||
|
embedding_data[i],
|
||||||
|
input_uri[i],
|
||||||
|
dataset[i],
|
||||||
|
custom_quality_score[i],
|
||||||
|
category_name[i],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
if all(x == lengths[0] for x in lengths):
|
if all(x == lengths[0] for x in lengths):
|
||||||
self._conn.executemany('''
|
self._conn.executemany(
|
||||||
INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)''',
|
"""
|
||||||
data_to_insert
|
INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)""",
|
||||||
|
data_to_insert,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# if any of the types are 'list' - throw an error
|
# if any of the types are 'list' - throw an error
|
||||||
if any(x == list for x in [input_uri, dataset, custom_quality_score, category_name]):
|
if any(x == list for x in [input_uri, dataset, custom_quality_score, category_name]):
|
||||||
raise Exception("Invalid input types. One input is a list where others are not: " + str(types))
|
raise Exception(
|
||||||
|
"Invalid input types. One input is a list where others are not: " + str(types)
|
||||||
|
)
|
||||||
|
|
||||||
# single insert mode
|
# single insert mode
|
||||||
# This should never fire because we do everything in batch mode, but given the mode away from duckdb likely, I am just leaving it in
|
# This should never fire because we do everything in batch mode, but given the mode away from duckdb likely, I am just leaving it in
|
||||||
self._conn.execute('''
|
self._conn.execute(
|
||||||
INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)''',
|
"""
|
||||||
[embedding_data, input_uri, dataset, custom_quality_score, category_name]
|
INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)""",
|
||||||
|
[embedding_data, input_uri, dataset, custom_quality_score, category_name],
|
||||||
)
|
)
|
||||||
|
|
||||||
def count(self):
|
|
||||||
return self._conn.execute('''
|
|
||||||
SELECT COUNT(*) FROM embeddings;
|
|
||||||
''').fetchone()[0]
|
|
||||||
|
|
||||||
def update(self, data): # call this update_custom_quality_score! that is all it does
|
def count(self):
|
||||||
'''
|
return self._conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT COUNT(*) FROM embeddings;
|
||||||
|
"""
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
def update(self, data): # call this update_custom_quality_score! that is all it does
|
||||||
|
"""
|
||||||
I was not able to figure out (yet) how to do a bulk update in duckdb
|
I was not able to figure out (yet) how to do a bulk update in duckdb
|
||||||
This is going to be fairly slow
|
This is going to be fairly slow
|
||||||
'''
|
"""
|
||||||
for element in data:
|
for element in data:
|
||||||
if element['custom_quality_score'] is None:
|
if element["custom_quality_score"] is None:
|
||||||
continue
|
continue
|
||||||
self._conn.execute(f'''
|
self._conn.execute(
|
||||||
UPDATE embeddings SET custom_quality_score={element['custom_quality_score']} WHERE id={element['id']}'''
|
f"""
|
||||||
|
UPDATE embeddings SET custom_quality_score={element['custom_quality_score']} WHERE id={element['id']}"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def fetch(self, where_filter={}, sort=None, limit=None):
|
def fetch(self, where_filter={}, sort=None, limit=None):
|
||||||
@@ -99,12 +123,12 @@ class DuckDB(Database):
|
|||||||
if where_filter is not None:
|
if where_filter is not None:
|
||||||
if not isinstance(where_filter, dict):
|
if not isinstance(where_filter, dict):
|
||||||
raise Exception("Invalid where_filter: " + str(where_filter))
|
raise Exception("Invalid where_filter: " + str(where_filter))
|
||||||
|
|
||||||
# ensure where_filter is a flat dict
|
# ensure where_filter is a flat dict
|
||||||
for key in where_filter:
|
for key in where_filter:
|
||||||
if isinstance(where_filter[key], dict):
|
if isinstance(where_filter[key], dict):
|
||||||
raise Exception("Invalid where_filter: " + str(where_filter))
|
raise Exception("Invalid where_filter: " + str(where_filter))
|
||||||
|
|
||||||
where_filter = " AND ".join([f"{key} = '{value}'" for key, value in where_filter.items()])
|
where_filter = " AND ".join([f"{key} = '{value}'" for key, value in where_filter.items()])
|
||||||
|
|
||||||
if where_filter:
|
if where_filter:
|
||||||
@@ -116,7 +140,9 @@ class DuckDB(Database):
|
|||||||
if limit is not None or isinstance(limit, int):
|
if limit is not None or isinstance(limit, int):
|
||||||
where_filter += f" LIMIT {limit}"
|
where_filter += f" LIMIT {limit}"
|
||||||
|
|
||||||
return self._conn.execute(f'''
|
return (
|
||||||
|
self._conn.execute(
|
||||||
|
f"""
|
||||||
SELECT
|
SELECT
|
||||||
id,
|
id,
|
||||||
embedding_data,
|
embedding_data,
|
||||||
@@ -127,41 +153,49 @@ class DuckDB(Database):
|
|||||||
FROM
|
FROM
|
||||||
embeddings
|
embeddings
|
||||||
{where_filter}
|
{where_filter}
|
||||||
''').fetchdf().replace({np.nan: None}) # replace nan with None for json serialization
|
"""
|
||||||
|
)
|
||||||
|
.fetchdf()
|
||||||
|
.replace({np.nan: None})
|
||||||
|
) # replace nan with None for json serialization
|
||||||
|
|
||||||
def delete_batch(self, batch):
|
def delete_batch(self, batch):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def persist(self):
|
def persist(self):
|
||||||
'''
|
"""
|
||||||
Persist the database to disk
|
Persist the database to disk
|
||||||
'''
|
"""
|
||||||
if self._conn is None:
|
if self._conn is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._conn.execute('''
|
self._conn.execute(
|
||||||
|
"""
|
||||||
COPY
|
COPY
|
||||||
(SELECT * FROM embeddings)
|
(SELECT * FROM embeddings)
|
||||||
TO '.chroma/chroma.parquet'
|
TO '.chroma/chroma.parquet'
|
||||||
(FORMAT PARQUET);
|
(FORMAT PARQUET);
|
||||||
''')
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
def load(self, path=".chroma/chroma.parquet"):
|
def load(self, path=".chroma/chroma.parquet"):
|
||||||
'''
|
"""
|
||||||
Load the database from disk
|
Load the database from disk
|
||||||
'''
|
"""
|
||||||
self._conn.execute(f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');")
|
self._conn.execute(f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');")
|
||||||
|
|
||||||
def get_by_ids(self, ids=list):
|
def get_by_ids(self, ids=list):
|
||||||
# select from duckdb table where ids are in the list
|
# select from duckdb table where ids are in the list
|
||||||
if not isinstance(ids, list):
|
if not isinstance(ids, list):
|
||||||
raise Exception("ids must be a list")
|
raise Exception("ids must be a list")
|
||||||
|
|
||||||
if not ids:
|
if not ids:
|
||||||
# create an empty pandas dataframe
|
# create an empty pandas dataframe
|
||||||
return pd.DataFrame()
|
return pd.DataFrame()
|
||||||
|
|
||||||
return self._conn.execute(f'''
|
return (
|
||||||
|
self._conn.execute(
|
||||||
|
f"""
|
||||||
SELECT
|
SELECT
|
||||||
id,
|
id,
|
||||||
embedding_data,
|
embedding_data,
|
||||||
@@ -173,4 +207,8 @@ class DuckDB(Database):
|
|||||||
embeddings
|
embeddings
|
||||||
WHERE
|
WHERE
|
||||||
id IN ({','.join([str(x) for x in ids])})
|
id IN ({','.join([str(x) for x in ids])})
|
||||||
''').fetchdf().replace({np.nan: None}) # replace nan with None for json serialization
|
"""
|
||||||
|
)
|
||||||
|
.fetchdf()
|
||||||
|
.replace({np.nan: None})
|
||||||
|
) # replace nan with None for json serialization
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
|
||||||
class Index():
|
|
||||||
|
class Index:
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
@@ -23,4 +24,4 @@ class Index():
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load(self):
|
def load(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import numpy as np
|
|||||||
from chroma_server.index.abstract import Index
|
from chroma_server.index.abstract import Index
|
||||||
from chroma_server.utils import logger
|
from chroma_server.utils import logger
|
||||||
|
|
||||||
|
|
||||||
class Hnswlib(Index):
|
class Hnswlib(Index):
|
||||||
|
|
||||||
_index = None
|
_index = None
|
||||||
@@ -14,20 +15,22 @@ class Hnswlib(Index):
|
|||||||
# more comments available at the source: https://github.com/nmslib/hnswlib
|
# more comments available at the source: https://github.com/nmslib/hnswlib
|
||||||
|
|
||||||
# We split the data in two batches:
|
# We split the data in two batches:
|
||||||
data1 = embedding_data['embedding_data'].to_numpy().tolist()
|
data1 = embedding_data["embedding_data"].to_numpy().tolist()
|
||||||
dim = len(data1[0])
|
dim = len(data1[0])
|
||||||
num_elements = len(data1)
|
num_elements = len(data1)
|
||||||
# logger.debug("dimensionality is:", dim)
|
# logger.debug("dimensionality is:", dim)
|
||||||
# logger.debug("total number of elements is:", num_elements)
|
# logger.debug("total number of elements is:", num_elements)
|
||||||
# logger.debug("max elements", num_elements//2)
|
# logger.debug("max elements", num_elements//2)
|
||||||
|
|
||||||
concatted_data = data1
|
concatted_data = data1
|
||||||
# logger.debug("concatted_data", len(concatted_data))
|
# logger.debug("concatted_data", len(concatted_data))
|
||||||
|
|
||||||
p = hnswlib.Index(space='l2', dim=dim) # # Declaring index, possible options are l2, cosine or ip
|
p = hnswlib.Index(
|
||||||
p.init_index(max_elements=len(data1), ef_construction=100, M=16) # Initing index
|
space="l2", dim=dim
|
||||||
|
) # # Declaring index, possible options are l2, cosine or ip
|
||||||
|
p.init_index(max_elements=len(data1), ef_construction=100, M=16) # Initing index
|
||||||
p.set_ef(10) # Controlling the recall by setting ef:
|
p.set_ef(10) # Controlling the recall by setting ef:
|
||||||
p.set_num_threads(4) # Set number of threads used during batch search/construction
|
p.set_num_threads(4) # Set number of threads used during batch search/construction
|
||||||
|
|
||||||
# logger.debug("Adding first batch of elements", (len(data1)))
|
# logger.debug("Adding first batch of elements", (len(data1)))
|
||||||
p.add_items(data1, embedding_data["id"])
|
p.add_items(data1, embedding_data["id"])
|
||||||
@@ -37,12 +40,15 @@ class Hnswlib(Index):
|
|||||||
# logger.debug("database_ids", database_ids)
|
# logger.debug("database_ids", database_ids)
|
||||||
# logger.debug("distances", distances)
|
# logger.debug("distances", distances)
|
||||||
# logger.debug(len(distances))
|
# logger.debug(len(distances))
|
||||||
logger.debug("Recall for the first batch:" + str(np.mean(database_ids.reshape(-1) == np.arange(len(data1)))))
|
logger.debug(
|
||||||
|
"Recall for the first batch:"
|
||||||
|
+ str(np.mean(database_ids.reshape(-1) == np.arange(len(data1))))
|
||||||
|
)
|
||||||
|
|
||||||
self._index = p
|
self._index = p
|
||||||
|
|
||||||
def fetch(self, query):
|
def fetch(self, query):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def delete_batch(self, batch):
|
def delete_batch(self, batch):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -51,12 +57,12 @@ class Hnswlib(Index):
|
|||||||
if self._index is None:
|
if self._index is None:
|
||||||
return
|
return
|
||||||
self._index.save_index(".chroma/index.bin")
|
self._index.save_index(".chroma/index.bin")
|
||||||
logger.debug('Index saved to .chroma/index.bin')
|
logger.debug("Index saved to .chroma/index.bin")
|
||||||
|
|
||||||
def load(self, elements, dimensionality):
|
def load(self, elements, dimensionality):
|
||||||
p = hnswlib.Index(space='l2', dim= dimensionality)
|
p = hnswlib.Index(space="l2", dim=dimensionality)
|
||||||
self._index = p
|
self._index = p
|
||||||
self._index.load_index(".chroma/index.bin", max_elements= elements)
|
self._index.load_index(".chroma/index.bin", max_elements=elements)
|
||||||
|
|
||||||
# do knn_query on hnswlib to get nearest neighbors
|
# do knn_query on hnswlib to get nearest neighbors
|
||||||
def get_nearest_neighbors(self, query, k, ids=None):
|
def get_nearest_neighbors(self, query, k, ids=None):
|
||||||
|
|||||||
@@ -4,32 +4,45 @@ from httpx import AsyncClient
|
|||||||
|
|
||||||
from ..api import app
|
from ..api import app
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def anyio_backend():
|
def anyio_backend():
|
||||||
return 'asyncio'
|
return "asyncio"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_root():
|
async def test_root():
|
||||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||||
response = await ac.get("/api/v1")
|
response = await ac.get("/api/v1")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert abs(response.json()["nanosecond heartbeat"] - int(1000 * time.time_ns())) < 3_000_000_000 # a billion nanoseconds = 3s
|
assert (
|
||||||
|
abs(response.json()["nanosecond heartbeat"] - int(1000 * time.time_ns())) < 3_000_000_000
|
||||||
|
) # a billion nanoseconds = 3s
|
||||||
|
|
||||||
|
|
||||||
async def post_one_record(ac):
|
async def post_one_record(ac):
|
||||||
return await ac.post("/api/v1/add", json={
|
return await ac.post(
|
||||||
"embedding_data": [1.02, 2.03, 3.03],
|
"/api/v1/add",
|
||||||
"input_uri": "https://example.com",
|
json={
|
||||||
"dataset": "coco",
|
"embedding_data": [1.02, 2.03, 3.03],
|
||||||
"category_name": "person"
|
"input_uri": "https://example.com",
|
||||||
})
|
"dataset": "coco",
|
||||||
|
"category_name": "person",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def post_batch_records(ac):
|
async def post_batch_records(ac):
|
||||||
return await ac.post("/api/v1/add", json={
|
return await ac.post(
|
||||||
"embedding_data": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
|
"/api/v1/add",
|
||||||
"input_uri": ["https://example.com", "https://example.com"],
|
json={
|
||||||
"dataset": "training",
|
"embedding_data": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
|
||||||
"category_name": "person"
|
"input_uri": ["https://example.com", "https://example.com"],
|
||||||
})
|
"dataset": "training",
|
||||||
|
"category_name": "person",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_add_to_db():
|
async def test_add_to_db():
|
||||||
@@ -38,6 +51,7 @@ async def test_add_to_db():
|
|||||||
assert response.status_code == 201
|
assert response.status_code == 201
|
||||||
assert response.json() == {"response": "Added record to database"}
|
assert response.json() == {"response": "Added record to database"}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_add_to_db_batch():
|
async def test_add_to_db_batch():
|
||||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||||
@@ -45,6 +59,7 @@ async def test_add_to_db_batch():
|
|||||||
assert response.status_code == 201
|
assert response.status_code == 201
|
||||||
assert response.json() == {"response": "Added record to database"}
|
assert response.json() == {"response": "Added record to database"}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_fetch_from_db():
|
async def test_fetch_from_db():
|
||||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||||
@@ -53,15 +68,17 @@ async def test_fetch_from_db():
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert len(response.json()) == 1
|
assert len(response.json()) == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_count_from_db():
|
async def test_count_from_db():
|
||||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||||
await ac.get("/api/v1/reset") # reset db
|
await ac.get("/api/v1/reset") # reset db
|
||||||
await post_batch_records(ac)
|
await post_batch_records(ac)
|
||||||
response = await ac.get("/api/v1/count")
|
response = await ac.get("/api/v1/count")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {"count": 2}
|
assert response.json() == {"count": 2}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_reset_db():
|
async def test_reset_db():
|
||||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||||
@@ -74,25 +91,19 @@ async def test_reset_db():
|
|||||||
response = await ac.get("/api/v1/count")
|
response = await ac.get("/api/v1/count")
|
||||||
assert response.json() == {"count": 0}
|
assert response.json() == {"count": 0}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_nearest_neighbors():
|
async def test_get_nearest_neighbors():
|
||||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||||
await ac.get("/api/v1/reset")
|
await ac.get("/api/v1/reset")
|
||||||
await post_batch_records(ac)
|
await post_batch_records(ac)
|
||||||
await ac.get("/api/v1/process")
|
await ac.get("/api/v1/process")
|
||||||
response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1})
|
response = await ac.post(
|
||||||
|
"/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1}
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert len(response.json()["ids"]) == 1
|
assert len(response.json()["ids"]) == 1
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_get_nearest_neighbors_filter():
|
|
||||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
|
||||||
await ac.get("/api/v1/reset")
|
|
||||||
await post_batch_records(ac)
|
|
||||||
await ac.get("/api/v1/process")
|
|
||||||
response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "dataset": "training", "category_name": "monkey"})
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert len(response.json()["ids"]) == 0
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_nearest_neighbors_filter():
|
async def test_get_nearest_neighbors_filter():
|
||||||
@@ -100,7 +111,34 @@ async def test_get_nearest_neighbors_filter():
|
|||||||
await ac.get("/api/v1/reset")
|
await ac.get("/api/v1/reset")
|
||||||
await post_batch_records(ac)
|
await post_batch_records(ac)
|
||||||
await ac.get("/api/v1/process")
|
await ac.get("/api/v1/process")
|
||||||
response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 2, "dataset": "training", "category_name": "person"})
|
response = await ac.post(
|
||||||
|
"/api/v1/get_nearest_neighbors",
|
||||||
|
json={
|
||||||
|
"embedding": [1.1, 2.3, 3.2],
|
||||||
|
"n_results": 1,
|
||||||
|
"dataset": "training",
|
||||||
|
"category_name": "monkey",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert len(response.json()["ids"]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_nearest_neighbors_filter():
|
||||||
|
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||||
|
await ac.get("/api/v1/reset")
|
||||||
|
await post_batch_records(ac)
|
||||||
|
await ac.get("/api/v1/process")
|
||||||
|
response = await ac.post(
|
||||||
|
"/api/v1/get_nearest_neighbors",
|
||||||
|
json={
|
||||||
|
"embedding": [1.1, 2.3, 3.2],
|
||||||
|
"n_results": 2,
|
||||||
|
"dataset": "training",
|
||||||
|
"category_name": "person",
|
||||||
|
},
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert len(response.json()["ids"]) == 2
|
assert len(response.json()["ids"]) == 2
|
||||||
|
|
||||||
@@ -118,4 +156,4 @@ async def test_get_nearest_neighbors_filter():
|
|||||||
|
|
||||||
# Purposefully untested
|
# Purposefully untested
|
||||||
# - process
|
# - process
|
||||||
# - rand
|
# - rand
|
||||||
|
|||||||
@@ -6,9 +6,10 @@ class AddEmbedding(BaseModel):
|
|||||||
embedding_data: list
|
embedding_data: list
|
||||||
input_uri: Union[str, list]
|
input_uri: Union[str, list]
|
||||||
dataset: Union[str, list] = None
|
dataset: Union[str, list] = None
|
||||||
custom_quality_score: Union[float, list] = None
|
custom_quality_score: Union[float, list] = None
|
||||||
category_name: Union[str, list] = None
|
category_name: Union[str, list] = None
|
||||||
|
|
||||||
|
|
||||||
class QueryEmbedding(BaseModel):
|
class QueryEmbedding(BaseModel):
|
||||||
embedding: list
|
embedding: list
|
||||||
n_results: int = 10
|
n_results: int = 10
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
def setup_logging():
|
def setup_logging():
|
||||||
logging.basicConfig(filename="chroma_logs.log")
|
logging.basicConfig(filename="chroma_logs.log")
|
||||||
logger = logging.getLogger("Chroma")
|
logger = logging.getLogger("Chroma")
|
||||||
@@ -7,4 +8,5 @@ def setup_logging():
|
|||||||
logger.debug("Logger created")
|
logger.debug("Logger created")
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
logger = setup_logging()
|
|
||||||
|
logger = setup_logging()
|
||||||
|
|||||||
8
pyproject.toml
Normal file
8
pyproject.toml
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
[tool.black]
|
||||||
|
line-length = 100
|
||||||
|
|
||||||
|
# Black will refuse to run if it's not this version.
|
||||||
|
required-version = "22.6.0"
|
||||||
|
|
||||||
|
# Ensure black's output will be compatible with all listed versions.
|
||||||
|
target-version = ['py36', 'py37', 'py38', 'py39', 'py310']
|
||||||
Reference in New Issue
Block a user