This commit is contained in:
Jeffrey Huber
2022-11-06 22:08:41 -08:00
parent 8839232b75
commit e662654e38
10 changed files with 186 additions and 306 deletions

View File

@@ -6,7 +6,10 @@ class Chroma:
_api_url = "http://localhost:8000/api/v1"
def __init__(self, url=None):
# we enable the user to set the space_key in the constructor
_space_key = None
def __init__(self, url=None, app=None, model_version=None, layer=None):
"""Initialize Chroma client"""
if isinstance(url, str) and url.startswith("http"):
@@ -14,30 +17,54 @@ class Chroma:
self.url = url
def count(self):
if app and model_version and layer:
self._space_key = app + "_" + model_version + "_" + layer
def set_context(self, app, model_version, layer):
'''
Sets the context of the client
'''
self._space_key = app + "_" + model_version + "_" + layer
def set_space_key(self, space_key):
'''
Sets the space key for the client
'''
self._space_key = space_key
def get_context(self):
'''
Returns the space key
'''
return self._space_key
def count(self, space_key=None):
'''
Returns the number of embeddings in the database
'''
x = requests.get(self._api_url + "/count")
payload = json.dumps({"space_key": space_key or self._space_key})
x = requests.get(self._api_url + "/count", data=payload)
return x.json()
def fetch(self, where_filter={}, sort=None, limit=None):
'''
Fetches embeddings from the database
'''
where_filter["space_key"] = self._space_key
x = requests.get(self._api_url + "/fetch", data=json.dumps({
"where_filter":json.dumps(where_filter),
"where_filter":where_filter,
"sort":sort,
"limit":limit
}))
return x.json()
def process(self):
def process(self, space_key=None):
'''
Processes embeddings in the database
- currently this only runs hnswlib, doesnt return anything
'''
requests.get(self._api_url + "/process")
requests.get(self._api_url + "/process", data = json.dumps({"space_key": space_key or self._space_key}))
return True
def reset(self):
@@ -46,19 +73,6 @@ class Chroma:
'''
return requests.get(self._api_url + "/reset")
def persist(self):
'''
Persists the database to disk in the .chroma folder inside chroma-server
'''
return requests.get(self._api_url + "/persist")
def rand(self):
'''
Stubbed out sampling endpoint, returns a random bisection of the database
'''
x = requests.get(self._api_url + "/rand")
return x.json()
def heartbeat(self):
'''
Returns the current server time in milliseconds to check if the server is alive
@@ -70,13 +84,22 @@ class Chroma:
embedding_data: list,
input_uri: list,
dataset: list = None,
category_name: list = None):
category_name: list = None,
space_keys: list = None
):
'''
Logs a batch of embeddings to the database
- pass in column oriented data lists
'''
if not space_keys:
if isinstance(dataset, list):
space_keys = [self._space_key] * len(dataset)
else:
space_keys = self._space_key
x = requests.post(self._api_url + "/add", data = json.dumps({
"space_key": space_keys,
"embedding_data": embedding_data,
"input_uri": input_uri,
"dataset": dataset,
@@ -88,7 +111,7 @@ class Chroma:
else:
return False
def log_training(self, embedding_data: list, input_uri: list, category_name: list):
def log_training(self, embedding_data: list, input_uri: list, category_name: list, space_keys: list = None):
'''
Small wrapper around log() to log a batch of training embedding
- sets dataset to "training"
@@ -97,10 +120,11 @@ class Chroma:
embedding_data=embedding_data,
input_uri=input_uri,
dataset="training",
category_name=category_name
category_name=category_name,
space_keys=space_keys
)
def log_production(self, embedding_data: list, input_uri: list, category_name: list):
def log_production(self, embedding_data: list, input_uri: list, category_name: list, space_keys: list = None):
'''
Small wrapper around log() to log a batch of production embedding
- sets dataset to "production"
@@ -109,10 +133,11 @@ class Chroma:
embedding_data=embedding_data,
input_uri=input_uri,
dataset="production",
category_name=category_name
category_name=category_name,
space_keys=space_keys
)
def log_triage(self, embedding_data: list, input_uri: list, category_name: list):
def log_triage(self, embedding_data: list, input_uri: list, category_name: list, space_keys: list = None):
'''
Small wrapper around log() to log a batch of triage embedding
- sets dataset to "triage"
@@ -121,14 +146,19 @@ class Chroma:
embedding_data=embedding_data,
input_uri=input_uri,
dataset="triage",
category_name=category_name
category_name=category_name,
space_keys=space_keys
)
def get_nearest_neighbors(self, embedding, n_results=10, category_name=None, dataset="training"):
def get_nearest_neighbors(self, embedding, n_results=10, category_name=None, dataset="training", space_key = None):
'''
Gets the nearest neighbors of a single embedding
'''
if not space_key:
space_key = self._space_key
x = requests.post(self._api_url + "/get_nearest_neighbors", data = json.dumps({
"space_key": space_key,
"embedding": embedding,
"n_results": n_results,
"category_name": category_name,

View File

@@ -7,11 +7,9 @@ from fastapi import FastAPI, Response, status
from chroma_server.db.clickhouse import Clickhouse
from chroma_server.index.hnswlib import Hnswlib
from chroma_server.algorithms.rand_subsample import rand_bisectional_subsample
from chroma_server.types import AddEmbedding, QueryEmbedding
from chroma_server.types import AddEmbedding, QueryEmbedding, ProcessEmbedding, FetchEmbedding, CountEmbedding
from chroma_server.utils import logger
# Boot script
db = Clickhouse
ann_index = Hnswlib
@@ -22,20 +20,15 @@ app = FastAPI(debug=True)
app._db = db()
app._ann_index = ann_index()
if not os.path.exists(".chroma"):
os.mkdir(".chroma")
if os.path.exists(".chroma/chroma.parquet"):
logger.info("Loading existing chroma database")
app._db.load()
if os.path.exists(".chroma/index.bin"):
logger.info("Loading existing chroma index")
app._ann_index.load(app._db.count(), len(app._db.fetch(limit=1).embedding_data))
# scoping
# an embedding space is specific to a particular trained model and layer
# instead of making the user manage this complexity, we will handle some conventions here
# that being said, we will only store a single string, "space_key" in the db, which the user can, in principle, override
# - embeddings are always written with and fetched from the same space_key
# - indexes are specific to a space_key and a timestamp
# - the client can handle the app + model_verison + layer => space_key string generation
# API Endpoints
@app.get("/api/v1")
async def root():
'''
@@ -49,78 +42,65 @@ async def add_to_db(new_embedding: AddEmbedding):
Save embedding to database
- supports single or batched embeddings
'''
print("add_to_db, new_embedding.space_key", new_embedding, new_embedding.space_key)
app._db.add_batch(
new_embedding.space_key,
new_embedding.embedding_data,
new_embedding.input_uri,
new_embedding.dataset,
new_embedding.custom_quality_score,
new_embedding.category_name
)
)
return {"response": "Added record to database"}
@app.get("/api/v1/process")
async def process():
async def process(process_embedding: ProcessEmbedding):
'''
Currently generates an index for the embedding db
'''
app._ann_index.run(app._db.fetch())
where_filter = {"space_key": process_embedding.space_key}
print("process, where_filter", where_filter)
app._ann_index.run(process_embedding.space_key, app._db.fetch(where_filter))
@app.get("/api/v1/fetch")
async def fetch(where_filter={}, sort=None, limit=None):
async def fetch(fetch_embedding: FetchEmbedding):
'''
Fetches embeddings from the database
- enables filtering by where_filter, sorting by key, and limiting the number of results
'''
return app._db.fetch(where_filter, sort, limit).to_dict(orient="records")
return app._db.fetch(fetch_embedding.where_filter, fetch_embedding.sort, fetch_embedding.limit)
@app.get("/api/v1/count")
async def count():
async def count(count_embedding: CountEmbedding):
'''
Returns the number of records in the database
'''
return ({"count": app._db.count()})
@app.get("/api/v1/persist")
async def persist():
'''
Persist the database and index to disk
'''
if not os.path.exists(".chroma"):
os.mkdir(".chroma")
app._db.persist()
app._ann_index.persist()
return True
return {"count": app._db.count(space_key=count_embedding.space_key)}
@app.get("/api/v1/reset")
async def reset():
'''
Reset the database and index
'''
shutil.rmtree(".chroma", ignore_errors=True)
app._db = db()
app._db.reset()
app._ann_index = ann_index()
return True
@app.get("/api/v1/rand")
async def rand(where_filter={}, sort=None, limit=None):
'''
Randomly bisection the database
'''
results = app._db.fetch(where_filter, sort, limit)
rand = rand_bisectional_subsample(results)
return rand.to_dict(orient="records")
@app.post("/api/v1/get_nearest_neighbors")
async def get_nearest_neighbors(embedding: QueryEmbedding):
'''
return the distance, database ids, and embedding themselves for the input embedding
'''
print("get_nearest_neighbors, embedding.space_key", embedding.space_key)
if embedding.space_key is None:
return {"error": "space_key is required"}
ids = None
filter_by_where = {}
filter_by_where["space_key"] = embedding.space_key
if embedding.category_name is not None:
filter_by_where['category_name'] = embedding.category_name
if embedding.dataset is not None:
@@ -128,9 +108,9 @@ async def get_nearest_neighbors(embedding: QueryEmbedding):
if filter_by_where is not None:
results = app._db.fetch(filter_by_where)
ids = [str(item[0]) for item in results]
ids = [str(item[1]) for item in results] # 1 is the uuid column
uuids, distances = app._ann_index.get_nearest_neighbors(embedding.embedding, embedding.n_results, ids)
uuids, distances = app._ann_index.get_nearest_neighbors(embedding.space_key, embedding.embedding, embedding.n_results, ids)
return {
"ids": uuids,
"embeddings": app._db.get_by_ids(uuids),#

View File

@@ -1,5 +1,6 @@
from abc import abstractmethod
# TODO: update this to match the clickhouse implementation
class Database():
@abstractmethod
def __init__(self):
@@ -11,16 +12,4 @@ class Database():
@abstractmethod
def fetch(self, query):
pass
@abstractmethod
def delete_batch(self, batch):
pass
@abstractmethod
def persist(self):
pass
@abstractmethod
def load(self):
pass

View File

@@ -10,6 +10,7 @@ class Clickhouse(Database):
def _create_table_embeddings(self):
self._conn.execute('''CREATE TABLE IF NOT EXISTS embeddings (
space_key String,
uuid UUID,
embedding_data Array(Float64),
input_uri String,
@@ -24,22 +25,21 @@ class Clickhouse(Database):
self._conn = client
self._create_table_embeddings()
def add_batch(self, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
def add_batch(self, space_key, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
data_to_insert = []
for i in range(len(embedding_data)):
data_to_insert.append([uuid.uuid4(), embedding_data[i], input_uri[i], dataset[i], category_name[i]])
data_to_insert.append([space_key[i], uuid.uuid4(), embedding_data[i], input_uri[i], dataset[i], category_name[i]])
self._conn.execute('''
INSERT INTO embeddings (uuid, embedding_data, input_uri, dataset, category_name) VALUES''', data_to_insert)
INSERT INTO embeddings (space_key, uuid, embedding_data, input_uri, dataset, category_name) VALUES''', data_to_insert)
def count(self):
return self._conn.execute('SELECT COUNT() FROM embeddings')
def update(self, data): # call this update_custom_quality_score! that is all it does
pass
def count(self, space_key=None):
return self._conn.execute(f"SELECT COUNT() FROM embeddings WHERE space_key = '{space_key}'")[0][0]
def fetch(self, where_filter={}, sort=None, limit=None):
if where_filter["space_key"] is None:
return {"error": "space_key is required"}
s3= time.time()
# check to see if query is a dict and if it is a flat list of key value pairs
if where_filter is not None:
@@ -62,8 +62,11 @@ class Clickhouse(Database):
if limit is not None or isinstance(limit, int):
where_filter += f" LIMIT {limit}"
print("where_filter", where_filter)
val = self._conn.execute(f'''
SELECT
space_key,
uuid,
embedding_data,
input_uri,
@@ -78,18 +81,10 @@ class Clickhouse(Database):
return val
def delete_batch(self, batch):
pass
def persist(self):
pass
def load(self, path=".chroma/chroma.parquet"):
pass
def get_by_ids(self, ids=list):
return self._conn.execute(f'''
SELECT
space_key,
uuid,
embedding_data,
input_uri,

View File

@@ -1,176 +0,0 @@
from os import EX_CANTCREAT
from chroma_server.db.abstract import Database
import duckdb
# import numpy as np
# import pandas as pd
class DuckDB(Database):
_conn = None
def __init__(self):
self._conn = duckdb.connect()
self._conn.execute('''
CREATE TABLE embeddings (
id integer PRIMARY KEY,
embedding_data REAL[],
input_uri STRING,
dataset STRING,
custom_quality_score REAL,
category_name STRING
)
''')
# ids to manage internal bookkeeping and *nothing else*, users should not have to care about these ids
self._conn.execute('''
CREATE SEQUENCE seq_id START 1;
''')
self._conn.execute('''
-- change the default null sorting order to either NULLS FIRST and NULLS LAST
PRAGMA default_null_order='NULLS LAST';
-- change the default sorting order to either DESC or ASC
PRAGMA default_order='DESC';
''')
return
def add_batch(self, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
'''
Add embeddings to the database
This accepts both a single input and a list of inputs
'''
# create list of the types of all inputs
types = [type(x).__name__ for x in [embedding_data, input_uri]]
# if all of the types are 'list' - do batch mode
if all(x == 'list' for x in types):
lengths = [len(x) for x in [embedding_data, input_uri]]
# accepts some inputs as str or none, and this multiples them out to the correct length
if custom_quality_score is None or isinstance(custom_quality_score, str):
custom_quality_score = [custom_quality_score] * lengths[0]
if category_name is None or isinstance(category_name, str):
category_name = [category_name] * lengths[0]
if dataset is None or isinstance(dataset, str):
dataset = [dataset] * lengths[0]
# we have to move from column to row format for duckdb
data_to_insert = []
for i in range(lengths[0]):
data_to_insert.append([embedding_data[i], input_uri[i], dataset[i], custom_quality_score[i], category_name[i]])
if all(x == lengths[0] for x in lengths):
self._conn.executemany('''
INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)''',
data_to_insert
)
return
# if any of the types are 'list' - throw an error
if any(x == list for x in [input_uri, dataset, custom_quality_score, category_name]):
raise Exception("Invalid input types. One input is a list where others are not: " + str(types))
# single insert mode
# This should never fire because we do everything in batch mode, but given the mode away from duckdb likely, I am just leaving it in
self._conn.execute('''
INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)''',
[embedding_data, input_uri, dataset, custom_quality_score, category_name]
)
def count(self):
return self._conn.execute('''
SELECT COUNT(*) FROM embeddings;
''').fetchone()[0]
def update(self, data): # call this update_custom_quality_score! that is all it does
'''
I was not able to figure out (yet) how to do a bulk update in duckdb
This is going to be fairly slow
'''
for element in data:
if element['custom_quality_score'] is None:
continue
self._conn.execute(f'''
UPDATE embeddings SET custom_quality_score={element['custom_quality_score']} WHERE id={element['id']}'''
)
def fetch(self, where_filter={}, sort=None, limit=None):
# check to see if query is a dict and if it is a flat list of key value pairs
if where_filter is not None:
if not isinstance(where_filter, dict):
raise Exception("Invalid where_filter: " + str(where_filter))
# ensure where_filter is a flat dict
for key in where_filter:
if isinstance(where_filter[key], dict):
raise Exception("Invalid where_filter: " + str(where_filter))
where_filter = " AND ".join([f"{key} = '{value}'" for key, value in where_filter.items()])
if where_filter:
where_filter = f"WHERE {where_filter}"
if sort is not None:
where_filter += f" ORDER BY {sort}"
if limit is not None or isinstance(limit, int):
where_filter += f" LIMIT {limit}"
return self._conn.execute(f'''
SELECT
id,
embedding_data,
input_uri,
dataset,
custom_quality_score,
category_name
FROM
embeddings
{where_filter}
''').fetchdf().replace({np.nan: None}) # replace nan with None for json serialization
def delete_batch(self, batch):
raise NotImplementedError
def persist(self):
'''
Persist the database to disk
'''
if self._conn is None:
return
self._conn.execute('''
COPY
(SELECT * FROM embeddings)
TO '.chroma/chroma.parquet'
(FORMAT PARQUET);
''')
def load(self, path=".chroma/chroma.parquet"):
'''
Load the database from disk
'''
self._conn.execute(f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');")
def get_by_ids(self, ids=list):
# select from duckdb table where ids are in the list
if not isinstance(ids, list):
raise Exception("ids must be a list")
if not ids:
# create an empty pandas dataframe
return pd.DataFrame()
return self._conn.execute(f'''
SELECT
id,
embedding_data,
input_uri,
dataset,
custom_quality_score,
category_name
FROM
embeddings
WHERE
id IN ({','.join([str(x) for x in ids])})
''').fetchdf().replace({np.nan: None}) # replace nan with None for json serialization

View File

@@ -1,4 +1,5 @@
import hnswlib
import pickle
import time
import numpy as np
from chroma_server.index.abstract import Index
@@ -6,38 +7,37 @@ from chroma_server.utils import logger
class Hnswlib(Index):
# we cache the index and mappers for the latest space_key
_space_key = None
_index = None
# these data structures enable us to map between uuids and ids
# - our uuids are strings (clickhouse doesnt do autoincrementing ids for performance)
# - but hnswlib uses integers only as ids
# - so this is a temporary bandaid.
# TODO: this should get written to disk somehow or we the index will be come useless after a restart
# - so this is a bandaid.
_id_to_uuid = {}
_uuid_to_id = {}
def __init__(self):
pass
def run(self, embedding_data):
def run(self, space_key, embedding_data):
# more comments available at the source: https://github.com/nmslib/hnswlib
self._space_key = space_key
s1 = time.time()
uuids = []
embeddings = []
ids = []
i = 0
for embedding in embedding_data:
uuids.append(str(embedding[0]))
embeddings.append((embedding[1]))
uuids.append(str(embedding[1]))
embeddings.append((embedding[2]))
ids.append(i)
self._id_to_uuid[i] = str(embedding[0])
self._uuid_to_id[str(embedding[0])] = i
self._id_to_uuid[i] = str(embedding[1])
self._uuid_to_id[str(embedding[1])] = i
i += 1
print('time to create uuids and embeddings: ', time.time() - s1)
# We split the data in two batches:
data1 = embeddings
dim = len(data1[0])
num_elements = len(data1)
@@ -67,25 +67,40 @@ class Hnswlib(Index):
self._index = p
def fetch(self, query):
raise NotImplementedError
self.save()
def delete_batch(self, batch):
raise NotImplementedError
def persist(self):
def save(self):
if self._index is None:
return
self._index.save_index(".chroma/index.bin")
logger.debug('Index saved to .chroma/index.bin')
self._index.save_index(f"/index_data/index_{self._space_key}.bin")
def load(self, elements, dimensionality):
# pickle the mappers
with open(f"/index_data/id_to_uuid_{self._space_key}.pkl", 'wb') as f:
pickle.dump(self._id_to_uuid, f, pickle.HIGHEST_PROTOCOL)
with open(f"/index_data/uuid_to_id_{self._space_key}.pkl", 'wb') as f:
pickle.dump(self._uuid_to_id, f, pickle.HIGHEST_PROTOCOL)
logger.debug('Index saved to /index_data/index.bin')
def load(self, space_key, elements, dimensionality):
p = hnswlib.Index(space='l2', dim= dimensionality)
self._index = p
self._index.load_index(".chroma/index.bin", max_elements= elements)
self._index.load_index(f"/index_data/index_{space_key}.bin", max_elements= elements)
# unpickle the mappers
with open(f"/index_data/id_to_uuid_{space_key}.pkl", 'rb') as f:
self._id_to_uuid = pickle.load(f)
with open(f"/index_data/uuid_to_id_{space_key}.pkl", 'rb') as f:
self._uuid_to_id = pickle.load(f)
self._space_key = space_key
# do knn_query on hnswlib to get nearest neighbors
def get_nearest_neighbors(self, query, k, uuids=None):
def get_nearest_neighbors(self, space_key, query, k, uuids=None):
if self._space_key != space_key:
# TODO: deal with this magic number....
self.load(space_key, 500_000, len(query))
s2= time.time()
# get ids from uuids

View File

@@ -3,6 +3,7 @@ from typing import Union, Any
# type supports single and batch mode
class AddEmbedding(BaseModel):
space_key: Union[str, list]
embedding_data: list
input_uri: Union[str, list]
dataset: Union[str, list] = None
@@ -10,7 +11,19 @@ class AddEmbedding(BaseModel):
category_name: Union[str, list] = None
class QueryEmbedding(BaseModel):
space_key: str = None
embedding: list
n_results: int = 10
category_name: str = None
dataset: str = None
class ProcessEmbedding(BaseModel):
space_key: str = None
class FetchEmbedding(BaseModel):
where_filter: dict = {}
sort: str = None
limit: int = None
class CountEmbedding(BaseModel):
space_key: str = None

View File

@@ -11,6 +11,7 @@ services:
dockerfile: Dockerfile
volumes:
- ./chroma-server/:/chroma-server/
- index_data:/index_data
command: uvicorn chroma_server:app --reload --workers 1 --host 0.0.0.0 --port 8000
# env_file:
# - ./chroma-server/.env
@@ -36,4 +37,6 @@ services:
volumes:
clickhouse_data:
driver: local
index_data:
driver: local

View File

@@ -1,5 +1,36 @@
from hashlib import new
import chroma_client
import pandas as pd
from chroma_client import Chroma
new_labels = chroma_client.fetch_new_labels()
print(new_labels)
client1 = Chroma(app="yolov3", model_version="1", layer="1")
# client1.reset()
knife_embedding = [0.2310010939836502, -0.3462161719799042, 0.29164767265319824, -0.09828940033912659, 1.814868450164795, -10.517369270324707, -13.531850814819336, -12.730537414550781, -13.011675834655762, -10.257010459899902, -13.779699325561523, -11.963963508605957, -13.948140144348145, -12.46799087524414, -14.569470405578613, -16.388280868530273, -13.76762580871582, -12.192169189453125, -12.204055786132812, -12.259000778198242, -13.696036338806152, -14.609177589416504, -16.951879501342773, -17.096384048461914, -14.355693817138672, -16.643482208251953, -14.270745277404785, -14.375198364257812, -14.381218910217285, -13.475995063781738, -12.694938659667969, -10.011992454528809, -9.770626068115234, -13.155019760131836, -16.136341094970703, -6.552414417266846, -11.243837356567383, -16.678457260131836, -14.629229545593262, -10.052337646484375, -15.451828956604004, -12.561151504516602, -11.68396282196045, -11.975972175598145, -11.09926986694336, -13.060500144958496, -12.075592994689941, -1.0808746814727783, 1.7046797275543213, -3.8080708980560303, -11.401922225952148, -12.184720039367676, -13.262567520141602, -11.299583435058594, -13.654638290405273, -10.767330169677734, -9.012763977050781, -10.202326774597168, -10.088111877441406, -13.247991561889648, -9.651527404785156, -11.903244972229004, -13.922954559326172, -17.37179946899414, -12.51513385772705, -7.8046746253967285, -14.406414985656738, -13.172696113586426, -11.194984436035156, -12.029500961303711, -10.996524810791016, -10.828441619873047, -8.673471450805664, -13.800869941711426, -9.680946350097656, -12.964024543762207, -9.694372177124023, -13.132003784179688, -9.38864803314209, -14.305071830749512, -14.4693603515625, -5.0566205978393555, -15.685358047485352, -12.493011474609375, -8.424881935119629]
get_nearest_neighbors = client1.get_nearest_neighbors(knife_embedding, 4, None, "training")
res_df = pd.DataFrame(get_nearest_neighbors['embeddings'])
print(res_df.head())
# client1.log([[1,2,3,4,5]], ["/images/1"], ["training"], ['spoon'])
# client1.log([[1,2,3,4,5]], ["/images/2"], ["training"], ['spoon'])
# client1.log([[1,2,3,4,5]], ["/images/3"], ["training"], ['spoon'])
# client1.log([[1,2,3,4,5]], ["/images/4"], ["training"], ['spoon'])
# client1.log([[1,2,3,4,5]], ["/prod/1"], ["test"], ['spoon'])
# client1.log([[1,2,3,4,5]], ["/prod/2"], ["test"], ['spoon'])
# print("context", client1.get_context())
# print("context", client1.heartbeat())
# print("layer 1", client1.count(client1.get_context()))
# # print("fetch", client1.fetch())
# print(client1.process())
# print(client1.get_nearest_neighbors([1,2,3,4,5], 2))
# client1.set_context("test", "1", "2")
# client1.log([[1,2,3,4,5]], ["/images/1"], ["training"], ['knife'])
# client1.log([[1,2,3,4,5]], ["/images/4"], ["training"], ['knife'])
# client1.log([[1,2,3,4,5]], ["/prod/2"], ["test"], ['knife'])
# print("context", client1.get_context())
# print("context", client1.heartbeat())
# print("layer 1", client1.count(client1.get_context()))
# # print("fetch", client1.fetch())
# print(client1.process())
# print(client1.get_nearest_neighbors([1,2,3,4,5], 2))

View File

@@ -19,7 +19,7 @@ if __name__ == "__main__":
data_length = len(df)
chroma = Chroma()
chroma = Chroma(app="yolov3", model_version="1", layer="1")
chroma.reset() #make sure we are using a fresh db
allstart = time.time()
start = time.time()