mirror of
https://github.com/placeholder-soft/chroma.git
synced 2026-05-28 15:13:41 +08:00
space_key -> model_space
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from chroma_client import Chroma
|
||||
|
||||
chroma = Chroma()
|
||||
chroma.set_space_key('sample_space')
|
||||
chroma.set_model_space('sample_space')
|
||||
print("Getting heartbeat to verify the server is up")
|
||||
print(chroma.heartbeat())
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Union
|
||||
class Chroma:
|
||||
|
||||
_api_url = "http://localhost:8000/api/v1"
|
||||
_space_key = "default_scope"
|
||||
_model_space = "default_scope"
|
||||
|
||||
def __init__(self, url=None, app=None, model_version=None, layer=None):
|
||||
"""Initialize Chroma client"""
|
||||
@@ -14,38 +14,38 @@ class Chroma:
|
||||
self._api_url = url
|
||||
|
||||
if app and model_version and layer:
|
||||
self._space_key = app + "_" + model_version + "_" + layer
|
||||
self._model_space = app + "_" + model_version + "_" + layer
|
||||
|
||||
def set_context(self, app, model_version, layer):
|
||||
'''Sets the context of the client'''
|
||||
self._space_key = app + "_" + model_version + "_" + layer
|
||||
def set_model_space(self, app, model_version, layer):
|
||||
'''Sets the model_space of the client'''
|
||||
self._model_space = app + "_" + model_version + "_" + layer
|
||||
|
||||
def set_space_key(self, space_key):
|
||||
def set_model_space(self, model_space):
|
||||
'''Sets the space key for the client, enables overriding the string concat'''
|
||||
self._space_key = space_key
|
||||
self._model_space = model_space
|
||||
|
||||
def get_context(self):
|
||||
'''Returns the space key'''
|
||||
return self._space_key
|
||||
def get_model_space(self):
|
||||
'''Returns the model_space key'''
|
||||
return self._model_space
|
||||
|
||||
def heartbeat(self):
|
||||
'''Returns the current server time in nanoseconds to check if the server is alive'''
|
||||
return requests.get(self._api_url).json()
|
||||
|
||||
def count(self, space_key=None, all=False):
|
||||
def count(self, model_space=None, all=False):
|
||||
'''Returns the number of embeddings in the database'''
|
||||
params = {"space_key": space_key or self._space_key}
|
||||
params = {"model_space": model_space or self._model_space}
|
||||
|
||||
if all:
|
||||
params["space_key"] = None
|
||||
params["model_space"] = None
|
||||
|
||||
x = requests.get(self._api_url + "/count", params=params)
|
||||
return x.json()
|
||||
|
||||
def fetch(self, where_filter={}, sort=None, limit=None, offset=None, page=None, page_size=None):
|
||||
'''Fetches embeddings from the database'''
|
||||
if self._space_key:
|
||||
where_filter["space_key"] = self._space_key
|
||||
if self._model_space:
|
||||
where_filter["model_space"] = self._model_space
|
||||
|
||||
if page and page_size:
|
||||
offset = (page - 1) * page_size
|
||||
@@ -60,8 +60,8 @@ class Chroma:
|
||||
|
||||
def delete(self, where_filter={}):
|
||||
'''Deletes embeddings from the database'''
|
||||
if self._space_key:
|
||||
where_filter["space_key"] = self._space_key
|
||||
if self._model_space:
|
||||
where_filter["model_space"] = self._model_space
|
||||
|
||||
return requests.post(self._api_url + "/delete", data=json.dumps({
|
||||
"where_filter":where_filter,
|
||||
@@ -72,17 +72,17 @@ class Chroma:
|
||||
input_uri: list,
|
||||
dataset: list = None,
|
||||
category_name: list = None,
|
||||
space_keys: list = None):
|
||||
model_spaces: list = None):
|
||||
'''
|
||||
Logs a batch of embeddings to the database
|
||||
- pass in column oriented data lists
|
||||
'''
|
||||
|
||||
if not space_keys:
|
||||
space_keys = self._space_key
|
||||
if not model_spaces:
|
||||
model_spaces = self._model_space
|
||||
|
||||
x = requests.post(self._api_url + "/add", data = json.dumps({
|
||||
"space_key": space_keys,
|
||||
"model_space": model_spaces,
|
||||
"embedding_data": embedding_data,
|
||||
"input_uri": input_uri,
|
||||
"dataset": dataset,
|
||||
@@ -94,7 +94,7 @@ class Chroma:
|
||||
else:
|
||||
return False
|
||||
|
||||
def log_training(self, embedding_data: list, input_uri: list, category_name: list, space_keys: list = None):
|
||||
def log_training(self, embedding_data: list, input_uri: list, category_name: list, model_spaces: list = None):
|
||||
'''
|
||||
Small wrapper around log() to log a batch of training embedding - sets dataset to "training"
|
||||
'''
|
||||
@@ -104,10 +104,10 @@ class Chroma:
|
||||
input_uri=input_uri,
|
||||
dataset=datasets,
|
||||
category_name=category_name,
|
||||
space_keys=space_keys
|
||||
model_spaces=model_spaces
|
||||
)
|
||||
|
||||
def log_production(self, embedding_data: list, input_uri: list, category_name: list, space_keys: list = None):
|
||||
def log_production(self, embedding_data: list, input_uri: list, category_name: list, model_spaces: list = None):
|
||||
'''
|
||||
Small wrapper around log() to log a batch of production embedding - sets dataset to "production"
|
||||
'''
|
||||
@@ -117,10 +117,10 @@ class Chroma:
|
||||
input_uri=input_uri,
|
||||
dataset=datasets,
|
||||
category_name=category_name,
|
||||
space_keys=space_keys
|
||||
model_spaces=model_spaces
|
||||
)
|
||||
|
||||
def log_triage(self, embedding_data: list, input_uri: list, category_name: list, space_keys: list = None):
|
||||
def log_triage(self, embedding_data: list, input_uri: list, category_name: list, model_spaces: list = None):
|
||||
'''
|
||||
Small wrapper around log() to log a batch of triage embedding - sets dataset to "triage"
|
||||
'''
|
||||
@@ -130,16 +130,16 @@ class Chroma:
|
||||
input_uri=input_uri,
|
||||
dataset=datasets,
|
||||
category_name=category_name,
|
||||
space_keys=space_keys
|
||||
model_spaces=model_spaces
|
||||
)
|
||||
|
||||
def get_nearest_neighbors(self, embedding, n_results=10, category_name=None, dataset="training", space_key = None):
|
||||
def get_nearest_neighbors(self, embedding, n_results=10, category_name=None, dataset="training", model_space = None):
|
||||
'''Gets the nearest neighbors of a single embedding'''
|
||||
if not space_key:
|
||||
space_key = self._space_key
|
||||
if not model_space:
|
||||
model_space = self._model_space
|
||||
|
||||
x = requests.post(self._api_url + "/get_nearest_neighbors", data = json.dumps({
|
||||
"space_key": space_key,
|
||||
"model_space": model_space,
|
||||
"embedding": embedding,
|
||||
"n_results": n_results,
|
||||
"category_name": category_name,
|
||||
@@ -151,12 +151,12 @@ class Chroma:
|
||||
else:
|
||||
return False
|
||||
|
||||
def process(self, space_key=None):
|
||||
def process(self, model_space=None):
|
||||
'''
|
||||
Processes embeddings in the database
|
||||
- currently this only runs hnswlib, doesnt return anything
|
||||
'''
|
||||
requests.post(self._api_url + "/process", data = json.dumps({"space_key": space_key or self._space_key}))
|
||||
requests.post(self._api_url + "/process", data = json.dumps({"model_space": model_space or self._model_space}))
|
||||
return True
|
||||
|
||||
def reset(self):
|
||||
@@ -167,13 +167,13 @@ class Chroma:
|
||||
'''Runs a raw SQL query against the database'''
|
||||
return requests.post(self._api_url + "/raw_sql", data = json.dumps({"raw_sql": sql})).json()
|
||||
|
||||
def calculate_results(self, space_key=None):
|
||||
def calculate_results(self, model_space=None):
|
||||
'''Calculates the results for the given space key'''
|
||||
return requests.post(self._api_url + "/calculate_results", data = json.dumps({"space_key": space_key or self._space_key})).json()
|
||||
return requests.post(self._api_url + "/calculate_results", data = json.dumps({"model_space": model_space or self._model_space})).json()
|
||||
|
||||
def get_results(self, space_key=None, n_results = 100):
|
||||
def get_results(self, model_space=None, n_results = 100):
|
||||
'''Gets the results for the given space key'''
|
||||
return requests.post(self._api_url + "/get_results", data = json.dumps({"space_key": space_key or self._space_key, "n_results": n_results})).json()
|
||||
return requests.post(self._api_url + "/get_results", data = json.dumps({"model_space": model_space or self._model_space, "n_results": n_results})).json()
|
||||
|
||||
def get_task_status(self, task_id):
|
||||
'''Gets the status of a task'''
|
||||
|
||||
@@ -38,8 +38,8 @@ async def root():
|
||||
|
||||
|
||||
@app.post("/api/v1/calculate_results")
|
||||
async def calculate_results(space_key: SpaceKeyInput):
|
||||
task = heavy_offline_analysis.delay(space_key.space_key)
|
||||
async def calculate_results(model_space: SpaceKeyInput):
|
||||
task = heavy_offline_analysis.delay(model_space.model_space)
|
||||
chroma_telemetry.capture('heavy-offline-analysis')
|
||||
return JSONResponse({"task_id": task.id})
|
||||
|
||||
@@ -55,7 +55,7 @@ async def get_status(task_id):
|
||||
|
||||
@app.post("/api/v1/get_results")
|
||||
async def get_results(results: Results):
|
||||
return app._db.return_results(results.space_key, results.n_results)
|
||||
return app._db.return_results(results.model_space, results.n_results)
|
||||
|
||||
|
||||
|
||||
@@ -65,12 +65,12 @@ async def add_to_db(new_embedding: AddEmbedding):
|
||||
|
||||
number_of_embeddings = len(new_embedding.embedding_data)
|
||||
|
||||
if isinstance(new_embedding.space_key, str):
|
||||
space_key = [new_embedding.space_key] * number_of_embeddings
|
||||
elif len(new_embedding.space_key) == 1:
|
||||
space_key = [new_embedding.space_key[0]] * number_of_embeddings
|
||||
if isinstance(new_embedding.model_space, str):
|
||||
model_space = [new_embedding.model_space] * number_of_embeddings
|
||||
elif len(new_embedding.model_space) == 1:
|
||||
model_space = [new_embedding.model_space[0]] * number_of_embeddings
|
||||
else:
|
||||
space_key = new_embedding.space_key
|
||||
model_space = new_embedding.model_space
|
||||
|
||||
if isinstance(new_embedding.dataset, str):
|
||||
dataset = [new_embedding.dataset] * number_of_embeddings
|
||||
@@ -80,10 +80,10 @@ async def add_to_db(new_embedding: AddEmbedding):
|
||||
dataset = new_embedding.dataset
|
||||
|
||||
# print the len of all inputs to add_batch
|
||||
print(len(new_embedding.embedding_data), len(new_embedding.input_uri), len(space_key), len(dataset))
|
||||
print(len(new_embedding.embedding_data), len(new_embedding.input_uri), len(model_space), len(dataset))
|
||||
|
||||
app._db.add_batch(
|
||||
space_key,
|
||||
model_space,
|
||||
new_embedding.embedding_data,
|
||||
new_embedding.input_uri,
|
||||
dataset,
|
||||
@@ -98,9 +98,9 @@ async def process(process_embedding: ProcessEmbedding):
|
||||
'''
|
||||
Currently generates an index for the embedding db
|
||||
'''
|
||||
fetch = app._db.fetch({"space_key": process_embedding.space_key}, columnar=True)
|
||||
fetch = app._db.fetch({"model_space": process_embedding.model_space}, columnar=True)
|
||||
chroma_telemetry.capture('created-index', {'n': len(fetch[2])})
|
||||
app._ann_index.run(process_embedding.space_key, fetch[1], fetch[2]) # more magic number, ugh
|
||||
app._ann_index.run(process_embedding.model_space, fetch[1], fetch[2]) # more magic number, ugh
|
||||
|
||||
return {"response": "Processed space"}
|
||||
|
||||
@@ -121,11 +121,11 @@ async def delete(embedding: DeleteEmbedding):
|
||||
return app._db.delete(embedding.where_filter)
|
||||
|
||||
@app.get("/api/v1/count")
|
||||
async def count(space_key: str = None):
|
||||
async def count(model_space: str = None):
|
||||
'''
|
||||
Returns the number of records in the database
|
||||
'''
|
||||
return {"count": app._db.count(space_key=space_key)}
|
||||
return {"count": app._db.count(model_space=model_space)}
|
||||
|
||||
@app.post("/api/v1/reset")
|
||||
async def reset():
|
||||
@@ -143,12 +143,12 @@ async def get_nearest_neighbors(embedding: QueryEmbedding):
|
||||
'''
|
||||
return the distance, database ids, and embedding themselves for the input embedding
|
||||
'''
|
||||
if embedding.space_key is None:
|
||||
return {"error": "space_key is required"}
|
||||
if embedding.model_space is None:
|
||||
return {"error": "model_space is required"}
|
||||
|
||||
ids = None
|
||||
filter_by_where = {}
|
||||
filter_by_where["space_key"] = embedding.space_key
|
||||
filter_by_where["model_space"] = embedding.model_space
|
||||
if embedding.category_name is not None:
|
||||
filter_by_where["category_name"] = embedding.category_name
|
||||
if embedding.dataset is not None:
|
||||
@@ -158,7 +158,7 @@ async def get_nearest_neighbors(embedding: QueryEmbedding):
|
||||
results = app._db.fetch(filter_by_where)
|
||||
ids = [str(item[get_col_pos('uuid')]) for item in results]
|
||||
|
||||
uuids, distances = app._ann_index.get_nearest_neighbors(embedding.space_key, embedding.embedding, embedding.n_results, ids)
|
||||
uuids, distances = app._ann_index.get_nearest_neighbors(embedding.model_space, embedding.embedding, embedding.n_results, ids)
|
||||
return {
|
||||
"ids": uuids,
|
||||
"embeddings": app._db.get_by_ids(uuids),
|
||||
|
||||
@@ -7,11 +7,11 @@ class Database:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_batch(self, space_key, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
|
||||
def add_batch(self, model_space, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def count(self, space_key=None):
|
||||
def count(self, model_space=None):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
from clickhouse_driver import connect, Client
|
||||
|
||||
EMBEDDING_TABLE_SCHEMA = [
|
||||
{'space_key': 'String'},
|
||||
{'model_space': 'String'},
|
||||
{'uuid': 'UUID'},
|
||||
{'embedding_data': 'Array(Float64)'},
|
||||
{'input_uri': 'String'},
|
||||
@@ -15,7 +15,7 @@ EMBEDDING_TABLE_SCHEMA = [
|
||||
]
|
||||
|
||||
RESULTS_TABLE_SCHEMA = [
|
||||
{'space_key': 'String'},
|
||||
{'model_space': 'String'},
|
||||
{'uuid': 'UUID'},
|
||||
{'custom_quality_score': ' Nullable(Float64)'},
|
||||
]
|
||||
@@ -47,7 +47,7 @@ class Clickhouse(Database):
|
||||
def _create_table_embeddings(self):
|
||||
self._conn.execute(f'''CREATE TABLE IF NOT EXISTS embeddings (
|
||||
{db_array_schema_to_clickhouse_schema(EMBEDDING_TABLE_SCHEMA)}
|
||||
) ENGINE = MergeTree() ORDER BY space_key''')
|
||||
) ENGINE = MergeTree() ORDER BY model_space''')
|
||||
|
||||
self._conn.execute(f'''SET allow_experimental_lightweight_delete = true''')
|
||||
self._conn.execute(f'''SET mutations_sync = 1''') # https://clickhouse.com/docs/en/operations/settings/settings/#mutations_sync
|
||||
@@ -55,7 +55,7 @@ class Clickhouse(Database):
|
||||
def _create_table_results(self):
|
||||
self._conn.execute(f'''CREATE TABLE IF NOT EXISTS results (
|
||||
{db_array_schema_to_clickhouse_schema(RESULTS_TABLE_SCHEMA)}
|
||||
) ENGINE = MergeTree() ORDER BY space_key''')
|
||||
) ENGINE = MergeTree() ORDER BY model_space''')
|
||||
|
||||
def __init__(self):
|
||||
client = Client(host='clickhouse', port=os.getenv('CLICKHOUSE_TCP_PORT', '9000'))
|
||||
@@ -63,23 +63,23 @@ class Clickhouse(Database):
|
||||
self._create_table_embeddings()
|
||||
self._create_table_results()
|
||||
|
||||
def add_batch(self, space_key, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
|
||||
def add_batch(self, model_space, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
|
||||
data_to_insert = []
|
||||
for i in range(len(embedding_data)):
|
||||
data_to_insert.append([space_key[i], uuid.uuid4(), embedding_data[i], input_uri[i], dataset[i], category_name[i]])
|
||||
data_to_insert.append([model_space[i], uuid.uuid4(), embedding_data[i], input_uri[i], dataset[i], category_name[i]])
|
||||
|
||||
self._conn.execute('''
|
||||
INSERT INTO embeddings (space_key, uuid, embedding_data, input_uri, dataset, category_name) VALUES''', data_to_insert)
|
||||
INSERT INTO embeddings (model_space, uuid, embedding_data, input_uri, dataset, category_name) VALUES''', data_to_insert)
|
||||
|
||||
def count(self, space_key=None):
|
||||
def count(self, model_space=None):
|
||||
where_string = ""
|
||||
if space_key is not None:
|
||||
where_string = f"WHERE space_key = '{space_key}'"
|
||||
if model_space is not None:
|
||||
where_string = f"WHERE model_space = '{model_space}'"
|
||||
return self._conn.execute(f"SELECT COUNT() FROM embeddings {where_string}")[0][0]
|
||||
|
||||
def fetch(self, where_filter={}, sort=None, limit=None, offset=None, columnar=False):
|
||||
if where_filter["space_key"] is None:
|
||||
return {"error": "space_key is required"}
|
||||
if where_filter["model_space"] is None:
|
||||
return {"error": "model_space is required"}
|
||||
|
||||
s3= time.time()
|
||||
# check to see if query is a dict and if it is a flat list of key value pairs
|
||||
@@ -100,7 +100,7 @@ class Clickhouse(Database):
|
||||
if sort is not None:
|
||||
where_filter += f" ORDER BY {sort}"
|
||||
else:
|
||||
where_filter += f" ORDER BY space_key" # stable ordering
|
||||
where_filter += f" ORDER BY model_space" # stable ordering
|
||||
|
||||
if limit is not None or isinstance(limit, int):
|
||||
where_filter += f" LIMIT {limit}"
|
||||
@@ -120,8 +120,8 @@ class Clickhouse(Database):
|
||||
return val
|
||||
|
||||
def delete(self, where_filter={}):
|
||||
if where_filter["space_key"] is None:
|
||||
return {"error": "space_key is required. Use reset to clear the entire db"}
|
||||
if where_filter["model_space"] is None:
|
||||
return {"error": "model_space is required. Use reset to clear the entire db"}
|
||||
|
||||
s3= time.time()
|
||||
# check to see if query is a dict and if it is a flat list of key value pairs
|
||||
@@ -160,18 +160,18 @@ class Clickhouse(Database):
|
||||
def raw_sql(self, sql):
|
||||
return self._conn.execute(sql)
|
||||
|
||||
def add_results(self, space_keys, uuids, custom_quality_score):
|
||||
def add_results(self, model_spaces, uuids, custom_quality_score):
|
||||
data_to_insert = []
|
||||
for i in range(len(space_keys)):
|
||||
data_to_insert.append([space_keys[i], uuids[i], custom_quality_score[i]])
|
||||
for i in range(len(model_spaces)):
|
||||
data_to_insert.append([model_spaces[i], uuids[i], custom_quality_score[i]])
|
||||
|
||||
self._conn.execute('''
|
||||
INSERT INTO results (space_key, uuid, custom_quality_score) VALUES''', data_to_insert)
|
||||
INSERT INTO results (model_space, uuid, custom_quality_score) VALUES''', data_to_insert)
|
||||
|
||||
def delete_results(self, space_key):
|
||||
self._conn.execute(f"DELETE FROM results WHERE space_key = '{space_key}'")
|
||||
def delete_results(self, model_space):
|
||||
self._conn.execute(f"DELETE FROM results WHERE model_space = '{model_space}'")
|
||||
|
||||
def return_results(self, space_key, n_results = 100):
|
||||
def return_results(self, model_space, n_results = 100):
|
||||
return self._conn.execute(f'''
|
||||
SELECT
|
||||
embeddings.input_uri,
|
||||
@@ -184,7 +184,7 @@ class Clickhouse(Database):
|
||||
ON
|
||||
results.uuid = embeddings.uuid
|
||||
WHERE
|
||||
results.space_key = '{space_key}'
|
||||
results.model_space = '{model_space}'
|
||||
ORDER BY
|
||||
results.custom_quality_score DESC
|
||||
LIMIT {n_results}
|
||||
|
||||
@@ -9,7 +9,7 @@ from chroma_server.logger import logger
|
||||
|
||||
class Hnswlib(Index):
|
||||
|
||||
_space_key = None
|
||||
_model_space = None
|
||||
_index = None
|
||||
_index_metadata = {
|
||||
'dimensionality': None,
|
||||
@@ -23,7 +23,7 @@ class Hnswlib(Index):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def run(self, space_key, uuids, embeddings):
|
||||
def run(self, model_space, uuids, embeddings):
|
||||
# more comments available at the source: https://github.com/nmslib/hnswlib
|
||||
dimensionality = len(embeddings[0])
|
||||
ids = []
|
||||
@@ -42,7 +42,7 @@ class Hnswlib(Index):
|
||||
index.add_items(embeddings, ids)
|
||||
|
||||
self._index = index
|
||||
self._space_key = space_key
|
||||
self._model_space = model_space
|
||||
self._index_metadata = {
|
||||
'dimensionality': dimensionality,
|
||||
'elements': len(embeddings) ,
|
||||
@@ -53,38 +53,38 @@ class Hnswlib(Index):
|
||||
def save(self):
|
||||
if self._index is None:
|
||||
return
|
||||
self._index.save_index(f"/index_data/index_{self._space_key}.bin")
|
||||
self._index.save_index(f"/index_data/index_{self._model_space}.bin")
|
||||
|
||||
# pickle the mappers
|
||||
with open(f"/index_data/id_to_uuid_{self._space_key}.pkl", 'wb') as f:
|
||||
with open(f"/index_data/id_to_uuid_{self._model_space}.pkl", 'wb') as f:
|
||||
pickle.dump(self._id_to_uuid, f, pickle.HIGHEST_PROTOCOL)
|
||||
with open(f"/index_data/uuid_to_id_{self._space_key}.pkl", 'wb') as f:
|
||||
with open(f"/index_data/uuid_to_id_{self._model_space}.pkl", 'wb') as f:
|
||||
pickle.dump(self._uuid_to_id, f, pickle.HIGHEST_PROTOCOL)
|
||||
with open(f"/index_data/index_metadata_{self._space_key}.pkl", 'wb') as f:
|
||||
with open(f"/index_data/index_metadata_{self._model_space}.pkl", 'wb') as f:
|
||||
pickle.dump(self._index_metadata, f, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
logger.debug('Index saved to /index_data/index.bin')
|
||||
|
||||
def load(self, space_key):
|
||||
def load(self, model_space):
|
||||
# unpickle the mappers
|
||||
with open(f"/index_data/id_to_uuid_{space_key}.pkl", 'rb') as f:
|
||||
with open(f"/index_data/id_to_uuid_{model_space}.pkl", 'rb') as f:
|
||||
self._id_to_uuid = pickle.load(f)
|
||||
with open(f"/index_data/uuid_to_id_{space_key}.pkl", 'rb') as f:
|
||||
with open(f"/index_data/uuid_to_id_{model_space}.pkl", 'rb') as f:
|
||||
self._uuid_to_id = pickle.load(f)
|
||||
with open(f"/index_data/index_metadata_{space_key}.pkl", 'rb') as f:
|
||||
with open(f"/index_data/index_metadata_{model_space}.pkl", 'rb') as f:
|
||||
self._index_metadata = pickle.load(f)
|
||||
|
||||
p = hnswlib.Index(space='l2', dim= self._index_metadata['dimensionality'])
|
||||
self._index = p
|
||||
self._index.load_index(f"/index_data/index_{space_key}.bin", max_elements= self._index_metadata['elements'])
|
||||
self._index.load_index(f"/index_data/index_{model_space}.bin", max_elements= self._index_metadata['elements'])
|
||||
|
||||
self._space_key = space_key
|
||||
self._model_space = model_space
|
||||
|
||||
# do knn_query on hnswlib to get nearest neighbors
|
||||
def get_nearest_neighbors(self, space_key, query, k, uuids=None):
|
||||
def get_nearest_neighbors(self, model_space, query, k, uuids=None):
|
||||
|
||||
if self._space_key != space_key:
|
||||
self.load(space_key)
|
||||
if self._model_space != model_space:
|
||||
self.load(model_space)
|
||||
|
||||
s2= time.time()
|
||||
# get ids from uuids
|
||||
|
||||
@@ -23,7 +23,7 @@ async def post_batch_records(ac):
|
||||
"input_uri": ["https://example.com", "https://example.com"],
|
||||
"dataset": ["training", "training"],
|
||||
"category_name": ["person", "person"],
|
||||
"space_key": ["test_space", "test_space"],
|
||||
"model_space": ["test_space", "test_space"],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -35,7 +35,7 @@ async def post_batch_records_minimal(ac):
|
||||
"input_uri": ["https://example.com", "https://example.com"],
|
||||
"dataset": "training",
|
||||
"category_name": ["person", "person"],
|
||||
"space_key": "test_space"
|
||||
"model_space": "test_space"
|
||||
},
|
||||
)
|
||||
|
||||
@@ -47,7 +47,7 @@ async def test_add_to_db_batch():
|
||||
response = await post_batch_records(ac)
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {"response": "Added records to database"}
|
||||
response = await ac.get("/api/v1/count", params={"space_key": "test_space"})
|
||||
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
|
||||
assert response.json() == {"count": 2}
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ async def test_add_to_db_batch_minimal():
|
||||
response = await post_batch_records_minimal(ac)
|
||||
assert response.status_code == 201
|
||||
assert response.json() == {"response": "Added records to database"}
|
||||
response = await ac.get("/api/v1/count", params={"space_key": "test_space"})
|
||||
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
|
||||
assert response.json() == {"count": 2}
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -66,7 +66,7 @@ async def test_fetch_from_db():
|
||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||
await ac.post("/api/v1/reset")
|
||||
await post_batch_records(ac)
|
||||
params = {"where_filter": {"space_key": "test_space"}}
|
||||
params = {"where_filter": {"model_space": "test_space"}}
|
||||
response = await ac.post("/api/v1/fetch", json=params)
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 2
|
||||
@@ -76,7 +76,7 @@ async def test_count_from_db():
|
||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||
await ac.post("/api/v1/reset") # reset db
|
||||
await post_batch_records(ac)
|
||||
response = await ac.get("/api/v1/count", params={"space_key": "test_space"})
|
||||
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"count": 2}
|
||||
|
||||
@@ -85,11 +85,11 @@ async def test_reset_db():
|
||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||
await ac.post("/api/v1/reset")
|
||||
await post_batch_records(ac)
|
||||
response = await ac.get("/api/v1/count", params={"space_key": "test_space"})
|
||||
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
|
||||
assert response.json() == {"count": 2}
|
||||
response = await ac.post("/api/v1/reset")
|
||||
assert response.json() == True
|
||||
response = await ac.get("/api/v1/count", params={"space_key": "test_space"})
|
||||
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
|
||||
assert response.json() == {"count": 0}
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -97,9 +97,9 @@ async def test_get_nearest_neighbors():
|
||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||
await ac.post("/api/v1/reset")
|
||||
await post_batch_records(ac)
|
||||
await ac.post("/api/v1/process", json={"space_key": "test_space"})
|
||||
await ac.post("/api/v1/process", json={"model_space": "test_space"})
|
||||
response = await ac.post(
|
||||
"/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "space_key": "test_space"}
|
||||
"/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "model_space": "test_space"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()["ids"]) == 1
|
||||
@@ -109,7 +109,7 @@ async def test_get_nearest_neighbors_filter():
|
||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||
await ac.post("/api/v1/reset")
|
||||
await post_batch_records(ac)
|
||||
await ac.post("/api/v1/process", json={"space_key": "test_space"})
|
||||
await ac.post("/api/v1/process", json={"model_space": "test_space"})
|
||||
response = await ac.post(
|
||||
"/api/v1/get_nearest_neighbors",
|
||||
json={
|
||||
@@ -117,7 +117,7 @@ async def test_get_nearest_neighbors_filter():
|
||||
"n_results": 1,
|
||||
"dataset": "training",
|
||||
"category_name": "monkey",
|
||||
"space_key": "test_space",
|
||||
"model_space": "test_space",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -128,7 +128,7 @@ async def test_process():
|
||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||
await ac.post("/api/v1/reset")
|
||||
await post_batch_records(ac)
|
||||
response = await ac.post("/api/v1/process", json={"space_key": "test_space"})
|
||||
response = await ac.post("/api/v1/process", json={"model_space": "test_space"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"response": "Processed space"}
|
||||
|
||||
@@ -138,11 +138,11 @@ async def test_delete():
|
||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||
await ac.post("/api/v1/reset")
|
||||
await post_batch_records(ac)
|
||||
response = await ac.get("/api/v1/count", params={"space_key": "test_space"})
|
||||
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
|
||||
assert response.json() == {"count": 2}
|
||||
response = await ac.post("/api/v1/delete", json={"where_filter": {"space_key": "test_space"}})
|
||||
response = await ac.post("/api/v1/delete", json={"where_filter": {"model_space": "test_space"}})
|
||||
assert response.json() == []
|
||||
response = await ac.get("/api/v1/count", params={"space_key": "test_space"})
|
||||
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
|
||||
assert response.json() == {"count": 0}
|
||||
|
||||
# test calculate results
|
||||
@@ -151,11 +151,11 @@ async def test_delete():
|
||||
# async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||
# await ac.post("/api/v1/reset")
|
||||
# await post_batch_records(ac)
|
||||
# await ac.post("/api/v1/process", json={"space_key": "test_space"})
|
||||
# await ac.post("/api/v1/process", json={"model_space": "test_space"})
|
||||
# response = await ac.post(
|
||||
# "/api/v1/calculate_results",
|
||||
# json={
|
||||
# "space_key": "test_space",
|
||||
# "model_space": "test_space",
|
||||
# },
|
||||
# )
|
||||
# assert response.status_code == 200
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Union, Any
|
||||
|
||||
# type supports single and batch mode
|
||||
class AddEmbedding(BaseModel):
|
||||
space_key: Union[str, list]
|
||||
model_space: Union[str, list]
|
||||
embedding_data: list
|
||||
input_uri: Union[str, list]
|
||||
dataset: Union[str, list] = None
|
||||
@@ -11,14 +11,14 @@ class AddEmbedding(BaseModel):
|
||||
|
||||
|
||||
class QueryEmbedding(BaseModel):
|
||||
space_key: str = None
|
||||
model_space: str = None
|
||||
embedding: list
|
||||
n_results: int = 10
|
||||
category_name: str = None
|
||||
dataset: str = None
|
||||
|
||||
class ProcessEmbedding(BaseModel):
|
||||
space_key: str = None
|
||||
model_space: str = None
|
||||
|
||||
class FetchEmbedding(BaseModel):
|
||||
where_filter: dict = {}
|
||||
@@ -27,17 +27,17 @@ class FetchEmbedding(BaseModel):
|
||||
offset: int = None
|
||||
|
||||
class CountEmbedding(BaseModel):
|
||||
space_key: str = None
|
||||
model_space: str = None
|
||||
|
||||
class RawSql(BaseModel):
|
||||
raw_sql: str = None
|
||||
|
||||
class Results(BaseModel):
|
||||
space_key: str
|
||||
model_space: str
|
||||
n_results: int = 100
|
||||
|
||||
class SpaceKeyInput(BaseModel):
|
||||
space_key: str
|
||||
model_space: str
|
||||
|
||||
class DeleteEmbedding(BaseModel):
|
||||
where_filter: dict = {}
|
||||
@@ -14,9 +14,9 @@ def create_task(task_type):
|
||||
return True
|
||||
|
||||
@celery.task(name="heavy_offline_analysis")
|
||||
def heavy_offline_analysis(space_key):
|
||||
def heavy_offline_analysis(model_space):
|
||||
task_db_conn = Clickhouse()
|
||||
embedding_rows = task_db_conn.fetch({"space_key": space_key})
|
||||
embedding_rows = task_db_conn.fetch({"model_space": model_space})
|
||||
|
||||
uuids = []
|
||||
custom_quality_scores = []
|
||||
@@ -25,9 +25,9 @@ def heavy_offline_analysis(space_key):
|
||||
uuids.append(row[get_col_pos("uuid")])
|
||||
custom_quality_scores.append(random.random())
|
||||
|
||||
spaces = [space_key] * len(uuids)
|
||||
spaces = [model_space] * len(uuids)
|
||||
|
||||
task_db_conn.delete_results(space_key)
|
||||
task_db_conn.delete_results(model_space)
|
||||
task_db_conn.add_results(spaces, uuids, custom_quality_scores)
|
||||
|
||||
return "Wrote custom quality scores to database"
|
||||
|
||||
Reference in New Issue
Block a user