Merge branch 'main' into jeff/celery

This commit is contained in:
Jeffrey Huber
2022-11-10 16:04:53 -08:00
40 changed files with 666 additions and 76 deletions

View File

@@ -44,6 +44,7 @@ jobs:
with:
context: chroma-server
push: true
target: chroma_server
tags: ${{ steps.tag.outputs.tag_name}}
- name: Get Release Version
id: version

View File

@@ -17,19 +17,10 @@ jobs:
strategy:
matrix:
python: ['3.10']
platform: [ubuntu-latest, macos-latest]
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}
- name: Install test dependencies
run: |
cd chroma-server && python -m pip install -r requirements.txt -r requirements_dev.txt
- name: Install chroma_client
run: cd chroma-client && pip install .
- name: Test
run: cd chroma-server && python -m pytest
run: cd chroma-server && bin/test

2
.gitignore vendored
View File

@@ -9,4 +9,4 @@ chroma-server/chroma_logs.log
**/data__nogit
**/.ipynb_checkpoints
**/.ipynb_checkpoints

5
Makefile Normal file
View File

@@ -0,0 +1,5 @@
black:
black --fast chroma-server chroma-client
check_black:
black --check --fast chroma-server chroma-client

View File

@@ -9,3 +9,38 @@ Contents:
- `/chroma-client` - Python client for Chroma
- `/chroma-server` - FastAPI server used as the backend for Chroma client
### Get up and running on Linux
No requirements
```
/bin/bash -c "$(curl -fsSL https://gist.githubusercontent.com/jeffchuber/effcbac05021e863bbd634f4b7d0283d/raw/4d38b150809d6ccbc379f88433cadd86c81d32cd/chroma_setup.sh)"
python3 chroma/bin/test.py
```
### Get up and running on Mac
Requirements
- git
- Docker & `docker-compose`
- pip
```
/bin/bash -c "$(curl -fsSL https://gist.githubusercontent.com/jeffchuber/27a3cbb28e6521c811da6398346cd35f/raw/55c2d82870436431120a9446b47f19b72d88fa31/chroma_setup_mac.sh)"
python3 chroma/bin/test.py
```
* These URLs will be swapped out for links in the repo once it is live
### You should see something like
```
Getting heartbeat to verify the server is up
{'nanosecond heartbeat': 1667865642509760965000}
Logging embeddings into the database
Generating the index
True
Running a nearest neighbor search
{'ids': ['11540ca6-ebbc-4c81-8299-108d8c47c88c'], 'embeddings': [['sample_space', '11540ca6-ebbc-4c81-8299-108d8c47c88c', [1.0, 2.0, 3.0, 4.0, 5.0], '/images/1', 'training', None, 'spoon']], 'distances': [0.0]}
Success! Everything worked!
```

42
bin/setup_linux.sh Normal file
View File

@@ -0,0 +1,42 @@
#!/usr/bin/env bash
# install pip
apt install -y python3-pip
# install docker
sudo apt-get update
sudo apt-get -y install \
ca-certificates \
curl \
gnupg \
lsb-release
sudo mkdir -p /etc/apt/keyrings
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
echo \
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \
$(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
sudo apt-get update
sudo apt-get -y install docker-ce docker-ce-cli containerd.io docker-compose-plugin
pip3 install docker-compose
# get the code
git clone https://oauth2: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@github.com/chroma-core/chroma.git
#checkout the right branch
cd chroma
git checkout jeff/packaging
# run docker
cd chroma-server
docker-compose up -d --build
# install chroma-client
cd ../chroma-client
pip3 install --upgrade pip # you have to do this or it will use UNKNOWN as the package name
pip3 install .

19
bin/setup_mac.sh Normal file
View File

@@ -0,0 +1,19 @@
# requirements
# - docker
# - pip
# get the code
git clone https://oauth2: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@github.com/chroma-core/chroma.git
#checkout the right branch
cd chroma
git checkout jeff/packaging
# run docker
cd chroma-server
docker-compose up -d --build
# install chroma-client
cd ../chroma-client
pip install --upgrade pip # you have to do this or it will use UNKNOWN as the package name
pip install .

23
bin/test.py Normal file
View File

@@ -0,0 +1,23 @@
from chroma_client import Chroma
chroma = Chroma()
chroma.set_space_key('sample_space')
print("Getting heartbeat to verify the server is up")
print(chroma.heartbeat())
print("Logging embeddings into the database")
chroma.log(
[[1,2,3,4,5], [5,4,3,2,1], [10,9,8,7,6]],
["/images/1", "/images/2", "/images/3"],
["training", "training", "training"],
['spoon', 'knife', 'fork']
)
# print("fetch", chroma.fetch())
print("Generating the index")
print(chroma.process())
print("Running a nearest neighbor search")
print(chroma.get_nearest_neighbors([1,2,3,4,5], 1))
print("Success! Everything worked!")

View File

@@ -2,10 +2,11 @@ import requests
import json
from typing import Union
class Chroma:
_api_url = "http://localhost:8000/api/v1"
_space_key = None
_space_key = "default_scope"
def __init__(self, url=None, app=None, model_version=None, layer=None):
"""Initialize Chroma client"""
@@ -159,4 +160,4 @@ class Chroma:
def get_task_status(self, task_id):
'''Gets the status of a task'''
return requests.get(self._api_url + f"/tasks/{task_id}").json()
return requests.get(self._api_url + f"/tasks/{task_id}").json()

View File

@@ -4,11 +4,14 @@ from chroma_client import Chroma
import pytest
import time
from httpx import AsyncClient
# from ..api import app # this wont work because i moved the file
@pytest.fixture
def anyio_backend():
return 'asyncio'
return "asyncio"
def test_init():
chroma = Chroma()

3
chroma-server/.env Normal file
View File

@@ -0,0 +1,3 @@
disable_anonymized_telemetry=False
environment=development
telemetry_anonymized_uuid=f80b11fc-1c5a-4a90-ba35-8c3a3c5371cc

View File

@@ -1,4 +1,4 @@
FROM --platform=linux/amd64 python:3.10
FROM --platform=linux/amd64 python:3.10 AS chroma_server
#RUN apt-get update -qq
#RUN apt-get install python3.10 python3-pip -y --no-install-recommends && rm -rf /var/lib/apt/lists_/*
@@ -15,3 +15,13 @@ EXPOSE 8000
CMD ["uvicorn", "chroma_server:app", "--host", "0.0.0.0", "--port", "8000", "--proxy-headers"]
# Use a multi-stage build to layer in test dependencies without bloating server image
# https://docs.docker.com/build/building/multi-stage/
# Note: requires passing --target to docker-build.
FROM chroma_server AS chroma_server_test
COPY ./requirements_dev.txt requirements_dev.txt
RUN pip install --no-cache-dir --upgrade -r requirements_dev.txt
CMD ["python", "-m", "pytest"]

View File

@@ -12,7 +12,8 @@ pip install -r requirements.txt
pip install -r requirements_dev.txt
```
To run tests, run `pytest`.
To run tests, run `bin/test`. This will run the test suite inside a
docker compose cluster, with the database available, and clean up when complete.
To run the server locally, in development mode, run `uvicorn chroma_server:app --reload`
@@ -43,4 +44,4 @@ To run use `docker images` to see what containers and tags you have available:
docker run -p 8000:8000 ghcr.io/chroma-core/chroma-server:<tag name -- eg 0.0.2-dirty>
```
This will expose the internal app at `localhost:8000`
This will expose the internal app at `localhost:8000`

View File

@@ -1,3 +1,3 @@
#!/usr/bin/env bash
docker build . -t ghcr.io/chroma-core/chroma-server:`bin/version`
docker build . --target chroma_server -t ghcr.io/chroma-core/chroma-server:`bin/version`

11
chroma-server/bin/test Executable file
View File

@@ -0,0 +1,11 @@
#!/usr/bin/env bash
set -e
function cleanup {
docker-compose -f docker-compose.test.yml down
}
trap cleanup EXIT
docker-compose -f docker-compose.test.yml run --rm server_test

View File

@@ -1,8 +1,8 @@
import random
def rand_bisectional_subsample(data):
"""
Randomly bisectionally subsample a list of data to size.
"""
return data.sample(frac=0.5, replace=True, random_state=1)
return data.sample(frac=0.5, replace=True, random_state=1)

View File

@@ -2,22 +2,23 @@ import numpy as np
import json
import ast
def class_distances(data):
''''
"""'
This is all very subject to change, so essentially just copy and paste from what we had before
'''
"""
return False
# def unpack_annotations(embeddings):
# annotations = [json.loads(embedding['infer'])["annotations"]for embedding in embeddings]
# annotations = [annotation for annotation_list in annotations for annotation in annotation_list]
# annotations = [annotation for annotation_list in annotations for annotation in annotation_list]
# # Unpack embedding data
# embeddings = [embedding["embedding_data"] for embedding in embeddings]
# embedding_vectors_by_category = {}
# for embedding_annotation_pair in zip(embeddings, annotations):
# data = np.array(embedding_annotation_pair[0])
# category = embedding_annotation_pair[1]['category_id']
# category = embedding_annotation_pair[1]['category_id']
# if category in embedding_vectors_by_category.keys():
# embedding_vectors_by_category[category] = np.append(
# embedding_vectors_by_category[category], data[np.newaxis, :], axis=0
@@ -84,5 +85,5 @@ def class_distances(data):
# if (len(inferences) == 0):
# raise Exception("No inferences found for datapoint")
# return output_distances
# return output_distances

View File

@@ -2,7 +2,7 @@ import os
import shutil
import time
from fastapi import FastAPI, Response, status
from fastapi import FastAPI, status
from fastapi.responses import JSONResponse
from worker import heavy_offline_analysis
@@ -12,6 +12,13 @@ from chroma_server.index.hnswlib import Hnswlib
from chroma_server.types import AddEmbedding, QueryEmbedding, ProcessEmbedding, FetchEmbedding, CountEmbedding, RawSql, Results, SpaceKeyInput
from chroma_server.utils import logger
from chroma_server.utils.telemetry.capture import Capture
from chroma_server.utils.error_reporting import init_error_reporting
chroma_telemetry = Capture()
chroma_telemetry.capture('server-start')
init_error_reporting()
from celery.result import AsyncResult
# Boot script
@@ -66,6 +73,7 @@ async def add_to_db(new_embedding: AddEmbedding):
return {"response": "Added records to database"}
@app.get("/api/v1/process")
async def process(process_embedding: ProcessEmbedding):
'''
@@ -74,6 +82,7 @@ async def process(process_embedding: ProcessEmbedding):
fetch = app._db.fetch({"space_key": process_embedding.space_key}, columnar=True)
app._ann_index.run(process_embedding.space_key, fetch[1], fetch[2]) # more magic number, ugh
@app.get("/api/v1/fetch")
async def fetch(fetch_embedding: FetchEmbedding):
'''
@@ -82,6 +91,7 @@ async def fetch(fetch_embedding: FetchEmbedding):
'''
return app._db.fetch(fetch_embedding.where_filter, fetch_embedding.sort, fetch_embedding.limit)
@app.get("/api/v1/count")
async def count(count_embedding: CountEmbedding):
'''
@@ -89,6 +99,7 @@ async def count(count_embedding: CountEmbedding):
'''
return {"count": app._db.count(space_key=count_embedding.space_key)}
@app.get("/api/v1/reset")
async def reset():
'''
@@ -102,7 +113,7 @@ async def reset():
@app.post("/api/v1/get_nearest_neighbors")
async def get_nearest_neighbors(embedding: QueryEmbedding):
'''
"""
return the distance, database ids, and embedding themselves for the input embedding
'''
if embedding.space_key is None:
@@ -112,9 +123,9 @@ async def get_nearest_neighbors(embedding: QueryEmbedding):
filter_by_where = {}
filter_by_where["space_key"] = embedding.space_key
if embedding.category_name is not None:
filter_by_where['category_name'] = embedding.category_name
filter_by_where["category_name"] = embedding.category_name
if embedding.dataset is not None:
filter_by_where['dataset'] = embedding.dataset
filter_by_where["dataset"] = embedding.dataset
if filter_by_where is not None:
results = app._db.fetch(filter_by_where)
@@ -129,4 +140,4 @@ async def get_nearest_neighbors(embedding: QueryEmbedding):
@app.get("/api/v1/raw_sql")
async def raw_sql(raw_sql: RawSql):
return app._db.raw_sql(raw_sql.raw_sql)
return app._db.raw_sql(raw_sql.raw_sql)

View File

@@ -1,6 +1,7 @@
from abc import abstractmethod
class Database():
class Database:
@abstractmethod
def __init__(self):
pass
@@ -23,4 +24,4 @@ class Database():
@abstractmethod
def reset(self):
pass
pass

View File

@@ -0,0 +1,212 @@
from os import EX_CANTCREAT
from chroma_server.db.abstract import Database
import duckdb
import numpy as np  # used below to map NaN to None for JSON serialization
import pandas as pd  # used below for the empty-DataFrame return in get_by_ids
class DuckDB(Database):
_conn = None
def __init__(self):
self._conn = duckdb.connect()
self._conn.execute(
"""
CREATE TABLE embeddings (
id integer PRIMARY KEY,
embedding_data REAL[],
input_uri STRING,
dataset STRING,
custom_quality_score REAL,
category_name STRING
)
"""
)
# ids to manage internal bookkeeping and *nothing else*, users should not have to care about these ids
self._conn.execute(
"""
CREATE SEQUENCE seq_id START 1;
"""
)
self._conn.execute(
"""
-- change the default null sorting order to either NULLS FIRST and NULLS LAST
PRAGMA default_null_order='NULLS LAST';
-- change the default sorting order to either DESC or ASC
PRAGMA default_order='DESC';
"""
)
return
def add_batch(
self, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None
):
"""
Add embeddings to the database
This accepts both a single input and a list of inputs
"""
# create list of the types of all inputs
types = [type(x).__name__ for x in [embedding_data, input_uri]]
# if all of the types are 'list' - do batch mode
if all(x == "list" for x in types):
lengths = [len(x) for x in [embedding_data, input_uri]]
# accepts some inputs as str or None, and this multiplies them out to the correct length
if custom_quality_score is None or isinstance(custom_quality_score, str):
custom_quality_score = [custom_quality_score] * lengths[0]
if category_name is None or isinstance(category_name, str):
category_name = [category_name] * lengths[0]
if dataset is None or isinstance(dataset, str):
dataset = [dataset] * lengths[0]
# we have to move from column to row format for duckdb
data_to_insert = []
for i in range(lengths[0]):
data_to_insert.append(
[
embedding_data[i],
input_uri[i],
dataset[i],
custom_quality_score[i],
category_name[i],
]
)
if all(x == lengths[0] for x in lengths):
self._conn.executemany(
"""
INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)""",
data_to_insert,
)
return
# if any of the types are 'list' - throw an error
if any(x == list for x in [input_uri, dataset, custom_quality_score, category_name]):
raise Exception(
"Invalid input types. One input is a list where others are not: " + str(types)
)
# single insert mode
# This should never fire because we do everything in batch mode, but given the likely move away from duckdb, I am just leaving it in
self._conn.execute(
"""
INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)""",
[embedding_data, input_uri, dataset, custom_quality_score, category_name],
)
def count(self):
return self._conn.execute(
"""
SELECT COUNT(*) FROM embeddings;
"""
).fetchone()[0]
def update(self, data): # call this update_custom_quality_score! that is all it does
"""
I was not able to figure out (yet) how to do a bulk update in duckdb
This is going to be fairly slow
"""
for element in data:
if element["custom_quality_score"] is None:
continue
self._conn.execute(
f"""
UPDATE embeddings SET custom_quality_score={element['custom_quality_score']} WHERE id={element['id']}"""
)
def fetch(self, where_filter={}, sort=None, limit=None):
# check that where_filter is a dict and is a flat set of key-value pairs
if where_filter is not None:
if not isinstance(where_filter, dict):
raise Exception("Invalid where_filter: " + str(where_filter))
# ensure where_filter is a flat dict
for key in where_filter:
if isinstance(where_filter[key], dict):
raise Exception("Invalid where_filter: " + str(where_filter))
where_filter = " AND ".join([f"{key} = '{value}'" for key, value in where_filter.items()])
if where_filter:
where_filter = f"WHERE {where_filter}"
if sort is not None:
where_filter += f" ORDER BY {sort}"
if limit is not None or isinstance(limit, int):
where_filter += f" LIMIT {limit}"
return (
self._conn.execute(
f"""
SELECT
id,
embedding_data,
input_uri,
dataset,
custom_quality_score,
category_name
FROM
embeddings
{where_filter}
"""
)
.fetchdf()
.replace({np.nan: None})
) # replace nan with None for json serialization
def delete_batch(self, batch):
raise NotImplementedError
def persist(self):
"""
Persist the database to disk
"""
if self._conn is None:
return
self._conn.execute(
"""
COPY
(SELECT * FROM embeddings)
TO '.chroma/chroma.parquet'
(FORMAT PARQUET);
"""
)
def load(self, path=".chroma/chroma.parquet"):
"""
Load the database from disk
"""
self._conn.execute(f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');")
def get_by_ids(self, ids: list):
# select from duckdb table where ids are in the list
if not isinstance(ids, list):
raise Exception("ids must be a list")
if not ids:
# create an empty pandas dataframe
return pd.DataFrame()
return (
self._conn.execute(
f"""
SELECT
id,
embedding_data,
input_uri,
dataset,
custom_quality_score,
category_name
FROM
embeddings
WHERE
id IN ({','.join([str(x) for x in ids])})
"""
)
.fetchdf()
.replace({np.nan: None})
) # replace nan with None for json serialization
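As a usage sketch of this class (module path assumed to be `chroma_server.db.duckdb`, which the diff does not name): `add_batch` broadcasts scalar `dataset`/`category_name` values across the batch, `fetch` builds its WHERE clause from a flat dict, and `persist` writes the table to `.chroma/chroma.parquet` (the `.chroma/` directory must already exist):
```
from chroma_server.db.duckdb import DuckDB  # assumed module path

db = DuckDB()

# Batch mode: embedding_data and input_uri are equal-length lists; the scalar
# dataset and category_name are multiplied out to one value per row.
db.add_batch(
    embedding_data=[[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]],
    input_uri=["/images/1", "/images/2"],
    dataset="training",
    category_name="spoon",
)

print(db.count())  # 2

# Flat key/value filter -> "WHERE dataset = 'training' LIMIT 10"
print(db.fetch({"dataset": "training"}, limit=10))

db.persist()  # COPY ... TO '.chroma/chroma.parquet'
```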

View File

@@ -1,6 +1,7 @@
from abc import abstractmethod
class Index():
class Index:
@abstractmethod
def __init__(self):
pass
@@ -23,4 +24,4 @@ class Index():
@abstractmethod
def load(self):
pass
pass

View File

@@ -4,7 +4,8 @@ import time
import os
import numpy as np
from chroma_server.index.abstract import Index
from chroma_server.utils import logger
from chroma_server.logger import logger
class Hnswlib(Index):
@@ -16,9 +17,6 @@ class Hnswlib(Index):
'time_created': None,
}
# these data structures enable us to map between uuids and ids
# - our uuids are strings (clickhouse doesnt do autoincrementing ids for performance)
# - but hnswlib uses integers only as ids
_id_to_uuid = {}
_uuid_to_id = {}
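The mapping described in that comment is just two dictionaries kept in lockstep with the index: hnswlib only accepts integer labels, while the rest of the system speaks string uuids. A standalone sketch of the pattern against plain `hnswlib` (dimensions and data made up):
```
import uuid
import hnswlib
import numpy as np

dim = 5
index = hnswlib.Index(space="l2", dim=dim)
index.init_index(max_elements=100)

id_to_uuid = {}
uuid_to_id = {}

vectors = np.random.rand(3, dim).astype(np.float32)
for i, vector in enumerate(vectors):
    item_uuid = str(uuid.uuid4())  # what the database uses
    id_to_uuid[i] = item_uuid      # hnswlib side: integer label only
    uuid_to_id[item_uuid] = i
    index.add_items(vector.reshape(1, -1), np.array([i]))

# Queries return integer labels; translate them back to uuids for callers.
labels, distances = index.knn_query(vectors[0].reshape(1, -1), k=1)
print([id_to_uuid[int(label)] for label in labels[0]], distances)
```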

View File

@@ -1,5 +1,6 @@
import logging
def setup_logging():
logging.basicConfig(filename="chroma_logs.log")
logger = logging.getLogger("Chroma")
@@ -7,4 +8,5 @@ def setup_logging():
logger.debug("Logger created")
return logger
logger = setup_logging()
logger = setup_logging()

View File

@@ -4,32 +4,45 @@ from httpx import AsyncClient
from ..api import app
@pytest.fixture
def anyio_backend():
return 'asyncio'
return "asyncio"
@pytest.mark.anyio
async def test_root():
async with AsyncClient(app=app, base_url="http://test") as ac:
response = await ac.get("/api/v1")
assert response.status_code == 200
assert abs(response.json()["nanosecond heartbeat"] - int(1000 * time.time_ns())) < 3_000_000_000 # a billion nanoseconds = 3s
assert (
abs(response.json()["nanosecond heartbeat"] - int(1000 * time.time_ns())) < 3_000_000_000
) # 3 billion nanoseconds = 3s
async def post_one_record(ac):
return await ac.post("/api/v1/add", json={
"embedding_data": [1.02, 2.03, 3.03],
"input_uri": "https://example.com",
"dataset": "coco",
"category_name": "person"
})
return await ac.post(
"/api/v1/add",
json={
"embedding_data": [1.02, 2.03, 3.03],
"input_uri": "https://example.com",
"dataset": "coco",
"category_name": "person",
},
)
async def post_batch_records(ac):
return await ac.post("/api/v1/add", json={
"embedding_data": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"input_uri": ["https://example.com", "https://example.com"],
"dataset": "training",
"category_name": "person"
})
return await ac.post(
"/api/v1/add",
json={
"embedding_data": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"input_uri": ["https://example.com", "https://example.com"],
"dataset": "training",
"category_name": "person",
},
)
@pytest.mark.anyio
async def test_add_to_db():
@@ -38,6 +51,7 @@ async def test_add_to_db():
assert response.status_code == 201
assert response.json() == {"response": "Added record to database"}
@pytest.mark.anyio
async def test_add_to_db_batch():
async with AsyncClient(app=app, base_url="http://test") as ac:
@@ -45,6 +59,7 @@ async def test_add_to_db_batch():
assert response.status_code == 201
assert response.json() == {"response": "Added record to database"}
@pytest.mark.anyio
async def test_fetch_from_db():
async with AsyncClient(app=app, base_url="http://test") as ac:
@@ -53,15 +68,17 @@ async def test_fetch_from_db():
assert response.status_code == 200
assert len(response.json()) == 1
@pytest.mark.anyio
async def test_count_from_db():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.get("/api/v1/reset") # reset db
await ac.get("/api/v1/reset") # reset db
await post_batch_records(ac)
response = await ac.get("/api/v1/count")
assert response.status_code == 200
assert response.json() == {"count": 2}
@pytest.mark.anyio
async def test_reset_db():
async with AsyncClient(app=app, base_url="http://test") as ac:
@@ -74,25 +91,19 @@ async def test_reset_db():
response = await ac.get("/api/v1/count")
assert response.json() == {"count": 0}
@pytest.mark.anyio
async def test_get_nearest_neighbors():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.get("/api/v1/reset")
await post_batch_records(ac)
await ac.get("/api/v1/process")
response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1})
response = await ac.post(
"/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1}
)
assert response.status_code == 200
assert len(response.json()["ids"]) == 1
@pytest.mark.anyio
async def test_get_nearest_neighbors_filter():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.get("/api/v1/reset")
await post_batch_records(ac)
await ac.get("/api/v1/process")
response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "dataset": "training", "category_name": "monkey"})
assert response.status_code == 200
assert len(response.json()["ids"]) == 0
@pytest.mark.anyio
async def test_get_nearest_neighbors_filter():
@@ -100,6 +111,33 @@ async def test_get_nearest_neighbors_filter():
await ac.get("/api/v1/reset")
await post_batch_records(ac)
await ac.get("/api/v1/process")
response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 2, "dataset": "training", "category_name": "person"})
response = await ac.post(
"/api/v1/get_nearest_neighbors",
json={
"embedding": [1.1, 2.3, 3.2],
"n_results": 1,
"dataset": "training",
"category_name": "monkey",
},
)
assert response.status_code == 200
assert len(response.json()["ids"]) == 0
@pytest.mark.anyio
async def test_get_nearest_neighbors_filter():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.get("/api/v1/reset")
await post_batch_records(ac)
await ac.get("/api/v1/process")
response = await ac.post(
"/api/v1/get_nearest_neighbors",
json={
"embedding": [1.1, 2.3, 3.2],
"n_results": 2,
"dataset": "training",
"category_name": "person",
},
)
assert response.status_code == 200
assert len(response.json()["ids"]) == 2

View File

@@ -7,9 +7,10 @@ class AddEmbedding(BaseModel):
embedding_data: list
input_uri: Union[str, list]
dataset: Union[str, list] = None
custom_quality_score: Union[float, list] = None
custom_quality_score: Union[float, list] = None
category_name: Union[str, list] = None
class QueryEmbedding(BaseModel):
space_key: str = None
embedding: list

View File

@@ -0,0 +1,15 @@
from functools import lru_cache
from typing import Union
from pydantic import BaseSettings
class Settings(BaseSettings):
disable_anonymized_telemetry: bool = False
telemetry_anonymized_uuid: str = ''
environment: str = 'development'
class Config:
env_file = ".env"
@lru_cache()
def get_settings():
return Settings()
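A quick sketch of how these settings resolve: process environment variables win, then values from the `.env` file (like the one added above), then the class defaults, and `lru_cache` makes `get_settings()` hand back the same instance for the life of the process:
```
import os
from chroma_server.utils.config.settings import get_settings

# Override before the first get_settings() call; pydantic env names are
# case-insensitive by default.
os.environ["environment"] = "test"

settings = get_settings()
print(settings.environment)                   # "test"
print(settings.disable_anonymized_telemetry)  # False unless overridden

assert get_settings() is settings  # cached by lru_cache
```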

View File

@@ -0,0 +1,27 @@
from chroma_server.utils.config.settings import get_settings
import sentry_sdk
from sentry_sdk.client import Client
from sentry_sdk import configure_scope
from posthog.sentry.posthog_integration import PostHogIntegration
PostHogIntegration.organization = "chroma"
sample_rate = 1.0
if get_settings().environment == "production":
sample_rate = 0.1
def strip_sensitive_data(event, hint):
if 'server_name' in event:
del event['server_name']
return event
def init_error_reporting():
sentry_sdk.init(
dsn="https://ef5fae1e461f49b3a7a2adf3404378ab@o4504080408051712.ingest.sentry.io/4504080409296896",
traces_sample_rate=sample_rate,
integrations=[PostHogIntegration()],
environment=get_settings().environment,
before_send=strip_sensitive_data,
)
with configure_scope() as scope:
scope.set_tag('posthog_distinct_id', get_settings().telemetry_anonymized_uuid)

View File

@@ -0,0 +1,10 @@
from abc import abstractmethod
class Telemetry():
@abstractmethod
def __init__(self):
pass
@abstractmethod
def capture(self, event, properties=None):
pass

View File

@@ -0,0 +1,34 @@
import posthog
import uuid
import sys
from chroma_server.utils.telemetry.abstract import Telemetry
from chroma_server.utils.config.settings import get_settings
class Capture(Telemetry):
_conn = None
_telemetry_anonymized_uuid = None
def __init__(self):
if get_settings().disable_anonymized_telemetry:
posthog.disabled = True
# disable telemetry if we're running tests
if "pytest" in sys.modules:
posthog.disabled = True
posthog.project_api_key = 'phc_YeUxaojbKk5KPi8hNlx1bBKHzuZ4FDtl67kH1blv8Bh'
posthog.host = 'https://app.posthog.com'
self._conn = posthog
if not get_settings().telemetry_anonymized_uuid:
self._telemetry_anonymized_uuid = uuid.uuid4()
with open(".env", "a") as f:
f.write(f"\ntelemetry_anonymized_uuid={self._telemetry_anonymized_uuid}\n")
else:
self._telemetry_anonymized_uuid = get_settings().telemetry_anonymized_uuid
def capture(self, event, properties=None):
self._conn.capture(self._telemetry_anonymized_uuid, event, properties)

View File

@@ -0,0 +1,26 @@
version: '3.9'
networks:
my-network:
driver: bridge
services:
server_test:
build:
context: .
dockerfile: Dockerfile
target: chroma_server_test
depends_on:
- clickhouse
networks:
- my-network
clickhouse:
image: docker.io/bitnami/clickhouse:22.9
environment:
- ALLOW_EMPTY_PASSWORD=yes
ports:
- '8123:8123'
- '9000:9000'
networks:
- my-network

View File

@@ -6,7 +6,10 @@ pandas==1.5.0
duckdb==0.5.1
hnswlib @ git+https://oauth2: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@github.com/chroma-core/hnswlib.git
clickhouse_driver==0.2.4
redis==3.5.3
celery==4.4.7
celery==4.4.7
clickhouse_driver==0.2.4
posthog==2.1.2
uuid==1.30
sentry_sdk==1.10.1
pydantic==1.9.0

View File

@@ -1,8 +1,3 @@
httpx
pytest
setuptools_scm
duckdb
hnswlib @ git+https://oauth2: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@github.com/chroma-core/hnswlib.git
pandas
numpy
pyarrow

View File

@@ -0,0 +1,61 @@
# Clickhouse Architecture
## Context
The current prototype of Chroma Server uses DuckDB and Parquet files for
persistence. Although the simplicity and batch data retrieval
characteristics of this approach are attractive, we have determined that it is
suboptimal for three primary reasons:
- Chroma's primary mode of ingesting data is a stream of small batches
of embeddings. DuckDB and Parquet are not well optimized for
streaming input. In fact, it's impossible to append to a Parquet
file; the entire file must be re-written or additional files
created.
- DuckDB explicitly does not support multiple writer processes, which
we will likely want in the medium term.
- DuckDB + Parquet requires an explicit flush or write operation to
persist data. This adds an element of "state management" and is
complexity that we would rather not expose to the client.
Therefore, we are looking for an architecture with the following qualities:
- Efficient streaming ingest
- Efficient bulk read to pull data into memory for processing (OLAP)
- Low volume transactional CRUD operations (e.g. datasets and metadata)
- Low administrative overhead, to present as small a client API as
possible. We want to avoid exposing any methods aside from those
that define Chroma as a product, for a focused user experience.
## Decision
We will use Clickhouse as the persistence layer. For now it will be
the only persistence mechanism used by Chroma.
Instances of Chroma Server will be stateless, aside from caching for
performance.
The MVP will run in a simple `docker-compose` configuration with a
single Clickhouse and a single Chroma Server.
The Chroma Server will, when required to service a read operation,
pull entire datasets from Clickhouse into memory and keep them cached
in order to perform algorithmic work on demand.
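As a rough illustration of that read path (not code from this commit): a stateless server process can pull a whole space out of Clickhouse with `clickhouse_driver` and hold it in a per-process cache keyed by space. Table and column names here are made up for the sketch:
```
from clickhouse_driver import Client

# In-process cache: space_key -> rows pulled from Clickhouse.
_cache = {}

def load_space(space_key, host="clickhouse"):
    """Pull an entire space into memory (hypothetical table/columns)."""
    if space_key not in _cache:
        client = Client(host=host)
        _cache[space_key] = client.execute(
            "SELECT uuid, embedding_data, input_uri FROM embeddings "
            "WHERE space_key = %(space_key)s",
            {"space_key": space_key},
        )
    return _cache[space_key]

rows = load_space("sample_space")
print(len(rows))
```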
![Clickhouse Architecture](./2022-11-01-clickhouse-architecture/diagram.png "Clickhouse Architecture")
## Consequences
- The MVP is actually less complex than the previous DuckDB based
solution.
- We can scale horizontally by adding more Chroma Server instances in
a cluster.
- We can scale vertically by using a larger instance of Clickhouse or
moving to clustered Clickhouse as workloads grow.
- At some point in the future, we will likely need to add an OLTP
database, when the system contains enough transactional data that
Clickhouse starts to perform poorly for row-based updates.
- We maintain separation of concerns, and can make future changes to
the data persistence mechanisms without disrupting the backend
protocol between Chroma Client and Chroma Server, or the user-facing
API.

Binary file not shown (new image, 38 KiB).

View File

@@ -9,6 +9,7 @@ services:
build:
context: ./chroma-server
dockerfile: Dockerfile
target: chroma_server
volumes:
- ./chroma-server/:/chroma-server/
- index_data:/index_data

8
pyproject.toml Normal file
View File

@@ -0,0 +1,8 @@
[tool.black]
line-length = 100
# Black will refuse to run if it's not this version.
required-version = "22.6.0"
# Ensure black's output will be compatible with all listed versions.
target-version = ['py36', 'py37', 'py38', 'py39', 'py310']