black python formatting

Jeffrey Huber
2022-10-31 12:42:34 -07:00
parent 6b87b2ed48
commit 30fbd06bb1
14 changed files with 317 additions and 191 deletions

Makefile Normal file

@@ -0,0 +1,5 @@
black:
black --fast chroma-server chroma-client
check_black:
black --check --fast chroma-server chroma-client
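
The black target rewrites both packages in place, while check_black uses black's --check flag to report files that would be reformatted without modifying them, making it the target suited to CI. In both targets, --fast skips black's AST safety check in exchange for speed.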


@@ -2,6 +2,7 @@ import requests
import json
from typing import Union
class Chroma:
_api_url = "http://localhost:8000/api/v1"
@@ -15,127 +16,142 @@ class Chroma:
self.url = url
def count(self):
'''
"""
Returns the number of embeddings in the database
'''
"""
x = requests.get(self._api_url + "/count")
return x.json()
def fetch(self, where_filter={}, sort=None, limit=None):
'''
"""
Fetches embeddings from the database
'''
x = requests.get(self._api_url + "/fetch", data=json.dumps({
"where_filter":json.dumps(where_filter),
"sort":sort,
"limit":limit
}))
"""
x = requests.get(
self._api_url + "/fetch",
data=json.dumps(
{"where_filter": json.dumps(where_filter), "sort": sort, "limit": limit}
),
)
return x.json()
def process(self):
'''
"""
Processes embeddings in the database
- currently this only runs hnswlib, doesnt return anything
'''
"""
requests.get(self._api_url + "/process")
return True
def reset(self):
'''
"""
Resets the database
'''
"""
return requests.get(self._api_url + "/reset")
def persist(self):
'''
"""
Persists the database to disk in the .chroma folder inside chroma-server
'''
"""
return requests.get(self._api_url + "/persist")
def rand(self):
'''
"""
Stubbed out sampling endpoint, returns a random bisection of the database
'''
"""
x = requests.get(self._api_url + "/rand")
return x.json()
def heartbeat(self):
'''
"""
Returns the current server time in milliseconds to check if the server is alive
'''
"""
x = requests.get(self._api_url)
return x.json()
def log(self,
embedding_data: list,
input_uri: list,
def log(
self,
embedding_data: list,
input_uri: list,
dataset: list = None,
category_name: list = None):
'''
category_name: list = None,
):
"""
Logs a batch of embeddings to the database
- pass in column oriented data lists
'''
"""
x = requests.post(self._api_url + "/add", data = json.dumps({
"embedding_data": embedding_data,
"input_uri": input_uri,
"dataset": dataset,
"category_name": category_name
}) )
x = requests.post(
self._api_url + "/add",
data=json.dumps(
{
"embedding_data": embedding_data,
"input_uri": input_uri,
"dataset": dataset,
"category_name": category_name,
}
),
)
if x.status_code == 201:
return True
else:
return False
def log_training(self, embedding_data: list, input_uri: list, category_name: list):
'''
"""
Small wrapper around log() to log a batch of training embedding
- sets dataset to "training"
'''
"""
return self.log(
embedding_data=embedding_data,
input_uri=input_uri,
embedding_data=embedding_data,
input_uri=input_uri,
dataset="training",
category_name=category_name
category_name=category_name,
)
def log_production(self, embedding_data: list, input_uri: list, category_name: list):
'''
"""
Small wrapper around log() to log a batch of production embedding
- sets dataset to "production"
'''
"""
return self.log(
embedding_data=embedding_data,
input_uri=input_uri,
embedding_data=embedding_data,
input_uri=input_uri,
dataset="production",
category_name=category_name
category_name=category_name,
)
def log_triage(self, embedding_data: list, input_uri: list, category_name: list):
'''
"""
Small wrapper around log() to log a batch of triage embedding
- sets dataset to "triage"
'''
"""
return self.log(
embedding_data=embedding_data,
input_uri=input_uri,
embedding_data=embedding_data,
input_uri=input_uri,
dataset="triage",
category_name=category_name
category_name=category_name,
)
def get_nearest_neighbors(self, embedding, n_results=10, category_name=None, dataset="training"):
'''
def get_nearest_neighbors(
self, embedding, n_results=10, category_name=None, dataset="training"
):
"""
Gets the nearest neighbors of a single embedding
'''
x = requests.post(self._api_url + "/get_nearest_neighbors", data = json.dumps({
"embedding": embedding,
"n_results": n_results,
"category_name": category_name,
"dataset": dataset
}) )
"""
x = requests.post(
self._api_url + "/get_nearest_neighbors",
data=json.dumps(
{
"embedding": embedding,
"n_results": n_results,
"category_name": category_name,
"dataset": dataset,
}
),
)
if x.status_code == 200:
return x.json()
else:
return False
return False
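
For context, a minimal usage sketch of the reformatted client above, assuming a chroma-server instance is listening at the default http://localhost:8000 (the URIs are placeholders):

    from chroma_client import Chroma

    chroma = Chroma()  # defaults to http://localhost:8000/api/v1
    chroma.reset()  # start from an empty database

    # column-oriented batch: parallel lists of embeddings, URIs, and categories
    chroma.log_training(
        embedding_data=[[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
        input_uri=["https://example.com/a", "https://example.com/b"],
        category_name=["person", "person"],
    )

    chroma.process()  # ask the server to build its hnswlib index
    print(chroma.count())  # {'count': 2}
    print(chroma.get_nearest_neighbors([1.1, 2.3, 3.2], n_results=1))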


@@ -4,11 +4,14 @@ from chroma_client import Chroma
import pytest
import time
from httpx import AsyncClient
# from ..api import app # this wont work because i moved the file
@pytest.fixture
def anyio_backend():
return 'asyncio'
return "asyncio"
def test_init():
chroma = Chroma()
@@ -25,4 +28,3 @@ def test_init():
# chroma = Chroma(url="http://test/api/v1")
# response = await chroma.count()
# raise Exception("response" + response)


@@ -1,8 +1,8 @@
import random
def rand_bisectional_subsample(data):
"""
Randomly bisectionally subsample a list of data to size.
"""
return data.sample(frac=0.5, replace=True, random_state=1)
return data.sample(frac=0.5, replace=True, random_state=1)


@@ -2,22 +2,23 @@ import numpy as np
import json
import ast
def class_distances(data):
''''
"""'
This is all very subject to change, so essentially just copy and paste from what we had before
'''
"""
return False
# def unpack_annotations(embeddings):
# annotations = [json.loads(embedding['infer'])["annotations"]for embedding in embeddings]
# annotations = [annotation for annotation_list in annotations for annotation in annotation_list]
# annotations = [annotation for annotation_list in annotations for annotation in annotation_list]
# # Unpack embedding data
# embeddings = [embedding["embedding_data"] for embedding in embeddings]
# embedding_vectors_by_category = {}
# for embedding_annotation_pair in zip(embeddings, annotations):
# data = np.array(embedding_annotation_pair[0])
# category = embedding_annotation_pair[1]['category_id']
# category = embedding_annotation_pair[1]['category_id']
# if category in embedding_vectors_by_category.keys():
# embedding_vectors_by_category[category] = np.append(
# embedding_vectors_by_category[category], data[np.newaxis, :], axis=0
@@ -84,5 +85,5 @@ def class_distances(data):
# if (len(inferences) == 0):
# raise Exception("No inferences found for datapoint")
# return output_distances
# return output_distances


@@ -11,7 +11,6 @@ from chroma_server.types import AddEmbedding, QueryEmbedding
from chroma_server.utils import logger
# Boot script
db = DuckDB
ann_index = Hnswlib
@@ -34,104 +33,112 @@ if os.path.exists(".chroma/index.bin"):
app._ann_index.load(app._db.count(), len(app._db.fetch(limit=1).embedding_data))
# API Endpoints
@app.get("/api/v1")
async def root():
'''
"""
Heartbeat endpoint
'''
"""
return {"nanosecond heartbeat": int(1000 * time.time_ns())}
@app.post("/api/v1/add", status_code=status.HTTP_201_CREATED)
async def add_to_db(new_embedding: AddEmbedding):
'''
"""
Save embedding to database
- supports single or batched embeddings
'''
"""
app._db.add_batch(
new_embedding.embedding_data,
new_embedding.input_uri,
new_embedding.embedding_data,
new_embedding.input_uri,
new_embedding.dataset,
new_embedding.custom_quality_score,
new_embedding.category_name
)
new_embedding.custom_quality_score,
new_embedding.category_name,
)
return {"response": "Added record to database"}
@app.get("/api/v1/process")
async def process():
'''
"""
Currently generates an index for the embedding db
'''
"""
app._ann_index.run(app._db.fetch())
@app.get("/api/v1/fetch")
async def fetch(where_filter={}, sort=None, limit=None):
'''
"""
Fetches embeddings from the database
- enables filtering by where_filter, sorting by key, and limiting the number of results
'''
"""
return app._db.fetch(where_filter, sort, limit).to_dict(orient="records")
@app.get("/api/v1/count")
async def count():
'''
"""
Returns the number of records in the database
'''
return ({"count": app._db.count()})
"""
return {"count": app._db.count()}
@app.get("/api/v1/persist")
async def persist():
'''
"""
Persist the database and index to disk
'''
"""
if not os.path.exists(".chroma"):
os.mkdir(".chroma")
app._db.persist()
app._ann_index.persist()
return True
@app.get("/api/v1/reset")
async def reset():
'''
"""
Reset the database and index
'''
"""
shutil.rmtree(".chroma", ignore_errors=True)
app._db = db()
app._ann_index = ann_index()
return True
@app.get("/api/v1/rand")
async def rand(where_filter={}, sort=None, limit=None):
'''
"""
Randomly bisection the database
'''
"""
results = app._db.fetch(where_filter, sort, limit)
rand = rand_bisectional_subsample(results)
return rand.to_dict(orient="records")
@app.post("/api/v1/get_nearest_neighbors")
async def get_nearest_neighbors(embedding: QueryEmbedding):
'''
"""
return the distance, database ids, and embedding themselves for the input embedding
'''
"""
ids = None
filter_by_where = {}
if embedding.category_name is not None:
filter_by_where['category_name'] = embedding.category_name
filter_by_where["category_name"] = embedding.category_name
if embedding.dataset is not None:
filter_by_where['dataset'] = embedding.dataset
filter_by_where["dataset"] = embedding.dataset
if filter_by_where is not None:
ids = app._db.fetch(filter_by_where)["id"].tolist()
nn = app._ann_index.get_nearest_neighbors(embedding.embedding, embedding.n_results, ids)
return {
"ids": nn[0].tolist()[0],
"embeddings": app._db.get_by_ids(nn[0].tolist()[0]).to_dict(orient="records"),
"distances": nn[1].tolist()[0]
}
"distances": nn[1].tolist()[0],
}
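
Together these endpoints form the add, process, query loop that the client wraps; a minimal sketch exercising them directly with requests, under the same local-server assumption:

    import requests

    base = "http://localhost:8000/api/v1"
    requests.get(base + "/reset")
    requests.post(
        base + "/add",
        json={
            "embedding_data": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
            "input_uri": ["https://example.com", "https://example.com"],
            "dataset": "training",
            "category_name": "person",
        },
    )
    requests.get(base + "/process")  # index everything currently in the db
    nn = requests.post(
        base + "/get_nearest_neighbors",
        json={"embedding": [1.1, 2.3, 3.2], "n_results": 1},
    ).json()
    print(nn["ids"], nn["distances"])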


@@ -1,6 +1,7 @@
from abc import abstractmethod
class Database():
class Database:
@abstractmethod
def __init__(self):
pass
@@ -23,4 +24,4 @@ class Database():
@abstractmethod
def load(self):
pass
pass


@@ -4,12 +4,14 @@ import duckdb
import numpy as np
import pandas as pd
class DuckDB(Database):
_conn = None
def __init__(self):
self._conn = duckdb.connect()
self._conn.execute('''
self._conn.execute(
"""
CREATE TABLE embeddings (
id integer PRIMARY KEY,
embedding_data REAL[],
@@ -18,32 +20,39 @@ class DuckDB(Database):
custom_quality_score REAL,
category_name STRING
)
''')
"""
)
# ids to manage internal bookkeeping and *nothing else*, users should not have to care about these ids
self._conn.execute('''
self._conn.execute(
"""
CREATE SEQUENCE seq_id START 1;
''')
"""
)
self._conn.execute('''
self._conn.execute(
"""
-- change the default null sorting order to either NULLS FIRST and NULLS LAST
PRAGMA default_null_order='NULLS LAST';
-- change the default sorting order to either DESC or ASC
PRAGMA default_order='DESC';
''')
"""
)
return
def add_batch(self, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
'''
def add_batch(
self, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None
):
"""
Add embeddings to the database
This accepts both a single input and a list of inputs
'''
"""
# create list of the types of all inputs
types = [type(x).__name__ for x in [embedding_data, input_uri]]
# if all of the types are 'list' - do batch mode
if all(x == 'list' for x in types):
if all(x == "list" for x in types):
lengths = [len(x) for x in [embedding_data, input_uri]]
# accepts some inputs as str or none, and this multiples them out to the correct length
@@ -57,41 +66,56 @@ class DuckDB(Database):
# we have to move from column to row format for duckdb
data_to_insert = []
for i in range(lengths[0]):
data_to_insert.append([embedding_data[i], input_uri[i], dataset[i], custom_quality_score[i], category_name[i]])
data_to_insert.append(
[
embedding_data[i],
input_uri[i],
dataset[i],
custom_quality_score[i],
category_name[i],
]
)
if all(x == lengths[0] for x in lengths):
self._conn.executemany('''
INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)''',
data_to_insert
self._conn.executemany(
"""
INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)""",
data_to_insert,
)
return
# if any of the types are 'list' - throw an error
if any(x == list for x in [input_uri, dataset, custom_quality_score, category_name]):
raise Exception("Invalid input types. One input is a list where others are not: " + str(types))
raise Exception(
"Invalid input types. One input is a list where others are not: " + str(types)
)
# single insert mode
# This should never fire because we do everything in batch mode, but given the mode away from duckdb likely, I am just leaving it in
self._conn.execute('''
INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)''',
[embedding_data, input_uri, dataset, custom_quality_score, category_name]
self._conn.execute(
"""
INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)""",
[embedding_data, input_uri, dataset, custom_quality_score, category_name],
)
def count(self):
return self._conn.execute('''
SELECT COUNT(*) FROM embeddings;
''').fetchone()[0]
def update(self, data): # call this update_custom_quality_score! that is all it does
'''
def count(self):
return self._conn.execute(
"""
SELECT COUNT(*) FROM embeddings;
"""
).fetchone()[0]
def update(self, data): # call this update_custom_quality_score! that is all it does
"""
I was not able to figure out (yet) how to do a bulk update in duckdb
This is going to be fairly slow
'''
for element in data:
if element['custom_quality_score'] is None:
"""
for element in data:
if element["custom_quality_score"] is None:
continue
self._conn.execute(f'''
UPDATE embeddings SET custom_quality_score={element['custom_quality_score']} WHERE id={element['id']}'''
self._conn.execute(
f"""
UPDATE embeddings SET custom_quality_score={element['custom_quality_score']} WHERE id={element['id']}"""
)
def fetch(self, where_filter={}, sort=None, limit=None):
@@ -99,12 +123,12 @@ class DuckDB(Database):
if where_filter is not None:
if not isinstance(where_filter, dict):
raise Exception("Invalid where_filter: " + str(where_filter))
# ensure where_filter is a flat dict
for key in where_filter:
if isinstance(where_filter[key], dict):
raise Exception("Invalid where_filter: " + str(where_filter))
where_filter = " AND ".join([f"{key} = '{value}'" for key, value in where_filter.items()])
if where_filter:
@@ -116,7 +140,9 @@ class DuckDB(Database):
if limit is not None or isinstance(limit, int):
where_filter += f" LIMIT {limit}"
return self._conn.execute(f'''
return (
self._conn.execute(
f"""
SELECT
id,
embedding_data,
@@ -127,41 +153,49 @@ class DuckDB(Database):
FROM
embeddings
{where_filter}
''').fetchdf().replace({np.nan: None}) # replace nan with None for json serialization
"""
)
.fetchdf()
.replace({np.nan: None})
) # replace nan with None for json serialization
def delete_batch(self, batch):
raise NotImplementedError
def persist(self):
'''
"""
Persist the database to disk
'''
"""
if self._conn is None:
return
self._conn.execute('''
self._conn.execute(
"""
COPY
(SELECT * FROM embeddings)
TO '.chroma/chroma.parquet'
(FORMAT PARQUET);
''')
"""
)
def load(self, path=".chroma/chroma.parquet"):
'''
"""
Load the database from disk
'''
"""
self._conn.execute(f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');")
def get_by_ids(self, ids=list):
# select from duckdb table where ids are in the list
if not isinstance(ids, list):
raise Exception("ids must be a list")
if not ids:
# create an empty pandas dataframe
return pd.DataFrame()
return self._conn.execute(f'''
return (
self._conn.execute(
f"""
SELECT
id,
embedding_data,
@@ -173,4 +207,8 @@ class DuckDB(Database):
embeddings
WHERE
id IN ({','.join([str(x) for x in ids])})
''').fetchdf().replace({np.nan: None}) # replace nan with None for json serialization
"""
)
.fetchdf()
.replace({np.nan: None})
) # replace nan with None for json serialization
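
A minimal standalone sketch of this wrapper (the import path chroma_server.db.duckdb is an assumption, since the module path is not shown in this hunk); scalar dataset and category_name values are broadcast to the batch length, as the add_batch comments above describe:

    from chroma_server.db.duckdb import DuckDB  # assumed module path

    db = DuckDB()  # in-memory duckdb connection with the embeddings table
    db.add_batch(
        embedding_data=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        input_uri=["https://example.com/a", "https://example.com/b"],
        dataset="training",  # broadcast to both rows
        category_name="person",
    )
    print(db.count())  # 2
    rows = db.fetch({"dataset": "training"}, limit=1)  # pandas DataFrame
    print(rows.to_dict(orient="records"))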


@@ -1,6 +1,7 @@
from abc import abstractmethod
class Index():
class Index:
@abstractmethod
def __init__(self):
pass
@@ -23,4 +24,4 @@ class Index():
@abstractmethod
def load(self):
pass
pass


@@ -3,6 +3,7 @@ import numpy as np
from chroma_server.index.abstract import Index
from chroma_server.utils import logger
class Hnswlib(Index):
_index = None
@@ -14,20 +15,22 @@ class Hnswlib(Index):
# more comments available at the source: https://github.com/nmslib/hnswlib
# We split the data in two batches:
data1 = embedding_data['embedding_data'].to_numpy().tolist()
data1 = embedding_data["embedding_data"].to_numpy().tolist()
dim = len(data1[0])
num_elements = len(data1)
num_elements = len(data1)
# logger.debug("dimensionality is:", dim)
# logger.debug("total number of elements is:", num_elements)
# logger.debug("max elements", num_elements//2)
concatted_data = data1
concatted_data = data1
# logger.debug("concatted_data", len(concatted_data))
p = hnswlib.Index(space='l2', dim=dim) # # Declaring index, possible options are l2, cosine or ip
p.init_index(max_elements=len(data1), ef_construction=100, M=16) # Initing index
p = hnswlib.Index(
space="l2", dim=dim
) # # Declaring index, possible options are l2, cosine or ip
p.init_index(max_elements=len(data1), ef_construction=100, M=16) # Initing index
p.set_ef(10) # Controlling the recall by setting ef:
p.set_num_threads(4) # Set number of threads used during batch search/construction
p.set_num_threads(4) # Set number of threads used during batch search/construction
# logger.debug("Adding first batch of elements", (len(data1)))
p.add_items(data1, embedding_data["id"])
@@ -37,12 +40,15 @@ class Hnswlib(Index):
# logger.debug("database_ids", database_ids)
# logger.debug("distances", distances)
# logger.debug(len(distances))
logger.debug("Recall for the first batch:" + str(np.mean(database_ids.reshape(-1) == np.arange(len(data1)))))
logger.debug(
"Recall for the first batch:"
+ str(np.mean(database_ids.reshape(-1) == np.arange(len(data1))))
)
self._index = p
def fetch(self, query):
raise NotImplementedError
raise NotImplementedError
def delete_batch(self, batch):
raise NotImplementedError
@@ -51,12 +57,12 @@ class Hnswlib(Index):
if self._index is None:
return
self._index.save_index(".chroma/index.bin")
logger.debug('Index saved to .chroma/index.bin')
logger.debug("Index saved to .chroma/index.bin")
def load(self, elements, dimensionality):
p = hnswlib.Index(space='l2', dim= dimensionality)
p = hnswlib.Index(space="l2", dim=dimensionality)
self._index = p
self._index.load_index(".chroma/index.bin", max_elements= elements)
self._index.load_index(".chroma/index.bin", max_elements=elements)
# do knn_query on hnswlib to get nearest neighbors
def get_nearest_neighbors(self, query, k, ids=None):
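
The body of get_nearest_neighbors is truncated in this hunk. For reference, a standalone sketch of the hnswlib calls the index above is built from, using only the public hnswlib API rather than Chroma's wrapper:

    import hnswlib
    import numpy as np

    data = np.random.rand(100, 3).astype(np.float32)
    ids = np.arange(100)

    p = hnswlib.Index(space="l2", dim=3)  # l2, cosine, or ip
    p.init_index(max_elements=len(data), ef_construction=100, M=16)
    p.add_items(data, ids)
    p.set_ef(10)  # ef must be at least k at query time

    labels, distances = p.knn_query(data[:1], k=5)
    print(labels[0].tolist(), distances[0].tolist())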


@@ -4,32 +4,45 @@ from httpx import AsyncClient
from ..api import app
@pytest.fixture
def anyio_backend():
return 'asyncio'
return "asyncio"
@pytest.mark.anyio
async def test_root():
async with AsyncClient(app=app, base_url="http://test") as ac:
response = await ac.get("/api/v1")
assert response.status_code == 200
assert abs(response.json()["nanosecond heartbeat"] - int(1000 * time.time_ns())) < 3_000_000_000 # a billion nanoseconds = 3s
assert (
abs(response.json()["nanosecond heartbeat"] - int(1000 * time.time_ns())) < 3_000_000_000
) # a billion nanoseconds = 3s
async def post_one_record(ac):
return await ac.post("/api/v1/add", json={
"embedding_data": [1.02, 2.03, 3.03],
"input_uri": "https://example.com",
"dataset": "coco",
"category_name": "person"
})
return await ac.post(
"/api/v1/add",
json={
"embedding_data": [1.02, 2.03, 3.03],
"input_uri": "https://example.com",
"dataset": "coco",
"category_name": "person",
},
)
async def post_batch_records(ac):
return await ac.post("/api/v1/add", json={
"embedding_data": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"input_uri": ["https://example.com", "https://example.com"],
"dataset": "training",
"category_name": "person"
})
return await ac.post(
"/api/v1/add",
json={
"embedding_data": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"input_uri": ["https://example.com", "https://example.com"],
"dataset": "training",
"category_name": "person",
},
)
@pytest.mark.anyio
async def test_add_to_db():
@@ -38,6 +51,7 @@ async def test_add_to_db():
assert response.status_code == 201
assert response.json() == {"response": "Added record to database"}
@pytest.mark.anyio
async def test_add_to_db_batch():
async with AsyncClient(app=app, base_url="http://test") as ac:
@@ -45,6 +59,7 @@ async def test_add_to_db_batch():
assert response.status_code == 201
assert response.json() == {"response": "Added record to database"}
@pytest.mark.anyio
async def test_fetch_from_db():
async with AsyncClient(app=app, base_url="http://test") as ac:
@@ -53,15 +68,17 @@ async def test_fetch_from_db():
assert response.status_code == 200
assert len(response.json()) == 1
@pytest.mark.anyio
async def test_count_from_db():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.get("/api/v1/reset") # reset db
await ac.get("/api/v1/reset") # reset db
await post_batch_records(ac)
response = await ac.get("/api/v1/count")
assert response.status_code == 200
assert response.json() == {"count": 2}
@pytest.mark.anyio
async def test_reset_db():
async with AsyncClient(app=app, base_url="http://test") as ac:
@@ -74,25 +91,19 @@ async def test_reset_db():
response = await ac.get("/api/v1/count")
assert response.json() == {"count": 0}
@pytest.mark.anyio
async def test_get_nearest_neighbors():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.get("/api/v1/reset")
await post_batch_records(ac)
await ac.get("/api/v1/process")
response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1})
response = await ac.post(
"/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1}
)
assert response.status_code == 200
assert len(response.json()["ids"]) == 1
@pytest.mark.anyio
async def test_get_nearest_neighbors_filter():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.get("/api/v1/reset")
await post_batch_records(ac)
await ac.get("/api/v1/process")
response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "dataset": "training", "category_name": "monkey"})
assert response.status_code == 200
assert len(response.json()["ids"]) == 0
@pytest.mark.anyio
async def test_get_nearest_neighbors_filter():
@@ -100,7 +111,34 @@ async def test_get_nearest_neighbors_filter():
await ac.get("/api/v1/reset")
await post_batch_records(ac)
await ac.get("/api/v1/process")
response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 2, "dataset": "training", "category_name": "person"})
response = await ac.post(
"/api/v1/get_nearest_neighbors",
json={
"embedding": [1.1, 2.3, 3.2],
"n_results": 1,
"dataset": "training",
"category_name": "monkey",
},
)
assert response.status_code == 200
assert len(response.json()["ids"]) == 0
@pytest.mark.anyio
async def test_get_nearest_neighbors_filter():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.get("/api/v1/reset")
await post_batch_records(ac)
await ac.get("/api/v1/process")
response = await ac.post(
"/api/v1/get_nearest_neighbors",
json={
"embedding": [1.1, 2.3, 3.2],
"n_results": 2,
"dataset": "training",
"category_name": "person",
},
)
assert response.status_code == 200
assert len(response.json()["ids"]) == 2
@@ -118,4 +156,4 @@ async def test_get_nearest_neighbors_filter():
# Purposefully untested
# - process
# - rand
# - rand


@@ -6,9 +6,10 @@ class AddEmbedding(BaseModel):
embedding_data: list
input_uri: Union[str, list]
dataset: Union[str, list] = None
custom_quality_score: Union[float, list] = None
custom_quality_score: Union[float, list] = None
category_name: Union[str, list] = None
class QueryEmbedding(BaseModel):
embedding: list
n_results: int = 10
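
A quick sketch of how these pydantic models accept both the single-record and batched forms (values mirror the tests above):

    from chroma_server.types import AddEmbedding, QueryEmbedding

    single = AddEmbedding(
        embedding_data=[1.02, 2.03, 3.03],
        input_uri="https://example.com",
        dataset="coco",
        category_name="person",
    )
    batch = AddEmbedding(
        embedding_data=[[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
        input_uri=["https://example.com", "https://example.com"],
        dataset="training",
        category_name="person",
    )
    query = QueryEmbedding(embedding=[1.1, 2.3, 3.2], n_results=1)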


@@ -1,5 +1,6 @@
import logging
def setup_logging():
logging.basicConfig(filename="chroma_logs.log")
logger = logging.getLogger("Chroma")
@@ -7,4 +8,5 @@ def setup_logging():
logger.debug("Logger created")
return logger
logger = setup_logging()
logger = setup_logging()

pyproject.toml Normal file

@@ -0,0 +1,8 @@
[tool.black]
line-length = 100
# Black will refuse to run if it's not this version.
required-version = "22.6.0"
# Ensure black's output will be compatible with all listed versions.
target-version = ['py36', 'py37', 'py38', 'py39', 'py310']
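
required-version makes black exit with an error when the installed version is not exactly 22.6.0, so every contributor and CI run produces identical output; target-version keeps the emitted code compatible with Python 3.6 through 3.10; and line-length = 100 overrides black's default of 88.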