mirror of https://github.com/placeholder-soft/chroma.git (synced 2026-01-12 17:02:54 +08:00)
Merge branch 'main' into jeff/celery
.github/workflows/chroma-release.yml (vendored, 1 change)

@@ -44,6 +44,7 @@ jobs:
       with:
         context: chroma-server
         push: true
+        target: chroma_server
         tags: ${{ steps.tag.outputs.tag_name}}
     - name: Get Release Version
       id: version
.github/workflows/chroma-server-test.yml (vendored, 13 changes)

@@ -17,19 +17,10 @@ jobs:
     strategy:
       matrix:
         python: ['3.10']
-        platform: [ubuntu-latest, macos-latest]
+        platform: [ubuntu-latest]
     runs-on: ${{ matrix.platform }}
     steps:
       - name: Checkout
         uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python }}
-      - name: Install test dependencies
-        run: |
-          cd chroma-server && python -m pip install -r requirements.txt -r requirements_dev.txt
-      - name: Install chroma_client
-        run: cd chroma-client && pip install .
       - name: Test
-        run: cd chroma-server && python -m pytest
+        run: cd chroma-server && bin/test
.gitignore (vendored, 2 changes)

@@ -9,4 +9,4 @@ chroma-server/chroma_logs.log
 
 **/data__nogit
 
-**/.ipynb_checkpoints
+**/.ipynb_checkpoints
Makefile (new file, 5 lines)

@@ -0,0 +1,5 @@
black:
	black --fast chroma-server chroma-client

check_black:
	black --check --fast chroma-server chroma-client
README.md (35 changes)

@@ -9,3 +9,38 @@ Contents:
 - `/chroma-client` - Python client for Chroma
 - `/chroma-server` - FastAPI server used as the backend for Chroma client

### Get up and running on Linux

No requirements

```
/bin/bash -c "$(curl -fsSL https://gist.githubusercontent.com/jeffchuber/effcbac05021e863bbd634f4b7d0283d/raw/4d38b150809d6ccbc379f88433cadd86c81d32cd/chroma_setup.sh)"
python3 chroma/bin/test.py
```

### Get up and running on Mac

Requirements:
- git
- Docker & `docker-compose`
- pip

```
/bin/bash -c "$(curl -fsSL https://gist.githubusercontent.com/jeffchuber/27a3cbb28e6521c811da6398346cd35f/raw/55c2d82870436431120a9446b47f19b72d88fa31/chroma_setup_mac.sh)"
python3 chroma/bin/test.py
```

* These URLs will be swapped out for links into this repo once it is live.

### You should see something like

```
Getting heartbeat to verify the server is up
{'nanosecond heartbeat': 1667865642509760965000}
Logging embeddings into the database
Generating the index
True
Running a nearest neighbor search
{'ids': ['11540ca6-ebbc-4c81-8299-108d8c47c88c'], 'embeddings': [['sample_space', '11540ca6-ebbc-4c81-8299-108d8c47c88c', [1.0, 2.0, 3.0, 4.0, 5.0], '/images/1', 'training', None, 'spoon']], 'distances': [0.0]}
Success! Everything worked!
```
bin/setup_linux.sh (new file, 42 lines)

@@ -0,0 +1,42 @@
#!/usr/bin/env bash

# install pip
apt install -y python3-pip

# install docker
sudo apt-get update
sudo apt-get -y install \
    ca-certificates \
    curl \
    gnupg \
    lsb-release

sudo mkdir -p /etc/apt/keyrings
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg

echo \
  "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \
  $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null

sudo apt-get update

sudo apt-get -y install docker-ce docker-ce-cli containerd.io docker-compose-plugin

pip3 install docker-compose

# get the code
git clone https://oauth2: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@github.com/chroma-core/chroma.git

# checkout the right branch
cd chroma
git checkout jeff/packaging

# run docker
cd chroma-server
docker-compose up -d --build

# install chroma-client
cd ../chroma-client
pip3 install --upgrade pip  # you have to do this or it will use UNKNOWN as the package name
pip3 install .
bin/setup_mac.sh (new file, 19 lines)

@@ -0,0 +1,19 @@
# requirements
# - docker
# - pip

# get the code
git clone https://oauth2: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@github.com/chroma-core/chroma.git

# checkout the right branch
cd chroma
git checkout jeff/packaging

# run docker
cd chroma-server
docker-compose up -d --build

# install chroma-client
cd ../chroma-client
pip install --upgrade pip  # you have to do this or it will use UNKNOWN as the package name
pip install .
bin/test.py (new file, 23 lines)

@@ -0,0 +1,23 @@
from chroma_client import Chroma

chroma = Chroma()
chroma.set_space_key('sample_space')
print("Getting heartbeat to verify the server is up")
print(chroma.heartbeat())

print("Logging embeddings into the database")
chroma.log(
    [[1,2,3,4,5], [5,4,3,2,1], [10,9,8,7,6]],
    ["/images/1", "/images/2", "/images/3"],
    ["training", "training", "training"],
    ['spoon', 'knife', 'fork']
)

# print("fetch", chroma.fetch())
print("Generating the index")
print(chroma.process())

print("Running a nearest neighbor search")
print(chroma.get_nearest_neighbors([1,2,3,4,5], 1))

print("Success! Everything worked!")
@@ -2,10 +2,11 @@ import requests
 import json
 from typing import Union
 
+
 class Chroma:
 
     _api_url = "http://localhost:8000/api/v1"
-    _space_key = None
+    _space_key = "default_scope"
 
     def __init__(self, url=None, app=None, model_version=None, layer=None):
         """Initialize Chroma client"""

@@ -159,4 +160,4 @@ class Chroma:
 
     def get_task_status(self, task_id):
         '''Gets the status of a task'''
-        return requests.get(self._api_url + f"/tasks/{task_id}").json()
+        return requests.get(self._api_url + f"/tasks/{task_id}").json()
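The two client hunks above change the default scope and black-format the module. A minimal usage sketch of this client, assuming the `url` argument of `__init__` overrides `_api_url` (only the parameter name is visible in this diff; `set_space_key` and `heartbeat` appear in `bin/test.py` above):

```python
from chroma_client import Chroma

chroma = Chroma(url="http://localhost:8000/api/v1")  # assumed to override _api_url shown above
chroma.set_space_key("sample_space")                 # without this, the new "default_scope" default applies
print(chroma.heartbeat())
print(chroma.get_task_status("some-task-id"))        # hypothetical task id; polls /tasks/{task_id}
```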
@@ -4,11 +4,14 @@ from chroma_client import Chroma
 import pytest
 import time
 from httpx import AsyncClient
 
+# from ..api import app # this wont work because i moved the file
+
 
 @pytest.fixture
 def anyio_backend():
-    return 'asyncio'
+    return "asyncio"
 
 
 def test_init():
     chroma = Chroma()
chroma-server/.env (new file, 3 lines)

@@ -0,0 +1,3 @@
disable_anonymized_telemetry=False
environment=development
telemetry_anonymized_uuid=f80b11fc-1c5a-4a90-ba35-8c3a3c5371cc
@@ -1,4 +1,4 @@
-FROM --platform=linux/amd64 python:3.10
+FROM --platform=linux/amd64 python:3.10 AS chroma_server
 
 #RUN apt-get update -qq
 #RUN apt-get install python3.10 python3-pip -y --no-install-recommends && rm -rf /var/lib/apt/lists/*

@@ -15,3 +15,13 @@ EXPOSE 8000
 
 CMD ["uvicorn", "chroma_server:app", "--host", "0.0.0.0", "--port", "8000", "--proxy-headers"]
 
+# Use a multi-stage build to layer in test dependencies without bloating server image
+# https://docs.docker.com/build/building/multi-stage/
+# Note: requires passing --target to docker build.
+FROM chroma_server AS chroma_server_test
+
+COPY ./requirements_dev.txt requirements_dev.txt
+RUN pip install --no-cache-dir --upgrade -r requirements_dev.txt
+
+CMD ["python", "-m", "pytest"]
+
@@ -12,7 +12,8 @@ pip install -r requirements.txt
 pip install -r requirements_dev.txt
 ```
 
-To run tests, run `pytest`.
+To run tests, run `bin/test`. This will run the test suite inside a
+docker compose cluster, with the database available, and clean up when complete.
 
 To run the server locally, in development mode, run `uvicorn chroma_server:app --reload`

@@ -43,4 +44,4 @@ To run use `docker images` to see what containers and tags you have available:
 docker run -p 8000:8000 ghcr.io/chroma-core/chroma-server:<tag name -- eg 0.0.2-dirty>>
 ```
 
-This will expose the internal app at `localhost:8000`
+This will expose the internal app at `localhost:8000`

@@ -1,3 +1,3 @@
 #!/usr/bin/env bash
 
-docker build . -t ghcr.io/chroma-core/chroma-server:`bin/version`
+docker build . --target chroma_server -t ghcr.io/chroma-core/chroma-server:`bin/version`
chroma-server/bin/test (new executable file, 11 lines)

@@ -0,0 +1,11 @@
#!/usr/bin/env bash

set -e

function cleanup {
  docker-compose -f docker-compose.test.yml down
}

trap cleanup EXIT

docker-compose -f docker-compose.test.yml run --rm server_test
@@ -1,8 +1,8 @@
-
 import random
 
 
 def rand_bisectional_subsample(data):
     """
     Randomly bisectionally subsample a list of data to size.
     """
-    return data.sample(frac=0.5, replace=True, random_state=1)
+    return data.sample(frac=0.5, replace=True, random_state=1)
@@ -2,22 +2,23 @@ import numpy as np
 import json
 import ast
 
+
 def class_distances(data):
-    ''''
+    """'
     This is all very subject to change, so essentially just copy and paste from what we had before
-    '''
+    """
 
     return False
 
 # def unpack_annotations(embeddings):
 #     annotations = [json.loads(embedding['infer'])["annotations"]for embedding in embeddings]
-#     annotations = [annotation for annotation_list in annotations for annotation in annotation_list]
+#     annotations = [annotation for annotation_list in annotations for annotation in annotation_list]
 #     # Unpack embedding data
 #     embeddings = [embedding["embedding_data"] for embedding in embeddings]
 #     embedding_vectors_by_category = {}
 #     for embedding_annotation_pair in zip(embeddings, annotations):
 #         data = np.array(embedding_annotation_pair[0])
-#         category = embedding_annotation_pair[1]['category_id']
+#         category = embedding_annotation_pair[1]['category_id']
 #         if category in embedding_vectors_by_category.keys():
 #             embedding_vectors_by_category[category] = np.append(
 #                 embedding_vectors_by_category[category], data[np.newaxis, :], axis=0

@@ -84,5 +85,5 @@ def class_distances(data):
 
 #     if (len(inferences) == 0):
 #         raise Exception("No inferences found for datapoint")
 
-#     return output_distances
+#     return output_distances
@@ -2,7 +2,7 @@ import os
 import shutil
 import time
 
-from fastapi import FastAPI, Response, status
+from fastapi import FastAPI, status
 from fastapi.responses import JSONResponse
 
 from worker import heavy_offline_analysis
@@ -12,6 +12,13 @@ from chroma_server.index.hnswlib import Hnswlib
 from chroma_server.types import AddEmbedding, QueryEmbedding, ProcessEmbedding, FetchEmbedding, CountEmbedding, RawSql, Results, SpaceKeyInput
 from chroma_server.utils import logger
 
+from chroma_server.utils.telemetry.capture import Capture
+from chroma_server.utils.error_reporting import init_error_reporting
+
+chroma_telemetry = Capture()
+chroma_telemetry.capture('server-start')
+init_error_reporting()
+
 from celery.result import AsyncResult
 
 # Boot script
@@ -66,6 +73,7 @@ async def add_to_db(new_embedding: AddEmbedding):
 
     return {"response": "Added records to database"}
 
+
 @app.get("/api/v1/process")
 async def process(process_embedding: ProcessEmbedding):
     '''
@@ -74,6 +82,7 @@ async def process(process_embedding: ProcessEmbedding):
     fetch = app._db.fetch({"space_key": process_embedding.space_key}, columnar=True)
     app._ann_index.run(process_embedding.space_key, fetch[1], fetch[2]) # more magic number, ugh
 
+
 @app.get("/api/v1/fetch")
 async def fetch(fetch_embedding: FetchEmbedding):
     '''
@@ -82,6 +91,7 @@ async def fetch(fetch_embedding: FetchEmbedding):
     '''
     return app._db.fetch(fetch_embedding.where_filter, fetch_embedding.sort, fetch_embedding.limit)
 
+
 @app.get("/api/v1/count")
 async def count(count_embedding: CountEmbedding):
     '''
@@ -89,6 +99,7 @@ async def count(count_embedding: CountEmbedding):
     '''
     return {"count": app._db.count(space_key=count_embedding.space_key)}
 
+
 @app.get("/api/v1/reset")
 async def reset():
     '''
@@ -102,7 +113,7 @@ async def reset():
 
 @app.post("/api/v1/get_nearest_neighbors")
 async def get_nearest_neighbors(embedding: QueryEmbedding):
-    '''
+    """
     return the distance, database ids, and embedding themselves for the input embedding
-    '''
+    """
     if embedding.space_key is None:
@@ -112,9 +123,9 @@ async def get_nearest_neighbors(embedding: QueryEmbedding):
     filter_by_where = {}
     filter_by_where["space_key"] = embedding.space_key
     if embedding.category_name is not None:
-        filter_by_where['category_name'] = embedding.category_name
+        filter_by_where["category_name"] = embedding.category_name
     if embedding.dataset is not None:
-        filter_by_where['dataset'] = embedding.dataset
+        filter_by_where["dataset"] = embedding.dataset
 
     if filter_by_where is not None:
         results = app._db.fetch(filter_by_where)
@@ -129,4 +140,4 @@ async def get_nearest_neighbors(embedding: QueryEmbedding):
 
 @app.get("/api/v1/raw_sql")
 async def raw_sql(raw_sql: RawSql):
-    return app._db.raw_sql(raw_sql.raw_sql)
+    return app._db.raw_sql(raw_sql.raw_sql)
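Taken together, these hunks show that `category_name` and `dataset` are optional narrowing filters on the query. A hedged example of calling the endpoint directly, with field names taken from the hunks above and payload values mirroring the test suite later in this diff:

```python
import requests

resp = requests.post(
    "http://localhost:8000/api/v1/get_nearest_neighbors",
    json={
        "embedding": [1.1, 2.3, 3.2],
        "n_results": 1,
        "dataset": "training",       # optional; added to filter_by_where
        "category_name": "person",   # optional; added to filter_by_where
    },
)
print(resp.json()["ids"])
```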
@@ -1,6 +1,7 @@
 from abc import abstractmethod
 
-class Database():
+
+class Database:
     @abstractmethod
     def __init__(self):
         pass

@@ -23,4 +24,4 @@ class Database():
 
     @abstractmethod
     def reset(self):
-        pass
+        pass
chroma-server/chroma_server/db/duckdb.py (new file, 212 lines)

@@ -0,0 +1,212 @@
from os import EX_CANTCREAT
from chroma_server.db.abstract import Database
import duckdb
import numpy as np  # used below for NaN -> None replacement; missing from the original imports
import pandas as pd  # used below for the empty-result DataFrame; missing from the original imports


class DuckDB(Database):
    _conn = None

    def __init__(self):
        self._conn = duckdb.connect()
        self._conn.execute(
            """
            CREATE TABLE embeddings (
                id integer PRIMARY KEY,
                embedding_data REAL[],
                input_uri STRING,
                dataset STRING,
                custom_quality_score REAL,
                category_name STRING
            )
        """
        )

        # ids to manage internal bookkeeping and *nothing else*, users should not have to care about these ids
        self._conn.execute(
            """
            CREATE SEQUENCE seq_id START 1;
        """
        )

        self._conn.execute(
            """
            -- change the default null sorting order to either NULLS FIRST and NULLS LAST
            PRAGMA default_null_order='NULLS LAST';
            -- change the default sorting order to either DESC or ASC
            PRAGMA default_order='DESC';
        """
        )
        return

    def add_batch(
        self, embedding_data, input_uri, dataset=None, custom_quality_score=None, category_name=None
    ):
        """
        Add embeddings to the database
        This accepts both a single input and a list of inputs
        """

        # create list of the types of all inputs
        types = [type(x).__name__ for x in [embedding_data, input_uri]]

        # if all of the types are 'list' - do batch mode
        if all(x == "list" for x in types):
            lengths = [len(x) for x in [embedding_data, input_uri]]

            # accepts some inputs as str or none, and this multiples them out to the correct length
            if custom_quality_score is None or isinstance(custom_quality_score, str):
                custom_quality_score = [custom_quality_score] * lengths[0]
            if category_name is None or isinstance(category_name, str):
                category_name = [category_name] * lengths[0]
            if dataset is None or isinstance(dataset, str):
                dataset = [dataset] * lengths[0]

            # we have to move from column to row format for duckdb
            data_to_insert = []
            for i in range(lengths[0]):
                data_to_insert.append(
                    [
                        embedding_data[i],
                        input_uri[i],
                        dataset[i],
                        custom_quality_score[i],
                        category_name[i],
                    ]
                )

            if all(x == lengths[0] for x in lengths):
                self._conn.executemany(
                    """
                    INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)""",
                    data_to_insert,
                )
                return

        # if any of the types are 'list' - throw an error
        if any(x == list for x in [input_uri, dataset, custom_quality_score, category_name]):
            raise Exception(
                "Invalid input types. One input is a list where others are not: " + str(types)
            )

        # single insert mode
        # This should never fire because we do everything in batch mode, but given the move away from duckdb likely, I am just leaving it in
        self._conn.execute(
            """
            INSERT INTO embeddings VALUES (nextval('seq_id'), ?, ?, ?, ?, ?)""",
            [embedding_data, input_uri, dataset, custom_quality_score, category_name],
        )

    def count(self):
        return self._conn.execute(
            """
            SELECT COUNT(*) FROM embeddings;
        """
        ).fetchone()[0]

    def update(self, data):  # call this update_custom_quality_score! that is all it does
        """
        I was not able to figure out (yet) how to do a bulk update in duckdb
        This is going to be fairly slow
        """
        for element in data:
            if element["custom_quality_score"] is None:
                continue
            self._conn.execute(
                f"""
                UPDATE embeddings SET custom_quality_score={element['custom_quality_score']} WHERE id={element['id']}"""
            )

    def fetch(self, where_filter={}, sort=None, limit=None):
        # check to see if query is a dict and if it is a flat list of key value pairs
        if where_filter is not None:
            if not isinstance(where_filter, dict):
                raise Exception("Invalid where_filter: " + str(where_filter))

            # ensure where_filter is a flat dict
            for key in where_filter:
                if isinstance(where_filter[key], dict):
                    raise Exception("Invalid where_filter: " + str(where_filter))

        where_filter = " AND ".join([f"{key} = '{value}'" for key, value in where_filter.items()])

        if where_filter:
            where_filter = f"WHERE {where_filter}"

        if sort is not None:
            where_filter += f" ORDER BY {sort}"

        if limit is not None or isinstance(limit, int):
            where_filter += f" LIMIT {limit}"

        return (
            self._conn.execute(
                f"""
                SELECT
                    id,
                    embedding_data,
                    input_uri,
                    dataset,
                    custom_quality_score,
                    category_name
                FROM
                    embeddings
                {where_filter}
            """
            )
            .fetchdf()
            .replace({np.nan: None})
        )  # replace nan with None for json serialization

    def delete_batch(self, batch):
        raise NotImplementedError

    def persist(self):
        """
        Persist the database to disk
        """
        if self._conn is None:
            return

        self._conn.execute(
            """
            COPY
                (SELECT * FROM embeddings)
            TO '.chroma/chroma.parquet'
                (FORMAT PARQUET);
        """
        )

    def load(self, path=".chroma/chroma.parquet"):
        """
        Load the database from disk
        """
        self._conn.execute(f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');")

    def get_by_ids(self, ids=list):
        # select from duckdb table where ids are in the list
        if not isinstance(ids, list):
            raise Exception("ids must be a list")

        if not ids:
            # create an empty pandas dataframe
            return pd.DataFrame()

        return (
            self._conn.execute(
                f"""
                SELECT
                    id,
                    embedding_data,
                    input_uri,
                    dataset,
                    custom_quality_score,
                    category_name
                FROM
                    embeddings
                WHERE
                    id IN ({','.join([str(x) for x in ids])})
            """
            )
            .fetchdf()
            .replace({np.nan: None})
        )  # replace nan with None for json serialization
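A hedged usage sketch of the wrapper above; the batch shapes mirror `bin/test.py`, and `persist()` writes `.chroma/chroma.parquet` exactly as the method shows (the `os.makedirs` call is an assumption, since `persist()` does not create the directory itself):

```python
import os
from chroma_server.db.duckdb import DuckDB

db = DuckDB()
db.add_batch(
    embedding_data=[[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]],
    input_uri=["/images/1", "/images/2"],
    dataset="training",                  # a plain str is broadcast to every row
    category_name=["spoon", "knife"],
)
print(db.count())                        # 2
print(db.fetch({"category_name": "spoon"}, limit=1))

os.makedirs(".chroma", exist_ok=True)    # assumed prerequisite for persist()
db.persist()
```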
@@ -1,6 +1,7 @@
 from abc import abstractmethod
 
-class Index():
+
+class Index:
     @abstractmethod
     def __init__(self):
         pass

@@ -23,4 +24,4 @@ class Index():
 
     @abstractmethod
     def load(self):
-        pass
+        pass
@@ -4,7 +4,8 @@ import time
 import os
 import numpy as np
 from chroma_server.index.abstract import Index
-from chroma_server.utils import logger
+from chroma_server.logger import logger
 
+
 class Hnswlib(Index):

@@ -16,9 +17,6 @@ class Hnswlib(Index):
         'time_created': None,
     }
 
-    # these data structures enable us to map between uuids and ids
-    # - our uuids are strings (clickhouse doesnt do autoincrementing ids for performance)
-    # - but hnswlib uses integers only as ids
     _id_to_uuid = {}
     _uuid_to_id = {}
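The deleted comment explained why `_id_to_uuid` and `_uuid_to_id` exist: hnswlib only accepts integer labels, while Chroma's records are keyed by string uuids. A minimal sketch of that mapping pattern (the dimension, data, and uuid are illustrative, not taken from this file):

```python
import hnswlib
import numpy as np

_id_to_uuid, _uuid_to_id = {}, {}

index = hnswlib.Index(space="l2", dim=5)
index.init_index(max_elements=10)

uuid_str = "11540ca6-ebbc-4c81-8299-108d8c47c88c"
int_id = len(_id_to_uuid)           # next free integer label
_id_to_uuid[int_id] = uuid_str
_uuid_to_id[uuid_str] = int_id

index.add_items(np.array([[1.0, 2.0, 3.0, 4.0, 5.0]]), np.array([int_id]))
labels, distances = index.knn_query(np.array([[1.0, 2.0, 3.0, 4.0, 5.0]]), k=1)
print(_id_to_uuid[int(labels[0][0])], distances[0][0])  # map the int label back to the uuid
```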
@@ -1,5 +1,6 @@
 import logging
 
+
 def setup_logging():
     logging.basicConfig(filename="chroma_logs.log")
     logger = logging.getLogger("Chroma")

@@ -7,4 +8,5 @@ def setup_logging():
     logger.debug("Logger created")
     return logger
 
-logger = setup_logging()
+
+logger = setup_logging()
@@ -4,32 +4,45 @@ from httpx import AsyncClient
 
 from ..api import app
 
+
 @pytest.fixture
 def anyio_backend():
-    return 'asyncio'
+    return "asyncio"
 
+
 @pytest.mark.anyio
 async def test_root():
     async with AsyncClient(app=app, base_url="http://test") as ac:
         response = await ac.get("/api/v1")
     assert response.status_code == 200
-    assert abs(response.json()["nanosecond heartbeat"] - int(1000 * time.time_ns())) < 3_000_000_000 # a billion nanoseconds = 3s
+    assert (
+        abs(response.json()["nanosecond heartbeat"] - int(1000 * time.time_ns())) < 3_000_000_000
+    )  # a billion nanoseconds = 3s
 
+
 async def post_one_record(ac):
-    return await ac.post("/api/v1/add", json={
-        "embedding_data": [1.02, 2.03, 3.03],
-        "input_uri": "https://example.com",
-        "dataset": "coco",
-        "category_name": "person"
-    })
+    return await ac.post(
+        "/api/v1/add",
+        json={
+            "embedding_data": [1.02, 2.03, 3.03],
+            "input_uri": "https://example.com",
+            "dataset": "coco",
+            "category_name": "person",
+        },
+    )
 
+
 async def post_batch_records(ac):
-    return await ac.post("/api/v1/add", json={
-        "embedding_data": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
-        "input_uri": ["https://example.com", "https://example.com"],
-        "dataset": "training",
-        "category_name": "person"
-    })
+    return await ac.post(
+        "/api/v1/add",
+        json={
+            "embedding_data": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
+            "input_uri": ["https://example.com", "https://example.com"],
+            "dataset": "training",
+            "category_name": "person",
+        },
+    )
 
+
 @pytest.mark.anyio
 async def test_add_to_db():
@@ -38,6 +51,7 @@ async def test_add_to_db():
         assert response.status_code == 201
         assert response.json() == {"response": "Added record to database"}
 
+
 @pytest.mark.anyio
 async def test_add_to_db_batch():
     async with AsyncClient(app=app, base_url="http://test") as ac:
@@ -45,6 +59,7 @@ async def test_add_to_db_batch():
         assert response.status_code == 201
         assert response.json() == {"response": "Added record to database"}
 
+
 @pytest.mark.anyio
 async def test_fetch_from_db():
     async with AsyncClient(app=app, base_url="http://test") as ac:
@@ -53,15 +68,17 @@ async def test_fetch_from_db():
         assert response.status_code == 200
         assert len(response.json()) == 1
 
+
 @pytest.mark.anyio
 async def test_count_from_db():
     async with AsyncClient(app=app, base_url="http://test") as ac:
-        await ac.get("/api/v1/reset") # reset db
+        await ac.get("/api/v1/reset")  # reset db
         await post_batch_records(ac)
         response = await ac.get("/api/v1/count")
         assert response.status_code == 200
         assert response.json() == {"count": 2}
 
+
 @pytest.mark.anyio
 async def test_reset_db():
     async with AsyncClient(app=app, base_url="http://test") as ac:
@@ -74,25 +91,19 @@ async def test_reset_db():
         response = await ac.get("/api/v1/count")
         assert response.json() == {"count": 0}
 
+
 @pytest.mark.anyio
 async def test_get_nearest_neighbors():
     async with AsyncClient(app=app, base_url="http://test") as ac:
         await ac.get("/api/v1/reset")
         await post_batch_records(ac)
         await ac.get("/api/v1/process")
-        response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1})
+        response = await ac.post(
+            "/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1}
+        )
         assert response.status_code == 200
         assert len(response.json()["ids"]) == 1
 
-@pytest.mark.anyio
-async def test_get_nearest_neighbors_filter():
-    async with AsyncClient(app=app, base_url="http://test") as ac:
-        await ac.get("/api/v1/reset")
-        await post_batch_records(ac)
-        await ac.get("/api/v1/process")
-        response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "dataset": "training", "category_name": "monkey"})
-        assert response.status_code == 200
-        assert len(response.json()["ids"]) == 0
-
+
 @pytest.mark.anyio
 async def test_get_nearest_neighbors_filter():
@@ -100,6 +111,33 @@ async def test_get_nearest_neighbors_filter():
         await ac.get("/api/v1/reset")
         await post_batch_records(ac)
         await ac.get("/api/v1/process")
-        response = await ac.post("/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 2, "dataset": "training", "category_name": "person"})
+        response = await ac.post(
+            "/api/v1/get_nearest_neighbors",
+            json={
+                "embedding": [1.1, 2.3, 3.2],
+                "n_results": 1,
+                "dataset": "training",
+                "category_name": "monkey",
+            },
+        )
+        assert response.status_code == 200
+        assert len(response.json()["ids"]) == 0
+
+
+@pytest.mark.anyio
+async def test_get_nearest_neighbors_filter():
+    async with AsyncClient(app=app, base_url="http://test") as ac:
+        await ac.get("/api/v1/reset")
+        await post_batch_records(ac)
+        await ac.get("/api/v1/process")
+        response = await ac.post(
+            "/api/v1/get_nearest_neighbors",
+            json={
+                "embedding": [1.1, 2.3, 3.2],
+                "n_results": 2,
+                "dataset": "training",
+                "category_name": "person",
+            },
+        )
         assert response.status_code == 200
         assert len(response.json()["ids"]) == 2
@@ -7,9 +7,10 @@ class AddEmbedding(BaseModel):
     embedding_data: list
     input_uri: Union[str, list]
     dataset: Union[str, list] = None
-    custom_quality_score: Union[float, list] = None
+    custom_quality_score: Union[float, list] = None
     category_name: Union[str, list] = None
 
+
 class QueryEmbedding(BaseModel):
     space_key: str = None
     embedding: list
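A quick check of the `AddEmbedding` model above; the field values mirror the test payloads elsewhere in this diff:

```python
from chroma_server.types import AddEmbedding

record = AddEmbedding(
    embedding_data=[1.02, 2.03, 3.03],
    input_uri="https://example.com",
    dataset="coco",
    category_name="person",  # custom_quality_score stays None by default
)
print(record.dict())
```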
chroma-server/chroma_server/utils/__init__.py (new empty file)
chroma-server/chroma_server/utils/config/settings.py (new file, 15 lines)

@@ -0,0 +1,15 @@
from functools import lru_cache
from typing import Union
from pydantic import BaseSettings

class Settings(BaseSettings):
    disable_anonymized_telemetry: bool = False
    telemetry_anonymized_uuid: str = ''
    environment: str = 'development'

    class Config:
        env_file = ".env"

@lru_cache()
def get_settings():
    return Settings()
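A sketch of how these pydantic `BaseSettings` behave: values come from the `.env` file added above, and environment variables override them (the override value here is hypothetical). Note that `get_settings()` is cached by `lru_cache`, so overrides must be in place before the first call:

```python
import os

os.environ["disable_anonymized_telemetry"] = "True"  # env vars win over .env defaults

from chroma_server.utils.config.settings import get_settings

settings = get_settings()                     # cached; later env changes won't be seen
print(settings.disable_anonymized_telemetry)  # True
print(settings.environment)                   # "development" unless overridden
```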
chroma-server/chroma_server/utils/error_reporting.py (new file, 27 lines)

@@ -0,0 +1,27 @@
from chroma_server.utils.config.settings import get_settings

import sentry_sdk
from sentry_sdk.client import Client
from sentry_sdk import configure_scope
from posthog.sentry.posthog_integration import PostHogIntegration

PostHogIntegration.organization = "chroma"
sample_rate = 1.0
if get_settings().environment == "production":
    sample_rate = 0.1

def strip_sensitive_data(event, hint):
    if 'server_name' in event:
        del event['server_name']
    return event

def init_error_reporting():

    sentry_sdk.init(
        dsn="https://ef5fae1e461f49b3a7a2adf3404378ab@o4504080408051712.ingest.sentry.io/4504080409296896",
        traces_sample_rate=sample_rate,
        integrations=[PostHogIntegration()],
        environment=get_settings().environment,
        before_send=strip_sensitive_data,
    )
    with configure_scope() as scope:
        scope.set_tag('posthog_distinct_id', get_settings().telemetry_anonymized_uuid)
chroma-server/chroma_server/utils/telemetry/abstract.py (new file, 10 lines)

@@ -0,0 +1,10 @@
from abc import abstractmethod

class Telemetry():
    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def capture(self, event, properties=None):
        pass
chroma-server/chroma_server/utils/telemetry/capture.py (new file, 34 lines)

@@ -0,0 +1,34 @@
import posthog
import uuid
import sys
from chroma_server.utils.telemetry.abstract import Telemetry
from chroma_server.utils.config.settings import get_settings

class Capture(Telemetry):
    _conn = None
    _telemetry_anonymized_uuid = None

    def __init__(self):
        if get_settings().disable_anonymized_telemetry:
            posthog.disabled = True

        # disable telemetry if we're running tests
        if "pytest" in sys.modules:
            posthog.disabled = True

        posthog.project_api_key = 'phc_YeUxaojbKk5KPi8hNlx1bBKHzuZ4FDtl67kH1blv8Bh'
        posthog.host = 'https://app.posthog.com'
        self._conn = posthog

        if not get_settings().telemetry_anonymized_uuid:
            self._telemetry_anonymized_uuid = uuid.uuid4()

            with open(".env", "a") as f:
                f.write(f"\ntelemetry_anonymized_uuid={self._telemetry_anonymized_uuid}\n")

        else:
            self._telemetry_anonymized_uuid = get_settings().telemetry_anonymized_uuid

    def capture(self, event, properties=None):
        self._conn.capture(self._telemetry_anonymized_uuid, event, properties)
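This mirrors the two `chroma_telemetry` lines added to the server boot code earlier in this diff; the second call is a hypothetical event showing the optional `properties` argument. Note the constructor appends a fresh `telemetry_anonymized_uuid` to `.env` if none is configured:

```python
from chroma_server.utils.telemetry.capture import Capture

chroma_telemetry = Capture()
chroma_telemetry.capture("server-start")
chroma_telemetry.capture("add-embeddings", properties={"count": 3})  # hypothetical event
```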
chroma-server/docker-compose.test.yml (new file, 26 lines)

@@ -0,0 +1,26 @@
version: '3.9'

networks:
  my-network:
    driver: bridge

services:
  server_test:
    build:
      context: .
      dockerfile: Dockerfile
      target: chroma_server_test
    depends_on:
      - clickhouse
    networks:
      - my-network

  clickhouse:
    image: docker.io/bitnami/clickhouse:22.9
    environment:
      - ALLOW_EMPTY_PASSWORD=yes
    ports:
      - '8123:8123'
      - '9000:9000'
    networks:
      - my-network
@@ -6,7 +6,10 @@ pandas==1.5.0
 duckdb==0.5.1
 hnswlib @ git+https://oauth2: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@github.com/chroma-core/hnswlib.git
 
-clickhouse_driver==0.2.4
-
 redis==3.5.3
-celery==4.4.7
+celery==4.4.7
+clickhouse_driver==0.2.4
+posthog==2.1.2
+uuid==1.30
+sentry_sdk==1.10.1
+pydantic==1.9.0
@@ -1,8 +1,3 @@
 httpx
 pytest
 setuptools_scm
-duckdb
-hnswlib @ git+https://oauth2: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@github.com/chroma-core/hnswlib.git
-pandas
-numpy
-pyarrow
doc/adr/2022-11-03-clickhouse-architecture.md (new file, 61 lines)

@@ -0,0 +1,61 @@
# Clickhouse Architecture

## Context

The current prototype of Chroma Server uses DuckDB and Parquet files for
persistence. Although the simplicity and batch data retrieval
characteristics of this are attractive, we determine that this is
suboptimal for three primary reasons:

- Chroma's primary mode of ingesting data is a stream of small batches
  of embeddings. DuckDB and Parquet are not well optimized for
  streaming input. In fact, it's impossible to append to a Parquet
  file; the entire file must be re-written or additional files
  created.
- DuckDB explicitly does not support multiple writer processes, which
  we will likely want in the medium term.
- DuckDB + Parquet requires an explicit flush or write operation to
  persist data. This adds an element of "state management" and is
  complexity that we would rather not expose to the client.

Therefore, we are looking for an architecture with the following qualities:

- Efficient streaming ingest
- Efficient bulk read to pull data into memory for processing (OLAP)
- Low volume transactional CRUD operations (e.g. datasets and metadata)
- Low administrative overhead, to present as small a client API as
  possible. We want to avoid exposing any methods aside from those
  that define Chroma as a product, for a focused user experience.

## Decision

We will use Clickhouse as the persistence layer. For now it will be
the only persistence mechanism used by Chroma.

Instances of Chroma Server will be stateless, aside from caching for
performance.

The MVP will run in a simple `docker-compose` configuration with a
single Clickhouse and a single Chroma Server.

The Chroma Server will, when required to service a read operation,
pull entire datasets from Clickhouse into memory and keep them cached
in order to perform algorithmic work on demand.

![Architecture diagram](2022-11-03-clickhouse-architecture/diagram.png)

## Consequences

- The MVP is actually less complex than the previous DuckDB based
  solution.
- We can scale horizontally by adding more Chroma Server instances in
  a cluster.
- We can scale vertically by using a larger instance of Clickhouse or
  moving to clustered Clickhouse as workloads grow.
- At some point in the future, we will likely need to add an OLTP
  database, when the system contains enough transactional data that
  Clickhouse starts to perform poorly for row-based updates.
- We maintain separation of concerns, and can make future changes to
  the data persistence mechanisms without disrupting the backend
  protocol between Chroma Client and Chroma Server, or the user-facing
  API.
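A sketch of the streaming-ingest pattern the ADR argues for, using `clickhouse_driver` (pinned in `requirements.txt` above); the table schema is hypothetical, and the host name assumes the `clickhouse` service from `docker-compose.test.yml`:

```python
from clickhouse_driver import Client

client = Client(host="clickhouse", port=9000)
client.execute(
    "CREATE TABLE IF NOT EXISTS embeddings (uuid String, embedding Array(Float32)) "
    "ENGINE = MergeTree() ORDER BY uuid"
)
# Each small batch is a plain INSERT; unlike Parquet, no rewrite or explicit flush is needed.
client.execute(
    "INSERT INTO embeddings (uuid, embedding) VALUES",
    [("11540ca6-ebbc-4c81-8299-108d8c47c88c", [1.0, 2.0, 3.0, 4.0, 5.0])],
)
```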
doc/adr/2022-11-03-clickhouse-architecture/diagram.graffle (new binary file, not shown)

doc/adr/2022-11-03-clickhouse-architecture/diagram.png (new binary file, 38 KiB, not shown)
@@ -9,6 +9,7 @@ services:
     build:
       context: ./chroma-server
       dockerfile: Dockerfile
+      target: chroma_server
     volumes:
       - ./chroma-server/:/chroma-server/
       - index_data:/index_data
pyproject.toml (new file, 8 lines)

@@ -0,0 +1,8 @@
[tool.black]
line-length = 100

# Black will refuse to run if it's not this version.
required-version = "22.6.0"

# Ensure black's output will be compatible with all listed versions.
target-version = ['py36', 'py37', 'py38', 'py39', 'py310']