space_key -> model_space

This commit is contained in:
Jeffrey Huber
2022-11-14 12:40:42 -08:00
parent 6e84b205f0
commit 2bf317c43c
9 changed files with 125 additions and 125 deletions

View File

@@ -1,7 +1,7 @@
from chroma_client import Chroma
chroma = Chroma()
chroma.set_space_key('sample_space')
chroma.set_model_space('sample_space')
print("Getting heartbeat to verify the server is up")
print(chroma.heartbeat())

View File

@@ -5,7 +5,7 @@ from typing import Union
class Chroma:
_api_url = "http://localhost:8000/api/v1"
_space_key = "default_scope"
_model_space = "default_scope"
def __init__(self, url=None, app=None, model_version=None, layer=None):
"""Initialize Chroma client"""
@@ -14,38 +14,38 @@ class Chroma:
self._api_url = url
if app and model_version and layer:
self._space_key = app + "_" + model_version + "_" + layer
self._model_space = app + "_" + model_version + "_" + layer
def set_context(self, app, model_version, layer):
'''Sets the context of the client'''
self._space_key = app + "_" + model_version + "_" + layer
def set_model_space(self, app, model_version, layer):
'''Sets the model_space of the client'''
self._model_space = app + "_" + model_version + "_" + layer
def set_space_key(self, space_key):
def set_model_space(self, model_space):
'''Sets the space key for the client, enables overriding the string concat'''
self._space_key = space_key
self._model_space = model_space
def get_context(self):
'''Returns the space key'''
return self._space_key
def get_model_space(self):
'''Returns the model_space key'''
return self._model_space
def heartbeat(self):
'''Returns the current server time in nanoseconds to check if the server is alive'''
return requests.get(self._api_url).json()
def count(self, space_key=None, all=False):
def count(self, model_space=None, all=False):
'''Returns the number of embeddings in the database'''
params = {"space_key": space_key or self._space_key}
params = {"model_space": model_space or self._model_space}
if all:
params["space_key"] = None
params["model_space"] = None
x = requests.get(self._api_url + "/count", params=params)
return x.json()
def fetch(self, where_filter={}, sort=None, limit=None, offset=None, page=None, page_size=None):
'''Fetches embeddings from the database'''
if self._space_key:
where_filter["space_key"] = self._space_key
if self._model_space:
where_filter["model_space"] = self._model_space
if page and page_size:
offset = (page - 1) * page_size
@@ -60,8 +60,8 @@ class Chroma:
def delete(self, where_filter={}):
'''Deletes embeddings from the database'''
if self._space_key:
where_filter["space_key"] = self._space_key
if self._model_space:
where_filter["model_space"] = self._model_space
return requests.post(self._api_url + "/delete", data=json.dumps({
"where_filter":where_filter,
@@ -72,17 +72,17 @@ class Chroma:
input_uri: list,
dataset: list = None,
category_name: list = None,
space_keys: list = None):
model_spaces: list = None):
'''
Logs a batch of embeddings to the database
- pass in column oriented data lists
'''
if not space_keys:
space_keys = self._space_key
if not model_spaces:
model_spaces = self._model_space
x = requests.post(self._api_url + "/add", data = json.dumps({
"space_key": space_keys,
"model_space": model_spaces,
"embedding_data": embedding_data,
"input_uri": input_uri,
"dataset": dataset,
@@ -94,7 +94,7 @@ class Chroma:
else:
return False
def log_training(self, embedding_data: list, input_uri: list, category_name: list, space_keys: list = None):
def log_training(self, embedding_data: list, input_uri: list, category_name: list, model_spaces: list = None):
'''
Small wrapper around log() to log a batch of training embedding - sets dataset to "training"
'''
@@ -104,10 +104,10 @@ class Chroma:
input_uri=input_uri,
dataset=datasets,
category_name=category_name,
space_keys=space_keys
model_spaces=model_spaces
)
def log_production(self, embedding_data: list, input_uri: list, category_name: list, space_keys: list = None):
def log_production(self, embedding_data: list, input_uri: list, category_name: list, model_spaces: list = None):
'''
Small wrapper around log() to log a batch of production embedding - sets dataset to "production"
'''
@@ -117,10 +117,10 @@ class Chroma:
input_uri=input_uri,
dataset=datasets,
category_name=category_name,
space_keys=space_keys
model_spaces=model_spaces
)
def log_triage(self, embedding_data: list, input_uri: list, category_name: list, space_keys: list = None):
def log_triage(self, embedding_data: list, input_uri: list, category_name: list, model_spaces: list = None):
'''
Small wrapper around log() to log a batch of triage embedding - sets dataset to "triage"
'''
@@ -130,16 +130,16 @@ class Chroma:
input_uri=input_uri,
dataset=datasets,
category_name=category_name,
space_keys=space_keys
model_spaces=model_spaces
)
def get_nearest_neighbors(self, embedding, n_results=10, category_name=None, dataset="training", space_key = None):
def get_nearest_neighbors(self, embedding, n_results=10, category_name=None, dataset="training", model_space = None):
'''Gets the nearest neighbors of a single embedding'''
if not space_key:
space_key = self._space_key
if not model_space:
model_space = self._model_space
x = requests.post(self._api_url + "/get_nearest_neighbors", data = json.dumps({
"space_key": space_key,
"model_space": model_space,
"embedding": embedding,
"n_results": n_results,
"category_name": category_name,
@@ -151,12 +151,12 @@ class Chroma:
else:
return False
def process(self, space_key=None):
def process(self, model_space=None):
'''
Processes embeddings in the database
- currently this only runs hnswlib, doesnt return anything
'''
requests.post(self._api_url + "/process", data = json.dumps({"space_key": space_key or self._space_key}))
requests.post(self._api_url + "/process", data = json.dumps({"model_space": model_space or self._model_space}))
return True
def reset(self):
@@ -167,13 +167,13 @@ class Chroma:
'''Runs a raw SQL query against the database'''
return requests.post(self._api_url + "/raw_sql", data = json.dumps({"raw_sql": sql})).json()
def calculate_results(self, space_key=None):
def calculate_results(self, model_space=None):
'''Calculates the results for the given space key'''
return requests.post(self._api_url + "/calculate_results", data = json.dumps({"space_key": space_key or self._space_key})).json()
return requests.post(self._api_url + "/calculate_results", data = json.dumps({"model_space": model_space or self._model_space})).json()
def get_results(self, space_key=None, n_results = 100):
def get_results(self, model_space=None, n_results = 100):
'''Gets the results for the given space key'''
return requests.post(self._api_url + "/get_results", data = json.dumps({"space_key": space_key or self._space_key, "n_results": n_results})).json()
return requests.post(self._api_url + "/get_results", data = json.dumps({"model_space": model_space or self._model_space, "n_results": n_results})).json()
def get_task_status(self, task_id):
'''Gets the status of a task'''

View File

@@ -38,8 +38,8 @@ async def root():
@app.post("/api/v1/calculate_results")
async def calculate_results(space_key: SpaceKeyInput):
task = heavy_offline_analysis.delay(space_key.space_key)
async def calculate_results(model_space: SpaceKeyInput):
task = heavy_offline_analysis.delay(model_space.model_space)
chroma_telemetry.capture('heavy-offline-analysis')
return JSONResponse({"task_id": task.id})
@@ -55,7 +55,7 @@ async def get_status(task_id):
@app.post("/api/v1/get_results")
async def get_results(results: Results):
return app._db.return_results(results.space_key, results.n_results)
return app._db.return_results(results.model_space, results.n_results)
@@ -65,12 +65,12 @@ async def add_to_db(new_embedding: AddEmbedding):
number_of_embeddings = len(new_embedding.embedding_data)
if isinstance(new_embedding.space_key, str):
space_key = [new_embedding.space_key] * number_of_embeddings
elif len(new_embedding.space_key) == 1:
space_key = [new_embedding.space_key[0]] * number_of_embeddings
if isinstance(new_embedding.model_space, str):
model_space = [new_embedding.model_space] * number_of_embeddings
elif len(new_embedding.model_space) == 1:
model_space = [new_embedding.model_space[0]] * number_of_embeddings
else:
space_key = new_embedding.space_key
model_space = new_embedding.model_space
if isinstance(new_embedding.dataset, str):
dataset = [new_embedding.dataset] * number_of_embeddings
@@ -80,10 +80,10 @@ async def add_to_db(new_embedding: AddEmbedding):
dataset = new_embedding.dataset
# print the len of all inputs to add_batch
print(len(new_embedding.embedding_data), len(new_embedding.input_uri), len(space_key), len(dataset))
print(len(new_embedding.embedding_data), len(new_embedding.input_uri), len(model_space), len(dataset))
app._db.add_batch(
space_key,
model_space,
new_embedding.embedding_data,
new_embedding.input_uri,
dataset,
@@ -98,9 +98,9 @@ async def process(process_embedding: ProcessEmbedding):
'''
Currently generates an index for the embedding db
'''
fetch = app._db.fetch({"space_key": process_embedding.space_key}, columnar=True)
fetch = app._db.fetch({"model_space": process_embedding.model_space}, columnar=True)
chroma_telemetry.capture('created-index', {'n': len(fetch[2])})
app._ann_index.run(process_embedding.space_key, fetch[1], fetch[2]) # more magic number, ugh
app._ann_index.run(process_embedding.model_space, fetch[1], fetch[2]) # more magic number, ugh
return {"response": "Processed space"}
@@ -121,11 +121,11 @@ async def delete(embedding: DeleteEmbedding):
return app._db.delete(embedding.where_filter)
@app.get("/api/v1/count")
async def count(space_key: str = None):
async def count(model_space: str = None):
'''
Returns the number of records in the database
'''
return {"count": app._db.count(space_key=space_key)}
return {"count": app._db.count(model_space=model_space)}
@app.post("/api/v1/reset")
async def reset():
@@ -143,12 +143,12 @@ async def get_nearest_neighbors(embedding: QueryEmbedding):
'''
return the distance, database ids, and embedding themselves for the input embedding
'''
if embedding.space_key is None:
return {"error": "space_key is required"}
if embedding.model_space is None:
return {"error": "model_space is required"}
ids = None
filter_by_where = {}
filter_by_where["space_key"] = embedding.space_key
filter_by_where["model_space"] = embedding.model_space
if embedding.category_name is not None:
filter_by_where["category_name"] = embedding.category_name
if embedding.dataset is not None:
@@ -158,7 +158,7 @@ async def get_nearest_neighbors(embedding: QueryEmbedding):
results = app._db.fetch(filter_by_where)
ids = [str(item[get_col_pos('uuid')]) for item in results]
uuids, distances = app._ann_index.get_nearest_neighbors(embedding.space_key, embedding.embedding, embedding.n_results, ids)
uuids, distances = app._ann_index.get_nearest_neighbors(embedding.model_space, embedding.embedding, embedding.n_results, ids)
return {
"ids": uuids,
"embeddings": app._db.get_by_ids(uuids),

View File

@@ -7,11 +7,11 @@ class Database:
pass
@abstractmethod
def add_batch(self, space_key, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
def add_batch(self, model_space, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
pass
@abstractmethod
def count(self, space_key=None):
def count(self, model_space=None):
pass
@abstractmethod

View File

@@ -6,7 +6,7 @@ import os
from clickhouse_driver import connect, Client
EMBEDDING_TABLE_SCHEMA = [
{'space_key': 'String'},
{'model_space': 'String'},
{'uuid': 'UUID'},
{'embedding_data': 'Array(Float64)'},
{'input_uri': 'String'},
@@ -15,7 +15,7 @@ EMBEDDING_TABLE_SCHEMA = [
]
RESULTS_TABLE_SCHEMA = [
{'space_key': 'String'},
{'model_space': 'String'},
{'uuid': 'UUID'},
{'custom_quality_score': ' Nullable(Float64)'},
]
@@ -47,7 +47,7 @@ class Clickhouse(Database):
def _create_table_embeddings(self):
self._conn.execute(f'''CREATE TABLE IF NOT EXISTS embeddings (
{db_array_schema_to_clickhouse_schema(EMBEDDING_TABLE_SCHEMA)}
) ENGINE = MergeTree() ORDER BY space_key''')
) ENGINE = MergeTree() ORDER BY model_space''')
self._conn.execute(f'''SET allow_experimental_lightweight_delete = true''')
self._conn.execute(f'''SET mutations_sync = 1''') # https://clickhouse.com/docs/en/operations/settings/settings/#mutations_sync
@@ -55,7 +55,7 @@ class Clickhouse(Database):
def _create_table_results(self):
self._conn.execute(f'''CREATE TABLE IF NOT EXISTS results (
{db_array_schema_to_clickhouse_schema(RESULTS_TABLE_SCHEMA)}
) ENGINE = MergeTree() ORDER BY space_key''')
) ENGINE = MergeTree() ORDER BY model_space''')
def __init__(self):
client = Client(host='clickhouse', port=os.getenv('CLICKHOUSE_TCP_PORT', '9000'))
@@ -63,23 +63,23 @@ class Clickhouse(Database):
self._create_table_embeddings()
self._create_table_results()
def add_batch(self, space_key, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
def add_batch(self, model_space, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None):
data_to_insert = []
for i in range(len(embedding_data)):
data_to_insert.append([space_key[i], uuid.uuid4(), embedding_data[i], input_uri[i], dataset[i], category_name[i]])
data_to_insert.append([model_space[i], uuid.uuid4(), embedding_data[i], input_uri[i], dataset[i], category_name[i]])
self._conn.execute('''
INSERT INTO embeddings (space_key, uuid, embedding_data, input_uri, dataset, category_name) VALUES''', data_to_insert)
INSERT INTO embeddings (model_space, uuid, embedding_data, input_uri, dataset, category_name) VALUES''', data_to_insert)
def count(self, space_key=None):
def count(self, model_space=None):
where_string = ""
if space_key is not None:
where_string = f"WHERE space_key = '{space_key}'"
if model_space is not None:
where_string = f"WHERE model_space = '{model_space}'"
return self._conn.execute(f"SELECT COUNT() FROM embeddings {where_string}")[0][0]
def fetch(self, where_filter={}, sort=None, limit=None, offset=None, columnar=False):
if where_filter["space_key"] is None:
return {"error": "space_key is required"}
if where_filter["model_space"] is None:
return {"error": "model_space is required"}
s3= time.time()
# check to see if query is a dict and if it is a flat list of key value pairs
@@ -100,7 +100,7 @@ class Clickhouse(Database):
if sort is not None:
where_filter += f" ORDER BY {sort}"
else:
where_filter += f" ORDER BY space_key" # stable ordering
where_filter += f" ORDER BY model_space" # stable ordering
if limit is not None or isinstance(limit, int):
where_filter += f" LIMIT {limit}"
@@ -120,8 +120,8 @@ class Clickhouse(Database):
return val
def delete(self, where_filter={}):
if where_filter["space_key"] is None:
return {"error": "space_key is required. Use reset to clear the entire db"}
if where_filter["model_space"] is None:
return {"error": "model_space is required. Use reset to clear the entire db"}
s3= time.time()
# check to see if query is a dict and if it is a flat list of key value pairs
@@ -160,18 +160,18 @@ class Clickhouse(Database):
def raw_sql(self, sql):
return self._conn.execute(sql)
def add_results(self, space_keys, uuids, custom_quality_score):
def add_results(self, model_spaces, uuids, custom_quality_score):
data_to_insert = []
for i in range(len(space_keys)):
data_to_insert.append([space_keys[i], uuids[i], custom_quality_score[i]])
for i in range(len(model_spaces)):
data_to_insert.append([model_spaces[i], uuids[i], custom_quality_score[i]])
self._conn.execute('''
INSERT INTO results (space_key, uuid, custom_quality_score) VALUES''', data_to_insert)
INSERT INTO results (model_space, uuid, custom_quality_score) VALUES''', data_to_insert)
def delete_results(self, space_key):
self._conn.execute(f"DELETE FROM results WHERE space_key = '{space_key}'")
def delete_results(self, model_space):
self._conn.execute(f"DELETE FROM results WHERE model_space = '{model_space}'")
def return_results(self, space_key, n_results = 100):
def return_results(self, model_space, n_results = 100):
return self._conn.execute(f'''
SELECT
embeddings.input_uri,
@@ -184,7 +184,7 @@ class Clickhouse(Database):
ON
results.uuid = embeddings.uuid
WHERE
results.space_key = '{space_key}'
results.model_space = '{model_space}'
ORDER BY
results.custom_quality_score DESC
LIMIT {n_results}

View File

@@ -9,7 +9,7 @@ from chroma_server.logger import logger
class Hnswlib(Index):
_space_key = None
_model_space = None
_index = None
_index_metadata = {
'dimensionality': None,
@@ -23,7 +23,7 @@ class Hnswlib(Index):
def __init__(self):
pass
def run(self, space_key, uuids, embeddings):
def run(self, model_space, uuids, embeddings):
# more comments available at the source: https://github.com/nmslib/hnswlib
dimensionality = len(embeddings[0])
ids = []
@@ -42,7 +42,7 @@ class Hnswlib(Index):
index.add_items(embeddings, ids)
self._index = index
self._space_key = space_key
self._model_space = model_space
self._index_metadata = {
'dimensionality': dimensionality,
'elements': len(embeddings) ,
@@ -53,38 +53,38 @@ class Hnswlib(Index):
def save(self):
if self._index is None:
return
self._index.save_index(f"/index_data/index_{self._space_key}.bin")
self._index.save_index(f"/index_data/index_{self._model_space}.bin")
# pickle the mappers
with open(f"/index_data/id_to_uuid_{self._space_key}.pkl", 'wb') as f:
with open(f"/index_data/id_to_uuid_{self._model_space}.pkl", 'wb') as f:
pickle.dump(self._id_to_uuid, f, pickle.HIGHEST_PROTOCOL)
with open(f"/index_data/uuid_to_id_{self._space_key}.pkl", 'wb') as f:
with open(f"/index_data/uuid_to_id_{self._model_space}.pkl", 'wb') as f:
pickle.dump(self._uuid_to_id, f, pickle.HIGHEST_PROTOCOL)
with open(f"/index_data/index_metadata_{self._space_key}.pkl", 'wb') as f:
with open(f"/index_data/index_metadata_{self._model_space}.pkl", 'wb') as f:
pickle.dump(self._index_metadata, f, pickle.HIGHEST_PROTOCOL)
logger.debug('Index saved to /index_data/index.bin')
def load(self, space_key):
def load(self, model_space):
# unpickle the mappers
with open(f"/index_data/id_to_uuid_{space_key}.pkl", 'rb') as f:
with open(f"/index_data/id_to_uuid_{model_space}.pkl", 'rb') as f:
self._id_to_uuid = pickle.load(f)
with open(f"/index_data/uuid_to_id_{space_key}.pkl", 'rb') as f:
with open(f"/index_data/uuid_to_id_{model_space}.pkl", 'rb') as f:
self._uuid_to_id = pickle.load(f)
with open(f"/index_data/index_metadata_{space_key}.pkl", 'rb') as f:
with open(f"/index_data/index_metadata_{model_space}.pkl", 'rb') as f:
self._index_metadata = pickle.load(f)
p = hnswlib.Index(space='l2', dim= self._index_metadata['dimensionality'])
self._index = p
self._index.load_index(f"/index_data/index_{space_key}.bin", max_elements= self._index_metadata['elements'])
self._index.load_index(f"/index_data/index_{model_space}.bin", max_elements= self._index_metadata['elements'])
self._space_key = space_key
self._model_space = model_space
# do knn_query on hnswlib to get nearest neighbors
def get_nearest_neighbors(self, space_key, query, k, uuids=None):
def get_nearest_neighbors(self, model_space, query, k, uuids=None):
if self._space_key != space_key:
self.load(space_key)
if self._model_space != model_space:
self.load(model_space)
s2= time.time()
# get ids from uuids

View File

@@ -23,7 +23,7 @@ async def post_batch_records(ac):
"input_uri": ["https://example.com", "https://example.com"],
"dataset": ["training", "training"],
"category_name": ["person", "person"],
"space_key": ["test_space", "test_space"],
"model_space": ["test_space", "test_space"],
},
)
@@ -35,7 +35,7 @@ async def post_batch_records_minimal(ac):
"input_uri": ["https://example.com", "https://example.com"],
"dataset": "training",
"category_name": ["person", "person"],
"space_key": "test_space"
"model_space": "test_space"
},
)
@@ -47,7 +47,7 @@ async def test_add_to_db_batch():
response = await post_batch_records(ac)
assert response.status_code == 201
assert response.json() == {"response": "Added records to database"}
response = await ac.get("/api/v1/count", params={"space_key": "test_space"})
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
assert response.json() == {"count": 2}
@@ -58,7 +58,7 @@ async def test_add_to_db_batch_minimal():
response = await post_batch_records_minimal(ac)
assert response.status_code == 201
assert response.json() == {"response": "Added records to database"}
response = await ac.get("/api/v1/count", params={"space_key": "test_space"})
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
assert response.json() == {"count": 2}
@pytest.mark.anyio
@@ -66,7 +66,7 @@ async def test_fetch_from_db():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset")
await post_batch_records(ac)
params = {"where_filter": {"space_key": "test_space"}}
params = {"where_filter": {"model_space": "test_space"}}
response = await ac.post("/api/v1/fetch", json=params)
assert response.status_code == 200
assert len(response.json()) == 2
@@ -76,7 +76,7 @@ async def test_count_from_db():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset") # reset db
await post_batch_records(ac)
response = await ac.get("/api/v1/count", params={"space_key": "test_space"})
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
assert response.status_code == 200
assert response.json() == {"count": 2}
@@ -85,11 +85,11 @@ async def test_reset_db():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset")
await post_batch_records(ac)
response = await ac.get("/api/v1/count", params={"space_key": "test_space"})
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
assert response.json() == {"count": 2}
response = await ac.post("/api/v1/reset")
assert response.json() == True
response = await ac.get("/api/v1/count", params={"space_key": "test_space"})
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
assert response.json() == {"count": 0}
@pytest.mark.anyio
@@ -97,9 +97,9 @@ async def test_get_nearest_neighbors():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset")
await post_batch_records(ac)
await ac.post("/api/v1/process", json={"space_key": "test_space"})
await ac.post("/api/v1/process", json={"model_space": "test_space"})
response = await ac.post(
"/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "space_key": "test_space"}
"/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "model_space": "test_space"}
)
assert response.status_code == 200
assert len(response.json()["ids"]) == 1
@@ -109,7 +109,7 @@ async def test_get_nearest_neighbors_filter():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset")
await post_batch_records(ac)
await ac.post("/api/v1/process", json={"space_key": "test_space"})
await ac.post("/api/v1/process", json={"model_space": "test_space"})
response = await ac.post(
"/api/v1/get_nearest_neighbors",
json={
@@ -117,7 +117,7 @@ async def test_get_nearest_neighbors_filter():
"n_results": 1,
"dataset": "training",
"category_name": "monkey",
"space_key": "test_space",
"model_space": "test_space",
},
)
assert response.status_code == 200
@@ -128,7 +128,7 @@ async def test_process():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset")
await post_batch_records(ac)
response = await ac.post("/api/v1/process", json={"space_key": "test_space"})
response = await ac.post("/api/v1/process", json={"model_space": "test_space"})
assert response.status_code == 200
assert response.json() == {"response": "Processed space"}
@@ -138,11 +138,11 @@ async def test_delete():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset")
await post_batch_records(ac)
response = await ac.get("/api/v1/count", params={"space_key": "test_space"})
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
assert response.json() == {"count": 2}
response = await ac.post("/api/v1/delete", json={"where_filter": {"space_key": "test_space"}})
response = await ac.post("/api/v1/delete", json={"where_filter": {"model_space": "test_space"}})
assert response.json() == []
response = await ac.get("/api/v1/count", params={"space_key": "test_space"})
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
assert response.json() == {"count": 0}
# test calculate results
@@ -151,11 +151,11 @@ async def test_delete():
# async with AsyncClient(app=app, base_url="http://test") as ac:
# await ac.post("/api/v1/reset")
# await post_batch_records(ac)
# await ac.post("/api/v1/process", json={"space_key": "test_space"})
# await ac.post("/api/v1/process", json={"model_space": "test_space"})
# response = await ac.post(
# "/api/v1/calculate_results",
# json={
# "space_key": "test_space",
# "model_space": "test_space",
# },
# )
# assert response.status_code == 200

View File

@@ -3,7 +3,7 @@ from typing import Union, Any
# type supports single and batch mode
class AddEmbedding(BaseModel):
space_key: Union[str, list]
model_space: Union[str, list]
embedding_data: list
input_uri: Union[str, list]
dataset: Union[str, list] = None
@@ -11,14 +11,14 @@ class AddEmbedding(BaseModel):
class QueryEmbedding(BaseModel):
space_key: str = None
model_space: str = None
embedding: list
n_results: int = 10
category_name: str = None
dataset: str = None
class ProcessEmbedding(BaseModel):
space_key: str = None
model_space: str = None
class FetchEmbedding(BaseModel):
where_filter: dict = {}
@@ -27,17 +27,17 @@ class FetchEmbedding(BaseModel):
offset: int = None
class CountEmbedding(BaseModel):
space_key: str = None
model_space: str = None
class RawSql(BaseModel):
raw_sql: str = None
class Results(BaseModel):
space_key: str
model_space: str
n_results: int = 100
class SpaceKeyInput(BaseModel):
space_key: str
model_space: str
class DeleteEmbedding(BaseModel):
where_filter: dict = {}

View File

@@ -14,9 +14,9 @@ def create_task(task_type):
return True
@celery.task(name="heavy_offline_analysis")
def heavy_offline_analysis(space_key):
def heavy_offline_analysis(model_space):
task_db_conn = Clickhouse()
embedding_rows = task_db_conn.fetch({"space_key": space_key})
embedding_rows = task_db_conn.fetch({"model_space": model_space})
uuids = []
custom_quality_scores = []
@@ -25,9 +25,9 @@ def heavy_offline_analysis(space_key):
uuids.append(row[get_col_pos("uuid")])
custom_quality_scores.append(random.random())
spaces = [space_key] * len(uuids)
spaces = [model_space] * len(uuids)
task_db_conn.delete_results(space_key)
task_db_conn.delete_results(model_space)
task_db_conn.add_results(spaces, uuids, custom_quality_scores)
return "Wrote custom quality scores to database"