diff --git a/.github/workflows/chroma-client-integration-test.yml b/.github/workflows/chroma-client-integration-test.yml index 734b38e..9a6fdd4 100644 --- a/.github/workflows/chroma-client-integration-test.yml +++ b/.github/workflows/chroma-client-integration-test.yml @@ -7,7 +7,7 @@ on: pull_request: branches: - main - - '*' + - '**' jobs: test: diff --git a/.github/workflows/chroma-integration-test.yml b/.github/workflows/chroma-integration-test.yml index 187a10f..9b332d5 100644 --- a/.github/workflows/chroma-integration-test.yml +++ b/.github/workflows/chroma-integration-test.yml @@ -8,7 +8,7 @@ on: pull_request: branches: - main - - '*' + - '**' jobs: test: diff --git a/.github/workflows/chroma-test.yml b/.github/workflows/chroma-test.yml index 1cbefc9..cccd6a6 100644 --- a/.github/workflows/chroma-test.yml +++ b/.github/workflows/chroma-test.yml @@ -8,7 +8,7 @@ on: pull_request: branches: - main - - '*' + - '**' jobs: test: diff --git a/Dockerfile b/Dockerfile index ee06e28..cb56b05 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.10-slim-bullseye as builder +FROM python:3.10-slim-bookworm as builder #RUN apt-get update -qq #RUN apt-get install python3.10 python3-pip -y --no-install-recommends && rm -rf /var/lib/apt/lists_/* @@ -11,7 +11,7 @@ COPY ./requirements.txt requirements.txt RUN pip install --no-cache-dir --upgrade --prefix="/install" -r requirements.txt -FROM python:3.10-slim-bullseye as final +FROM python:3.10-slim-bookworm as final RUN mkdir /chroma WORKDIR /chroma diff --git a/bin/backup.sh b/bin/backup.sh deleted file mode 100644 index 75582b6..0000000 --- a/bin/backup.sh +++ /dev/null @@ -1,46 +0,0 @@ -# #!/bin/bash - -# # check to see if the docker container called chroma-private-clickhouse-1 is running -# if [ "$(docker inspect -f '{{.State.Running}}' chroma-private-clickhouse-1)" = "true" ]; then -# echo "chroma-private-clickhouse-1 is up, proceeding with backup" -# else -# echo "chroma-private-clickhouse-1 is not up" -# exit 1 -# fi - -# backup_name=${backup_name:-backup} -# backup_name="$backup_name-$(date +%Y_%m_%d-%H_%M_%S)" - -# # date with format YYYY_MM_DD-HH_MM_SS -# # backup_date=$(date +%Y_%m_%d-%H_%M_%S) - -# while [ $# -gt 0 ]; do - -# if [[ $1 == *"--"* ]]; then -# param="${1/--/}" -# declare $param="$2" -# # echo $1 $2 // Optional to see the parameter:value result -# fi - -# shift -# done - -# echo $backup_name - -# # create a folder at ../backup to store the backup -# mkdir -p ../backups - -# # create a folder inside of ../backups with the name of the backup -# mkdir -p ../backups/$backup_name - -# # create folder in ../backups with that name string, if folder already exists, exit -# docker exec -u 0 -it chroma-private-clickhouse-1 clickhouse-client --query="BACKUP DATABASE default TO Disk('backups', '$backup_name.zip')" - -# # use that name to dump the clickhouse db and copy into the folder -# docker cp chroma-private-clickhouse-1:/etc/clickhouse-server/$backup_name.zip ../backups/$backup_name/$backup_name.zip - -# # remove the backup from teh clickhouse container -# docker exec -u 0 -it chroma-private-clickhouse-1 rm /etc/clickhouse-server/$backup_name.zip - -# # copy the entire contents of -# docker cp chroma-private-server-1:/index_data ../backups/$backup_name diff --git a/bin/docker_entrypoint.sh b/bin/docker_entrypoint.sh index 8695815..3b0d146 100755 --- a/bin/docker_entrypoint.sh +++ b/bin/docker_entrypoint.sh @@ -1,5 +1,6 @@ #!/bin/bash echo "Rebuilding hnsw to ensure architecture compatibility" -pip install 
--force-reinstall --no-cache-dir hnswlib +pip install --force-reinstall --no-cache-dir chroma-hnswlib +export IS_PERSISTENT=1 uvicorn chromadb.app:app --workers 1 --host 0.0.0.0 --port 8000 --proxy-headers --log-config log_config.yml diff --git a/bin/integration-test b/bin/integration-test index b18624a..1432a96 100755 --- a/bin/integration-test +++ b/bin/integration-test @@ -13,7 +13,7 @@ trap cleanup EXIT docker compose -f docker-compose.test.yml up --build -d export CHROMA_INTEGRATION_TEST_ONLY=1 -export CHROMA_API_IMPL=rest +export CHROMA_API_IMPL=chromadb.api.fastapi.FastAPI export CHROMA_SERVER_HOST=localhost export CHROMA_SERVER_HTTP_PORT=8000 diff --git a/bin/restore.sh b/bin/restore.sh deleted file mode 100644 index 7f23752..0000000 --- a/bin/restore.sh +++ /dev/null @@ -1,42 +0,0 @@ -# #!/bin/bash - -# # check to see if the docker container called chroma_clickhouse_1 is running -# if [ "$(docker inspect -f '{{.State.Running}}' chroma_clickhouse_1)" = "true" ]; then -# echo "chroma_clickhouse_1 is up, proceeding with backup" -# else -# echo "chroma_clickhouse_1 is not up" -# exit 1 -# fi - -# while [ $# -gt 0 ]; do - -# if [[ $1 == *"--"* ]]; then -# param="${1/--/}" -# declare $param="$2" -# # echo $1 $2 // Optional to see the parameter:value result -# fi - -# shift -# done - -# # if backup name is not provided, exit -# if [ -z "$backup_name" ]; then -# echo "backup_name is not provided" -# exit 1 -# fi - -# echo $backup_name - -# # change file permissions to -rw-r----- -# chmod 440 ../backups/$backup_name/$backup_name.zip -# chmod 777 ../backups/$backup_name/$backup_name.zip -# docker cp ../backups/$backup_name/index_data chroma_server_1:/ -# docker cp ../backups/$backup_name/$backup_name.zip chroma_clickhouse_1:/etc/clickhouse-server/$backup_name.zip - -# docker exec -u 0 -it chroma_clickhouse_1 chmod 777 /etc/clickhouse-server/$backup_name.zip -# docker exec -u 0 -it chroma_clickhouse_1 chown 1001 /etc/clickhouse-server/$backup_name.zip -# docker exec -u 0 -it chroma_clickhouse_1 chgrp root /etc/clickhouse-server/$backup_name.zip -# docker exec -u 0 -it chroma_clickhouse_1 clickhouse-client --query="DROP TABLE embeddings" -# docker exec -u 0 -it chroma_clickhouse_1 clickhouse-client --query="DROP TABLE results" -# docker exec -u 0 -it chroma_clickhouse_1 rm -rf /bitnami/clickhouse/data/tmp -# docker exec -u 0 -it chroma_clickhouse_1 clickhouse-client --query="RESTORE DATABASE default FROM Disk('backups', '$backup_name.zip')" diff --git a/bin/setup_linux.sh b/bin/setup_linux.sh deleted file mode 100644 index 0c7ddf2..0000000 --- a/bin/setup_linux.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/env bash - -# install pip -apt install -y python3-pip - -# install docker -sudo apt-get update -sudo apt-get -y install \ - ca-certificates \ - curl \ - gnupg \ - lsb-release - -sudo mkdir -p /etc/apt/keyrings -curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg - -echo \ - "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ - $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null - -sudo apt-get update - -sudo apt-get -y install docker-ce docker-ce-cli containerd.io docker-compose-plugin - -pip3 install docker-compose - - -# get the code -git clone https://oauth2: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@github.com/chroma-core/chroma.git - -#checkout the right branch -cd chroma - -# 
run docker -cd chroma-server -docker-compose up -d --build - -# install chroma-client -cd ../chroma-client -pip3 install --upgrade pip # you have to do this or it will use UNKNOWN as the package name -pip3 install . diff --git a/bin/setup_mac.sh b/bin/setup_mac.sh deleted file mode 100644 index 316fab6..0000000 --- a/bin/setup_mac.sh +++ /dev/null @@ -1,18 +0,0 @@ -# requirements -# - docker -# - pip - -# get the code -git clone https://oauth2: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@github.com/chroma-core/chroma.git - -#checkout the right branch -cd chroma - -# run docker -cd chroma-server -docker-compose up -d --build - -# install chroma-client -cd ../chroma-client -pip install --upgrade pip # you have to do this or it will use UNKNOWN as the package name -pip install . diff --git a/bin/templates/docker-compose.yml b/bin/templates/docker-compose.yml index c618b96..d3199d6 100644 --- a/bin/templates/docker-compose.yml +++ b/bin/templates/docker-compose.yml @@ -9,37 +9,12 @@ services: image: ghcr.io/chroma-core/chroma:${ChromaVersion} volumes: - index_data:/index_data - environment: - - CHROMA_DB_IMPL=clickhouse - - CLICKHOUSE_HOST=clickhouse - - CLICKHOUSE_PORT=8123 ports: - 8000:8000 - depends_on: - - clickhouse - networks: - - net - - clickhouse: - image: clickhouse/clickhouse-server:22.9-alpine - environment: - - ALLOW_EMPTY_PASSWORD=yes - - CLICKHOUSE_TCP_PORT=9000 - - CLICKHOUSE_HTTP_PORT=8123 - ports: - - '8123:8123' - - '9000:9000' - volumes: - - clickhouse_data:/bitnami/clickhouse - - backups:/backups - - ./config/backup_disk.xml:/etc/clickhouse-server/config.d/backup_disk.xml - - ./config/chroma_users.xml:/etc/clickhouse-server/users.d/chroma.xml networks: - net volumes: - clickhouse_data: - driver: local index_data: driver: local backups: diff --git a/bin/test-remote b/bin/test-remote index 04c2990..9997baf 100755 --- a/bin/test-remote +++ b/bin/test-remote @@ -10,7 +10,7 @@ fi export CHROMA_INTEGRATION_TEST_ONLY=1 export CHROMA_SERVER_HOST=$1 -export CHROMA_API_IMPL=rest +export CHROMA_API_IMPL=chromadb.api.fastapi.FastAPI export CHROMA_SERVER_HTTP_PORT=8000 python -m pytest diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 4c980f8..9b7afba 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -1,3 +1,4 @@ +from typing import Dict import chromadb.config import logging from chromadb.telemetry.events import ClientStartEvent @@ -22,6 +23,58 @@ def get_settings() -> Settings: return __settings +def EphemeralClient(settings: Settings = Settings()) -> API: + """ + Creates an in-memory instance of Chroma. This is useful for testing and + development, but not recommended for production use. + """ + settings.is_persistent = False + + return Client(settings) + + +def PersistentClient(path: str = "./chroma", settings: Settings = Settings()) -> API: + """ + Creates a persistent instance of Chroma that saves to disk. This is useful for + testing and development, but not recommended for production use. + + Args: + path: The directory to save Chroma's data to. Defaults to "./chroma". + """ + settings.persist_directory = path + settings.is_persistent = True + + return Client(settings) + + +def HttpClient( + host: str = "localhost", + port: str = "8000", + ssl: bool = False, + headers: Dict[str, str] = {}, + settings: Settings = Settings(), +) -> API: + """ + Creates a client that connects to a remote Chroma server. 
This supports + many clients connecting to the same server, and is the recommended way to + use Chroma in production. + + Args: + host: The hostname of the Chroma server. Defaults to "localhost". + port: The port of the Chroma server. Defaults to "8000". + ssl: Whether to use SSL to connect to the Chroma server. Defaults to False. + headers: A dictionary of headers to send to the Chroma server. Defaults to {}. + """ + + settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI" + settings.chroma_server_host = host + settings.chroma_server_http_port = port + settings.chroma_server_ssl_enabled = ssl + settings.chroma_server_headers = headers + + return Client(settings) + + def Client(settings: Settings = __settings) -> API: """Return a running chroma.API instance""" diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index f4d57d8..195d7ee 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -16,7 +16,7 @@ from chromadb.api.types import ( GetResult, WhereDocument, ) -from chromadb.config import Component +from chromadb.config import Component, Settings import chromadb.utils.embedding_functions as ef @@ -389,17 +389,6 @@ class API(Component, ABC): """ pass - @abstractmethod - def persist(self) -> bool: - """Persist the database to disk - - Returns: - bool: True if the database was persisted successfully - - """ - - pass - @abstractmethod def get_version(self) -> str: """Get the version of Chroma. @@ -409,3 +398,13 @@ class API(Component, ABC): """ pass + + @abstractmethod + def get_settings(self) -> Settings: + """Get the settings used to initialize the client. + + Returns: + Settings: The settings used to initialize the client. + + """ + pass diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 19adc10..d0fc804 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -1,6 +1,6 @@ from typing import Optional, cast from chromadb.api import API -from chromadb.config import System +from chromadb.config import Settings, System from chromadb.api.types import ( Documents, Embeddings, @@ -27,6 +27,8 @@ from overrides import override class FastAPI(API): + _settings: Settings + def __init__(self, system: System): super().__init__(system) url_prefix = "https" if system.settings.chroma_server_ssl_enabled else "http" @@ -34,6 +36,7 @@ class FastAPI(API): system.settings.require("chroma_server_http_port") self._telemetry_client = self.require(Telemetry) + self._settings = system.settings port_suffix = ( f":{system.settings.chroma_server_http_port}" @@ -352,13 +355,6 @@ class FastAPI(API): raise_chroma_error(resp) return cast(bool, resp.json()) - @override - def persist(self) -> bool: - """Persists the database""" - resp = self._session.post(self._api_url + "/persist") - raise_chroma_error(resp) - return cast(bool, resp.json()) - @override def raw_sql(self, sql: str) -> pd.DataFrame: """Runs a raw SQL query against the database""" @@ -384,8 +380,14 @@ class FastAPI(API): raise_chroma_error(resp) return cast(str, resp.json()) + @override + def get_settings(self) -> Settings: + """Returns the settings of the client""" + return self._settings + def raise_chroma_error(resp: requests.Response) -> None: + """Raises an error if the response is not ok, using a ChromaError if possible""" if resp.ok: return diff --git a/chromadb/api/local.py b/chromadb/api/local.py deleted file mode 100644 index aa7ed88..0000000 --- a/chromadb/api/local.py +++ /dev/null @@ -1,485 +0,0 @@ -import json -import time -from uuid import UUID -from typing import List, Optional, 
Sequence, cast -from chromadb import __version__ -from chromadb.api import API -from chromadb.db import DB -from chromadb.api.types import ( - Documents, - EmbeddingFunction, - Embeddings, - GetResult, - IDs, - Include, - Metadata, - Metadatas, - QueryResult, - Where, - WhereDocument, - CollectionMetadata, - validate_metadata, -) -from chromadb.api.models.Collection import Collection -from chromadb.config import System -import chromadb.utils.embedding_functions as ef -import re - -from chromadb.telemetry import Telemetry -from chromadb.telemetry.events import CollectionAddEvent, CollectionDeleteEvent -from overrides import override -import pandas as pd -import logging - -logger = logging.getLogger(__name__) - - -# mimics s3 bucket requirements for naming -def check_index_name(index_name: str) -> None: - msg = ( - "Expected collection name that " - "(1) contains 3-63 characters, " - "(2) starts and ends with an alphanumeric character, " - "(3) otherwise contains only alphanumeric characters, underscores or hyphens (-), " - "(4) contains no two consecutive periods (..) and " - "(5) is not a valid IPv4 address, " - f"got {index_name}" - ) - if len(index_name) < 3 or len(index_name) > 63: - raise ValueError(msg) - if not re.match("^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$", index_name): - raise ValueError(msg) - if ".." in index_name: - raise ValueError(msg) - if re.match("^[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}$", index_name): - raise ValueError(msg) - - -class LocalAPI(API): - _db: DB - _telemetry_client: Telemetry - - def __init__(self, system: System): - super().__init__(system) - self._db = self.require(DB) - self._telemetry_client = self.require(Telemetry) - - @override - def heartbeat(self) -> int: - return int(time.time_ns()) - - # - # COLLECTION METHODS - # - @override - def list_collections(self) -> Sequence[Collection]: - collections = [] - db_collections = self._db.list_collections() - for db_collection in db_collections: - collections.append( - Collection( - client=self, - id=db_collection[0], - name=db_collection[1], - metadata=db_collection[2], - ) - ) - return collections - - @override - def create_collection( - self, - name: str, - metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), - get_or_create: bool = False, - ) -> Collection: - check_index_name(name) - - if metadata is not None: - validate_metadata(metadata) - - res = self._db.create_collection(name, metadata, get_or_create) - return Collection( - client=self, - name=name, - embedding_function=embedding_function, - id=res[0][0], - metadata=res[0][2], - ) - - @override - def get_or_create_collection( - self, - name: str, - metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), - ) -> Collection: - if metadata is not None: - validate_metadata(metadata) - - return self.create_collection( - name, metadata, embedding_function, get_or_create=True - ) - - @override - def get_collection( - self, - name: str, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), - ) -> Collection: - res = self._db.get_collection(name) - if len(res) == 0: - raise ValueError(f"Collection {name} does not exist") - return Collection( - client=self, - name=name, - id=res[0][0], - embedding_function=embedding_function, - metadata=res[0][2], - ) - - @override - def _modify( - self, - id: UUID, - new_name: Optional[str] = None, - new_metadata: 
Optional[CollectionMetadata] = None, - ) -> None: - if new_name is not None: - check_index_name(new_name) - - self._db.update_collection(id, new_name, new_metadata) - - @override - def delete_collection(self, name: str) -> None: - self._db.delete_collection(name) - - # - # ITEM METHODS - # - @override - def _add( - self, - ids: IDs, - collection_id: UUID, - embeddings: Embeddings, - metadatas: Optional[Metadatas] = None, - documents: Optional[Documents] = None, - increment_index: bool = True, - ) -> bool: - existing_ids = set(self._get(collection_id, ids=ids, include=[])["ids"]) - if len(existing_ids) > 0: - logger.info(f"Adding {len(existing_ids)} items with ids that already exist") - # Partially add the items that don't already exist - valid_indices = [i for i, id in enumerate(ids) if id not in existing_ids] - if len(valid_indices) == 0: - return False - filtered_ids: IDs = [] - filtered_embeddings: Embeddings = [] - if metadatas is not None: - filtered_metadatas: Metadatas = [] - if documents is not None: - filtered_documents: Documents = [] - for index in valid_indices: - filtered_ids.append(ids[index]) - filtered_embeddings.append(embeddings[index]) - if metadatas is not None: - filtered_metadatas.append(metadatas[index]) - if documents is not None: - filtered_documents.append(documents[index]) - ids = filtered_ids - embeddings = filtered_embeddings - if metadatas is not None: - metadatas = filtered_metadatas - if documents is not None: - documents = filtered_documents - - added_uuids = self._db.add( - collection_id, - embeddings=embeddings, - metadatas=metadatas, - documents=documents, - ids=ids, - ) - - if increment_index: - self._db.add_incremental(collection_id, added_uuids, embeddings) - - self._telemetry_client.capture(CollectionAddEvent(str(collection_id), len(ids))) - return True # NIT: should this return the ids of the succesfully added items? 
- - @override - def _update( - self, - collection_id: UUID, - ids: IDs, - embeddings: Optional[Embeddings] = None, - metadatas: Optional[Metadatas] = None, - documents: Optional[Documents] = None, - ) -> bool: - self._db.update(collection_id, ids, embeddings, metadatas, documents) - return True - - @override - def _upsert( - self, - collection_id: UUID, - ids: IDs, - embeddings: Embeddings, - metadatas: Optional[Metadatas] = None, - documents: Optional[Documents] = None, - increment_index: bool = True, - ) -> bool: - # Determine which ids need to be added and which need to be updated based on the ids already in the collection - existing_ids = set(self._get(collection_id, ids=ids, include=[])["ids"]) - - ids_to_add = [] - ids_to_update = [] - embeddings_to_add: Embeddings = [] - embeddings_to_update: Embeddings = [] - metadatas_to_add: Optional[Metadatas] = [] if metadatas else None - metadatas_to_update: Optional[Metadatas] = [] if metadatas else None - documents_to_add: Optional[Documents] = [] if documents else None - documents_to_update: Optional[Documents] = [] if documents else None - - for i, id in enumerate(ids): - if id in existing_ids: - ids_to_update.append(id) - if embeddings is not None: - embeddings_to_update.append(embeddings[i]) - if metadatas is not None: - metadatas_to_update.append(metadatas[i]) # type: ignore - if documents is not None: - documents_to_update.append(documents[i]) # type: ignore - else: - ids_to_add.append(id) - if embeddings is not None: - embeddings_to_add.append(embeddings[i]) - if metadatas is not None: - metadatas_to_add.append(metadatas[i]) # type: ignore - if documents is not None: - documents_to_add.append(documents[i]) # type: ignore - - if len(ids_to_add) > 0: - self._add( - ids_to_add, - collection_id, - embeddings_to_add, - metadatas_to_add, - documents_to_add, - increment_index=increment_index, - ) - - if len(ids_to_update) > 0: - self._update( - collection_id, - ids_to_update, - embeddings_to_update, - metadatas_to_update, - documents_to_update, - ) - self._db.update(collection_id, ids, embeddings, metadatas, documents) - - return True - - @override - def _get( - self, - collection_id: UUID, - ids: Optional[IDs] = None, - where: Optional[Where] = {}, - sort: Optional[str] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - page: Optional[int] = None, - page_size: Optional[int] = None, - where_document: Optional[WhereDocument] = {}, - include: Include = ["embeddings", "metadatas", "documents"], - ) -> GetResult: - if where is None: - where = {} - - if where_document is None: - where_document = {} - - if page and page_size: - offset = (page - 1) * page_size - limit = page_size - - include_embeddings = "embeddings" in include - include_documents = "documents" in include - include_metadatas = "metadatas" in include - - # Remove plural from include since db columns are singular - db_columns = [column[:-1] for column in include] + ["id"] - column_index = { - column_name: index for index, column_name in enumerate(db_columns) - } - - db_result = self._db.get( - collection_uuid=collection_id, - ids=ids, - where=where, - sort=sort, - limit=limit, - offset=offset, - where_document=where_document, - columns=db_columns, - ) - - get_result = GetResult( - ids=[], - embeddings=[] if include_embeddings else None, - documents=[] if include_documents else None, - metadatas=[] if include_metadatas else None, - ) - - for entry in db_result: - if include_embeddings: - cast(List, get_result["embeddings"]).append( # type: ignore - 
entry[column_index["embedding"]] - ) - if include_documents: - cast(List, get_result["documents"]).append( # type: ignore - entry[column_index["document"]] - ) - if include_metadatas: - cast(List, get_result["metadatas"]).append( # type: ignore - entry[column_index["metadata"]] - ) - get_result["ids"].append(entry[column_index["id"]]) - return get_result - - @override - def _delete( - self, - collection_id: UUID, - ids: Optional[IDs] = None, - where: Optional[Where] = None, - where_document: Optional[WhereDocument] = None, - ) -> IDs: - if where is None: - where = {} - - if where_document is None: - where_document = {} - - deleted_uuids = self._db.delete( - collection_uuid=collection_id, - where=where, - ids=ids, - where_document=where_document, - ) - self._telemetry_client.capture( - CollectionDeleteEvent(str(collection_id), len(deleted_uuids)) - ) - - return deleted_uuids - - @override - def _count(self, collection_id: UUID) -> int: - return self._db.count(collection_id) - - @override - def reset(self) -> bool: - self._db.reset_state() - return True - - @override - def _query( - self, - collection_id: UUID, - query_embeddings: Embeddings, - n_results: int = 10, - where: Where = {}, - where_document: WhereDocument = {}, - include: Include = ["documents", "metadatas", "distances"], - ) -> QueryResult: - uuids, distances = self._db.get_nearest_neighbors( - collection_uuid=collection_id, - where=where, - where_document=where_document, - embeddings=query_embeddings, - n_results=n_results, - ) - - include_embeddings = "embeddings" in include - include_documents = "documents" in include - include_metadatas = "metadatas" in include - include_distances = "distances" in include - - query_result = QueryResult( - ids=[], - embeddings=[] if include_embeddings else None, - documents=[] if include_documents else None, - metadatas=[] if include_metadatas else None, - distances=[] if include_distances else None, - ) - for i in range(len(uuids)): - embeddings: Embeddings = [] - documents: Documents = [] - ids: IDs = [] - metadatas: List[Optional[Metadata]] = [] - # Remove plural from include since db columns are singular - db_columns = [ - column[:-1] for column in include if column != "distances" - ] + ["id"] - column_index = { - column_name: index for index, column_name in enumerate(db_columns) - } - db_result = self._db.get_by_ids(uuids[i], columns=db_columns) - - for entry in db_result: - if include_embeddings: - embeddings.append(entry[column_index["embedding"]]) - if include_documents: - documents.append(entry[column_index["document"]]) - if include_metadatas: - metadatas.append( - json.loads(entry[column_index["metadata"]]) - if entry[column_index["metadata"]] - else None - ) - ids.append(entry[column_index["id"]]) - - if include_embeddings: - cast(List[Embeddings], query_result["embeddings"]).append(embeddings) - if include_documents: - cast(List[Documents], query_result["documents"]).append(documents) - if include_metadatas: - cast(List[List[Optional[Metadata]]], query_result["metadatas"]).append( - metadatas - ) - if include_distances: - cast(List[List[float]], query_result["distances"]).append(distances[i]) - query_result["ids"].append(ids) - - return query_result - - @override - def raw_sql(self, sql: str) -> pd.DataFrame: - return self._db.raw_sql(sql) # type: ignore - - @override - def create_index(self, collection_name: str) -> bool: - collection_uuid = self._db.get_collection_uuid_from_name(collection_name) - self._db.create_index(collection_uuid=collection_uuid) - return True - - 
@override - def _peek(self, collection_id: UUID, n: int = 10) -> GetResult: - return self._get( - collection_id=collection_id, - limit=n, - include=["embeddings", "documents", "metadatas"], - ) - - @override - def persist(self) -> bool: - self._db.persist() - return True - - @override - def get_version(self) -> str: - return __version__ diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 0014c6f..7904b42 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -1,13 +1,13 @@ from chromadb.api import API -from chromadb.config import System +from chromadb.config import Settings, System from chromadb.db.system import SysDB from chromadb.segment import SegmentManager, MetadataReader, VectorReader from chromadb.telemetry import Telemetry from chromadb.ingest import Producer from chromadb.api.models.Collection import Collection -import chromadb.api.local as old_api from chromadb import __version__ from chromadb.errors import InvalidDimensionException, InvalidCollectionException +import chromadb.utils.embedding_functions as ef from chromadb.api.types import ( CollectionMetadata, @@ -24,7 +24,10 @@ from chromadb.api.types import ( QueryResult, validate_metadata, validate_update_metadata, + validate_where, + validate_where_document, ) +from chromadb.telemetry.events import CollectionAddEvent, CollectionDeleteEvent import chromadb.types as t @@ -34,16 +37,40 @@ from uuid import UUID, uuid4 import pandas as pd import time import logging +import re logger = logging.getLogger(__name__) +# mimics s3 bucket requirements for naming +def check_index_name(index_name: str) -> None: + msg = ( + "Expected collection name that " + "(1) contains 3-63 characters, " + "(2) starts and ends with an alphanumeric character, " + "(3) otherwise contains only alphanumeric characters, underscores or hyphens (-), " + "(4) contains no two consecutive periods (..) and " + "(5) is not a valid IPv4 address, " + f"got {index_name}" + ) + if len(index_name) < 3 or len(index_name) > 63: + raise ValueError(msg) + if not re.match("^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$", index_name): + raise ValueError(msg) + if ".." 
in index_name: + raise ValueError(msg) + if re.match("^[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}$", index_name): + raise ValueError(msg) + + class SegmentAPI(API): """API implementation utilizing the new segment-based internal architecture""" + _settings: Settings _sysdb: SysDB _manager: SegmentManager _producer: Producer + # TODO: fire telemetry events _telemetry_client: Telemetry _tenant_id: str _topic_ns: str @@ -51,6 +78,7 @@ class SegmentAPI(API): def __init__(self, system: System): super().__init__(system) + self._settings = system.settings self._sysdb = self.require(SysDB) self._manager = self.require(SegmentManager) self._telemetry_client = self.require(Telemetry) @@ -71,7 +99,7 @@ class SegmentAPI(API): self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[EmbeddingFunction] = None, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), get_or_create: bool = False, ) -> Collection: existing = self._sysdb.get_collections(name=name) @@ -94,8 +122,8 @@ class SegmentAPI(API): else: raise ValueError(f"Collection {name} already exists.") - # backwards compatibility in naming requirements (for now) - old_api.check_index_name(name) + # TODO: remove backwards compatibility in naming requirements + check_index_name(name) id = uuid4() coll = t.Collection( @@ -120,7 +148,7 @@ class SegmentAPI(API): self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[EmbeddingFunction] = None, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), ) -> Collection: return self.create_collection( name=name, @@ -136,7 +164,7 @@ class SegmentAPI(API): def get_collection( self, name: str, - embedding_function: Optional[EmbeddingFunction] = None, + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), ) -> Collection: existing = self._sysdb.get_collections(name=name) @@ -175,7 +203,7 @@ class SegmentAPI(API): ) -> None: if new_name: # backwards compatibility in naming requirements (for now) - old_api.check_index_name(new_name) + check_index_name(new_name) if new_metadata: validate_update_metadata(new_metadata) @@ -214,11 +242,13 @@ class SegmentAPI(API): increment_index: bool = True, ) -> bool: coll = self._get_collection(collection_id) + self._manager.hint_use_collection(collection_id, t.Operation.ADD) for r in _records(t.Operation.ADD, ids, embeddings, metadatas, documents): self._validate_embedding_record(coll, r) self._producer.submit_embedding(coll["topic"], r) + self._telemetry_client.capture(CollectionAddEvent(str(collection_id), len(ids))) return True @override @@ -231,6 +261,7 @@ class SegmentAPI(API): documents: Optional[Documents] = None, ) -> bool: coll = self._get_collection(collection_id) + self._manager.hint_use_collection(collection_id, t.Operation.UPDATE) for r in _records(t.Operation.UPDATE, ids, embeddings, metadatas, documents): self._validate_embedding_record(coll, r) @@ -249,6 +280,8 @@ class SegmentAPI(API): increment_index: bool = True, ) -> bool: coll = self._get_collection(collection_id) + self._manager.hint_use_collection(collection_id, t.Operation.UPSERT) + for r in _records(t.Operation.UPSERT, ids, embeddings, metadatas, documents): self._validate_embedding_record(coll, r) self._producer.submit_embedding(coll["topic"], r) @@ -269,6 +302,13 @@ class SegmentAPI(API): where_document: Optional[WhereDocument] = {}, include: Include = ["embeddings", "metadatas", "documents"], ) -> GetResult: + where = 
validate_where(where) if where is not None and len(where) > 0 else None + where_document = ( + validate_where_document(where_document) + if where_document is not None and len(where_document) > 0 + else None + ) + metadata_segment = self._manager.get_segment(collection_id, MetadataReader) if sort is not None: @@ -316,7 +356,15 @@ class SegmentAPI(API): where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, ) -> IDs: + where = validate_where(where) if where is not None and len(where) > 0 else None + where_document = ( + validate_where_document(where_document) + if where_document is not None and len(where_document) > 0 + else None + ) + coll = self._get_collection(collection_id) + self._manager.hint_use_collection(collection_id, t.Operation.DELETE) # TODO: Do we want to warn the user that unrestricted _delete() is 99% of the # time a bad idea? @@ -333,6 +381,9 @@ class SegmentAPI(API): self._validate_embedding_record(coll, r) self._producer.submit_embedding(coll["topic"], r) + self._telemetry_client.capture( + CollectionDeleteEvent(str(collection_id), len(ids_to_delete)) + ) return ids_to_delete @override @@ -350,6 +401,13 @@ class SegmentAPI(API): where_document: WhereDocument = {}, include: Include = ["documents", "metadatas", "distances"], ) -> QueryResult: + where = validate_where(where) if where is not None and len(where) > 0 else where + where_document = ( + validate_where_document(where_document) + if where_document is not None and len(where_document) > 0 + else where_document + ) + allowed_ids = None coll = self._get_collection(collection_id) @@ -395,7 +453,16 @@ class SegmentAPI(API): records = metadata_reader.get_metadata(ids=list(all_ids)) metadata_by_id = {r["id"]: r["metadata"] for r in records} for id_list in ids: - metadata_list = [metadata_by_id[id] for id in id_list] + # In the segment based architecture, it is possible for one segment + # to have a record that another segment does not have. This results in + # data inconsistency. For the case of the local segments and the + # local segment manager, there is a case where a thread writes + # a record to the vector segment but not the metadata segment. + # Then a query'ing thread reads from the vector segment and + # queries the metadata segment. The metadata segment does not have + # the record. In this case we choose to return potentially + # incorrect data in the form of None. + metadata_list = [metadata_by_id.get(id, None) for id in id_list] if "metadatas" in include: metadatas.append(_clean_metadatas(metadata_list)) # type: ignore if "documents" in include: @@ -439,11 +506,8 @@ class SegmentAPI(API): return True @override - def persist(self) -> bool: - logger.warning( - "Calling persist is unnecessary, data is now automatically indexed." 
- ) - return True + def get_settings(self) -> Settings: + return self._settings def _topic(self, collection_id: UUID) -> str: return f"persistent://{self._tenant_id}/{self._topic_ns}/{collection_id}" diff --git a/chromadb/api/types.py b/chromadb/api/types.py index bd4df15..9f5ab12 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -124,8 +124,10 @@ def validate_ids(ids: IDs) -> IDs: def validate_metadata(metadata: Metadata) -> Metadata: """Validates metadata to ensure it is a dictionary of strings to strings, ints, or floats""" - if not isinstance(metadata, dict): - raise ValueError(f"Expected metadata to be a dict, got {metadata}") + if not isinstance(metadata, dict) and metadata is not None: + raise ValueError(f"Expected metadata to be a dict or None, got {metadata}") + if metadata is None: + return metadata if len(metadata) == 0: raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}") for key, value in metadata.items(): @@ -143,8 +145,10 @@ def validate_metadata(metadata: Metadata) -> Metadata: def validate_update_metadata(metadata: UpdateMetadata) -> UpdateMetadata: """Validates metadata to ensure it is a dictionary of strings to strings, ints, or floats""" - if not isinstance(metadata, dict): - raise ValueError(f"Expected metadata to be a dict, got {metadata}") + if not isinstance(metadata, dict) and metadata is not None: + raise ValueError(f"Expected metadata to be a dict or None, got {metadata}") + if metadata is None: + return metadata if len(metadata) == 0: raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}") for key, value in metadata.items(): diff --git a/chromadb/config.py b/chromadb/config.py index 75065e7..dbd0c93 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -18,17 +18,27 @@ except ImportError: logger = logging.getLogger(__name__) + +LEGACY_ERROR = "You are using a deprecated configuration of Chroma. Please pip install chroma-migrate and run `chroma-migrate` to upgrade your configuration. See https://docs.trychroma.com/migration for more information or join our discord at https://discord.gg/8g5FESbj for help!" + +_legacy_config_keys = { + "chroma_db_impl", +} + _legacy_config_values = { - "duckdb": "chromadb.db.duckdb.DuckDB", - "duckdb+parquet": "chromadb.db.duckdb.PersistentDuckDB", - "clickhouse": "chromadb.db.clickhouse.Clickhouse", - "rest": "chromadb.api.fastapi.FastAPI", - "local": "chromadb.api.local.LocalAPI", + "duckdb", + "duckdb+parquet", + "clickhouse", + "local", + "rest", + "chromadb.db.duckdb.DuckDB", + "chromadb.db.duckdb.PersistentDuckDB", + "chromadb.db.clickhouse.Clickhouse", + "chromadb.api.local.LocalAPI", } # TODO: Don't use concrete types here to avoid circular deps. Strings are fine for right here! 
_abstract_type_keys: Dict[str, str] = { - "chromadb.db.DB": "chroma_db_impl", "chromadb.api.API": "chroma_api_impl", "chromadb.telemetry.Telemetry": "chroma_telemetry_impl", "chromadb.ingest.Producer": "chroma_producer_impl", @@ -41,8 +51,10 @@ _abstract_type_keys: Dict[str, str] = { class Settings(BaseSettings): environment: str = "" - chroma_db_impl: str = "chromadb.db.duckdb.DuckDB" - chroma_api_impl: str = "chromadb.api.local.LocalAPI" + # Legacy config has to be kept around because pydantic will error on nonexisting keys + chroma_db_impl: Optional[str] = None + + chroma_api_impl: str = "chromadb.api.segment.SegmentAPI" # Can be "chromadb.api.segment.SegmentAPI" or "chromadb.api.fastapi.FastAPI" chroma_telemetry_impl: str = "chromadb.telemetry.posthog.Posthog" # New architecture components @@ -53,13 +65,11 @@ class Settings(BaseSettings): "chromadb.segment.impl.manager.local.LocalSegmentManager" ) - clickhouse_host: Optional[str] = None - clickhouse_port: Optional[str] = None - tenant_id: str = "default" topic_namespace: str = "default" - persist_directory: str = ".chroma" + is_persistent: bool = False + persist_directory: str = "./chroma" chroma_server_host: Optional[str] = None chroma_server_headers: Optional[Dict[str, str]] = None @@ -72,7 +82,6 @@ class Settings(BaseSettings): allow_reset: bool = False - sqlite_database: Optional[str] = ":memory:" migrations: Literal["none", "validate", "apply"] = "apply" def require(self, key: str) -> Any: @@ -85,10 +94,9 @@ class Settings(BaseSettings): def __getitem__(self, key: str) -> Any: val = getattr(self, key) - # Backwards compatibility with short names instead of full class names + # Error on legacy config values if val in _legacy_config_values: - newval = _legacy_config_values[val] - val = newval + raise ValueError(LEGACY_ERROR) return val class Config: @@ -143,6 +151,19 @@ class System(Component): _instances: Dict[Type[Component], Component] def __init__(self, settings: Settings): + if is_thin_client: + # The thin client is a system with only the API component + if settings["chroma_api_impl"] != "chromadb.api.fastapi.FastAPI": + raise RuntimeError( + "Chroma is running in http-only client mode, and can only be run with 'chromadb.api.fastapi.FastAPI' as the chroma_api_impl. \ + see https://docs.trychroma.com/usage-guide?lang=py#using-the-python-http-only-client for more information." 
+ ) + + # Validate settings don't contain any legacy config values + for key in _legacy_config_keys: + if settings[key] is not None: + raise ValueError(LEGACY_ERROR) + self.settings = settings self._instances = {} super().__init__(self) diff --git a/chromadb/db/__init__.py b/chromadb/db/__init__.py index 6ebd014..1ceb3fa 100644 --- a/chromadb/db/__init__.py +++ b/chromadb/db/__init__.py @@ -129,7 +129,3 @@ class DB(Component): @abstractmethod def create_index(self, collection_uuid: UUID): # type: ignore pass - - @abstractmethod - def persist(self) -> None: - pass diff --git a/chromadb/db/clickhouse.py b/chromadb/db/clickhouse.py deleted file mode 100644 index d53be3b..0000000 --- a/chromadb/db/clickhouse.py +++ /dev/null @@ -1,656 +0,0 @@ -# type: ignore -from chromadb.api.types import ( - Documents, - Embeddings, - IDs, - Metadatas, - Where, - WhereDocument, -) -from chromadb.db import DB -from chromadb.db.index.hnswlib import Hnswlib, delete_all_indexes -import uuid -import json -from typing import Optional, Sequence, List, Tuple, cast -import clickhouse_connect -from clickhouse_connect.driver.client import Client -from clickhouse_connect import common -import logging -from uuid import UUID -from chromadb.config import System -from overrides import override -from chromadb.api.types import Metadata - -logger = logging.getLogger(__name__) - -COLLECTION_TABLE_SCHEMA = [{"uuid": "UUID"}, {"name": "String"}, {"metadata": "String"}] - -EMBEDDING_TABLE_SCHEMA = [ - {"collection_uuid": "UUID"}, - {"uuid": "UUID"}, - {"embedding": "Array(Float64)"}, - {"document": "Nullable(String)"}, - {"id": "Nullable(String)"}, - {"metadata": "Nullable(String)"}, -] - - -def db_array_schema_to_clickhouse_schema(table_schema): - return_str = "" - for element in table_schema: - for k, v in element.items(): - return_str += f"{k} {v}, " - return return_str - - -def db_schema_to_keys() -> List[str]: - keys = [] - for element in EMBEDDING_TABLE_SCHEMA: - keys.append(list(element.keys())[0]) - return keys - - -class Clickhouse(DB): - # - # INIT METHODS - # - def __init__(self, system: System): - super().__init__(system) - self._conn = None - self._settings = system.settings - - self._settings.require("clickhouse_host") - self._settings.require("clickhouse_port") - - def _init_conn(self): - common.set_setting("autogenerate_session_id", False) - self._conn = clickhouse_connect.get_client( - host=self._settings.clickhouse_host, - port=int(self._settings.clickhouse_port), - ) - self._create_table_collections(self._conn) - self._create_table_embeddings(self._conn) - - def _get_conn(self) -> Client: - if self._conn is None: - self._init_conn() - return self._conn - - def _create_table_collections(self, conn): - conn.command( - f"""CREATE TABLE IF NOT EXISTS collections ( - {db_array_schema_to_clickhouse_schema(COLLECTION_TABLE_SCHEMA)} - ) ENGINE = MergeTree() ORDER BY uuid""" - ) - - def _create_table_embeddings(self, conn): - conn.command( - f"""CREATE TABLE IF NOT EXISTS embeddings ( - {db_array_schema_to_clickhouse_schema(EMBEDDING_TABLE_SCHEMA)} - ) ENGINE = MergeTree() ORDER BY collection_uuid""" - ) - - index_cache = {} - - def _index(self, collection_id): - """Retrieve an HNSW index instance for the given collection""" - - if collection_id not in self.index_cache: - coll = self.get_collection_by_id(collection_id) - collection_metadata = coll[2] - index = Hnswlib( - collection_id, - self._settings, - collection_metadata, - self.count(collection_id), - ) - self.index_cache[collection_id] = index - - return 
self.index_cache[collection_id] - - def _delete_index(self, collection_id): - """Delete an index from the cache""" - index = self._index(collection_id) - index.delete() - del self.index_cache[collection_id] - - # - # UTILITY METHODS - # - @override - def persist(self): - raise NotImplementedError( - "Clickhouse is a persistent database, this method is not needed" - ) - - @override - def get_collection_uuid_from_name(self, collection_name: str) -> UUID: - res = self._get_conn().query( - f""" - SELECT uuid FROM collections WHERE name = '{collection_name}' - """ - ) - return res.result_rows[0][0] - - def _create_where_clause( - self, - collection_uuid: str, - ids: Optional[List[str]] = None, - where: Where = {}, - where_document: WhereDocument = {}, - ): - where_clauses: List[str] = [] - self._format_where(where, where_clauses) - if len(where_document) > 0: - where_document_clauses = [] - self._format_where_document(where_document, where_document_clauses) - where_clauses.extend(where_document_clauses) - - if ids is not None: - where_clauses.append(f" id IN {tuple(ids)}") - - where_clauses.append(f"collection_uuid = '{collection_uuid}'") - where_str = " AND ".join(where_clauses) - where_str = f"WHERE {where_str}" - return where_str - - # - # COLLECTION METHODS - # - @override - def create_collection( - self, - name: str, - metadata: Optional[Metadata] = None, - get_or_create: bool = False, - ) -> Sequence: - # poor man's unique constraint - dupe_check = self.get_collection(name) - - if len(dupe_check) > 0: - if get_or_create: - if dupe_check[0][2] != metadata: - self.update_collection( - dupe_check[0][0], new_name=name, new_metadata=metadata - ) - dupe_check = self.get_collection(name) - logger.info( - f"collection with name {name} already exists, returning existing collection" - ) - return dupe_check - else: - raise ValueError(f"Collection with name {name} already exists") - - collection_uuid = uuid.uuid4() - data_to_insert = [[collection_uuid, name, json.dumps(metadata)]] - - self._get_conn().insert( - "collections", data_to_insert, column_names=["uuid", "name", "metadata"] - ) - return [[collection_uuid, name, metadata]] - - @override - def get_collection(self, name: str) -> Sequence: - res = ( - self._get_conn() - .query( - f""" - SELECT * FROM collections WHERE name = '{name}' - """ - ) - .result_rows - ) - # json.loads the metadata - return [[x[0], x[1], json.loads(x[2])] for x in res] - - def get_collection_by_id(self, collection_uuid: str): - res = ( - self._get_conn() - .query( - f""" - SELECT * FROM collections WHERE uuid = '{collection_uuid}' - """ - ) - .result_rows - ) - # json.loads the metadata - return [[x[0], x[1], json.loads(x[2])] for x in res][0] - - @override - def list_collections(self) -> Sequence: - res = self._get_conn().query("SELECT * FROM collections").result_rows - return [[x[0], x[1], json.loads(x[2])] for x in res] - - @override - def update_collection( - self, - id: UUID, - new_name: Optional[str] = None, - new_metadata: Optional[Metadata] = None, - ): - if new_name is not None: - dupe_check = self.get_collection(new_name) - if len(dupe_check) > 0 and dupe_check[0][0] != id: - raise ValueError(f"Collection with name {new_name} already exists") - - self._get_conn().command( - "ALTER TABLE collections UPDATE name = %(new_name)s WHERE uuid = %(uuid)s", - parameters={"new_name": new_name, "uuid": id}, - ) - - if new_metadata is not None: - self._get_conn().command( - "ALTER TABLE collections UPDATE metadata = %(new_metadata)s WHERE uuid = %(uuid)s", - 
parameters={"new_metadata": json.dumps(new_metadata), "uuid": id}, - ) - - @override - def delete_collection(self, name: str): - collection_uuid = self.get_collection_uuid_from_name(name) - self._get_conn().command( - f""" - DELETE FROM embeddings WHERE collection_uuid = '{collection_uuid}' - """ - ) - - self._delete_index(collection_uuid) - - self._get_conn().command( - f""" - DELETE FROM collections WHERE name = '{name}' - """ - ) - - # - # ITEM METHODS - # - @override - def add(self, collection_uuid, embeddings, metadatas, documents, ids) -> List[UUID]: - data_to_insert = [ - [ - collection_uuid, - uuid.uuid4(), - embedding, - json.dumps(metadatas[i]) if metadatas else None, - documents[i] if documents else None, - ids[i], - ] - for i, embedding in enumerate(embeddings) - ] - column_names = [ - "collection_uuid", - "uuid", - "embedding", - "metadata", - "document", - "id", - ] - self._get_conn().insert("embeddings", data_to_insert, column_names=column_names) - - return [x[1] for x in data_to_insert] # return uuids - - def _update( - self, - collection_uuid, - ids: IDs, - embeddings: Optional[Embeddings], - metadatas: Optional[Metadatas], - documents: Optional[Documents], - ): - updates = [] - parameters = {} - for i in range(len(ids)): - update_fields = [] - parameters[f"i{i}"] = ids[i] - if embeddings is not None: - update_fields.append(f"embedding = %(e{i})s") - parameters[f"e{i}"] = embeddings[i] - if metadatas is not None: - update_fields.append(f"metadata = %(m{i})s") - parameters[f"m{i}"] = json.dumps(metadatas[i]) - if documents is not None: - update_fields.append(f"document = %(d{i})s") - parameters[f"d{i}"] = documents[i] - - update_statement = f""" - UPDATE - {",".join(update_fields)} - WHERE - id = %(i{i})s AND - collection_uuid = '{collection_uuid}'{"" if i == len(ids) - 1 else ","} - """ - updates.append(update_statement) - - update_clauses = ("").join(updates) - self._get_conn().command( - f"ALTER TABLE embeddings {update_clauses}", parameters=parameters - ) - - @override - def update( - self, - collection_uuid, - ids: IDs, - embeddings: Optional[Embeddings] = None, - metadatas: Optional[Metadatas] = None, - documents: Optional[Documents] = None, - ) -> bool: - # Verify all IDs exist - existing_items = self.get(collection_uuid=collection_uuid, ids=ids) - if len(existing_items) != len(ids): - raise ValueError( - f"Could not find {len(ids) - len(existing_items)} items for update" - ) - - # Update the db - self._update(collection_uuid, ids, embeddings, metadatas, documents) - - # Update the index - if embeddings is not None: - # `get` current returns items in arbitrary order. - # TODO if we fix `get`, we can remove this explicit mapping. 
- uuid_mapping = {r[4]: r[1] for r in existing_items} - update_uuids = [uuid_mapping[id] for id in ids] - index = self._index(collection_uuid) - index.add(update_uuids, embeddings, update=True) - - def _get(self, where={}, columns: Optional[List] = None): - select_columns = db_schema_to_keys() if columns is None else columns - val = ( - self._get_conn() - .query(f"""SELECT {",".join(select_columns)} FROM embeddings {where}""") - .result_rows - ) - for i in range(len(val)): - # We know val has index abilities, so cast it for typechecker - val = cast(list, val) - val[i] = list(val[i]) - # json.load the metadata - if "metadata" in select_columns: - metadata_column_index = select_columns.index("metadata") - db_metadata = val[i][metadata_column_index] - val[i][metadata_column_index] = ( - json.loads(db_metadata) if db_metadata else None - ) - return val - - def _format_where(self, where, result): - for key, value in where.items(): - - def has_key_and(clause): - return f"(JSONHas(metadata,'{key}') = 1 AND {clause})" - - # Shortcut for $eq - if type(value) == str: - result.append( - has_key_and(f" JSONExtractString(metadata,'{key}') = '{value}'") - ) - elif type(value) == int: - result.append( - has_key_and(f" JSONExtractInt(metadata,'{key}') = {value}") - ) - elif type(value) == float: - result.append( - has_key_and(f" JSONExtractFloat(metadata,'{key}') = {value}") - ) - # Operator expression - elif type(value) == dict: - operator, operand = list(value.items())[0] - if operator == "$gt": - return result.append( - has_key_and(f" JSONExtractFloat(metadata,'{key}') > {operand}") - ) - elif operator == "$lt": - return result.append( - has_key_and(f" JSONExtractFloat(metadata,'{key}') < {operand}") - ) - elif operator == "$gte": - return result.append( - has_key_and(f" JSONExtractFloat(metadata,'{key}') >= {operand}") - ) - elif operator == "$lte": - return result.append( - has_key_and(f" JSONExtractFloat(metadata,'{key}') <= {operand}") - ) - elif operator == "$ne": - if type(operand) == str: - return result.append( - has_key_and( - f" JSONExtractString(metadata,'{key}') != '{operand}'" - ) - ) - return result.append( - has_key_and(f" JSONExtractFloat(metadata,'{key}') != {operand}") - ) - elif operator == "$eq": - if type(operand) == str: - return result.append( - has_key_and( - f" JSONExtractString(metadata,'{key}') = '{operand}'" - ) - ) - return result.append( - has_key_and(f" JSONExtractFloat(metadata,'{key}') = {operand}") - ) - else: - raise ValueError( - f"Expected one of $gt, $lt, $gte, $lte, $ne, $eq, got {operator}" - ) - elif type(value) == list: - all_subresults = [] - for subwhere in value: - subresults = [] - self._format_where(subwhere, subresults) - all_subresults.append(subresults[0]) - if key == "$or": - result.append(f"({' OR '.join(all_subresults)})") - elif key == "$and": - result.append(f"({' AND '.join(all_subresults)})") - else: - raise ValueError(f"Expected one of $or, $and, got {key}") - - def _format_where_document(self, where_document, results): - operator = list(where_document.keys())[0] - if operator == "$contains": - results.append(f"position(document, '{where_document[operator]}') > 0") - elif operator == "$and" or operator == "$or": - all_subresults = [] - for subwhere in where_document[operator]: - subresults = [] - self._format_where_document(subwhere, subresults) - all_subresults.append(subresults[0]) - if operator == "$or": - results.append(f"({' OR '.join(all_subresults)})") - if operator == "$and": - results.append(f"({' AND '.join(all_subresults)})") - else: 
- raise ValueError(f"Expected one of $contains, $and, $or, got {operator}") - - @override - def get( - self, - where: Where = {}, - collection_name: Optional[str] = None, - collection_uuid: Optional[UUID] = None, - ids: Optional[IDs] = None, - sort: Optional[str] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - where_document: WhereDocument = {}, - columns: Optional[List[str]] = None, - ) -> Sequence: - if collection_name is None and collection_uuid is None: - raise TypeError( - "Arguments collection_name and collection_uuid cannot both be None" - ) - - if collection_name is not None: - collection_uuid = self.get_collection_uuid_from_name(collection_name) - - where_str = self._create_where_clause( - # collection_uuid must be defined at this point, cast it for typechecker - cast(str, collection_uuid), - ids=ids, - where=where, - where_document=where_document, - ) - - if sort is not None: - where_str += f" ORDER BY {sort}" - else: - where_str += " ORDER BY collection_uuid" # stable ordering - - if limit is not None or isinstance(limit, int): - where_str += f" LIMIT {limit}" - - if offset is not None or isinstance(offset, int): - where_str += f" OFFSET {offset}" - - val = self._get(where=where_str, columns=columns) - - return val - - @override - def count(self, collection_id: UUID) -> int: - where_string = f"WHERE collection_uuid = '{collection_id}'" - return ( - self._get_conn() - .query(f"SELECT COUNT() FROM embeddings {where_string}") - .result_rows[0][0] - ) - - def _delete(self, where_str: Optional[str] = None) -> List: - deleted_uuids = ( - self._get_conn() - .query(f"""SELECT uuid FROM embeddings {where_str}""") - .result_rows - ) - self._get_conn().command( - f""" - DELETE FROM - embeddings - {where_str} - """ - ) - return [res[0] for res in deleted_uuids] if len(deleted_uuids) > 0 else [] - - @override - def delete( - self, - where: Where = {}, - collection_uuid: Optional[UUID] = None, - ids: Optional[IDs] = None, - where_document: WhereDocument = {}, - ) -> List[str]: - where_str = self._create_where_clause( - # collection_uuid must be defined at this point, cast it for typechecker - cast(str, collection_uuid), - ids=ids, - where=where, - where_document=where_document, - ) - - deleted_uuids = self._delete(where_str) - - index = self._index(collection_uuid) - index.delete_from_index(deleted_uuids) - - return deleted_uuids - - @override - def get_by_ids( - self, uuids: List[UUID], columns: Optional[List[str]] = None - ) -> Sequence: - columns = columns + ["uuid"] if columns else ["uuid"] - select_columns = db_schema_to_keys() if columns is None else columns - response = ( - self._get_conn() - .query( - f""" - SELECT {",".join(select_columns)} FROM embeddings WHERE uuid IN ({[id.hex for id in uuids]}) - """ - ) - .result_rows - ) - - # sort db results by the order of the uuids - response = sorted(response, key=lambda obj: uuids.index(obj[len(columns) - 1])) - - return response - - @override - def get_nearest_neighbors( - self, - collection_uuid: UUID, - where: Where = {}, - embeddings: Optional[Embeddings] = None, - n_results: int = 10, - where_document: WhereDocument = {}, - ) -> Tuple[List[List[UUID]], List[List[float]]]: - # Either the collection name or the collection uuid must be provided - if collection_uuid is None: - raise TypeError("Argument collection_uuid cannot be None") - - if len(where) != 0 or len(where_document) != 0: - results = self.get( - collection_uuid=collection_uuid, - where=where, - where_document=where_document, - ) - - if len(results) > 
0: - ids = [x[1] for x in results] - else: - # No results found, return empty lists - return [[] for _ in range(len(embeddings))], [ - [] for _ in range(len(embeddings)) - ] - else: - ids = None - - index = self._index(collection_uuid) - uuids, distances = index.get_nearest_neighbors(embeddings, n_results, ids) - - return uuids, distances - - @override - def create_index(self, collection_uuid: UUID): - """Create an index for a collection_uuid and optionally scoped to a dataset. - Args: - collection_uuid (str): The collection_uuid to create an index for - dataset (str, optional): The dataset to scope the index to. Defaults to None. - Returns: - None - """ - get = self.get(collection_uuid=collection_uuid) - - uuids = [x[1] for x in get] - embeddings = [x[2] for x in get] - - index = self._index(collection_uuid) - index.add(uuids, embeddings) - - @override - def add_incremental( - self, collection_uuid: UUID, ids: List[UUID], embeddings: Embeddings - ) -> None: - index = self._index(collection_uuid) - index.add(ids, embeddings) - - def reset_indexes(self): - delete_all_indexes(self._settings) - self.index_cache = {} - - @override - def reset_state(self): - conn = self._get_conn() - conn.command("DROP TABLE collections") - conn.command("DROP TABLE embeddings") - self._create_table_collections(conn) - self._create_table_embeddings(conn) - - self.reset_indexes() - - @override - def raw_sql(self, raw_sql): - return self._get_conn().query(raw_sql).result_rows diff --git a/chromadb/db/duckdb.py b/chromadb/db/duckdb.py deleted file mode 100644 index 1929413..0000000 --- a/chromadb/db/duckdb.py +++ /dev/null @@ -1,534 +0,0 @@ -# type: ignore -from chromadb.config import System -from chromadb.api.types import Documents, Embeddings, IDs, Metadatas -from chromadb.db.clickhouse import ( - Clickhouse, - db_array_schema_to_clickhouse_schema, - EMBEDDING_TABLE_SCHEMA, - db_schema_to_keys, - COLLECTION_TABLE_SCHEMA, -) -from typing import List, Optional, Sequence -import pandas as pd -import json -import duckdb -import uuid -import os -import logging -import atexit -from uuid import UUID -from overrides import override -from chromadb.api.types import Metadata - -logger = logging.getLogger(__name__) - - -def clickhouse_to_duckdb_schema(table_schema): - for item in table_schema: - if "embedding" in item: - item["embedding"] = "DOUBLE[]" - # capitalize the key - item[list(item.keys())[0]] = item[list(item.keys())[0]].upper() - if "NULLABLE" in item[list(item.keys())[0]]: - item[list(item.keys())[0]] = ( - item[list(item.keys())[0]].replace("NULLABLE(", "").replace(")", "") - ) - if "UUID" in item[list(item.keys())[0]]: - item[list(item.keys())[0]] = "STRING" - if "FLOAT64" in item[list(item.keys())[0]]: - item[list(item.keys())[0]] = "DOUBLE" - return table_schema - - -# TODO: inherits ClickHouse for convenience of copying behavior, not -# because it's logically a subtype. Factoring out the common behavior -# to a third superclass they both extend would be preferable. 
-class DuckDB(Clickhouse): - # duckdb has a different way of connecting to the database - def __init__(self, system: System): - self._conn = duckdb.connect() - self._create_table_collections(self._conn) - self._create_table_embeddings(self._conn) - self._settings = system.settings - - # Normally this would be handled by super(), but we actually can't invoke - # super().__init__ here because we're (incorrectly) inheriting from Clickhouse - self._dependencies = set() - - # https://duckdb.org/docs/extensions/overview - self._conn.execute("LOAD 'json';") - - @override - def _create_table_collections(self, conn): - conn.execute( - f"""CREATE TABLE collections ( - {db_array_schema_to_clickhouse_schema(clickhouse_to_duckdb_schema(COLLECTION_TABLE_SCHEMA))} - ) """ - ) - - # duckdb has different types, so we want to convert the clickhouse schema to duckdb schema - @override - def _create_table_embeddings(self, conn): - conn.execute( - f"""CREATE TABLE embeddings ( - {db_array_schema_to_clickhouse_schema(clickhouse_to_duckdb_schema(EMBEDDING_TABLE_SCHEMA))} - ) """ - ) - - # - # UTILITY METHODS - # - @override - def get_collection_uuid_from_name(self, collection_name: str) -> UUID: - return self._conn.execute( - "SELECT uuid FROM collections WHERE name = ?", [collection_name] - ).fetchall()[0][0] - - # - # COLLECTION METHODS - # - @override - def create_collection( - self, - name: str, - metadata: Optional[Metadata] = None, - get_or_create: bool = False, - ) -> Sequence: - # poor man's unique constraint - dupe_check = self.get_collection(name) - if len(dupe_check) > 0: - if get_or_create is True: - if dupe_check[0][2] != metadata: - self.update_collection( - dupe_check[0][0], new_name=name, new_metadata=metadata - ) - dupe_check = self.get_collection(name) - - logger.info( - f"collection with name {name} already exists, returning existing collection" - ) - return dupe_check - else: - raise ValueError(f"Collection with name {name} already exists") - - collection_uuid = uuid.uuid4() - self._conn.execute( - """INSERT INTO collections (uuid, name, metadata) VALUES (?, ?, ?)""", - [str(collection_uuid), name, json.dumps(metadata)], - ) - return [[str(collection_uuid), name, metadata]] - - @override - def get_collection(self, name: str) -> Sequence: - res = self._conn.execute( - """SELECT * FROM collections WHERE name = ?""", [name] - ).fetchall() - # json.loads the metadata - return [[x[0], x[1], json.loads(x[2])] for x in res] - - @override - def get_collection_by_id(self, collection_uuid: str): - res = self._conn.execute( - """SELECT * FROM collections WHERE uuid = ?""", [collection_uuid] - ).fetchone() - return [res[0], res[1], json.loads(res[2])] - - @override - def list_collections(self) -> Sequence: - res = self._conn.execute("""SELECT * FROM collections""").fetchall() - return [[x[0], x[1], json.loads(x[2])] for x in res] - - @override - def delete_collection(self, name: str): - collection_uuid = self.get_collection_uuid_from_name(name) - self._conn.execute( - """DELETE FROM embeddings WHERE collection_uuid = ?""", [collection_uuid] - ) - - self._delete_index(collection_uuid) - self._conn.execute("""DELETE FROM collections WHERE name = ?""", [name]) - - @override - def update_collection( - self, - id: UUID, - new_name: Optional[str] = None, - new_metadata: Optional[Metadata] = None, - ): - if new_name is not None: - dupe_check = self.get_collection(new_name) - if len(dupe_check) > 0 and dupe_check[0][0] != str(id): - raise ValueError(f"Collection with name {new_name} already exists") - - 
self._conn.execute( - """UPDATE collections SET name = ? WHERE uuid = ?""", - [new_name, id], - ) - - if new_metadata is not None: - self._conn.execute( - """UPDATE collections SET metadata = ? WHERE uuid = ?""", - [json.dumps(new_metadata), id], - ) - - # - # ITEM METHODS - # - # the execute many syntax is different than clickhouse, the (?,?) syntax is different than clickhouse - @override - def add(self, collection_uuid, embeddings, metadatas, documents, ids) -> List[UUID]: - data_to_insert = [ - [ - collection_uuid, - str(uuid.uuid4()), - embedding, - json.dumps(metadatas[i]) if metadatas else None, - documents[i] if documents else None, - ids[i], - ] - for i, embedding in enumerate(embeddings) - ] - - insert_string = "collection_uuid, uuid, embedding, metadata, document, id" - - self._conn.executemany( - f""" - INSERT INTO embeddings ({insert_string}) VALUES (?,?,?,?,?,?)""", - data_to_insert, - ) - - return [uuid.UUID(x[1]) for x in data_to_insert] # return uuids - - @override - def count(self, collection_id: UUID) -> int: - where_string = f"WHERE collection_uuid = '{collection_id}'" - return self._conn.query( - f"SELECT COUNT() FROM embeddings {where_string}" - ).fetchall()[0][0] - - @override - def _format_where(self, where, result): - for key, value in where.items(): - # Shortcut for $eq - if type(value) == str: - result.append(f" json_extract_string(metadata,'$.{key}') = '{value}'") - if type(value) == int: - result.append( - f" CAST(json_extract(metadata,'$.{key}') AS INT) = {value}" - ) - if type(value) == float: - result.append( - f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) = {value}" - ) - # Operator expression - elif type(value) == dict: - operator, operand = list(value.items())[0] - if operator == "$gt": - result.append( - f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) > {operand}" - ) - elif operator == "$lt": - result.append( - f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) < {operand}" - ) - elif operator == "$gte": - result.append( - f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) >= {operand}" - ) - elif operator == "$lte": - result.append( - f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) <= {operand}" - ) - elif operator == "$ne": - if type(operand) == str: - return result.append( - f" json_extract_string(metadata,'$.{key}') != '{operand}'" - ) - return result.append( - f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) != {operand}" - ) - elif operator == "$eq": - if type(operand) == str: - return result.append( - f" json_extract_string(metadata,'$.{key}') = '{operand}'" - ) - return result.append( - f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) = {operand}" - ) - else: - raise ValueError(f"Operator {operator} not supported") - elif type(value) == list: - all_subresults = [] - for subwhere in value: - subresults = [] - self._format_where(subwhere, subresults) - all_subresults.append(subresults[0]) - if key == "$or": - result.append(f"({' OR '.join(all_subresults)})") - elif key == "$and": - result.append(f"({' AND '.join(all_subresults)})") - else: - raise ValueError( - f"Operator {key} not supported with a list of where clauses" - ) - - @override - def _format_where_document(self, where_document, results): - operator = list(where_document.keys())[0] - if operator == "$contains": - results.append(f"position('{where_document[operator]}' in document) > 0") - elif operator == "$and" or operator == "$or": - all_subresults = [] - for subwhere in where_document[operator]: - subresults = [] - self._format_where_document(subwhere, subresults) 
- all_subresults.append(subresults[0]) - if operator == "$or": - results.append(f"({' OR '.join(all_subresults)})") - if operator == "$and": - results.append(f"({' AND '.join(all_subresults)})") - else: - raise ValueError(f"Operator {operator} not supported") - - @override - def _get(self, where, columns: Optional[List] = None): - select_columns = db_schema_to_keys() if columns is None else columns - val = self._conn.execute( - f"""SELECT {",".join(select_columns)} FROM embeddings {where}""" - ).fetchall() - for i in range(len(val)): - val[i] = list(val[i]) - if "collection_uuid" in select_columns: - collection_uuid_column_index = select_columns.index("collection_uuid") - val[i][collection_uuid_column_index] = uuid.UUID( - val[i][collection_uuid_column_index] - ) - if "uuid" in select_columns: - uuid_column_index = select_columns.index("uuid") - val[i][uuid_column_index] = uuid.UUID(val[i][uuid_column_index]) - if "metadata" in select_columns: - metadata_column_index = select_columns.index("metadata") - val[i][metadata_column_index] = ( - json.loads(val[i][metadata_column_index]) - if val[i][metadata_column_index] - else None - ) - - return val - - @override - def _update( - self, - collection_uuid, - ids: IDs, - embeddings: Optional[Embeddings], - metadatas: Optional[Metadatas], - documents: Optional[Documents], - ): - update_data = [] - for i in range(len(ids)): - data = [] - update_data.append(data) - if embeddings is not None: - data.append(embeddings[i]) - if metadatas is not None: - data.append(json.dumps(metadatas[i])) - if documents is not None: - data.append(documents[i]) - data.append(ids[i]) - - update_fields = [] - if embeddings is not None: - update_fields.append("embedding = ?") - if metadatas is not None: - update_fields.append("metadata = ?") - if documents is not None: - update_fields.append("document = ?") - - update_statement = f""" - UPDATE - embeddings - SET - {", ".join(update_fields)} - WHERE - id = ? 
AND - collection_uuid = '{collection_uuid}'; - """ - self._conn.executemany(update_statement, update_data) - - @override - def _delete(self, where_str: Optional[str] = None) -> List: - uuids_deleted = self._conn.execute( - f"""SELECT uuid FROM embeddings {where_str}""" - ).fetchall() - self._conn.execute( - f""" - DELETE FROM - embeddings - {where_str} - """ - ).fetchall()[0] - return [uuid.UUID(x[0]) for x in uuids_deleted] - - @override - def get_by_ids( - self, uuids: List[UUID], columns: Optional[List[str]] = None - ) -> Sequence: - # select from duckdb table where ids are in the list - if not isinstance(uuids, list): - raise TypeError(f"Expected ids to be a list, got {uuids}") - - if not uuids: - # create an empty pandas dataframe - return pd.DataFrame() - - columns = columns + ["uuid"] if columns else ["uuid"] - - select_columns = db_schema_to_keys() if columns is None else columns - response = self._conn.execute( - f""" - SELECT - {",".join(select_columns)} - FROM - embeddings - WHERE - uuid IN ({','.join([("'" + str(x) + "'") for x in uuids])}) - """ - ).fetchall() - - # sort db results by the order of the uuids - response = sorted( - response, key=lambda obj: uuids.index(uuid.UUID(obj[len(columns) - 1])) - ) - - return response - - @override - def raw_sql(self, raw_sql): - return self._conn.execute(raw_sql).df() - - # TODO: This method should share logic with clickhouse impl - @override - def reset_state(self): - self._conn.execute("DROP TABLE collections") - self._conn.execute("DROP TABLE embeddings") - self._create_table_collections(self._conn) - self._create_table_embeddings(self._conn) - - self.reset_indexes() - - def __del__(self): - logger.info("Exiting: Cleaning up .chroma directory") - self.reset_indexes() - - @override - def persist(self) -> None: - raise NotImplementedError( - "Set chroma_db_impl='duckdb+parquet' to get persistence functionality" - ) - - -class PersistentDuckDB(DuckDB): - _save_folder = None - - def __init__(self, system: System): - super().__init__(system=system) - - system.settings.require("persist_directory") - - if system.settings.persist_directory == ".chroma": - raise ValueError( - "You cannot use chroma's cache directory .chroma/, please set a different directory" - ) - - self._save_folder = system.settings.persist_directory - self.load() - # https://docs.python.org/3/library/atexit.html - atexit.register(self.persist) - - def set_save_folder(self, path): - self._save_folder = path - - def get_save_folder(self): - return self._save_folder - - @override - def persist(self): - """ - Persist the database to disk - """ - logger.info( - f"Persisting DB to disk, putting it in the save folder: {self._save_folder}" - ) - if self._conn is None: - return - - if not os.path.exists(self._save_folder): - os.makedirs(self._save_folder) - - # if the db is empty, dont save - if self._conn.query("SELECT COUNT() FROM embeddings") == 0: - return - - self._conn.execute( - f""" - COPY - (SELECT * FROM embeddings) - TO '{self._save_folder}/chroma-embeddings.parquet' - (FORMAT PARQUET); - """ - ) - - self._conn.execute( - f""" - COPY - (SELECT * FROM collections) - TO '{self._save_folder}/chroma-collections.parquet' - (FORMAT PARQUET); - """ - ) - - def load(self): - """ - Load the database from disk - """ - if not os.path.exists(self._save_folder): - os.makedirs(self._save_folder) - - # load in the embeddings - if not os.path.exists(f"{self._save_folder}/chroma-embeddings.parquet"): - logger.info(f"No existing DB found in {self._save_folder}, skipping load") - else: - 
path = self._save_folder + "/chroma-embeddings.parquet" - self._conn.execute( - f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');" - ) - logger.info( - f"""loaded in {self._conn.query(f"SELECT COUNT() FROM embeddings").fetchall()[0][0]} embeddings""" - ) - - # load in the collections - if not os.path.exists(f"{self._save_folder}/chroma-collections.parquet"): - logger.info(f"No existing DB found in {self._save_folder}, skipping load") - else: - path = self._save_folder + "/chroma-collections.parquet" - self._conn.execute( - f"INSERT INTO collections SELECT * FROM read_parquet('{path}');" - ) - logger.info( - f"""loaded in {self._conn.query(f"SELECT COUNT() FROM collections").fetchall()[0][0]} collections""" - ) - - def __del__(self): - # No-op for duckdb with persistence since the base class will delete the indexes - pass - - @override - def reset_state(self): - super().reset_state() - # empty the save folder - import shutil - import os - - shutil.rmtree(self._save_folder) - os.mkdir(self._save_folder) diff --git a/chromadb/db/impl/sqlite.py b/chromadb/db/impl/sqlite.py index 5358001..4db21c3 100644 --- a/chromadb/db/impl/sqlite.py +++ b/chromadb/db/impl/sqlite.py @@ -1,3 +1,4 @@ +from chromadb.db.impl.sqlite_pool import Connection, LockPool, PerThreadPool, Pool from chromadb.db.migrations import MigratableDB, Migration from chromadb.config import System, Settings import chromadb.db.base as base @@ -12,12 +13,18 @@ from types import TracebackType import os from uuid import UUID from threading import local +from importlib_resources import files +from importlib_resources.abc import Traversable class TxWrapper(base.TxWrapper): - def __init__(self, conn: sqlite3.Connection, stack: local) -> None: + _conn: Connection + _pool: Pool + + def __init__(self, conn_pool: Pool, stack: local): self._tx_stack = stack - self._conn = conn + self._conn = conn_pool.connect() + self._pool = conn_pool @override def __enter__(self) -> base.Cursor: @@ -39,32 +46,46 @@ class TxWrapper(base.TxWrapper): self._conn.commit() else: self._conn.rollback() + self._pool.return_to_pool(self._conn) return False class SqliteDB(MigratableDB, SqlEmbeddingsQueue, SqlSysDB): - _conn: sqlite3.Connection + _conn_pool: Pool _settings: Settings - _migration_dirs: Sequence[str] + _migration_imports: Sequence[Traversable] _db_file: str _tx_stack: local + _is_persistent: bool def __init__(self, system: System): self._settings = system.settings - self._migration_dirs = [ - "migrations/embeddings_queue", - "migrations/sysdb", - "migrations/metadb", + self._migration_imports = [ + files("chromadb.migrations.embeddings_queue"), + files("chromadb.migrations.sysdb"), + files("chromadb.migrations.metadb"), ] - self._db_file = self._settings.require("sqlite_database") + self._is_persistent = self._settings.require("is_persistent") + if not self._is_persistent: + # In order to allow sqlite to be shared between multiple threads, we need to use a + # URI connection string with shared cache. 
+ # See https://www.sqlite.org/sharedcache.html + # https://stackoverflow.com/questions/3315046/sharing-a-memory-database-between-different-threads-in-python-using-sqlite3-pa + self._db_file = "file::memory:?cache=shared" + self._conn_pool = LockPool(self._db_file, is_uri=True) + else: + self._db_file = ( + self._settings.require("persist_directory") + "/chroma.sqlite3" + ) + if not os.path.exists(self._db_file): + os.makedirs(os.path.dirname(self._db_file), exist_ok=True) + self._conn_pool = PerThreadPool(self._db_file) self._tx_stack = local() super().__init__(system) @override def start(self) -> None: super().start() - self._conn = sqlite3.connect(self._db_file) - self._conn.isolation_level = None # Handle commits explicitly with self.tx() as cur: cur.execute("PRAGMA foreign_keys = ON") cur.execute("PRAGMA case_sensitive_like = ON") @@ -73,7 +94,7 @@ class SqliteDB(MigratableDB, SqlEmbeddingsQueue, SqlSysDB): @override def stop(self) -> None: super().stop() - self._conn.close() + self._conn_pool.close() @staticmethod @override @@ -91,14 +112,14 @@ class SqliteDB(MigratableDB, SqlEmbeddingsQueue, SqlSysDB): return "sqlite" @override - def migration_dirs(self) -> Sequence[str]: - return self._migration_dirs + def migration_dirs(self) -> Sequence[Traversable]: + return self._migration_imports @override def tx(self) -> TxWrapper: if not hasattr(self._tx_stack, "stack"): self._tx_stack.stack = [] - return TxWrapper(self._conn, stack=self._tx_stack) + return TxWrapper(self._conn_pool, stack=self._tx_stack) @override def reset_state(self) -> None: @@ -106,10 +127,19 @@ class SqliteDB(MigratableDB, SqlEmbeddingsQueue, SqlSysDB): raise ValueError( "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted." ) - self._conn.close() - db_file = self._settings.require("sqlite_database") - if db_file != ":memory:": - os.remove(db_file) + with self.tx() as cur: + # Drop all tables + cur.execute( + """ + SELECT name FROM sqlite_master + WHERE type='table' + """ + ) + for row in cur.fetchall(): + cur.execute(f"DROP TABLE IF EXISTS {row[0]}") + self._conn_pool.close() + if self._is_persistent: + os.remove(self._db_file) self.start() super().reset_state() @@ -143,7 +173,7 @@ class SqliteDB(MigratableDB, SqlEmbeddingsQueue, SqlSysDB): return True @override - def db_migrations(self, dir: str) -> Sequence[Migration]: + def db_migrations(self, dir: Traversable) -> Sequence[Migration]: with self.tx() as cur: cur.execute( """ @@ -152,23 +182,23 @@ class SqliteDB(MigratableDB, SqlEmbeddingsQueue, SqlSysDB): WHERE dir = ? 
ORDER BY version ASC """, - (dir,), + (dir.name,), ) migrations = [] for row in cur.fetchall(): - dir = cast(str, row[0]) - version = cast(int, row[1]) - filename = cast(str, row[2]) - sql = cast(str, row[3]) - hash = cast(str, row[4]) + found_dir = cast(str, row[0]) + found_version = cast(int, row[1]) + found_filename = cast(str, row[2]) + found_sql = cast(str, row[3]) + found_hash = cast(str, row[4]) migrations.append( Migration( - dir=dir, - version=version, - filename=filename, - sql=sql, - hash=hash, + dir=found_dir, + version=found_version, + filename=found_filename, + sql=found_sql, + hash=found_hash, scope=self.migration_scope(), ) ) diff --git a/chromadb/db/impl/sqlite_pool.py b/chromadb/db/impl/sqlite_pool.py new file mode 100644 index 0000000..83a3edf --- /dev/null +++ b/chromadb/db/impl/sqlite_pool.py @@ -0,0 +1,159 @@ +import sqlite3 +from abc import ABC, abstractmethod +from typing import Any, Set +import threading +from overrides import override + + +class Connection: + """A threadpool connection that returns itself to the pool on close()""" + + _pool: "Pool" + _db_file: str + _conn: sqlite3.Connection + + def __init__( + self, pool: "Pool", db_file: str, is_uri: bool, *args: Any, **kwargs: Any + ): + self._pool = pool + self._db_file = db_file + self._conn = sqlite3.connect( + db_file, timeout=1000, check_same_thread=False, uri=is_uri, *args, **kwargs + ) # type: ignore + self._conn.isolation_level = None # Handle commits explicitly + + def execute(self, sql: str, parameters=...) -> sqlite3.Cursor: # type: ignore + if parameters is ...: + return self._conn.execute(sql) + return self._conn.execute(sql, parameters) + + def commit(self) -> None: + self._conn.commit() + + def rollback(self) -> None: + self._conn.rollback() + + def cursor(self) -> sqlite3.Cursor: + return self._conn.cursor() + + def close_actual(self) -> None: + """Actually closes the connection to the db""" + self._conn.close() + + +class Pool(ABC): + """Abstract base class for a pool of connections to a sqlite database.""" + + @abstractmethod + def __init__(self, db_file: str, is_uri: bool) -> None: + pass + + @abstractmethod + def connect(self, *args: Any, **kwargs: Any) -> Connection: + """Return a connection from the pool.""" + pass + + @abstractmethod + def close(self) -> None: + """Close all connections in the pool.""" + pass + + @abstractmethod + def return_to_pool(self, conn: Connection) -> None: + """Return a connection to the pool.""" + pass + + +class LockPool(Pool): + """A pool that has a single connection per thread but uses a lock to ensure that only one thread can use it at a time. + This is used because sqlite does not support multithreaded access with connection timeouts when using the + shared cache mode. We use the shared cache mode to allow multiple threads to share a database. 
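The pool classes introduced in sqlite_pool.py (Connection, LockPool, PerThreadPool) wrap sqlite3 connections so the in-memory shared-cache database can be shared safely across threads. A minimal usage sketch, not part of the patch, assuming standalone use of the pool; the file path and table are illustrative only:

from chromadb.db.impl.sqlite_pool import LockPool, PerThreadPool

# Shared-cache URI: several pooled connections see the same in-memory database.
pool = LockPool("file::memory:?cache=shared", is_uri=True)
conn = pool.connect()  # acquires the pool's lock and a per-thread connection
conn.execute("CREATE TABLE IF NOT EXISTS kv (k TEXT PRIMARY KEY, v TEXT)")
conn.execute("INSERT OR REPLACE INTO kv VALUES (?, ?)", ("greeting", "hello"))
conn.commit()
pool.return_to_pool(conn)  # releases the lock; the underlying connection stays open

# A persistent database file would typically use PerThreadPool instead (path is illustrative).
file_pool = PerThreadPool("/tmp/chroma-example.sqlite3")
cur = file_pool.connect().cursor()
cur.execute("SELECT 1")
file_pool.close()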
+ """ + + _connections: Set[Connection] + _lock: threading.RLock + _connection: threading.local + _db_file: str + _is_uri: bool + + def __init__(self, db_file: str, is_uri: bool = False): + self._connections = set() + self._connection = threading.local() + self._lock = threading.RLock() + self._db_file = db_file + self._is_uri = is_uri + + @override + def connect(self, *args: Any, **kwargs: Any) -> Connection: + self._lock.acquire() + if hasattr(self._connection, "conn") and self._connection.conn is not None: + return self._connection.conn # type: ignore # cast doesn't work here for some reason + else: + new_connection = Connection( + self, self._db_file, self._is_uri, *args, **kwargs + ) + self._connection.conn = new_connection + self._connections.add(new_connection) + return new_connection + + @override + def return_to_pool(self, conn: Connection) -> None: + try: + self._lock.release() + except RuntimeError: + pass + + @override + def close(self) -> None: + for conn in self._connections: + conn.close_actual() + self._connections.clear() + self._connection = threading.local() + try: + self._lock.release() + except RuntimeError: + pass + + +class PerThreadPool(Pool): + """Maintains a connection per thread. For now this does not maintain a cap on the number of connections, but it could be + extended to do so and block on connect() if the cap is reached. + """ + + _connections: Set[Connection] + _lock: threading.Lock + _connection: threading.local + _db_file: str + _is_uri_: bool + + def __init__(self, db_file: str, is_uri: bool = False): + self._connections = set() + self._connection = threading.local() + self._lock = threading.Lock() + self._db_file = db_file + self._is_uri = is_uri + + @override + def connect(self, *args: Any, **kwargs: Any) -> Connection: + if hasattr(self._connection, "conn") and self._connection.conn is not None: + return self._connection.conn # type: ignore # cast doesn't work here for some reason + else: + new_connection = Connection( + self, self._db_file, self._is_uri, *args, **kwargs + ) + self._connection.conn = new_connection + with self._lock: + self._connections.add(new_connection) + return new_connection + + @override + def close(self) -> None: + with self._lock: + for conn in self._connections: + conn.close_actual() + self._connections.clear() + self._connection = threading.local() + + @override + def return_to_pool(self, conn: Connection) -> None: + pass # Each thread gets its own connection, so we don't need to return it to the pool diff --git a/chromadb/db/index/__init__.py b/chromadb/db/index/__init__.py deleted file mode 100644 index 06a132e..0000000 --- a/chromadb/db/index/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from abc import ABC, abstractmethod - - -class Index(ABC): - @abstractmethod - def __init__(self, id, settings, metadata): - pass - - @abstractmethod - def delete(self): - pass - - @abstractmethod - def delete_from_index(self, ids): - pass - - @abstractmethod - def add(self, ids, embeddings, update=False): - pass - - @abstractmethod - def get_nearest_neighbors(self, embedding, n_results, ids): - pass diff --git a/chromadb/db/index/hnswlib.py b/chromadb/db/index/hnswlib.py deleted file mode 100644 index 0d635a0..0000000 --- a/chromadb/db/index/hnswlib.py +++ /dev/null @@ -1,306 +0,0 @@ -import os -import pickle -import time -from typing import Dict, List, Optional, Set, Tuple, Union, cast - -from chromadb.api.types import Embeddings, IndexMetadata -import hnswlib -from chromadb.config import Settings -from chromadb.db.index import Index 
-from chromadb.errors import ( - InvalidDimensionException, -) -import logging -import re -from uuid import UUID -import multiprocessing - -logger = logging.getLogger(__name__) - - -valid_params = { - "hnsw:space": r"^(l2|cosine|ip)$", - "hnsw:construction_ef": r"^\d+$", - "hnsw:search_ef": r"^\d+$", - "hnsw:M": r"^\d+$", - "hnsw:num_threads": r"^\d+$", - "hnsw:resize_factor": r"^\d+(\.\d+)?$", -} - -DEFAULT_CAPACITY = 1000 - - -class HnswParams: - space: str - construction_ef: int - search_ef: int - M: int - num_threads: int - resize_factor: float - - def __init__(self, metadata: Dict[str, str]): - metadata = metadata or {} - - # Convert all values to strings for future compatibility. - metadata = {k: str(v) for k, v in metadata.items()} - - for param, value in metadata.items(): - if param.startswith("hnsw:"): - if param not in valid_params: - raise ValueError(f"Unknown HNSW parameter: {param}") - if not re.match(valid_params[param], value): - raise ValueError( - f"Invalid value for HNSW parameter: {param} = {value}" - ) - - self.space = metadata.get("hnsw:space", "l2") - self.construction_ef = int(metadata.get("hnsw:construction_ef", 100)) - self.search_ef = int(metadata.get("hnsw:search_ef", 10)) - self.M = int(metadata.get("hnsw:M", 16)) - self.num_threads = int( - metadata.get("hnsw:num_threads", multiprocessing.cpu_count()) - ) - self.resize_factor = float(metadata.get("hnsw:resize_factor", 1.2)) - - -def hexid(id: Union[str, UUID]) -> str: - """Backwards compatibility for old indexes which called uuid.hex on UUID ids""" - return id.hex if isinstance(id, UUID) else id - - -def delete_all_indexes(settings: Settings) -> None: - if os.path.exists(f"{settings.persist_directory}/index"): - for file in os.listdir(f"{settings.persist_directory}/index"): - os.remove(f"{settings.persist_directory}/index/{file}") - - -class Hnswlib(Index): - _id: str - _index: hnswlib.Index - _index_metadata: IndexMetadata - _params: HnswParams - _id_to_label: Dict[str, int] - _label_to_id: Dict[int, UUID] - - def __init__( - self, - id: str, - settings: Settings, - metadata: Dict[str, str], - number_elements: int, - ): - self._save_folder = settings.persist_directory + "/index" - self._params = HnswParams(metadata) - self._id = id - self._index = None - # Mapping of IDs to HNSW integer labels - self._id_to_label = {} - self._label_to_id = {} - - self._load(number_elements) - - def _init_index(self, dimensionality: int) -> None: - # more comments available at the source: https://github.com/nmslib/hnswlib - - index = hnswlib.Index( - space=self._params.space, dim=dimensionality - ) # possible options are l2, cosine or ip - index.init_index( - max_elements=DEFAULT_CAPACITY, - ef_construction=self._params.construction_ef, - M=self._params.M, - ) - index.set_ef(self._params.search_ef) - index.set_num_threads(self._params.num_threads) - - self._index = index - self._index_metadata = { - "dimensionality": dimensionality, - "curr_elements": 0, - "total_elements_added": 0, - "time_created": time.time(), - } - self._save() - - def _check_dimensionality(self, data: Embeddings) -> None: - """Assert that the given data matches the index dimensionality""" - dim = len(data[0]) - idx_dim = self._index.dim - if dim != idx_dim: - raise InvalidDimensionException( - f"Dimensionality of ({dim}) does not match index dimensionality ({idx_dim})" - ) - - def add( - self, ids: List[UUID], embeddings: Embeddings, update: bool = False - ) -> None: - """Add or update embeddings to the index""" - - dim = len(embeddings[0]) - - if 
self._index is None: - self._init_index(dim) - # Calling init_index will ensure the index is not none, so we can safely cast - self._index = cast(hnswlib.Index, self._index) - - # Check dimensionality - self._check_dimensionality(embeddings) - - labels = [] - for id in ids: - if hexid(id) in self._id_to_label: - if update: - labels.append(self._id_to_label[hexid(id)]) - else: - raise ValueError(f"ID {id} already exists in index") - else: - self._index_metadata["total_elements_added"] += 1 - self._index_metadata["curr_elements"] += 1 - next_label = self._index_metadata["total_elements_added"] - self._id_to_label[hexid(id)] = next_label - self._label_to_id[next_label] = id - labels.append(next_label) - - if ( - self._index_metadata["total_elements_added"] - > self._index.get_max_elements() - ): - new_size = int( - max( - self._index_metadata["total_elements_added"] - * self._params.resize_factor, - DEFAULT_CAPACITY, - ) - ) - self._index.resize_index(new_size) - - self._index.add_items(embeddings, labels) - self._save() - - def delete(self) -> None: - # delete files, dont throw error if they dont exist - try: - os.remove(f"{self._save_folder}/id_to_uuid_{self._id}.pkl") - os.remove(f"{self._save_folder}/uuid_to_id_{self._id}.pkl") - os.remove(f"{self._save_folder}/index_{self._id}.bin") - os.remove(f"{self._save_folder}/index_metadata_{self._id}.pkl") - except Exception: - pass - - self._index = None - self._collection_uuid = None - self._id_to_label = {} - self._label_to_id = {} - - def delete_from_index(self, ids: List[UUID]) -> None: - if self._index is not None: - for id in ids: - label = self._id_to_label[hexid(id)] - self._index.mark_deleted(label) - del self._label_to_id[label] - del self._id_to_label[hexid(id)] - self._index_metadata["curr_elements"] -= 1 - - self._save() - - def _save(self) -> None: - # create the directory if it doesn't exist - if not os.path.exists(f"{self._save_folder}"): - os.makedirs(f"{self._save_folder}") - - if self._index is None: - return - self._index.save_index(f"{self._save_folder}/index_{self._id}.bin") - - # pickle the mappers - # Use old filenames for backwards compatibility - with open(f"{self._save_folder}/id_to_uuid_{self._id}.pkl", "wb") as f: - pickle.dump(self._label_to_id, f, pickle.HIGHEST_PROTOCOL) - with open(f"{self._save_folder}/uuid_to_id_{self._id}.pkl", "wb") as f: - pickle.dump(self._id_to_label, f, pickle.HIGHEST_PROTOCOL) - with open(f"{self._save_folder}/index_metadata_{self._id}.pkl", "wb") as f: - pickle.dump(self._index_metadata, f, pickle.HIGHEST_PROTOCOL) - - logger.debug(f"Index saved to {self._save_folder}/index.bin") - - def _exists(self) -> None: - return - - def _load(self, curr_elements: int) -> None: - if not os.path.exists(f"{self._save_folder}/index_{self._id}.bin"): - return - - # unpickle the mappers - with open(f"{self._save_folder}/id_to_uuid_{self._id}.pkl", "rb") as f: - self._label_to_id = pickle.load(f) - with open(f"{self._save_folder}/uuid_to_id_{self._id}.pkl", "rb") as f: - self._id_to_label = pickle.load(f) - with open(f"{self._save_folder}/index_metadata_{self._id}.pkl", "rb") as f: - self._index_metadata = pickle.load(f) - - self._index_metadata["curr_elements"] = curr_elements - # Backwards compatability with versions that don't have curr_elements or total_elements_added - if "total_elements_added" not in self._index_metadata: - self._index_metadata["total_elements_added"] = self._index_metadata[ - "elements" - ] - - p = hnswlib.Index( - space=self._params.space, 
dim=self._index_metadata["dimensionality"] - ) - self._index = p - self._index.load_index( - f"{self._save_folder}/index_{self._id}.bin", - max_elements=int( - max(curr_elements * self._params.resize_factor, DEFAULT_CAPACITY) - ), - ) - self._index.set_ef(self._params.search_ef) - self._index.set_num_threads(self._params.num_threads) - - def get_nearest_neighbors( - self, query: Embeddings, k: int, ids: Optional[List[UUID]] = None - ) -> Tuple[List[List[UUID]], List[List[float]]]: - # The only case where the index is none is if no elements have been added - # We don't save the index until at least one element has been added - # And so there is also nothing at load time for persisted indexes - # In the case where no elements have been added, we return empty - if self._index is None: - return [[] for _ in range(len(query))], [[] for _ in range(len(query))] - - # Check dimensionality - self._check_dimensionality(query) - - # Check Number of requested results - if k > self._index_metadata["curr_elements"]: - logger.warning( - f"Number of requested results {k} is greater than number of elements in index {self._index_metadata['curr_elements']}, updating n_results = {self._index_metadata['curr_elements']}" - ) - k = self._index_metadata["curr_elements"] - - s2 = time.time() - # get ids from uuids as a set, if they are available - labels: Set[int] = set() - if ids is not None: - labels = {self._id_to_label[hexid(id)] for id in ids} - if len(labels) < k: - k = len(labels) - - filter_function = None - if len(labels) != 0: - filter_function = lambda label: label in labels # NOQA: E731 - - logger.debug(f"time to pre process our knn query: {time.time() - s2}") - - s3 = time.time() - database_labels, distances = self._index.knn_query( - query, k=k, filter=filter_function - ) - distances = distances.tolist() - distances = cast(List[List[float]], distances) - logger.debug(f"time to run knn query: {time.time() - s3}") - - return_ids = [ - [self._label_to_id[label] for label in labels] for labels in database_labels - ] - return return_ids, distances diff --git a/chromadb/db/migrations.py b/chromadb/db/migrations.py index 70541a1..af2ecce 100644 --- a/chromadb/db/migrations.py +++ b/chromadb/db/migrations.py @@ -1,6 +1,6 @@ from typing import Sequence -from typing_extensions import TypedDict -import os +from typing_extensions import TypedDict, NotRequired +from importlib_resources.abc import Traversable import re import hashlib from chromadb.db.base import SqlDB, Cursor @@ -9,6 +9,7 @@ from chromadb.config import System, Settings class MigrationFile(TypedDict): + path: NotRequired[Traversable] dir: str filename: str version: int @@ -59,9 +60,12 @@ class InvalidMigrationFilename(Exception): class MigratableDB(SqlDB): """Simple base class for databases which support basic migrations. - Migrations are SQL files stored in a project-relative directory. All migrations in - the same directory are assumed to be dependent on previous migrations in the same - directory, where "previous" is defined on lexographical ordering of filenames. + Migrations are SQL files stored as package resources and accessed via + importlib_resources. + + All migrations in the same directory are assumed to be dependent on previous + migrations in the same directory, where "previous" is defined on lexographical + ordering of filenames. Migrations have a ascending numeric version number and a hash of the file contents. 
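With migrations now shipped as package resources, a migration directory is addressed with importlib_resources.files() and each file name encodes a version and scope via the (\d+)-(.+)\.(.+)\.sql pattern used by _parse_migration_filename. A rough illustration, outside the patch; the "sqlite" scope shown in the comment is inferred from the sqlite backend in this change:

import re
from importlib_resources import files

filename_regex = re.compile(r"(\d+)-(.+)\.(.+)\.sql")  # same pattern as chromadb/db/migrations.py

sysdb_dir = files("chromadb.migrations.sysdb")  # a Traversable, not a filesystem path
for entry in sysdb_dir.iterdir():
    if not entry.name.endswith(".sql"):
        continue
    version, name, scope = filename_regex.match(entry.name).groups()
    # e.g. 00001-collections.sqlite.sql -> version "00001", scope "sqlite"
    print(int(version), name, scope, len(entry.read_text()))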
When migrations are applied, the hashes of previous migrations are checked to ensure @@ -87,7 +91,7 @@ class MigratableDB(SqlDB): pass @abstractmethod - def migration_dirs(self) -> Sequence[str]: + def migration_dirs(self) -> Sequence[Traversable]: """Directories containing the migration sequences that should be applied to this DB.""" pass @@ -103,7 +107,7 @@ class MigratableDB(SqlDB): pass @abstractmethod - def db_migrations(self, dir: str) -> Sequence[Migration]: + def db_migrations(self, dir: Traversable) -> Sequence[Migration]: """Return a list of all migrations already applied to this database, from the given source directory, in ascending order.""" pass @@ -136,7 +140,7 @@ class MigratableDB(SqlDB): ) if len(unapplied_migrations) > 0: version = unapplied_migrations[0]["version"] - raise UnappliedMigrationsError(dir=dir, version=version) + raise UnappliedMigrationsError(dir=dir.name, version=version) def apply_migrations(self) -> None: """Validate existing migrations, and apply all new ones.""" @@ -157,13 +161,16 @@ class MigratableDB(SqlDB): filename_regex = re.compile(r"(\d+)-(.+)\.(.+)\.sql") -def _parse_migration_filename(dir: str, filename: str) -> MigrationFile: +def _parse_migration_filename( + dir: str, filename: str, path: Traversable +) -> MigrationFile: """Parse a migration filename into a MigrationFile object""" match = filename_regex.match(filename) if match is None: raise InvalidMigrationFilename("Invalid migration filename: " + filename) version, _, scope = match.groups() return { + "path": path, "dir": dir, "filename": filename, "version": int(version), @@ -203,13 +210,13 @@ def verify_migration_sequence( return source_migrations[len(db_migrations) :] -def find_migrations(dir: str, scope: str) -> Sequence[Migration]: +def find_migrations(dir: Traversable, scope: str) -> Sequence[Migration]: """Return a list of all migration present in the given directory, in ascending order. 
Filter by scope.""" files = [ - _parse_migration_filename(dir, filename) - for filename in os.listdir(dir) - if filename.endswith(".sql") + _parse_migration_filename(dir.name, t.name, t) + for t in dir.iterdir() + if t.name.endswith(".sql") ] files = list(filter(lambda f: f["scope"] == scope, files)) files = sorted(files, key=lambda f: f["version"]) @@ -218,7 +225,11 @@ def find_migrations(dir: str, scope: str) -> Sequence[Migration]: def _read_migration_file(file: MigrationFile) -> Migration: """Read a migration file""" - sql = open(os.path.join(file["dir"], file["filename"])).read() + if "path" not in file or not file["path"].is_file(): + raise FileNotFoundError( + f"No migration file found for dir {file['dir']} with filename {file['filename']} and scope {file['scope']} at version {file['version']}" + ) + sql = file["path"].read_text() hash = hashlib.md5(sql.encode("utf-8")).hexdigest() return { "hash": hash, diff --git a/chromadb/migrations/__init__.py b/chromadb/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/migrations/embeddings_queue/00001-embeddings.sqlite.sql b/chromadb/migrations/embeddings_queue/00001-embeddings.sqlite.sql similarity index 100% rename from migrations/embeddings_queue/00001-embeddings.sqlite.sql rename to chromadb/migrations/embeddings_queue/00001-embeddings.sqlite.sql diff --git a/migrations/metadb/00001-embedding-metadata.sqlite.sql b/chromadb/migrations/metadb/00001-embedding-metadata.sqlite.sql similarity index 100% rename from migrations/metadb/00001-embedding-metadata.sqlite.sql rename to chromadb/migrations/metadb/00001-embedding-metadata.sqlite.sql diff --git a/migrations/sysdb/00001-collections.sqlite.sql b/chromadb/migrations/sysdb/00001-collections.sqlite.sql similarity index 100% rename from migrations/sysdb/00001-collections.sqlite.sql rename to chromadb/migrations/sysdb/00001-collections.sqlite.sql diff --git a/migrations/sysdb/00002-segments.sqlite.sql b/chromadb/migrations/sysdb/00002-segments.sqlite.sql similarity index 100% rename from migrations/sysdb/00002-segments.sqlite.sql rename to chromadb/migrations/sysdb/00002-segments.sqlite.sql diff --git a/migrations/sysdb/00003-collection-dimension.sqlite.sql b/chromadb/migrations/sysdb/00003-collection-dimension.sqlite.sql similarity index 100% rename from migrations/sysdb/00003-collection-dimension.sqlite.sql rename to chromadb/migrations/sysdb/00003-collection-dimension.sqlite.sql diff --git a/chromadb/segment/__init__.py b/chromadb/segment/__init__.py index e75904a..5c2f431 100644 --- a/chromadb/segment/__init__.py +++ b/chromadb/segment/__init__.py @@ -3,6 +3,7 @@ from abc import abstractmethod from chromadb.types import ( Collection, MetadataEmbeddingRecord, + Operation, VectorEmbeddingRecord, Where, WhereDocument, @@ -105,3 +106,10 @@ class SegmentManager(Component): implementation full control over which segment impls are in or out of memory at a given time.)""" pass + + @abstractmethod + def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None: + """Signal to the segment manager that a collection is about to be used, so that + it can preload segments as needed. 
This is only a hint, and implementations are + free to ignore it.""" + pass diff --git a/chromadb/segment/impl/manager/local.py b/chromadb/segment/impl/manager/local.py index fdac85f..e9fe11a 100644 --- a/chromadb/segment/impl/manager/local.py +++ b/chromadb/segment/impl/manager/local.py @@ -1,3 +1,4 @@ +from threading import Lock from chromadb.segment import ( SegmentImplementation, SegmentManager, @@ -9,7 +10,7 @@ from chromadb.config import System, get_class from chromadb.db.system import SysDB from overrides import override from enum import Enum -from chromadb.types import Collection, Segment, SegmentScope, Metadata +from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata from typing import Dict, Type, Sequence, Optional, cast from uuid import UUID, uuid4 from collections import defaultdict @@ -18,11 +19,13 @@ from collections import defaultdict class SegmentType(Enum): SQLITE = "urn:chroma:segment/metadata/sqlite" HNSW_LOCAL_MEMORY = "urn:chroma:segment/vector/hnsw-local-memory" + HNSW_LOCAL_PERSISTED = "urn:chroma:segment/vector/hnsw-local-persisted" SEGMENT_TYPE_IMPLS = { SegmentType.SQLITE: "chromadb.segment.impl.metadata.sqlite.SqliteMetadataSegment", SegmentType.HNSW_LOCAL_MEMORY: "chromadb.segment.impl.vector.local_hnsw.LocalHnswSegment", + SegmentType.HNSW_LOCAL_PERSISTED: "chromadb.segment.impl.vector.local_persistent_hnsw.PersistentLocalHnswSegment", } @@ -31,6 +34,8 @@ class LocalSegmentManager(SegmentManager): _system: System _instances: Dict[UUID, SegmentImplementation] _segment_cache: Dict[UUID, Dict[SegmentScope, Segment]] + _vector_segment_type: SegmentType = SegmentType.HNSW_LOCAL_MEMORY + _lock: Lock def __init__(self, system: System): super().__init__(system) @@ -38,6 +43,10 @@ class LocalSegmentManager(SegmentManager): self._system = system self._instances = {} self._segment_cache = defaultdict(dict) + self._lock = Lock() + + if self._system.settings.require("is_persistent"): + self._vector_segment_type = SegmentType.HNSW_LOCAL_PERSISTED @override def start(self) -> None: @@ -62,7 +71,7 @@ class LocalSegmentManager(SegmentManager): @override def create_segments(self, collection: Collection) -> Sequence[Segment]: vector_segment = _segment( - SegmentType.HNSW_LOCAL_MEMORY, SegmentScope.VECTOR, collection + self._vector_segment_type, SegmentScope.VECTOR, collection ) metadata_segment = _segment( SegmentType.SQLITE, SegmentScope.METADATA, collection @@ -97,9 +106,20 @@ class LocalSegmentManager(SegmentManager): segment = next(filter(lambda s: s["type"] in known_types, segments)) self._segment_cache[collection_id][scope] = segment - instance = self._instance(self._segment_cache[collection_id][scope]) + # Instances must be atomically created, so we use a lock to ensure that only one thread + # creates the instance. + with self._lock: + instance = self._instance(self._segment_cache[collection_id][scope]) return cast(S, instance) + @override + def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None: + # The local segment manager responds to hints by pre-loading both the metadata and vector + # segments for the given collection. 
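hint_use_collection gives callers a way to warm a collection's segments ahead of a burst of reads or writes; the local manager answers the hint by loading both segments through get_segment, and other implementations may ignore it. A hedged sketch of a caller, where segment_manager and collection_id are assumed to come from the running System:

from chromadb.types import Operation

def prepare_for_ingest(segment_manager, collection_id):
    # Purely advisory: signals that ADD operations are about to hit this collection.
    segment_manager.hint_use_collection(collection_id, Operation.ADD)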
+ for type in [MetadataReader, VectorReader]: + # Just use get_segment to load the segment into the cache + self.get_segment(collection_id, type) + def _cls(self, segment: Segment) -> Type[SegmentImplementation]: classname = SEGMENT_TYPE_IMPLS[SegmentType(segment["type"])] cls = get_class(classname, SegmentImplementation) diff --git a/chromadb/segment/impl/vector/batch.py b/chromadb/segment/impl/vector/batch.py new file mode 100644 index 0000000..aac533b --- /dev/null +++ b/chromadb/segment/impl/vector/batch.py @@ -0,0 +1,106 @@ +from typing import Dict, List, Set, cast + +from chromadb.types import EmbeddingRecord, Operation, SeqId, Vector + + +class Batch: + """Used to model the set of changes as an atomic operation""" + + _ids_to_records: Dict[str, EmbeddingRecord] + _deleted_ids: Set[str] + _written_ids: Set[str] + _upsert_add_ids: Set[str] # IDs that are being added in an upsert + add_count: int + update_count: int + max_seq_id: SeqId + + def __init__(self) -> None: + self._ids_to_records = {} + self._deleted_ids = set() + self._written_ids = set() + self._upsert_add_ids = set() + self.add_count = 0 + self.update_count = 0 + self.max_seq_id = 0 + + def __len__(self) -> int: + """Get the number of changes in this batch""" + return len(self._written_ids) + len(self._deleted_ids) + + def get_deleted_ids(self) -> List[str]: + """Get the list of deleted embeddings in this batch""" + return list(self._deleted_ids) + + def get_written_ids(self) -> List[str]: + """Get the list of written embeddings in this batch""" + return list(self._written_ids) + + def get_written_vectors(self, ids: List[str]) -> List[Vector]: + """Get the list of vectors to write in this batch""" + return [cast(Vector, self._ids_to_records[id]["embedding"]) for id in ids] + + def get_record(self, id: str) -> EmbeddingRecord: + """Get the record for a given ID""" + return self._ids_to_records[id] + + def is_deleted(self, id: str) -> bool: + """Check if a given ID is deleted""" + return id in self._deleted_ids + + @property + def delete_count(self) -> int: + return len(self._deleted_ids) + + def apply(self, record: EmbeddingRecord, exists_already: bool = False) -> None: + """Apply an embedding record to this batch. Records passed to this method are assumed to be validated for correctness. + For example, a delete or update presumes the ID exists in the index. An add presumes the ID does not exist in the index. + The exists_already flag should be set to True if the ID does exist in the index, and False otherwise. 
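Batch keeps running add/update/delete counts plus the sets of written and deleted IDs so a group of changes can later be applied to the index as one unit. A minimal sketch of that bookkeeping, assuming EmbeddingRecord can be passed as a plain dict carrying only the fields apply() reads (id, operation, embedding, seq_id):

from chromadb.segment.impl.vector.batch import Batch
from chromadb.types import Operation

batch = Batch()
batch.apply({"id": "a", "operation": Operation.ADD, "embedding": [0.1, 0.2], "seq_id": 1})
batch.apply(
    {"id": "b", "operation": Operation.UPDATE, "embedding": [0.3, 0.4], "seq_id": 2},
    exists_already=True,  # "b" is assumed to already exist in the index
)

assert batch.add_count == 1 and batch.update_count == 1
assert set(batch.get_written_ids()) == {"a", "b"}
assert batch.get_written_vectors(["a"]) == [[0.1, 0.2]]
assert batch.max_seq_id == 2 and len(batch) == 2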
+ """ + + id = record["id"] + if record["operation"] == Operation.DELETE: + # If the ID was previously written, remove it from the written set + # And update the add/update/delete counts + if id in self._written_ids: + self._written_ids.remove(id) + if self._ids_to_records[id]["operation"] == Operation.ADD: + self.add_count -= 1 + elif self._ids_to_records[id]["operation"] == Operation.UPDATE: + self.update_count -= 1 + self._deleted_ids.add(id) + elif self._ids_to_records[id]["operation"] == Operation.UPSERT: + if id in self._upsert_add_ids: + self.add_count -= 1 + self._upsert_add_ids.remove(id) + else: + self.update_count -= 1 + self._deleted_ids.add(id) + elif id not in self._deleted_ids: + self._deleted_ids.add(id) + + # Remove the record from the batch + if id in self._ids_to_records: + del self._ids_to_records[id] + + else: + self._ids_to_records[id] = record + self._written_ids.add(id) + + # If the ID was previously deleted, remove it from the deleted set + # And update the delete count + if id in self._deleted_ids: + self._deleted_ids.remove(id) + + # Update the add/update counts + if record["operation"] == Operation.UPSERT: + if not exists_already: + self.add_count += 1 + self._upsert_add_ids.add(id) + else: + self.update_count += 1 + elif record["operation"] == Operation.ADD: + self.add_count += 1 + elif record["operation"] == Operation.UPDATE: + self.update_count += 1 + + self.max_seq_id = max(self.max_seq_id, record["seq_id"]) diff --git a/chromadb/segment/impl/vector/brute_force_index.py b/chromadb/segment/impl/vector/brute_force_index.py new file mode 100644 index 0000000..f9466e3 --- /dev/null +++ b/chromadb/segment/impl/vector/brute_force_index.py @@ -0,0 +1,153 @@ +from typing import Any, Callable, Dict, List, Optional, Sequence, Set +import numpy as np +import numpy.typing as npt +from chromadb.types import ( + EmbeddingRecord, + VectorEmbeddingRecord, + VectorQuery, + VectorQueryResult, +) + +from chromadb.utils import distance_functions +import logging + +logger = logging.getLogger(__name__) + + +class BruteForceIndex: + """A lightweight, numpy based brute force index that is used for batches that have not been indexed into hnsw yet. It is not + thread safe and callers should ensure that only one thread is accessing it at a time. 
+ """ + + id_to_index: Dict[str, int] + index_to_id: Dict[int, str] + id_to_seq_id: Dict[str, int] + deleted_ids: Set[str] + free_indices: List[int] + size: int + dimensionality: int + distance_fn: Callable[[npt.NDArray[Any], npt.NDArray[Any]], float] + vectors: npt.NDArray[Any] + + def __init__(self, size: int, dimensionality: int, space: str = "l2"): + if space == "l2": + self.distance_fn = distance_functions.l2 + elif space == "ip": + self.distance_fn = distance_functions.ip + elif space == "cosine": + self.distance_fn = distance_functions.cosine + else: + raise Exception(f"Unknown distance function: {space}") + + self.id_to_index = {} + self.index_to_id = {} + self.id_to_seq_id = {} + self.deleted_ids = set() + self.free_indices = list(range(size)) + self.size = size + self.dimensionality = dimensionality + self.vectors = np.zeros((size, dimensionality)) + + def __len__(self) -> int: + return len(self.id_to_index) + + def clear(self) -> None: + self.id_to_index = {} + self.index_to_id = {} + self.id_to_seq_id = {} + self.deleted_ids.clear() + self.free_indices = list(range(self.size)) + self.vectors.fill(0) + + def upsert(self, records: List[EmbeddingRecord]) -> None: + if len(records) + len(self) > self.size: + raise Exception( + "Index with capacity {} and {} current entries cannot add {} records".format( + self.size, len(self), len(records) + ) + ) + + for i, record in enumerate(records): + id = record["id"] + vector = record["embedding"] + self.id_to_seq_id[id] = record["seq_id"] + if id in self.deleted_ids: + self.deleted_ids.remove(id) + + # TODO: It may be faster to use multi-index selection on the vectors array + if id in self.id_to_index: + # Update + index = self.id_to_index[id] + self.vectors[index] = vector + else: + # Add + next_index = self.free_indices.pop() + self.id_to_index[id] = next_index + self.index_to_id[next_index] = id + self.vectors[next_index] = vector + + def delete(self, records: List[EmbeddingRecord]) -> None: + for record in records: + id = record["id"] + if id in self.id_to_index: + index = self.id_to_index[id] + self.deleted_ids.add(id) + del self.id_to_index[id] + del self.index_to_id[index] + del self.id_to_seq_id[id] + self.vectors[index].fill(np.NaN) + self.free_indices.append(index) + else: + logger.warning(f"Delete of nonexisting embedding ID: {id}") + + def has_id(self, id: str) -> bool: + """Returns whether the index contains the given ID""" + return id in self.id_to_index and id not in self.deleted_ids + + def get_vectors( + self, ids: Optional[Sequence[str]] = None + ) -> Sequence[VectorEmbeddingRecord]: + target_ids = ids or self.id_to_index.keys() + + return [ + VectorEmbeddingRecord( + id=id, + embedding=self.vectors[self.id_to_index[id]].tolist(), + seq_id=self.id_to_seq_id[id], + ) + for id in target_ids + ] + + def query(self, query: VectorQuery) -> Sequence[Sequence[VectorQueryResult]]: + np_query = np.array(query["vectors"]) + allowed_ids = ( + None if query["allowed_ids"] is None else set(query["allowed_ids"]) + ) + distances = np.apply_along_axis( + lambda query: np.apply_along_axis(self.distance_fn, 1, self.vectors, query), + 1, + np_query, + ) + + indices = np.argsort(distances).tolist() + # Filter out deleted labels + filtered_results = [] + for i, index_list in enumerate(indices): + curr_results = [] + for j in index_list: + # If the index is in the index_to_id map, then it has been added + if j in self.index_to_id: + id = self.index_to_id[j] + if id not in self.deleted_ids and ( + allowed_ids is None or id in allowed_ids + ): + 
curr_results.append( + VectorQueryResult( + id=id, + distance=distances[i][j].item(), + seq_id=self.id_to_seq_id[id], + embedding=self.vectors[j].tolist(), + ) + ) + filtered_results.append(curr_results) + return filtered_results diff --git a/chromadb/segment/impl/vector/hnsw_params.py b/chromadb/segment/impl/vector/hnsw_params.py new file mode 100644 index 0000000..b12c428 --- /dev/null +++ b/chromadb/segment/impl/vector/hnsw_params.py @@ -0,0 +1,88 @@ +import multiprocessing +import re +from typing import Any, Callable, Dict, Union + +from chromadb.types import Metadata + + +Validator = Callable[[Union[str, int, float]], bool] + +param_validators: Dict[str, Validator] = { + "hnsw:space": lambda p: bool(re.match(r"^(l2|cosine|ip)$", str(p))), + "hnsw:construction_ef": lambda p: isinstance(p, int), + "hnsw:search_ef": lambda p: isinstance(p, int), + "hnsw:M": lambda p: isinstance(p, int), + "hnsw:num_threads": lambda p: isinstance(p, int), + "hnsw:resize_factor": lambda p: isinstance(p, (int, float)), +} + +# Extra params used for persistent hnsw +persistent_param_validators: Dict[str, Validator] = { + "hnsw:batch_size": lambda p: isinstance(p, int) and p > 2, + "hnsw:sync_threshold": lambda p: isinstance(p, int) and p > 2, +} + + +class Params: + @staticmethod + def _select(metadata: Metadata) -> Dict[str, Any]: + segment_metadata = {} + for param, value in metadata.items(): + if param.startswith("hnsw:"): + segment_metadata[param] = value + return segment_metadata + + @staticmethod + def _validate(metadata: Dict[str, Any], validators: Dict[str, Validator]) -> None: + """Validates the metadata""" + # Validate it + for param, value in metadata.items(): + if param not in validators: + raise ValueError(f"Unknown HNSW parameter: {param}") + if not validators[param](value): + raise ValueError(f"Invalid value for HNSW parameter: {param} = {value}") + + +class HnswParams(Params): + space: str + construction_ef: int + search_ef: int + M: int + num_threads: int + resize_factor: float + + def __init__(self, metadata: Metadata): + metadata = metadata or {} + self.space = str(metadata.get("hnsw:space", "l2")) + self.construction_ef = int(metadata.get("hnsw:construction_ef", 100)) + self.search_ef = int(metadata.get("hnsw:search_ef", 10)) + self.M = int(metadata.get("hnsw:M", 16)) + self.num_threads = int( + metadata.get("hnsw:num_threads", multiprocessing.cpu_count()) + ) + self.resize_factor = float(metadata.get("hnsw:resize_factor", 1.2)) + + @staticmethod + def extract(metadata: Metadata) -> Metadata: + """Validate and return only the relevant hnsw params""" + segment_metadata = HnswParams._select(metadata) + HnswParams._validate(segment_metadata, param_validators) + return segment_metadata + + +class PersistentHnswParams(HnswParams): + batch_size: int + sync_threshold: int + + def __init__(self, metadata: Metadata): + super().__init__(metadata) + self.batch_size = int(metadata.get("hnsw:batch_size", 100)) + self.sync_threshold = int(metadata.get("hnsw:sync_threshold", 1000)) + + @staticmethod + def extract(metadata: Metadata) -> Metadata: + """Returns only the relevant hnsw params""" + all_validators = {**param_validators, **persistent_param_validators} + segment_metadata = PersistentHnswParams._select(metadata) + PersistentHnswParams._validate(segment_metadata, all_validators) + return segment_metadata diff --git a/chromadb/segment/impl/vector/local_hnsw.py b/chromadb/segment/impl/vector/local_hnsw.py index 81dcc7f..2b628bb 100644 --- a/chromadb/segment/impl/vector/local_hnsw.py +++ 
b/chromadb/segment/impl/vector/local_hnsw.py @@ -1,9 +1,11 @@ from overrides import override -from typing import Optional, Sequence, Dict, Set, List, Callable, Union, cast +from typing import Optional, Sequence, Dict, Set, List, cast from uuid import UUID from chromadb.segment import VectorReader from chromadb.ingest import Consumer from chromadb.config import System, Settings +from chromadb.segment.impl.vector.batch import Batch +from chromadb.segment.impl.vector.hnsw_params import HnswParams from chromadb.types import ( EmbeddingRecord, VectorEmbeddingRecord, @@ -16,83 +18,14 @@ from chromadb.types import ( Vector, ) from chromadb.errors import InvalidDimensionException -import re -import multiprocessing import hnswlib -from threading import Lock +from chromadb.utils.read_write_lock import ReadWriteLock, ReadRWLock, WriteRWLock import logging logger = logging.getLogger(__name__) DEFAULT_CAPACITY = 1000 -Validator = Callable[[Union[str, int, float]], bool] - -param_validators: Dict[str, Validator] = { - "hnsw:space": lambda p: bool(re.match(r"^(l2|cosine|ip)$", str(p))), - "hnsw:construction_ef": lambda p: isinstance(p, int), - "hnsw:search_ef": lambda p: isinstance(p, int), - "hnsw:M": lambda p: isinstance(p, int), - "hnsw:num_threads": lambda p: isinstance(p, int), - "hnsw:resize_factor": lambda p: isinstance(p, (int, float)), -} - - -class HnswParams: - space: str - construction_ef: int - search_ef: int - M: int - num_threads: int - resize_factor: float - - def __init__(self, metadata: Metadata): - metadata = metadata or {} - self.space = str(metadata.get("hnsw:space", "l2")) - self.construction_ef = int(metadata.get("hnsw:construction_ef", 100)) - self.search_ef = int(metadata.get("hnsw:search_ef", 10)) - self.M = int(metadata.get("hnsw:M", 16)) - self.num_threads = int( - metadata.get("hnsw:num_threads", multiprocessing.cpu_count()) - ) - self.resize_factor = float(metadata.get("hnsw:resize_factor", 1.2)) - - -class Batch: - """Used to model the set of changes as an atomic operation""" - - labels: List[Optional[int]] - vectors: List[Vector] - seq_ids: List[SeqId] - ids: List[str] - delete_labels: List[int] - delete_ids: List[str] - add_count: int - delete_count: int - - def __init__(self) -> None: - self.labels = [] - self.vectors = [] - self.seq_ids = [] - self.ids = [] - self.delete_labels = [] - self.delete_ids = [] - self.add_count = 0 - self.delete_count = 0 - - def add(self, label: Optional[int], record: EmbeddingRecord) -> None: - self.labels.append(label) - self.vectors.append(cast(Vector, record["embedding"])) - self.seq_ids.append(record["seq_id"]) - self.ids.append(record["id"]) - if not label: - self.add_count += 1 - - def delete(self, label: int, id: str) -> None: - self.delete_labels.append(label) - self.delete_ids.append(id) - self.delete_count += 1 - class LocalHnswSegment(VectorReader): _id: UUID @@ -104,10 +37,10 @@ class LocalHnswSegment(VectorReader): _index: Optional[hnswlib.Index] _dimensionality: Optional[int] - _elements: int + _total_elements_added: int _max_seq_id: SeqId - _lock: Lock + _lock: ReadWriteLock _id_to_label: Dict[str, int] _label_to_id: Dict[int, str] @@ -129,25 +62,14 @@ class LocalHnswSegment(VectorReader): self._id_to_label = {} self._label_to_id = {} - self._lock = Lock() + self._lock = ReadWriteLock() super().__init__(system, segment) @staticmethod @override def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]: # Extract relevant metadata - segment_metadata = {} - for param, value in metadata.items(): - if 
param.startswith("hnsw:"): - segment_metadata[param] = value - - # Validate it - for param, value in segment_metadata.items(): - if param not in param_validators: - raise ValueError(f"Unknown HNSW parameter: {param}") - if not param_validators[param](value): - raise ValueError(f"Invalid value for HNSW parameter: {param} = {value}") - + segment_metadata = HnswParams.extract(metadata) return segment_metadata @override @@ -209,7 +131,7 @@ class LocalHnswSegment(VectorReader): labels: Set[int] = set() ids = query["allowed_ids"] if ids is not None: - labels = {self._id_to_label[id] for id in ids} + labels = {self._id_to_label[id] for id in ids if id in self._id_to_label} if len(labels) < k: k = len(labels) @@ -218,31 +140,38 @@ class LocalHnswSegment(VectorReader): query_vectors = query["vectors"] - result_labels, distances = self._index.knn_query( - query_vectors, k=k, filter=filter_function if ids else None - ) + with ReadRWLock(self._lock): + result_labels, distances = self._index.knn_query( + query_vectors, k=k, filter=filter_function if ids else None + ) - distances = cast(List[List[float]], distances) - result_labels = cast(List[List[int]], result_labels) + # TODO: these casts are not correct, hnswlib returns np + # distances = cast(List[List[float]], distances) + # result_labels = cast(List[List[int]], result_labels) - all_results: List[List[VectorQueryResult]] = [] - for result_i in range(len(result_labels)): - results: List[VectorQueryResult] = [] - for label, distance in zip(result_labels[result_i], distances[result_i]): - id = self._label_to_id[label] - seq_id = self._id_to_seq_id[id] - if query["include_embeddings"]: - embedding = self._index.get_items([label])[0] - else: - embedding = None - results.append( - VectorQueryResult( - id=id, seq_id=seq_id, distance=distance, embedding=embedding + all_results: List[List[VectorQueryResult]] = [] + for result_i in range(len(result_labels)): + results: List[VectorQueryResult] = [] + for label, distance in zip( + result_labels[result_i], distances[result_i] + ): + id = self._label_to_id[label] + seq_id = self._id_to_seq_id[id] + if query["include_embeddings"]: + embedding = self._index.get_items([label])[0] + else: + embedding = None + results.append( + VectorQueryResult( + id=id, + seq_id=seq_id, + distance=distance.item(), + embedding=embedding, + ) ) - ) - all_results.append(results) + all_results.append(results) - return all_results + return all_results @override def max_seqid(self) -> SeqId: @@ -291,45 +220,52 @@ class LocalHnswSegment(VectorReader): def _apply_batch(self, batch: Batch) -> None: """Apply a batch of changes, as atomically as possible.""" + deleted_ids = batch.get_deleted_ids() + written_ids = batch.get_written_ids() + vectors_to_write = batch.get_written_vectors(written_ids) + labels_to_write = [0] * len(vectors_to_write) - if batch.delete_ids: + if len(deleted_ids) > 0: index = cast(hnswlib.Index, self._index) - for i in range(len(batch.delete_ids)): - label = batch.delete_labels[i] - id = batch.delete_ids[i] + for i in range(len(deleted_ids)): + id = deleted_ids[i] + # Never added this id to hnsw, so we can safely ignore it for deletions + if id not in self._id_to_label: + continue + label = self._id_to_label[id] index.mark_deleted(label) del self._id_to_label[id] del self._label_to_id[label] del self._id_to_seq_id[id] - if batch.ids: - self._ensure_index(batch.add_count, len(batch.vectors[0])) + if len(written_ids) > 0: + self._ensure_index(batch.add_count, len(vectors_to_write[0])) next_label = 
self._total_elements_added + 1 - for i in range(len(batch.labels)): - if batch.labels[i] is None: - batch.labels[i] = next_label + for i in range(len(written_ids)): + if written_ids[i] not in self._id_to_label: + labels_to_write[i] = next_label next_label += 1 - - labels = cast(List[int], batch.labels) + else: + labels_to_write[i] = self._id_to_label[written_ids[i]] index = cast(hnswlib.Index, self._index) # First, update the index - index.add_items(batch.vectors, labels) + index.add_items(vectors_to_write, labels_to_write) # If that succeeds, update the mappings - for id, label, seq_id in zip(batch.ids, labels, batch.seq_ids): - self._id_to_seq_id[id] = seq_id - self._id_to_label[id] = label - self._label_to_id[label] = id + for i, id in enumerate(written_ids): + self._id_to_seq_id[id] = batch.get_record(id)["seq_id"] + self._id_to_label[id] = labels_to_write[i] + self._label_to_id[labels_to_write[i]] = id # If that succeeds, update the total count self._total_elements_added += batch.add_count # If that succeeds, finally the seq ID - self._max_seq_id = max(self._max_seq_id, max(batch.seq_ids)) + self._max_seq_id = batch.max_seq_id def _write_records(self, records: Sequence[EmbeddingRecord]) -> None: """Add a batch of embeddings to the index""" @@ -337,7 +273,7 @@ class LocalHnswSegment(VectorReader): raise RuntimeError("Cannot add embeddings to stopped component") # Avoid all sorts of potential problems by ensuring single-threaded access - with self._lock: + with WriteRWLock(self._lock): batch = Batch() for record in records: @@ -348,30 +284,24 @@ class LocalHnswSegment(VectorReader): if op == Operation.DELETE: if label: - batch.delete(label, id) + batch.apply(record) else: logger.warning(f"Delete of nonexisting embedding ID: {id}") elif op == Operation.UPDATE: if record["embedding"] is not None: if label is not None: - batch.add(label, record) + batch.apply(record) else: logger.warning( f"Update of nonexisting embedding ID: {record['id']}" ) elif op == Operation.ADD: if not label: - batch.add(label, record) + batch.apply(record, False) else: logger.warning(f"Add of existing embedding ID: {id}") elif op == Operation.UPSERT: - batch.add(label, record) + batch.apply(record, label is not None) self._apply_batch(batch) - - -# TODO: Implement this as a performance improvement, if rebuilding the -# index on startup is too slow. But test this first. 
-class PersistentLocalHnswSegment(LocalHnswSegment): - pass diff --git a/chromadb/segment/impl/vector/local_persistent_hnsw.py b/chromadb/segment/impl/vector/local_persistent_hnsw.py new file mode 100644 index 0000000..c49cc6f --- /dev/null +++ b/chromadb/segment/impl/vector/local_persistent_hnsw.py @@ -0,0 +1,379 @@ +import os +from overrides import override +import pickle +from typing import Dict, List, Optional, Sequence, Set, cast +from chromadb.config import System +from chromadb.segment.impl.vector.batch import Batch +from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams +from chromadb.segment.impl.vector.local_hnsw import ( + DEFAULT_CAPACITY, + LocalHnswSegment, +) +from chromadb.segment.impl.vector.brute_force_index import BruteForceIndex +from chromadb.types import ( + EmbeddingRecord, + Metadata, + Operation, + Segment, + SeqId, + Vector, + VectorEmbeddingRecord, + VectorQuery, + VectorQueryResult, +) +import hnswlib +import logging + +from chromadb.utils.read_write_lock import ReadRWLock, WriteRWLock + + +logger = logging.getLogger(__name__) + + +class PersistentData: + """Stores the data and metadata needed for a PersistentLocalHnswSegment""" + + dimensionality: Optional[int] + total_elements_added: int + max_seq_id: SeqId + + id_to_label: Dict[str, int] + label_to_id: Dict[int, str] + id_to_seq_id: Dict[str, SeqId] + + def __init__( + self, + dimensionality: Optional[int], + total_elements_added: int, + max_seq_id: int, + id_to_label: Dict[str, int], + label_to_id: Dict[int, str], + id_to_seq_id: Dict[str, SeqId], + ): + self.dimensionality = dimensionality + self.total_elements_added = total_elements_added + self.max_seq_id = max_seq_id + self.id_to_label = id_to_label + self.label_to_id = label_to_id + self.id_to_seq_id = id_to_seq_id + + @staticmethod + def load_from_file(filename: str) -> "PersistentData": + """Load persistent data from a file""" + with open(filename, "rb") as f: + ret = cast(PersistentData, pickle.load(f)) + return ret + + +class PersistentLocalHnswSegment(LocalHnswSegment): + METADATA_FILE: str = "index_metadata.pickle" + # How many records to add to index at once, we do this because crossing the python/c++ boundary is expensive (for add()) + # When records are not added to the c++ index, they are buffered in memory and served + # via brute force search. 
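+    # With the defaults from PersistentHnswParams (hnsw:batch_size=100, hnsw:sync_threshold=1000),
+    # writes are staged in _curr_batch and mirrored into the brute force index; when the batch
+    # reaches 100 records it is applied to the hnsw index and the buffer is cleared, and once
+    # 1000 new elements have been added since the last sync the hnsw index and its metadata are
+    # persisted to disk.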
+ _batch_size: int + _brute_force_index: Optional[BruteForceIndex] + _curr_batch: Batch + # How many records to add to index before syncing to disk + _sync_threshold: int + _persist_data: PersistentData + _persist_directory: str + + def __init__(self, system: System, segment: Segment): + super().__init__(system, segment) + + self._params = PersistentHnswParams(segment["metadata"] or {}) + self._batch_size = self._params.batch_size + self._sync_threshold = self._params.sync_threshold + + self._persist_directory = system.settings.require("persist_directory") + self._curr_batch = Batch() + self._brute_force_index = None + if not os.path.exists(self._get_storage_folder()): + os.makedirs(self._get_storage_folder(), exist_ok=True) + # Load persist data if it exists already, otherwise create it + if self._index_exists(): + self._persist_data = PersistentData.load_from_file( + self._get_metadata_file() + ) + self._dimensionality = self._persist_data.dimensionality + self._total_elements_added = self._persist_data.total_elements_added + self._max_seq_id = self._persist_data.max_seq_id + self._id_to_label = self._persist_data.id_to_label + self._label_to_id = self._persist_data.label_to_id + self._id_to_seq_id = self._persist_data.id_to_seq_id + # If the index was written to, we need to re-initialize it + if len(self._id_to_label) > 0: + self._dimensionality = cast(int, self._dimensionality) + self._init_index(self._dimensionality) + else: + self._persist_data = PersistentData( + self._dimensionality, + self._total_elements_added, + self._max_seq_id, + self._id_to_label, + self._label_to_id, + self._id_to_seq_id, + ) + + @staticmethod + @override + def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]: + # Extract relevant metadata + segment_metadata = PersistentHnswParams.extract(metadata) + return segment_metadata + + def _index_exists(self) -> bool: + """Check if the index exists via the metadata file""" + return os.path.exists(self._get_metadata_file()) + + def _get_metadata_file(self) -> str: + """Get the metadata file path""" + return os.path.join(self._get_storage_folder(), self.METADATA_FILE) + + def _get_storage_folder(self) -> str: + """Get the storage folder path""" + folder = os.path.join(self._persist_directory, str(self._id)) + return folder + + @override + def _init_index(self, dimensionality: int) -> None: + index = hnswlib.Index(space=self._params.space, dim=dimensionality) + self._brute_force_index = BruteForceIndex( + size=self._batch_size, + dimensionality=dimensionality, + space=self._params.space, + ) + + # Check if index exists and load it if it does + if self._index_exists(): + index.load_index( + self._get_storage_folder(), + is_persistent_index=True, + max_elements=int( + max(self.count() * self._params.resize_factor, DEFAULT_CAPACITY) + ), + ) + else: + index.init_index( + max_elements=DEFAULT_CAPACITY, + ef_construction=self._params.construction_ef, + M=self._params.M, + is_persistent_index=True, + persistence_location=self._get_storage_folder(), + ) + + index.set_ef(self._params.search_ef) + index.set_num_threads(self._params.num_threads) + + self._index = index + self._dimensionality = dimensionality + + def _persist(self) -> None: + """Persist the index and data to disk""" + index = cast(hnswlib.Index, self._index) + + # Persist the index + index.persist_dirty() + + # Persist the metadata + self._persist_data.dimensionality = self._dimensionality + self._persist_data.total_elements_added = self._total_elements_added + self._persist_data.max_seq_id 
= self._max_seq_id + + # TODO: This should really be stored in sqlite, the index itself, or a better + # storage format + self._persist_data.id_to_label = self._id_to_label + self._persist_data.label_to_id = self._label_to_id + self._persist_data.id_to_seq_id = self._id_to_seq_id + + with open(self._get_metadata_file(), "wb") as metadata_file: + pickle.dump(self._persist_data, metadata_file, pickle.HIGHEST_PROTOCOL) + + @override + def _apply_batch(self, batch: Batch) -> None: + super()._apply_batch(batch) + if ( + self._total_elements_added - self._persist_data.total_elements_added + >= self._sync_threshold + ): + self._persist() + + @override + def _write_records(self, records: Sequence[EmbeddingRecord]) -> None: + """Add a batch of embeddings to the index""" + if not self._running: + raise RuntimeError("Cannot add embeddings to stopped component") + + with WriteRWLock(self._lock): + for record in records: + if record["embedding"] is not None: + self._ensure_index(len(records), len(record["embedding"])) + self._brute_force_index = cast(BruteForceIndex, self._brute_force_index) + + self._max_seq_id = max(self._max_seq_id, record["seq_id"]) + id = record["id"] + op = record["operation"] + exists_in_index = self._id_to_label.get( + id, None + ) is not None or self._brute_force_index.has_id(id) + + if op == Operation.DELETE: + if exists_in_index: + self._curr_batch.apply(record) + self._brute_force_index.delete([record]) + else: + logger.warning(f"Delete of nonexisting embedding ID: {id}") + + elif op == Operation.UPDATE: + if record["embedding"] is not None: + if exists_in_index: + self._curr_batch.apply(record) + self._brute_force_index.upsert([record]) + else: + logger.warning( + f"Update of nonexisting embedding ID: {record['id']}" + ) + elif op == Operation.ADD: + if record["embedding"] is not None: + if not exists_in_index: + self._curr_batch.apply(record, not exists_in_index) + self._brute_force_index.upsert([record]) + else: + logger.warning(f"Add of existing embedding ID: {id}") + elif op == Operation.UPSERT: + if record["embedding"] is not None: + self._curr_batch.apply(record, exists_in_index) + self._brute_force_index.upsert([record]) + if len(self._curr_batch) >= self._batch_size: + self._apply_batch(self._curr_batch) + self._curr_batch = Batch() + self._brute_force_index.clear() + + @override + def count(self) -> int: + return ( + len(self._id_to_label) + + self._curr_batch.add_count + - self._curr_batch.delete_count + ) + + @override + def get_vectors( + self, ids: Optional[Sequence[str]] = None + ) -> Sequence[VectorEmbeddingRecord]: + """Get the embeddings from the HNSW index and layered brute force batch index""" + results = [] + ids_hnsw: Set[str] = set() + ids_bf: Set[str] = set() + + if self._index is not None: + ids_hnsw = set(self._id_to_label.keys()) + if self._brute_force_index is not None: + ids_bf = set(self._curr_batch.get_written_ids()) + + target_ids = ids or list(ids_hnsw.union(ids_bf)) + self._brute_force_index = cast(BruteForceIndex, self._brute_force_index) + hnsw_labels = [] + + for id in target_ids: + if id in ids_bf: + results.append(self._brute_force_index.get_vectors([id])[0]) + elif id in ids_hnsw and id not in self._curr_batch._deleted_ids: + hnsw_labels.append(self._id_to_label[id]) + + if len(hnsw_labels) > 0 and self._index is not None: + vectors = cast(Sequence[Vector], self._index.get_items(hnsw_labels)) + + for label, vector in zip(hnsw_labels, vectors): + id = self._label_to_id[label] + seq_id = self._id_to_seq_id[id] + results.append( + 
VectorEmbeddingRecord(id=id, seq_id=seq_id, embedding=vector) + ) + + return results + + @override + def query_vectors( + self, query: VectorQuery + ) -> Sequence[Sequence[VectorQueryResult]]: + if self._index is None and self._brute_force_index is None: + return [[] for _ in range(len(query["vectors"]))] + + k = query["k"] + if k > self.count(): + logger.warning( + f"Number of requested results {k} is greater than number of elements in index {self.count()}, updating n_results = {self.count()}" + ) + k = self.count() + + # Overquery by updated and deleted elements layered on the index because they may + # hide the real nearest neighbors in the hnsw index + hnsw_k = k + self._curr_batch.update_count + self._curr_batch.delete_count + if hnsw_k > len(self._id_to_label): + hnsw_k = len(self._id_to_label) + hnsw_query = VectorQuery( + vectors=query["vectors"], + k=hnsw_k, + allowed_ids=query["allowed_ids"], + include_embeddings=query["include_embeddings"], + options=query["options"], + ) + + # For each query vector, we want to take the top k results from the + # combined results of the brute force and hnsw index + results: List[List[VectorQueryResult]] = [] + self._brute_force_index = cast(BruteForceIndex, self._brute_force_index) + with ReadRWLock(self._lock): + bf_results = self._brute_force_index.query(query) + hnsw_results = super().query_vectors(hnsw_query) + for i in range(len(query["vectors"])): + # Merge results into a single list of size k + bf_pointer: int = 0 + hnsw_pointer: int = 0 + curr_bf_result: Sequence[VectorQueryResult] = bf_results[i] + curr_hnsw_result: Sequence[VectorQueryResult] = hnsw_results[i] + curr_results: List[VectorQueryResult] = [] + # In the case where filters cause the number of results to be less than k, + # we set k to be the number of results + total_results = len(curr_bf_result) + len(curr_hnsw_result) + if total_results == 0: + results.append([]) + else: + while len(curr_results) < min(k, total_results): + if bf_pointer < len(curr_bf_result) and hnsw_pointer < len( + curr_hnsw_result + ): + bf_dist = curr_bf_result[bf_pointer]["distance"] + hnsw_dist = curr_hnsw_result[hnsw_pointer]["distance"] + if bf_dist <= hnsw_dist: + curr_results.append(curr_bf_result[bf_pointer]) + bf_pointer += 1 + else: + id = curr_hnsw_result[hnsw_pointer]["id"] + # Only add the hnsw result if it is not in the brute force index + # as updated or deleted + if not self._brute_force_index.has_id( + id + ) and not self._curr_batch.is_deleted(id): + curr_results.append(curr_hnsw_result[hnsw_pointer]) + hnsw_pointer += 1 + else: + break + remaining = min(k, total_results) - len(curr_results) + if remaining > 0 and hnsw_pointer < len(curr_hnsw_result): + for i in range( + hnsw_pointer, + min(len(curr_hnsw_result), hnsw_pointer + remaining + 1), + ): + id = curr_hnsw_result[i]["id"] + if not self._brute_force_index.has_id( + id + ) and not self._curr_batch.is_deleted(id): + curr_results.append(curr_hnsw_result[i]) + elif remaining > 0 and bf_pointer < len(curr_bf_result): + curr_results.extend( + curr_bf_result[bf_pointer : bf_pointer + remaining] + ) + results.append(curr_results) + return results diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index 1de6e1b..9a4d890 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -92,14 +92,21 @@ class FastAPI(chromadb.server.Server): self.router.add_api_route("/api/v1/reset", self.reset, methods=["POST"]) self.router.add_api_route("/api/v1/version", self.version, 
methods=["GET"]) self.router.add_api_route("/api/v1/heartbeat", self.heartbeat, methods=["GET"]) - self.router.add_api_route("/api/v1/persist", self.persist, methods=["POST"]) - self.router.add_api_route("/api/v1/raw_sql", self.raw_sql, methods=["POST"]) + self.router.add_api_route( + "/api/v1/raw_sql", self.raw_sql, methods=["POST"], response_model=None + ) self.router.add_api_route( - "/api/v1/collections", self.list_collections, methods=["GET"] + "/api/v1/collections", + self.list_collections, + methods=["GET"], + response_model=None, ) self.router.add_api_route( - "/api/v1/collections", self.create_collection, methods=["POST"] + "/api/v1/collections", + self.create_collection, + methods=["POST"], + response_model=None, ) self.router.add_api_route( @@ -107,46 +114,67 @@ class FastAPI(chromadb.server.Server): self.add, methods=["POST"], status_code=status.HTTP_201_CREATED, + response_model=None, ) self.router.add_api_route( - "/api/v1/collections/{collection_id}/update", self.update, methods=["POST"] + "/api/v1/collections/{collection_id}/update", + self.update, + methods=["POST"], + response_model=None, ) self.router.add_api_route( - "/api/v1/collections/{collection_id}/upsert", self.upsert, methods=["POST"] + "/api/v1/collections/{collection_id}/upsert", + self.upsert, + methods=["POST"], + response_model=None, ) self.router.add_api_route( - "/api/v1/collections/{collection_id}/get", self.get, methods=["POST"] + "/api/v1/collections/{collection_id}/get", + self.get, + methods=["POST"], + response_model=None, ) self.router.add_api_route( - "/api/v1/collections/{collection_id}/delete", self.delete, methods=["POST"] + "/api/v1/collections/{collection_id}/delete", + self.delete, + methods=["POST"], + response_model=None, ) self.router.add_api_route( - "/api/v1/collections/{collection_id}/count", self.count, methods=["GET"] + "/api/v1/collections/{collection_id}/count", + self.count, + methods=["GET"], + response_model=None, ) self.router.add_api_route( "/api/v1/collections/{collection_id}/query", self.get_nearest_neighbors, methods=["POST"], + response_model=None, ) self.router.add_api_route( "/api/v1/collections/{collection_name}/create_index", self.create_index, methods=["POST"], + response_model=None, ) self.router.add_api_route( "/api/v1/collections/{collection_name}", self.get_collection, methods=["GET"], + response_model=None, ) self.router.add_api_route( "/api/v1/collections/{collection_id}", self.update_collection, methods=["PUT"], + response_model=None, ) self.router.add_api_route( "/api/v1/collections/{collection_name}", self.delete_collection, methods=["DELETE"], + response_model=None, ) self._app.include_router(self.router) @@ -162,9 +190,6 @@ class FastAPI(chromadb.server.Server): def heartbeat(self) -> Dict[str, int]: return self.root() - def persist(self) -> None: - self._api.persist() - def version(self) -> str: return self._api.get_version() diff --git a/chromadb/telemetry/__init__.py b/chromadb/telemetry/__init__.py index 30b7bdb..db96254 100644 --- a/chromadb/telemetry/__init__.py +++ b/chromadb/telemetry/__init__.py @@ -11,8 +11,8 @@ from pathlib import Path from enum import Enum TELEMETRY_WHITELISTED_SETTINGS = [ - "chroma_db_impl", "chroma_api_impl", + "is_persistent", "chroma_server_ssl_enabled", ] diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index 5bb0001..23b2250 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -8,7 +8,7 @@ import os import uvicorn import time import pytest -from typing import Generator, List, Callable 
+from typing import Generator, List, Callable, Optional, Tuple import shutil import logging import socket @@ -39,13 +39,31 @@ def find_free_port() -> int: return s.getsockname()[1] # type: ignore -def _run_server(port: int) -> None: +def _run_server( + port: int, is_persistent: bool = False, persist_directory: Optional[str] = None +) -> None: """Run a Chroma server locally""" - settings = Settings( - chroma_api_impl="local", - chroma_db_impl="duckdb", - persist_directory=tempfile.gettempdir() + "/test_server", - ) + if is_persistent and persist_directory: + settings = Settings( + chroma_api_impl="chromadb.api.segment.SegmentAPI", + chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager", + is_persistent=is_persistent, + persist_directory=persist_directory, + allow_reset=True, + ) + else: + settings = Settings( + chroma_api_impl="chromadb.api.segment.SegmentAPI", + chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager", + is_persistent=False, + allow_reset=True, + ) server = chromadb.server.fastapi.FastAPI(settings) uvicorn.run(server.app(), host="0.0.0.0", port=port, log_level="error") @@ -63,16 +81,22 @@ def _await_server(api: API, attempts: int = 0) -> None: _await_server(api, attempts + 1) -def fastapi() -> Generator[System, None, None]: +def _fastapi_fixture(is_persistent: bool = False) -> Generator[System, None, None]: """Fixture generator that launches a server in a separate process, and yields a fastapi client connect to it""" + port = find_free_port() logger.info(f"Running test FastAPI server on port {port}") ctx = multiprocessing.get_context("spawn") - proc = ctx.Process(target=_run_server, args=(port,), daemon=True) + args: Tuple[int, bool, Optional[str]] = (port, False, None) + persist_directory = None + if is_persistent: + persist_directory = tempfile.mkdtemp() + args = (port, is_persistent, persist_directory) + proc = ctx.Process(target=_run_server, args=args, daemon=True) proc.start() settings = Settings( - chroma_api_impl="rest", + chroma_api_impl="chromadb.api.fastapi.FastAPI", chroma_server_host="localhost", chroma_server_http_port=str(port), allow_reset=True, @@ -84,38 +108,17 @@ def fastapi() -> Generator[System, None, None]: yield system system.stop() proc.kill() + if is_persistent and persist_directory is not None: + if os.path.exists(persist_directory): + shutil.rmtree(persist_directory) -def duckdb() -> Generator[System, None, None]: - """Fixture generator for duckdb""" - settings = Settings( - chroma_api_impl="local", - chroma_db_impl="duckdb", - persist_directory=tempfile.gettempdir(), - allow_reset=True, - ) - system = System(settings) - system.start() - yield system - system.stop() +def fastapi() -> Generator[System, None, None]: + return _fastapi_fixture(is_persistent=False) -def duckdb_parquet() -> Generator[System, None, None]: - """Fixture generator for duckdb+parquet""" - - save_path = tempfile.gettempdir() + "/tests" - settings = Settings( - chroma_api_impl="local", - chroma_db_impl="duckdb+parquet", - persist_directory=save_path, - allow_reset=True, - ) - system = System(settings) - system.start() - yield system - system.stop() - if 
os.path.exists(save_path): - shutil.rmtree(save_path) +def fastapi_persistent() -> Generator[System, None, None]: + return _fastapi_fixture(is_persistent=True) def integration() -> Generator[System, None, None]: @@ -137,7 +140,7 @@ def sqlite() -> Generator[System, None, None]: chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB", chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB", chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager", - sqlite_database=":memory:", + is_persistent=False, allow_reset=True, ) system = System(settings) @@ -146,8 +149,29 @@ def sqlite() -> Generator[System, None, None]: system.stop() +def sqlite_persistent() -> Generator[System, None, None]: + """Fixture generator for segment-based API using persistent Sqlite""" + save_path = tempfile.mkdtemp() + settings = Settings( + chroma_api_impl="chromadb.api.segment.SegmentAPI", + chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager", + allow_reset=True, + is_persistent=True, + persist_directory=save_path, + ) + system = System(settings) + system.start() + yield system + system.stop() + if os.path.exists(save_path): + shutil.rmtree(save_path) + + def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]: - fixtures = [duckdb, duckdb_parquet, fastapi, sqlite] + fixtures = [fastapi, fastapi_persistent, sqlite, sqlite_persistent] if "CHROMA_INTEGRATION_TEST" in os.environ: fixtures.append(integration) if "CHROMA_INTEGRATION_TEST_ONLY" in os.environ: diff --git a/chromadb/test/db/migrations/__init__.py b/chromadb/test/db/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chromadb/test/db/test_migrations.py b/chromadb/test/db/test_migrations.py index 73ccc35..96df89b 100644 --- a/chromadb/test/db/test_migrations.py +++ b/chromadb/test/db/test_migrations.py @@ -1,4 +1,5 @@ import pytest +from importlib_resources import files from typing import Generator, List, Callable import chromadb.db.migrations as migrations from chromadb.db.impl.sqlite import SqliteDB @@ -11,7 +12,10 @@ def sqlite() -> Generator[migrations.MigratableDB, None, None]: """Fixture generator for sqlite DB""" db = SqliteDB( System( - Settings(sqlite_database=":memory:", migrations="none", allow_reset=True) + Settings( + migrations="none", + allow_reset=True, + ) ) ) db.start() @@ -47,10 +51,9 @@ def test_setup_migrations(db: migrations.MigratableDB) -> None: def test_migrations(db: migrations.MigratableDB) -> None: db.initialize_migrations() - db_migrations = db.db_migrations("chromadb/test/db/migrations") - source_migrations = migrations.find_migrations( - "chromadb/test/db/migrations", db.migration_scope() - ) + dir = files("chromadb.test.db.migrations") + db_migrations = db.db_migrations(dir) + source_migrations = migrations.find_migrations(dir, db.migration_scope()) unapplied_migrations = migrations.verify_migration_sequence( db_migrations, source_migrations @@ -66,7 +69,7 @@ def test_migrations(db: migrations.MigratableDB) -> None: for m in unapplied_migrations[:-1]: db.apply_migration(cur, m) - db_migrations = db.db_migrations("chromadb/test/db/migrations") + db_migrations = db.db_migrations(dir) unapplied_migrations = migrations.verify_migration_sequence( db_migrations, source_migrations ) @@ -85,7 +88,7 @@ def test_migrations(db: migrations.MigratableDB) -> None: 
for m in unapplied_migrations: db.apply_migration(cur, m) - db_migrations = db.db_migrations("chromadb/test/db/migrations") + db_migrations = db.db_migrations(dir) unapplied_migrations = migrations.verify_migration_sequence( db_migrations, source_migrations ) @@ -102,11 +105,10 @@ def test_tampered_migration(db: migrations.MigratableDB) -> None: db.setup_migrations() - source_migrations = migrations.find_migrations( - "chromadb/test/db/migrations", db.migration_scope() - ) + dir = files("chromadb.test.db.migrations") + source_migrations = migrations.find_migrations(dir, db.migration_scope()) - db_migrations = db.db_migrations("chromadb/test/db/migrations") + db_migrations = db.db_migrations(dir) unapplied_migrations = migrations.verify_migration_sequence( db_migrations, source_migrations @@ -116,7 +118,7 @@ def test_tampered_migration(db: migrations.MigratableDB) -> None: for m in unapplied_migrations: db.apply_migration(cur, m) - db_migrations = db.db_migrations("chromadb/test/db/migrations") + db_migrations = db.db_migrations(dir) unapplied_migrations = migrations.verify_migration_sequence( db_migrations, source_migrations ) @@ -143,7 +145,8 @@ def test_initialization( monkeypatch: pytest.MonkeyPatch, db: migrations.MigratableDB ) -> None: db.reset_state() - monkeypatch.setattr(db, "migration_dirs", lambda: ["chromadb/test/db/migrations"]) + dir = files("chromadb.test.db.migrations") + monkeypatch.setattr(db, "migration_dirs", lambda: [dir]) assert not db.migrations_initialized() diff --git a/chromadb/test/db/test_system.py b/chromadb/test/db/test_system.py index 02ff978..82c9b6a 100644 --- a/chromadb/test/db/test_system.py +++ b/chromadb/test/db/test_system.py @@ -1,3 +1,6 @@ +import os +import shutil +import tempfile import pytest from typing import Generator, List, Callable, Dict, Union from chromadb.types import Collection, Segment, SegmentScope @@ -11,14 +14,29 @@ import uuid def sqlite() -> Generator[SysDB, None, None]: """Fixture generator for sqlite DB""" - db = SqliteDB(System(Settings(sqlite_database=":memory:", allow_reset=True))) + db = SqliteDB(System(Settings(allow_reset=True))) db.start() yield db db.stop() +def sqlite_persistent() -> Generator[SysDB, None, None]: + """Fixture generator for sqlite DB""" + save_path = tempfile.mkdtemp() + db = SqliteDB( + System( + Settings(allow_reset=True, is_persistent=True, persist_directory=save_path) + ) + ) + db.start() + yield db + db.stop() + if os.path.exists(save_path): + shutil.rmtree(save_path) + + def db_fixtures() -> List[Callable[[], Generator[SysDB, None, None]]]: - return [sqlite] + return [sqlite, sqlite_persistent] @pytest.fixture(scope="module", params=db_fixtures()) diff --git a/chromadb/test/hnswlib/test_hnswlib.py b/chromadb/test/hnswlib/test_hnswlib.py deleted file mode 100644 index 2039c67..0000000 --- a/chromadb/test/hnswlib/test_hnswlib.py +++ /dev/null @@ -1,67 +0,0 @@ -import os -import shutil -import tempfile -from typing import Generator - -import pytest -from chromadb.db.index.hnswlib import Hnswlib -from chromadb.config import Settings -import uuid -import numpy as np - - -@pytest.fixture(scope="module") -def settings() -> Generator[Settings, None, None]: - save_path = tempfile.gettempdir() + "/tests/hnswlib/" - yield Settings(persist_directory=save_path) - if os.path.exists(save_path): - shutil.rmtree(save_path) - - -def test_count_tracking(settings: Settings) -> None: - hnswlib = Hnswlib("test", settings, {}, 2) - hnswlib._init_index(2) - assert hnswlib._index_metadata["curr_elements"] == 0 - assert 
hnswlib._index_metadata["total_elements_added"] == 0 - idA, idB = uuid.uuid4(), uuid.uuid4() - - embeddingA = np.random.rand(1, 2) - hnswlib.add([idA], embeddingA.tolist()) - assert ( - hnswlib._index_metadata["curr_elements"] - == hnswlib._index_metadata["total_elements_added"] - == 1 - ) - embeddingB = np.random.rand(1, 2) - hnswlib.add([idB], embeddingB.tolist()) - assert ( - hnswlib._index_metadata["curr_elements"] - == hnswlib._index_metadata["total_elements_added"] - == 2 - ) - hnswlib.delete_from_index(ids=[idA]) - assert hnswlib._index_metadata["curr_elements"] == 1 - assert hnswlib._index_metadata["total_elements_added"] == 2 - hnswlib.delete_from_index(ids=[idB]) - assert hnswlib._index_metadata["curr_elements"] == 0 - assert hnswlib._index_metadata["total_elements_added"] == 2 - - -def test_add_delete_large_amount(settings: Settings) -> None: - # Test adding a large number of records - N = 2000 - D = 512 - large_records = np.random.rand(N, D).astype(np.float32).tolist() - ids = [uuid.uuid4() for _ in range(N)] - hnswlib = Hnswlib("test", settings, {}, N) - hnswlib._init_index(D) - hnswlib.add(ids, large_records) - assert hnswlib._index_metadata["curr_elements"] == N - assert hnswlib._index_metadata["total_elements_added"] == N - - # Test deleting a large number of records by getting a random subset of the ids - ids_to_delete = np.random.choice(np.array(ids), size=100, replace=False).tolist() - hnswlib.delete_from_index(ids_to_delete) - - assert hnswlib._index_metadata["curr_elements"] == N - 100 - assert hnswlib._index_metadata["total_elements_added"] == N diff --git a/chromadb/test/ingest/test_producer_consumer.py b/chromadb/test/ingest/test_producer_consumer.py index 2689196..02808aa 100644 --- a/chromadb/test/ingest/test_producer_consumer.py +++ b/chromadb/test/ingest/test_producer_consumer.py @@ -1,3 +1,6 @@ +import os +import shutil +import tempfile import pytest from itertools import count from typing import ( @@ -26,15 +29,29 @@ from asyncio import Event, wait_for, TimeoutError def sqlite() -> Generator[Tuple[Producer, Consumer], None, None]: """Fixture generator for sqlite Producer + Consumer""" - system = System(Settings(sqlite_database=":memory:", allow_reset=True)) + system = System(Settings(allow_reset=True)) db = system.require(SqliteDB) system.start() yield db, db system.stop() +def sqlite_persistent() -> Generator[Tuple[Producer, Consumer], None, None]: + """Fixture generator for sqlite_persistent Producer + Consumer""" + save_path = tempfile.mkdtemp() + system = System( + Settings(allow_reset=True, is_persistent=True, persist_directory=save_path) + ) + db = system.require(SqliteDB) + system.start() + yield db, db + system.stop() + if os.path.exists(save_path): + shutil.rmtree(save_path) + + def fixtures() -> List[Callable[[], Generator[Tuple[Producer, Consumer], None, None]]]: - return [sqlite] + return [sqlite, sqlite_persistent] @pytest.fixture(scope="module", params=fixtures()) diff --git a/chromadb/test/property/invariants.py b/chromadb/test/property/invariants.py index 6b1d6bd..329efcc 100644 --- a/chromadb/test/property/invariants.py +++ b/chromadb/test/property/invariants.py @@ -1,6 +1,6 @@ import math from chromadb.test.property.strategies import NormalizedRecordSet, RecordSet -from typing import Callable, Optional, Tuple, Union, List, TypeVar, cast, Dict +from typing import Callable, Optional, Tuple, Union, List, TypeVar, cast from typing_extensions import Literal import numpy as np import numpy.typing as npt @@ -9,6 +9,8 @@ from 
chromadb.api.models.Collection import Collection from hypothesis import note from hypothesis.errors import InvalidArgument +from chromadb.utils import distance_functions + T = TypeVar("T") @@ -124,23 +126,12 @@ def no_duplicates(collection: Collection) -> None: assert len(ids) == len(set(ids)) -# These match what the spec of hnswlib is -# This epsilon is used to prevent division by zero and the value is the same -# https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/python_bindings/bindings.cpp#L238 -NORM_EPS = 1e-30 -distance_functions: Dict[str, Callable[[npt.ArrayLike, npt.ArrayLike], float]] = { - "l2": lambda x, y: np.linalg.norm(x - y) ** 2, # type: ignore - "cosine": lambda x, y: 1 - np.dot(x, y) / ((np.linalg.norm(x) + NORM_EPS) * (np.linalg.norm(y) + NORM_EPS)), # type: ignore - "ip": lambda x, y: 1 - np.dot(x, y), # type: ignore -} - - def _exact_distances( query: types.Embeddings, targets: types.Embeddings, - distance_fn: Callable[[npt.ArrayLike, npt.ArrayLike], float] = distance_functions[ - "l2" - ], + distance_fn: Callable[ + [npt.ArrayLike, npt.ArrayLike], float + ] = distance_functions.l2, ) -> Tuple[List[List[int]], List[List[float]]]: """Return the ordered indices and distances from each query to each target""" np_query = np.array(query) @@ -168,6 +159,7 @@ def ann_accuracy( n_results: int = 1, min_recall: float = 0.99, embedding_function: Optional[types.EmbeddingFunction] = None, + query_indices: Optional[List[int]] = None, ) -> None: """Validate that the API performs nearest_neighbor searches correctly""" normalized_record_set = wrap_all(record_set) @@ -185,7 +177,7 @@ def ann_accuracy( embeddings = embedding_function(normalized_record_set["documents"]) # l2 is the default distance function - distance_function = distance_functions["l2"] + distance_function = distance_functions.l2 accuracy_threshold = 1e-6 assert collection.metadata is not None assert embeddings is not None @@ -200,19 +192,25 @@ def ann_accuracy( accuracy_threshold = accuracy_threshold * math.pow(10, int(math.log10(dim))) if space == "cosine": - distance_function = distance_functions["cosine"] - + distance_function = distance_functions.cosine if space == "ip": - distance_function = distance_functions["ip"] + distance_function = distance_functions.ip # Perform exact distance computation + query_embeddings = ( + embeddings if query_indices is None else [embeddings[i] for i in query_indices] + ) + query_documents = normalized_record_set["documents"] + if query_indices is not None and query_documents is not None: + query_documents = [query_documents[i] for i in query_indices] + indices, distances = _exact_distances( - embeddings, embeddings, distance_fn=distance_function + query_embeddings, embeddings, distance_fn=distance_function ) query_results = collection.query( - query_embeddings=normalized_record_set["embeddings"], - query_texts=normalized_record_set["documents"] if not have_embeddings else None, + query_embeddings=query_embeddings if have_embeddings else None, + query_texts=query_documents if not have_embeddings else None, n_results=n_results, include=["embeddings", "documents", "metadatas", "distances"], ) diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py index 57d4855..e22dcae 100644 --- a/chromadb/test/property/strategies.py +++ b/chromadb/test/property/strategies.py @@ -230,6 +230,7 @@ def collections( with_hnsw_params: bool = False, has_embeddings: Optional[bool] = None, has_documents: Optional[bool] = None, + 
with_persistent_hnsw_params: bool = False, ) -> Collection: """Strategy to generate a Collection object. If add_filterable_data is True, then known_metadata_keys and known_document_keywords will be populated with consistent data.""" @@ -240,10 +241,20 @@ def collections( dimension = draw(st.integers(min_value=2, max_value=2048)) dtype = draw(st.sampled_from(float_types)) + if with_persistent_hnsw_params and not with_hnsw_params: + raise ValueError( + "with_hnsw_params requires with_persistent_hnsw_params to be true" + ) + if with_hnsw_params: if metadata is None: metadata = {} metadata.update(test_hnsw_config) + if with_persistent_hnsw_params: + metadata["hnsw:batch_size"] = draw(st.integers(min_value=3, max_value=2000)) + metadata["hnsw:sync_threshold"] = draw( + st.integers(min_value=3, max_value=2000) + ) # Sometimes, select a space at random if draw(st.booleans()): # TODO: pull the distance functions from a source of truth that lives not diff --git a/chromadb/test/property/test_add.py b/chromadb/test/property/test_add.py index a70f3fc..d5d6b84 100644 --- a/chromadb/test/property/test_add.py +++ b/chromadb/test/property/test_add.py @@ -29,7 +29,7 @@ def test_add( if not invariants.is_metadata_valid(normalized_record_set): with pytest.raises(Exception): - collection.add(**normalized_record_set) + coll.add(**normalized_record_set) return coll.add(**record_set) diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py index f4b87ae..95a937c 100644 --- a/chromadb/test/property/test_cross_version_persist.py +++ b/chromadb/test/property/test_cross_version_persist.py @@ -78,9 +78,14 @@ def configurations(versions: List[str]) -> List[Tuple[str, Settings]]: ( version, Settings( - chroma_api_impl="local", - chroma_db_impl="duckdb+parquet", - persist_directory=tempfile.gettempdir() + "/tests/" + version + "/", + chroma_api_impl="chromadb.api.segment.SegmentAPI", + chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager", + allow_reset=True, + is_persistent=True, + persist_directory=tempfile.gettempdir(), ), ) for version in versions @@ -199,7 +204,6 @@ def persist_generated_data_with_old_version( embedding_id_to_index = {id: i for i, id in enumerate(check_embeddings["ids"])} actual_ids = sorted(actual_ids, key=lambda id: embedding_id_to_index[id]) assert actual_ids == check_embeddings["ids"] - api.persist() except Exception as e: conn.send(e) raise e diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py index de00741..c2bbf74 100644 --- a/chromadb/test/property/test_embeddings.py +++ b/chromadb/test/property/test_embeddings.py @@ -1,13 +1,12 @@ import pytest import logging import hypothesis.strategies as st -from typing import Set, cast, Union, DefaultDict +from typing import Dict, Set, cast, Union, DefaultDict from dataclasses import dataclass from chromadb.api.types import ID, Include, IDs import chromadb.errors as errors from chromadb.api import API from chromadb.api.models.Collection import Collection -from chromadb.db.impl.sqlite import SqliteDB import chromadb.test.property.strategies as strategies from hypothesis.stateful import ( Bundle, @@ -237,16 +236,11 @@ class EmbeddingStateMachine(RuleBasedStateMachine): # Sqlite merges the metadata, as opposed to old # implementations which 
overwrites it record_set_state = self.record_set_state["metadatas"][target_idx] - if ( - hasattr(self.api, "_sysdb") - and type(self.api._sysdb) == SqliteDB - and record_set_state is not None - ): + if record_set_state is not None: + record_set_state = cast( + Dict[str, Union[str, int, float]], record_set_state + ) record_set_state.update(normalized_record_set["metadatas"][idx]) - else: - self.record_set_state["metadatas"][ - target_idx - ] = normalized_record_set["metadatas"][idx] if normalized_record_set["documents"] is not None: self.record_set_state["documents"][ target_idx diff --git a/chromadb/test/property/test_persist.py b/chromadb/test/property/test_persist.py index 7a1e8ca..3e9ac0f 100644 --- a/chromadb/test/property/test_persist.py +++ b/chromadb/test/property/test_persist.py @@ -13,8 +13,15 @@ import chromadb.test.property.invariants as invariants from chromadb.test.property.test_embeddings import ( EmbeddingStateMachine, EmbeddingStateMachineStates, + collection_st as embedding_collection_st, + trace, +) +from hypothesis.stateful import ( + run_state_machine_as_test, + rule, + precondition, + initialize, ) -from hypothesis.stateful import run_state_machine_as_test, rule, precondition import os import shutil import tempfile @@ -23,24 +30,35 @@ CreatePersistAPI = Callable[[], API] configurations = [ Settings( - chroma_api_impl="local", - chroma_db_impl="duckdb+parquet", + chroma_api_impl="chromadb.api.segment.SegmentAPI", + chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager", + allow_reset=True, + is_persistent=True, persist_directory=tempfile.gettempdir() + "/tests", - ) + ), ] @pytest.fixture(scope="module", params=configurations) def settings(request: pytest.FixtureRequest) -> Generator[Settings, None, None]: configuration = request.param - yield configuration save_path = configuration.persist_directory + # Create if it doesn't exist + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + yield configuration # Remove if it exists if os.path.exists(save_path): shutil.rmtree(save_path) -collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll") +collection_st = st.shared( + strategies.collections(with_hnsw_params=True, with_persistent_hnsw_params=True), + key="coll", +) @given( @@ -77,7 +95,6 @@ def test_persist( embedding_function=collection_strategy.embedding_function, ) - api_1.persist() del api_1 api_2 = chromadb.Client(settings) @@ -130,6 +147,24 @@ class PersistEmbeddingsStateMachine(EmbeddingStateMachine): self.api.reset() super().__init__(self.api) + @initialize(collection=embedding_collection_st, batch_size=st.integers(min_value=3, max_value=2000), sync_threshold=st.integers(min_value=3, max_value=2000)) # type: ignore + def initialize( + self, collection: strategies.Collection, batch_size: int, sync_threshold: int + ): + self.api.reset() + self.collection = self.api.create_collection( + name=collection.name, + metadata=collection.metadata, + embedding_function=collection.embedding_function, + ) + self.embedding_function = collection.embedding_function + trace("init") + self.on_state_change(EmbeddingStateMachineStates.initialize) + + self.record_set_state = strategies.StateMachineRecordSet( + ids=[], metadatas=[], documents=[], embeddings=[] + ) + @precondition( lambda self: len(self.record_set_state["ids"]) >= 1 and 
self.last_persist_delay <= 0 @@ -137,7 +172,6 @@ class PersistEmbeddingsStateMachine(EmbeddingStateMachine): @rule() def persist(self) -> None: self.on_state_change(PersistEmbeddingsStateMachineStates.persist) - self.api.persist() collection_name = self.collection.name # Create a new process and then inside the process run the invariants # TODO: Once we switch off of duckdb and onto sqlite we can remove this @@ -160,6 +194,9 @@ class PersistEmbeddingsStateMachine(EmbeddingStateMachine): else: self.last_persist_delay -= 1 + def teardown(self) -> None: + self.api.reset() + def test_persist_embeddings_state( caplog: pytest.LogCaptureFixture, settings: Settings diff --git a/chromadb/test/segment/test_metadata.py b/chromadb/test/segment/test_metadata.py index 9c3d10d..a5b861b 100644 --- a/chromadb/test/segment/test_metadata.py +++ b/chromadb/test/segment/test_metadata.py @@ -1,3 +1,6 @@ +import os +import shutil +import tempfile import pytest from typing import Generator, List, Callable, Iterator, Dict, Optional, Union, Sequence from chromadb.config import System, Settings @@ -23,15 +26,29 @@ from itertools import count def sqlite() -> Generator[System, None, None]: """Fixture generator for sqlite DB""" - settings = Settings(sqlite_database=":memory:", allow_reset=True) + settings = Settings(allow_reset=True, is_persistent=False) system = System(settings) system.start() yield system system.stop() +def sqlite_persistent() -> Generator[System, None, None]: + """Fixture generator for sqlite DB""" + save_path = tempfile.mkdtemp() + settings = Settings( + allow_reset=True, is_persistent=True, persist_directory=save_path + ) + system = System(settings) + system.start() + yield system + system.stop() + if os.path.exists(save_path): + shutil.rmtree(save_path) + + def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]: - return [sqlite] + return [sqlite, sqlite_persistent] @pytest.fixture(scope="module", params=system_fixtures()) @@ -106,8 +123,8 @@ def sync(segment: MetadataReader, seq_id: SeqId) -> None: def test_insert_and_count( system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] ) -> None: - system.reset_state() producer = system.instance(Producer) + system.reset_state() topic = str(segment_definition["topic"]) @@ -144,9 +161,8 @@ def assert_equiv_records( def test_get( system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] ) -> None: - system.reset_state() - producer = system.instance(Producer) + system.reset_state() topic = str(segment_definition["topic"]) embeddings = [next(sample_embeddings) for i in range(10)] @@ -242,9 +258,8 @@ def test_get( def test_fulltext( system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] ) -> None: - system.reset_state() - producer = system.instance(Producer) + system.reset_state() topic = str(segment_definition["topic"]) segment = SqliteMetadataSegment(system, segment_definition) @@ -304,9 +319,8 @@ def test_fulltext( def test_delete( system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] ) -> None: - system.reset_state() - producer = system.instance(Producer) + system.reset_state() topic = str(segment_definition["topic"]) segment = SqliteMetadataSegment(system, segment_definition) @@ -367,9 +381,8 @@ def test_delete( def test_update( system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] ) -> None: - system.reset_state() - producer = system.instance(Producer) + system.reset_state() topic = str(segment_definition["topic"]) segment = SqliteMetadataSegment(system, segment_definition) 
@@ -395,9 +408,8 @@ def test_update( def test_upsert( system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] ) -> None: - system.reset_state() - producer = system.instance(Producer) + system.reset_state() topic = str(segment_definition["topic"]) segment = SqliteMetadataSegment(system, segment_definition) diff --git a/chromadb/test/segment/test_vector.py b/chromadb/test/segment/test_vector.py index 1ba084b..de142d7 100644 --- a/chromadb/test/segment/test_vector.py +++ b/chromadb/test/segment/test_vector.py @@ -1,5 +1,5 @@ import pytest -from typing import Generator, List, Callable, Iterator, cast +from typing import Generator, List, Callable, Iterator, Type, cast from chromadb.config import System, Settings from chromadb.types import ( SubmitEmbeddingRecord, @@ -16,23 +16,58 @@ from chromadb.segment import VectorReader import uuid import time -from chromadb.segment.impl.vector.local_hnsw import LocalHnswSegment +from chromadb.segment.impl.vector.local_hnsw import ( + LocalHnswSegment, +) + +from chromadb.segment.impl.vector.local_persistent_hnsw import ( + PersistentLocalHnswSegment, +) from pytest import FixtureRequest from itertools import count +import tempfile +import os +import shutil def sqlite() -> Generator[System, None, None]: """Fixture generator for sqlite DB""" - settings = Settings(sqlite_database=":memory:", allow_reset=True) + save_path = tempfile.mkdtemp() + settings = Settings( + allow_reset=True, + is_persistent=False, + persist_directory=save_path, + ) system = System(settings) system.start() yield system system.stop() + if os.path.exists(save_path): + shutil.rmtree(save_path) +def sqlite_persistent() -> Generator[System, None, None]: + """Fixture generator for sqlite DB""" + save_path = tempfile.mkdtemp() + settings = Settings( + allow_reset=True, + is_persistent=True, + persist_directory=save_path, + ) + system = System(settings) + system.start() + yield system + system.stop() + if os.path.exists(save_path): + shutil.rmtree(save_path) + + +# We will excercise in memory, persistent sqlite with both ephemeral and persistent hnsw. +# We technically never expose persitent sqlite with memory hnsw to users, but it's a valid +# configuration, so we test it here. 
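+# The vector reader tests below are parameterized over both of these system fixtures and over
+# both reader implementations (LocalHnswSegment and PersistentLocalHnswSegment), so each reader
+# is exercised against each storage configuration.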
def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]: - return [sqlite] + return [sqlite, sqlite_persistent] @pytest.fixture(scope="module", params=system_fixtures()) @@ -60,14 +95,24 @@ def sample_embeddings() -> Iterator[SubmitEmbeddingRecord]: return (create_record(i) for i in count()) -segment_definition = Segment( - id=uuid.uuid4(), - type="test_type", - scope=SegmentScope.VECTOR, - topic="persistent://test/test/test_topic_1", - collection=None, - metadata=None, -) +def vector_readers() -> List[Type[VectorReader]]: + return [LocalHnswSegment, PersistentLocalHnswSegment] + + +@pytest.fixture(scope="module", params=vector_readers()) +def vector_reader(request: FixtureRequest) -> Generator[Type[VectorReader], None, None]: + yield request.param + + +def create_random_segment_definition() -> Segment: + return Segment( + id=uuid.uuid4(), + type="test_type", + scope=SegmentScope.VECTOR, + topic="persistent://test/test/test_topic_1", + collection=None, + metadata=None, + ) def sync(segment: VectorReader, seq_id: SeqId) -> None: @@ -81,18 +126,21 @@ def sync(segment: VectorReader, seq_id: SeqId) -> None: def test_insert_and_count( - system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] + system: System, + sample_embeddings: Iterator[SubmitEmbeddingRecord], + vector_reader: Type[VectorReader], ) -> None: - system.reset_state() producer = system.instance(Producer) + system.reset_state() + segment_definition = create_random_segment_definition() topic = str(segment_definition["topic"]) max_id = 0 for i in range(3): max_id = producer.submit_embedding(topic, next(sample_embeddings)) - segment = LocalHnswSegment(system, segment_definition) + segment = vector_reader(system, segment_definition) segment.start() sync(segment, max_id) @@ -114,14 +162,16 @@ def approx_equal_vector(a: Vector, b: Vector, epsilon: float = 0.0001) -> bool: def test_get_vectors( - system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] + system: System, + sample_embeddings: Iterator[SubmitEmbeddingRecord], + vector_reader: Type[VectorReader], ) -> None: - system.reset_state() producer = system.instance(Producer) - + system.reset_state() + segment_definition = create_random_segment_definition() topic = str(segment_definition["topic"]) - segment = LocalHnswSegment(system, segment_definition) + segment = vector_reader(system, segment_definition) segment.start() embeddings = [next(sample_embeddings) for i in range(10)] @@ -157,14 +207,16 @@ def test_get_vectors( def test_ann_query( - system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] + system: System, + sample_embeddings: Iterator[SubmitEmbeddingRecord], + vector_reader: Type[VectorReader], ) -> None: - system.reset_state() producer = system.instance(Producer) - + system.reset_state() + segment_definition = create_random_segment_definition() topic = str(segment_definition["topic"]) - segment = LocalHnswSegment(system, segment_definition) + segment = vector_reader(system, segment_definition) segment.start() embeddings = [next(sample_embeddings) for i in range(100)] @@ -220,14 +272,16 @@ def test_ann_query( def test_delete( - system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] + system: System, + sample_embeddings: Iterator[SubmitEmbeddingRecord], + vector_reader: Type[VectorReader], ) -> None: - system.reset_state() producer = system.instance(Producer) - + system.reset_state() + segment_definition = create_random_segment_definition() topic = str(segment_definition["topic"]) - segment = 
LocalHnswSegment(system, segment_definition) + segment = vector_reader(system, segment_definition) segment.start() embeddings = [next(sample_embeddings) for i in range(5)] @@ -261,6 +315,8 @@ def test_delete( assert segment.get_vectors(ids=[embeddings[0]["id"]]) == [] results = segment.get_vectors() assert len(results) == 4 + # get_vectors returns results in arbitrary order + results = sorted(results, key=lambda v: v["id"]) for actual, expected in zip(results, embeddings[1:]): assert actual["id"] == expected["id"] assert approx_equal_vector( @@ -357,14 +413,16 @@ def _test_update( def test_update( - system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] + system: System, + sample_embeddings: Iterator[SubmitEmbeddingRecord], + vector_reader: Type[VectorReader], ) -> None: - system.reset_state() producer = system.instance(Producer) - + system.reset_state() + segment_definition = create_random_segment_definition() topic = str(segment_definition["topic"]) - segment = LocalHnswSegment(system, segment_definition) + segment = vector_reader(system, segment_definition) segment.start() _test_update(producer, topic, segment, sample_embeddings, Operation.UPDATE) @@ -388,14 +446,16 @@ def test_update( def test_upsert( - system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] + system: System, + sample_embeddings: Iterator[SubmitEmbeddingRecord], + vector_reader: Type[VectorReader], ) -> None: - system.reset_state() producer = system.instance(Producer) - + system.reset_state() + segment_definition = create_random_segment_definition() topic = str(segment_definition["topic"]) - segment = LocalHnswSegment(system, segment_definition) + segment = vector_reader(system, segment_definition) segment.start() _test_update(producer, topic, segment, sample_embeddings, Operation.UPSERT) diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index e370026..7eb554a 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -16,10 +16,15 @@ from chromadb.utils.embedding_functions import ( def local_persist_api(): return chromadb.Client( Settings( - chroma_api_impl="local", - chroma_db_impl="duckdb+parquet", - persist_directory=tempfile.gettempdir() + "/test_server", - ) + chroma_api_impl="chromadb.api.segment.SegmentAPI", + chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager", + allow_reset=True, + is_persistent=True, + persist_directory=tempfile.gettempdir(), + ), ) @@ -28,10 +33,15 @@ def local_persist_api(): def local_persist_api_cache_bust(): return chromadb.Client( Settings( - chroma_api_impl="local", - chroma_db_impl="duckdb+parquet", - persist_directory=tempfile.gettempdir() + "/test_server", - ) + chroma_api_impl="chromadb.api.segment.SegmentAPI", + chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager", + allow_reset=True, + is_persistent=True, + persist_directory=tempfile.gettempdir(), + ), ) @@ -52,9 +62,6 @@ def test_persist_index_loading(api_fixture, request): collection = api.create_collection("test") collection.add(ids="id1", documents="hello") - api.persist() - del api - api2 = request.getfixturevalue("local_persist_api_cache_bust") collection = 
api2.get_collection("test") @@ -75,9 +82,6 @@ def test_persist_index_loading_embedding_function(api_fixture, request): collection = api.create_collection("test", embedding_function=embedding_function) collection.add(ids="id1", documents="hello") - api.persist() - del api - api2 = request.getfixturevalue("local_persist_api_cache_bust") collection = api2.get_collection("test", embedding_function=embedding_function) @@ -100,9 +104,6 @@ def test_persist_index_get_or_create_embedding_function(api_fixture, request): ) collection.add(ids="id1", documents="hello") - api.persist() - del api - api2 = request.getfixturevalue("local_persist_api_cache_bust") collection = api2.get_or_create_collection( "test", embedding_function=embedding_function @@ -135,16 +136,11 @@ def test_persist(api_fixture, request): assert collection.count() == 2 - api.persist() - del api - api = request.getfixturevalue(api_fixture.__name__) collection = api.get_collection("testspace") assert collection.count() == 2 api.delete_collection("testspace") - api.persist() - del api api = request.getfixturevalue(api_fixture.__name__) assert api.list_collections() == [] @@ -1052,6 +1048,7 @@ def test_invalid_id(api): def test_index_params(api): + EPS = 1e-12 # first standard add api.reset() collection = api.create_collection(name="test_index_params") @@ -1073,8 +1070,8 @@ def test_index_params(api): query_embeddings=[0.6, 1.12, 1.6], n_results=1, ) - assert items["distances"][0][0] > 0 - assert items["distances"][0][0] < 1 + assert items["distances"][0][0] > 0 - EPS + assert items["distances"][0][0] < 1 + EPS # ip api.reset() @@ -1108,14 +1105,16 @@ def test_invalid_index_params(api): def test_persist_index_loading_params(api, request): api = request.getfixturevalue("local_persist_api") api.reset() - collection = api.create_collection("test", metadata={"hnsw:space": "ip"}) + collection = api.create_collection( + "test", + metadata={"hnsw:space": "ip"}, + ) collection.add(ids="id1", documents="hello") - api.persist() - del api - api2 = request.getfixturevalue("local_persist_api_cache_bust") - collection = api2.get_collection("test") + collection = api2.get_collection( + "test", + ) assert collection.metadata["hnsw:space"] == "ip" diff --git a/chromadb/test/test_chroma.py b/chromadb/test/test_chroma.py index 2ed1673..42b1441 100644 --- a/chromadb/test/test_chroma.py +++ b/chromadb/test/test_chroma.py @@ -1,53 +1,56 @@ import unittest import os from unittest.mock import patch, Mock - +import pytest import chromadb import chromadb.config -from chromadb.db import DB +from chromadb.db.system import SysDB +from chromadb.ingest import Consumer, Producer class GetDBTest(unittest.TestCase): - @patch("chromadb.db.duckdb.DuckDB", autospec=True) + @patch("chromadb.db.impl.sqlite.SqliteDB", autospec=True) def test_default_db(self, mock: Mock) -> None: system = chromadb.config.System( chromadb.config.Settings(persist_directory="./foo") ) - system.instance(DB) + system.instance(SysDB) assert mock.called - @patch("chromadb.db.duckdb.PersistentDuckDB", autospec=True) - def test_persistent_duckdb(self, mock: Mock) -> None: + @patch("chromadb.db.impl.sqlite.SqliteDB", autospec=True) + def test_sqlite_sysdb(self, mock: Mock) -> None: system = chromadb.config.System( chromadb.config.Settings( - chroma_db_impl="duckdb+parquet", persist_directory="./foo" - ) - ) - system.instance(DB) - assert mock.called - - @patch("chromadb.db.clickhouse.Clickhouse", autospec=True) - def test_clickhouse(self, mock: Mock) -> None: - system = chromadb.config.System( - 
chromadb.config.Settings( - chroma_db_impl="clickhouse", + chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", persist_directory="./foo", - clickhouse_host="foo", - clickhouse_port="666", ) ) - system.instance(DB) + system.instance(SysDB) + assert mock.called + + @patch("chromadb.db.impl.sqlite.SqliteDB", autospec=True) + def test_sqlite_queue(self, mock: Mock) -> None: + system = chromadb.config.System( + chromadb.config.Settings( + chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB", + chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB", + persist_directory="./foo", + ) + ) + system.instance(Producer) + system.instance(Consumer) assert mock.called class GetAPITest(unittest.TestCase): - @patch("chromadb.api.local.LocalAPI", autospec=True) + @patch("chromadb.api.segment.SegmentAPI", autospec=True) @patch.dict(os.environ, {}, clear=True) def test_local(self, mock_api: Mock) -> None: chromadb.Client(chromadb.config.Settings(persist_directory="./foo")) assert mock_api.called - @patch("chromadb.db.duckdb.DuckDB", autospec=True) + @patch("chromadb.db.impl.sqlite.SqliteDB", autospec=True) @patch.dict(os.environ, {}, clear=True) def test_local_db(self, mock_db: Mock) -> None: chromadb.Client(chromadb.config.Settings(persist_directory="./foo")) @@ -58,7 +61,7 @@ class GetAPITest(unittest.TestCase): def test_fastapi(self, mock: Mock) -> None: chromadb.Client( chromadb.config.Settings( - chroma_api_impl="rest", + chroma_api_impl="chromadb.api.fastapi.FastAPI", persist_directory="./foo", chroma_server_host="foo", chroma_server_http_port="80", @@ -70,7 +73,7 @@ class GetAPITest(unittest.TestCase): @patch.dict(os.environ, {}, clear=True) def test_settings_pass_to_fastapi(self, mock: Mock) -> None: settings = chromadb.config.Settings( - chroma_api_impl="rest", + chroma_api_impl="chromadb.api.fastapi.FastAPI", chroma_server_host="foo", chroma_server_http_port="80", chroma_server_headers={"foo": "bar"}, @@ -90,3 +93,15 @@ class GetAPITest(unittest.TestCase): # Check if the settings passed to the mock match the settings we used # raise Exception(passed_settings.settings) assert passed_settings.settings == settings + + +def test_legacy_values() -> None: + with pytest.raises(ValueError): + chromadb.Client( + chromadb.config.Settings( + chroma_api_impl="chromadb.api.local.LocalAPI", + persist_directory="./foo", + chroma_server_host="foo", + chroma_server_http_port="80", + ) + ) diff --git a/chromadb/test/test_client.py b/chromadb/test/test_client.py new file mode 100644 index 0000000..1164e1e --- /dev/null +++ b/chromadb/test/test_client.py @@ -0,0 +1,37 @@ +import chromadb +from chromadb.api import API +import chromadb.server.fastapi +import pytest +import tempfile + + +@pytest.fixture +def ephemeral_api() -> API: + return chromadb.EphemeralClient() + + +@pytest.fixture +def persistent_api() -> API: + return chromadb.PersistentClient( + path=tempfile.gettempdir() + "/test_server", + ) + + +@pytest.fixture +def http_api() -> API: + return chromadb.HttpClient() + + +def test_ephemeral_client(ephemeral_api: API) -> None: + settings = ephemeral_api.get_settings() + assert settings.is_persistent is False + + +def test_persistent_client(persistent_api: API) -> None: + settings = persistent_api.get_settings() + assert settings.is_persistent is True + + +def test_http_client(http_api: API) -> None: + settings = http_api.get_settings() + assert settings.chroma_api_impl == "chromadb.api.fastapi.FastAPI" diff --git a/chromadb/test/test_multithreaded.py 
b/chromadb/test/test_multithreaded.py new file mode 100644 index 0000000..57c259d --- /dev/null +++ b/chromadb/test/test_multithreaded.py @@ -0,0 +1,221 @@ +import multiprocessing +from concurrent.futures import Future, ThreadPoolExecutor, wait +import random +import threading +from typing import Any, Dict, List, Optional, Set, Tuple, cast +import numpy as np + +from chromadb.api import API +import chromadb.test.property.invariants as invariants +from chromadb.test.property.strategies import RecordSet +from chromadb.test.property.strategies import test_hnsw_config +from chromadb.types import Metadata + + +def generate_data_shape() -> Tuple[int, int]: + N = random.randint(10, 10000) + D = random.randint(10, 256) + return (N, D) + + +def generate_record_set(N: int, D: int) -> RecordSet: + ids = [str(i) for i in range(N)] + metadatas: List[Dict[str, int]] = [{f"{i}": i} for i in range(N)] + documents = [f"doc {i}" for i in range(N)] + embeddings = np.random.rand(N, D).tolist() + + # Create a normalized record set to compare against + normalized_record_set: RecordSet = { + "ids": ids, + "embeddings": embeddings, + "metadatas": metadatas, # type: ignore + "documents": documents, + } + + return normalized_record_set + + +# Hypothesis is bad at generating large datasets so we manually generate data in +# this test to test multithreaded add with larger datasets +def _test_multithreaded_add(api: API, N: int, D: int, num_workers: int) -> None: + records_set = generate_record_set(N, D) + ids = records_set["ids"] + embeddings = records_set["embeddings"] + metadatas = records_set["metadatas"] + documents = records_set["documents"] + + print(f"Adding {N} records with {D} dimensions on {num_workers} workers") + + # TODO: batch_size and sync_threshold should be configurable + api.reset() + coll = api.create_collection(name="test", metadata=test_hnsw_config) + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures: List[Future[Any]] = [] + total_sent = -1 + while total_sent < len(ids): + # Randomly grab up to 10% of the dataset and send it to the executor + batch_size = random.randint(1, N // 10) + to_send = min(batch_size, len(ids) - total_sent) + start = total_sent + 1 + end = total_sent + to_send + 1 + if embeddings is not None and len(embeddings[start:end]) == 0: + break + future = executor.submit( + coll.add, + ids=ids[start:end], + embeddings=embeddings[start:end] if embeddings is not None else None, + metadatas=metadatas[start:end] if metadatas is not None else None, # type: ignore + documents=documents[start:end] if documents is not None else None, + ) + futures.append(future) + total_sent += to_send + + wait(futures) + + for future in futures: + exception = future.exception() + if exception is not None: + raise exception + + # Check that invariants hold + invariants.count(coll, records_set) + invariants.ids_match(coll, records_set) + invariants.metadatas_match(coll, records_set) + invariants.no_duplicates(coll) + + # Check that the ANN accuracy is good + # On a random subset of the dataset + query_indices = random.sample([i for i in range(N)], 10) + n_results = 5 + invariants.ann_accuracy( + coll, + records_set, + n_results=n_results, + query_indices=query_indices, + ) + + +def _test_interleaved_add_query(api: API, N: int, D: int, num_workers: int) -> None: + """Test that will use multiple threads to interleave operations on the db and verify they work correctly""" + + api.reset() + coll = api.create_collection(name="test", metadata=test_hnsw_config) + + records_set = 
generate_record_set(N, D) + ids = cast(List[str], records_set["ids"]) + embeddings = cast(List[float], records_set["embeddings"]) + metadatas = cast(List[Metadata], records_set["metadatas"]) + documents = records_set["documents"] + + added_ids: Set[str] = set() + lock = threading.Lock() + + print(f"Adding {N} records with {D} dimensions on {num_workers} workers") + + def perform_operation( + operation: int, ids_to_modify: Optional[List[str]] = None + ) -> None: + """Perform a random operation on the collection""" + if operation == 0: + assert ids_to_modify is not None + indices_to_modify = [ids.index(id) for id in ids_to_modify] + # Add a subset of the dataset + if len(indices_to_modify) == 0: + return + coll.add( + ids=ids_to_modify, + embeddings=[embeddings[i] for i in indices_to_modify] + if embeddings is not None + else None, + metadatas=[metadatas[i] for i in indices_to_modify] + if metadatas is not None + else None, + documents=[documents[i] for i in indices_to_modify] + if documents is not None + else None, + ) + with lock: + added_ids.update(ids_to_modify) + elif operation == 1: + currently_added_ids = [] + n_results = 5 + with lock: + currently_added_ids = list(added_ids.copy()) + currently_added_indices = [ids.index(id) for id in currently_added_ids] + if ( + len(currently_added_ids) == 0 + or len(currently_added_indices) < n_results + ): + return + # Query the collection, we can't test the results because we want to interleave + # queries and adds. We cannot do so without a lock and serializing the operations + # which would defeat the purpose of this test. Instead we interleave queries and + # adds and check the invariants at the end + query_indices = random.sample( + currently_added_indices, + min(10, len(currently_added_indices)), + ) + query_vectors = [embeddings[i] for i in query_indices] + # Query the collections + coll.query( + query_vectors, + n_results=n_results, + ) + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures: List[Future[Any]] = [] + total_sent = -1 + while total_sent < len(ids) - 1: + operation = random.randint(0, 2) + if operation == 0: + # Randomly grab up to 10% of the dataset and send it to the executor + batch_size = random.randint(1, N // 10) + to_send = min(batch_size, len(ids) - total_sent) + start = total_sent + 1 + end = total_sent + to_send + 1 + future = executor.submit(perform_operation, operation, ids[start:end]) + futures.append(future) + total_sent += to_send + elif operation == 1: + future = executor.submit( + perform_operation, + operation, + ) + futures.append(future) + + wait(futures) + + for future in futures: + exception = future.exception() + if exception is not None: + raise exception + + # Check that invariants hold + invariants.count(coll, records_set) + invariants.ids_match(coll, records_set) + invariants.metadatas_match(coll, records_set) + invariants.no_duplicates(coll) + # Check that the ANN accuracy is good + # On a random subset of the dataset + query_indices = random.sample([i for i in range(N)], 10) + n_results = 5 + invariants.ann_accuracy( + coll, + records_set, + n_results=n_results, + query_indices=query_indices, + ) + + +def test_multithreaded_add(api: API) -> None: + for i in range(3): + num_workers = random.randint(2, multiprocessing.cpu_count() * 2) + N, D = generate_data_shape() + _test_multithreaded_add(api, N, D, num_workers) + + +def test_interleaved_add_query(api: API) -> None: + for i in range(3): + num_workers = random.randint(2, multiprocessing.cpu_count() * 2) + N, D = 
generate_data_shape() + _test_interleaved_add_query(api, N, D, num_workers) diff --git a/chromadb/utils/distance_functions.py b/chromadb/utils/distance_functions.py new file mode 100644 index 0000000..88fc770 --- /dev/null +++ b/chromadb/utils/distance_functions.py @@ -0,0 +1,18 @@ +from typing import Dict, Callable +import numpy as np +import numpy.typing as npt + + +# These match what the spec of hnswlib is +# This epsilon is used to prevent division by zero and the value is the same +# https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/python_bindings/bindings.cpp#L238 +NORM_EPS = 1e-30 +distance_functions: Dict[str, Callable[[npt.ArrayLike, npt.ArrayLike], float]] = { + "l2": lambda x, y: np.linalg.norm(x - y) ** 2, # type: ignore + "cosine": lambda x, y: 1 - np.dot(x, y) / ((np.linalg.norm(x) + NORM_EPS) * (np.linalg.norm(y) + NORM_EPS)), # type: ignore + "ip": lambda x, y: 1 - np.dot(x, y), # type: ignore +} + +l2 = distance_functions["l2"] +cosine = distance_functions["cosine"] +ip = distance_functions["ip"] diff --git a/chromadb/utils/read_write_lock.py b/chromadb/utils/read_write_lock.py new file mode 100644 index 0000000..16c60ca --- /dev/null +++ b/chromadb/utils/read_write_lock.py @@ -0,0 +1,74 @@ +import threading +from types import TracebackType +from typing import Optional, Type + + +class ReadWriteLock: + """A lock object that allows many simultaneous "read locks", but + only one "write lock." """ + + def __init__(self) -> None: + self._read_ready = threading.Condition(threading.RLock()) + self._readers = 0 + + def acquire_read(self) -> None: + """Acquire a read lock. Blocks only if a thread has + acquired the write lock.""" + self._read_ready.acquire() + try: + self._readers += 1 + finally: + self._read_ready.release() + + def release_read(self) -> None: + """Release a read lock.""" + self._read_ready.acquire() + try: + self._readers -= 1 + if not self._readers: + self._read_ready.notifyAll() + finally: + self._read_ready.release() + + def acquire_write(self) -> None: + """Acquire a write lock. Blocks until there are no + acquired read or write locks.""" + self._read_ready.acquire() + while self._readers > 0: + self._read_ready.wait() + + def release_write(self) -> None: + """Release a write lock.""" + self._read_ready.release() + + +class ReadRWLock: + def __init__(self, rwLock: ReadWriteLock): + self.rwLock = rwLock + + def __enter__(self) -> None: + self.rwLock.acquire_read() + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.rwLock.release_read() + + +class WriteRWLock: + def __init__(self, rwLock: ReadWriteLock): + self.rwLock = rwLock + + def __enter__(self) -> None: + self.rwLock.acquire_write() + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.rwLock.release_write() diff --git a/clients/js/src/ChromaClient.ts b/clients/js/src/ChromaClient.ts index fe377cb..8d20860 100644 --- a/clients/js/src/ChromaClient.ts +++ b/clients/js/src/ChromaClient.ts @@ -83,13 +83,6 @@ export class ChromaClient { return ret["nanosecond heartbeat"] } - /** - * @ignore - */ - public async persist(): Promise { - throw new Error("Not implemented in JS client"); - } - /** * Creates a new collection with the specified properties. 
* diff --git a/clients/js/src/Collection.ts b/clients/js/src/Collection.ts index 2852b7b..e99e7bd 100644 --- a/clients/js/src/Collection.ts +++ b/clients/js/src/Collection.ts @@ -195,6 +195,7 @@ export class Collection { embeddings: embeddingsArray as number[][], // We know this is defined because of the validate function // @ts-ignore documents: documentsArray, + // @ts-ignore metadatas: metadatasArray, }, this.api.options) .then(handleSuccess) @@ -248,6 +249,7 @@ export class Collection { embeddings: embeddingsArray as number[][], // We know this is defined because of the validate function //@ts-ignore documents: documentsArray, + //@ts-ignore metadatas: metadatasArray, }, this.api.options @@ -361,6 +363,7 @@ export class Collection { where, limit, offset, + //@ts-ignore include, where_document: whereDocument, }, this.api.options) @@ -512,6 +515,7 @@ export class Collection { where, n_results: nResults, where_document: whereDocument, + //@ts-ignore include: include, }, this.api.options) .then(handleSuccess) diff --git a/clients/js/src/generated/api.ts b/clients/js/src/generated/api.ts index ec92a5b..a465d41 100644 --- a/clients/js/src/generated/api.ts +++ b/clients/js/src/generated/api.ts @@ -413,32 +413,6 @@ export const ApiApiFetchParamCreator = function (configuration?: Configuration) options: localVarRequestOptions, }; }, - /** - * @summary Persist - * @param {RequestInit} [options] Override http request option. - * @throws {RequiredError} - */ - persist(options: RequestInit = {}): FetchArgs { - let localVarPath = `/api/v1/persist`; - const localVarPathQueryStart = localVarPath.indexOf("?"); - const localVarRequestOptions: RequestInit = Object.assign({ method: 'POST' }, options); - const localVarHeaderParameter: Headers = options.headers ? new Headers(options.headers) : new Headers(); - const localVarQueryParameter = new URLSearchParams(localVarPathQueryStart !== -1 ? localVarPath.substring(localVarPathQueryStart + 1) : ""); - if (localVarPathQueryStart !== -1) { - localVarPath = localVarPath.substring(0, localVarPathQueryStart); - } - - localVarRequestOptions.headers = localVarHeaderParameter; - - const localVarQueryParameterString = localVarQueryParameter.toString(); - if (localVarQueryParameterString) { - localVarPath += "?" + localVarQueryParameterString; - } - return { - url: localVarPath, - options: localVarRequestOptions, - }; - }, /** * @summary Raw Sql * @param {Api.RawSql} request @@ -1001,28 +975,6 @@ export const ApiApiFp = function(configuration?: Configuration) { }); }; }, - /** - * @summary Persist - * @param {RequestInit} [options] Override http request option. - * @throws {RequiredError} - */ - persist(options?: RequestInit): (fetch?: FetchAPI, basePath?: string) => Promise { - const localVarFetchArgs = ApiApiFetchParamCreator(configuration).persist(options); - return (fetch: FetchAPI = defaultFetch, basePath: string = BASE_PATH) => { - return fetch(basePath + localVarFetchArgs.url, localVarFetchArgs.options).then((response) => { - const contentType = response.headers.get('Content-Type'); - const mimeType = contentType ? 
contentType.replace(/;.*/, '') : undefined; - - if (response.status === 200) { - if (mimeType === 'application/json') { - return response.json() as any; - } - throw response; - } - throw response; - }); - }; - }, /** * @summary Raw Sql * @param {Api.RawSql} request @@ -1338,15 +1290,6 @@ export class ApiApi extends BaseAPI { return ApiApiFp(this.configuration).listCollections(options)(this.fetch, this.basePath); } - /** - * @summary Persist - * @param {RequestInit} [options] Override http request option. - * @throws {RequiredError} - */ - public persist(options?: RequestInit) { - return ApiApiFp(this.configuration).persist(options)(this.fetch, this.basePath); - } - /** * @summary Raw Sql * @param {Api.RawSql} request diff --git a/clients/js/src/generated/models.ts b/clients/js/src/generated/models.ts index ab9d6b8..b0d951a 100644 --- a/clients/js/src/generated/models.ts +++ b/clients/js/src/generated/models.ts @@ -17,10 +17,10 @@ export namespace Api { } export interface AddEmbedding { - embeddings: Api.AddEmbedding.Embedding[]; - metadatas?: Api.AddEmbedding.Metadatas.ArrayValue[] | Api.AddEmbedding.Metadatas.ObjectValue; - documents?: string | Api.AddEmbedding.Documents.ArrayValue[]; - ids?: string | Api.AddEmbedding.Ids.ArrayValue[]; + embeddings?: Api.AddEmbedding.Embedding[]; + metadatas?: Api.AddEmbedding.Metadata[]; + documents?: string[]; + ids: string[]; 'increment_index'?: boolean; } @@ -32,43 +32,7 @@ export namespace Api { export interface Embedding { } - export type Metadatas = Api.AddEmbedding.Metadatas.ArrayValue[] | Api.AddEmbedding.Metadatas.ObjectValue; - - /** - * @export - * @namespace Metadatas - */ - export namespace Metadatas { - export interface ArrayValue { - } - - export interface ObjectValue { - } - - } - - export type Documents = string | Api.AddEmbedding.Documents.ArrayValue[]; - - /** - * @export - * @namespace Documents - */ - export namespace Documents { - export interface ArrayValue { - } - - } - - export type Ids = string | Api.AddEmbedding.Ids.ArrayValue[]; - - /** - * @export - * @namespace Ids - */ - export namespace Ids { - export interface ArrayValue { - } - + export interface Metadata { } } @@ -108,7 +72,7 @@ export namespace Api { } export interface DeleteEmbedding { - ids?: Api.DeleteEmbedding.Id[]; + ids?: string[]; where?: Api.DeleteEmbedding.Where; 'where_document'?: Api.DeleteEmbedding.WhereDocument; } @@ -118,9 +82,6 @@ export namespace Api { * @namespace DeleteEmbedding */ export namespace DeleteEmbedding { - export interface Id { - } - export interface Where { } @@ -133,7 +94,7 @@ export namespace Api { } export interface GetEmbedding { - ids?: Api.GetEmbedding.Id[]; + ids?: string[]; where?: Api.GetEmbedding.Where; 'where_document'?: Api.GetEmbedding.WhereDocument; sort?: string; @@ -147,7 +108,7 @@ export namespace Api { * @memberof GetEmbedding */ offset?: number; - include?: Api.GetEmbedding.IncludeEnum[]; + include?: (Api.GetEmbedding.Include.EnumValueEnum | Api.GetEmbedding.Include.EnumValueEnum2 | Api.GetEmbedding.Include.EnumValueEnum3 | Api.GetEmbedding.Include.EnumValueEnum4)[]; } /** @@ -155,20 +116,35 @@ export namespace Api { * @namespace GetEmbedding */ export namespace GetEmbedding { - export interface Id { - } - export interface Where { } export interface WhereDocument { } - export enum IncludeEnum { - Documents = 'documents', - Embeddings = 'embeddings', - Metadatas = 'metadatas', - Distances = 'distances' + export type Include = Api.GetEmbedding.Include.EnumValueEnum | Api.GetEmbedding.Include.EnumValueEnum2 | 
Api.GetEmbedding.Include.EnumValueEnum3 | Api.GetEmbedding.Include.EnumValueEnum4; + + /** + * @export + * @namespace Include + */ + export namespace Include { + export enum EnumValueEnum { + Documents = 'documents' + } + + export enum EnumValueEnum2 { + Embeddings = 'embeddings' + } + + export enum EnumValueEnum3 { + Metadatas = 'metadatas' + } + + export enum EnumValueEnum4 { + Distances = 'distances' + } + } } @@ -186,9 +162,6 @@ export namespace Api { export interface ListCollections200Response { } - export interface Persist200Response { - } - export interface QueryEmbedding { where?: Api.QueryEmbedding.Where; 'where_document'?: Api.QueryEmbedding.WhereDocument; @@ -198,7 +171,7 @@ export namespace Api { * @memberof QueryEmbedding */ 'n_results'?: number; - include?: Api.QueryEmbedding.IncludeEnum[]; + include?: (Api.QueryEmbedding.Include.EnumValueEnum | Api.QueryEmbedding.Include.EnumValueEnum2 | Api.QueryEmbedding.Include.EnumValueEnum3 | Api.QueryEmbedding.Include.EnumValueEnum4)[]; } /** @@ -215,17 +188,35 @@ export namespace Api { export interface QueryEmbedding2 { } - export enum IncludeEnum { - Documents = 'documents', - Embeddings = 'embeddings', - Metadatas = 'metadatas', - Distances = 'distances' + export type Include = Api.QueryEmbedding.Include.EnumValueEnum | Api.QueryEmbedding.Include.EnumValueEnum2 | Api.QueryEmbedding.Include.EnumValueEnum3 | Api.QueryEmbedding.Include.EnumValueEnum4; + + /** + * @export + * @namespace Include + */ + export namespace Include { + export enum EnumValueEnum { + Documents = 'documents' + } + + export enum EnumValueEnum2 { + Embeddings = 'embeddings' + } + + export enum EnumValueEnum3 { + Metadatas = 'metadatas' + } + + export enum EnumValueEnum4 { + Distances = 'distances' + } + } } export interface RawSql { - 'raw_sql'?: string; + 'raw_sql': string; } export interface RawSql200Response { @@ -260,9 +251,9 @@ export namespace Api { export interface UpdateEmbedding { embeddings?: Api.UpdateEmbedding.Embedding[]; - metadatas?: Api.UpdateEmbedding.Metadatas.ArrayValue[] | Api.UpdateEmbedding.Metadatas.ObjectValue; - documents?: string | Api.UpdateEmbedding.Documents.ArrayValue[]; - ids?: string | Api.UpdateEmbedding.Ids.ArrayValue[]; + metadatas?: Api.UpdateEmbedding.Metadata[]; + documents?: string[]; + ids: string[]; 'increment_index'?: boolean; } @@ -274,43 +265,7 @@ export namespace Api { export interface Embedding { } - export type Metadatas = Api.UpdateEmbedding.Metadatas.ArrayValue[] | Api.UpdateEmbedding.Metadatas.ObjectValue; - - /** - * @export - * @namespace Metadatas - */ - export namespace Metadatas { - export interface ArrayValue { - } - - export interface ObjectValue { - } - - } - - export type Documents = string | Api.UpdateEmbedding.Documents.ArrayValue[]; - - /** - * @export - * @namespace Documents - */ - export namespace Documents { - export interface ArrayValue { - } - - } - - export type Ids = string | Api.UpdateEmbedding.Ids.ArrayValue[]; - - /** - * @export - * @namespace Ids - */ - export namespace Ids { - export interface ArrayValue { - } - + export interface Metadata { } } diff --git a/clients/js/test/client.test.ts b/clients/js/test/client.test.ts index 441d604..5fbf2b0 100644 --- a/clients/js/test/client.test.ts +++ b/clients/js/test/client.test.ts @@ -191,5 +191,5 @@ test('wrong code returns an error', async () => { // @ts-ignore - supposed to fail const results = await collection.get({ where: { "test": { "$contains": "hello" } } }); expect(results.error).toBeDefined() - 
expect(results.error).toBe("ValueError('Expected one of $gt, $lt, $gte, $lte, $ne, $eq, got $contains')") + expect(results.error).toBe("ValueError('Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, got $contains')") }) diff --git a/clients/js/test/update.collection.test.ts b/clients/js/test/update.collection.test.ts index 6de6430..e96f21d 100644 --- a/clients/js/test/update.collection.test.ts +++ b/clients/js/test/update.collection.test.ts @@ -38,7 +38,7 @@ test("it should get embedding with matching documents", async () => { expect(results2).toBeDefined(); expect(results2).toBeInstanceOf(Object); expect(results2.embeddings![0]).toEqual([1, 2, 3, 4, 5, 6, 7, 8, 9, 11]); - expect(results2.metadatas[0]).toEqual({ test: "test1new" }); + expect(results2.metadatas[0]).toEqual({ test: "test1new", float_value: -2 }); expect(results2.documents[0]).toEqual("doc1new"); }); diff --git a/clients/python/README.md b/clients/python/README.md index 640be06..c5e592b 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -16,11 +16,8 @@ To connect to your server and perform operations using the client only library, ```python import chromadb -from chromadb.config import Settings # Example setup of the client to connect to your chroma server -client = chromadb.Client(Settings(chroma_api_impl="rest", - chroma_server_host="localhost", - chroma_server_http_port=8000)) +client = chromadb.HttpClient(host="localhost", port=8000) collection = client.create_collection("all-my-documents") diff --git a/clients/python/integration-test.sh b/clients/python/integration-test.sh index 2288c6c..45a181e 100755 --- a/clients/python/integration-test.sh +++ b/clients/python/integration-test.sh @@ -18,7 +18,7 @@ trap cleanup EXIT docker compose -f docker-compose.test.yml up --build -d export CHROMA_INTEGRATION_TEST_ONLY=1 -export CHROMA_API_IMPL=rest +export CHROMA_API_IMPL=chromadb.api.fastapi.FastAPI export CHROMA_SERVER_HOST=localhost export CHROMA_SERVER_HTTP_PORT=8000 diff --git a/config/backup_disk.xml b/config/backup_disk.xml deleted file mode 100644 index 8c4d7c2..0000000 --- a/config/backup_disk.xml +++ /dev/null @@ -1,14 +0,0 @@ - - - - - local - /etc/clickhouse-server/ - - - - - backups - /etc/clickhouse-server/ - - diff --git a/config/chroma_users.xml b/config/chroma_users.xml deleted file mode 100644 index 59a09bc..0000000 --- a/config/chroma_users.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - 1 - 1 - - - diff --git a/docker-compose.server.example.yml b/docker-compose.server.example.yml index e150507..5c575fe 100644 --- a/docker-compose.server.example.yml +++ b/docker-compose.server.example.yml @@ -8,38 +8,12 @@ services: image: ghcr.io/chroma-core/chroma:latest volumes: - index_data:/chroma/.chroma/index - environment: - - CHROMA_DB_IMPL=clickhouse - - CLICKHOUSE_HOST=clickhouse - - CLICKHOUSE_PORT=8123 ports: - 8000:8000 - depends_on: - - clickhouse - networks: - - net - clickhouse: - image: clickhouse/clickhouse-server:22.9-alpine - environment: - - ALLOW_EMPTY_PASSWORD=yes - - CLICKHOUSE_TCP_PORT=9000 - - CLICKHOUSE_HTTP_PORT=8123 - ports: - - '8123:8123' - - '9000:9000' - volumes: - - clickhouse_data:/var/lib/clickhouse - - clickhouse_logs:/var/log/clickhouse-server - - backups:/backups - - ${PWD}/config/backup_disk.xml:/etc/clickhouse-server/config.d/backup_disk.xml - - ${PWD}/config/chroma_users.xml:/etc/clickhouse-server/users.d/chroma.xml networks: - net + volumes: - clickhouse_data: - driver: local - clickhouse_logs: - driver: local index_data: driver: local backups: diff 
--git a/docker-compose.test.yml b/docker-compose.test.yml index a15ca51..eb65303 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -14,38 +14,15 @@ services: - test_index_data:/index_data command: uvicorn chromadb.app:app --workers 1 --host 0.0.0.0 --port 8000 --log-config log_config.yml environment: - - CHROMA_DB_IMPL=clickhouse - - CLICKHOUSE_HOST=test_clickhouse - - CLICKHOUSE_PORT=8123 - ANONYMIZED_TELEMETRY=False - ALLOW_RESET=True + - IS_PERSISTENT=TRUE ports: - ${CHROMA_PORT}:8000 - depends_on: - - test_clickhouse - networks: - - test_net - - test_clickhouse: - image: clickhouse/clickhouse-server:22.9-alpine - environment: - - ALLOW_EMPTY_PASSWORD=yes - - CLICKHOUSE_TCP_PORT=9000 - - CLICKHOUSE_HTTP_PORT=8123 - ports: - - '8123:8123' - - '9000:9000' - volumes: - - test_clickhouse_data:/bitnami/clickhouse - - test_backups:/backups - - ./config/backup_disk.xml:/etc/clickhouse-server/config.d/backup_disk.xml - - ./config/chroma_users.xml:/etc/clickhouse-server/users.d/chroma.xml networks: - test_net volumes: - test_clickhouse_data: - driver: local test_index_data: driver: local test_backups: diff --git a/docker-compose.yml b/docker-compose.yml index 803d886..5f298f1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,39 +15,13 @@ services: - index_data:/index_data command: uvicorn chromadb.app:app --reload --workers 1 --host 0.0.0.0 --port 8000 --log-config log_config.yml environment: - - CHROMA_DB_IMPL=clickhouse - - CLICKHOUSE_HOST=clickhouse - - CLICKHOUSE_PORT=8123 + - IS_PERSISTENT=TRUE ports: - 8000:8000 - depends_on: - - clickhouse - networks: - - net - - clickhouse: - image: clickhouse/clickhouse-server:22.9-alpine - environment: - - ALLOW_EMPTY_PASSWORD=yes - - CLICKHOUSE_TCP_PORT=9000 - - CLICKHOUSE_HTTP_PORT=8123 - ports: - - '8123:8123' - - '9000:9000' - volumes: - - clickhouse_data:/var/lib/clickhouse - - clickhouse_logs:/var/log/clickhouse-server - - backups:/backups - - ./config/backup_disk.xml:/etc/clickhouse-server/config.d/backup_disk.xml - - ./config/chroma_users.xml:/etc/clickhouse-server/users.d/chroma.xml networks: - net volumes: - clickhouse_data: - driver: local - clickhouse_logs: - driver: local index_data: driver: local backups: diff --git a/examples/basic_functionality/alternative_embeddings.ipynb b/examples/basic_functionality/alternative_embeddings.ipynb index e12fc29..e069cdd 100644 --- a/examples/basic_functionality/alternative_embeddings.ipynb +++ b/examples/basic_functionality/alternative_embeddings.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -12,7 +13,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -21,24 +22,16 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using embedded DuckDB without persistence: data will be transient\n" - ] - } - ], + "outputs": [], "source": [ "client = chromadb.Client()" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -47,30 +40,30 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Using OpenAI Embeddings. 
This assumes you have the openai package installed\n", "openai_ef = embedding_functions.OpenAIEmbeddingFunction(\n", - " api_key=\"OPENAI_API_KEY\", # Replace with your own OpenAI API key\n", + " api_key=\"OPENAI_KEY\", # Replace with your own OpenAI API key\n", " model_name=\"text-embedding-ada-002\"\n", ")" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Create a new chroma collection\n", - "openai_collection = client.create_collection(name=\"openai_embeddings\", embedding_function=openai_ef)" + "openai_collection = client.get_or_create_collection(name=\"openai_embeddings\", embedding_function=openai_ef)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -83,20 +76,20 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'ids': [['id1', 'id2']],\n", - " 'embeddings': None,\n", - " 'documents': [['This is a document', 'This is another document']],\n", + " 'distances': [[0.1385088860988617, 0.2017185091972351]],\n", " 'metadatas': [[{'source': 'my_source'}, {'source': 'my_source'}]],\n", - " 'distances': [[0.13865342736244202, 0.20187020301818848]]}" + " 'embeddings': None,\n", + " 'documents': [['This is a document', 'This is another document']]}" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -324,7 +317,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.8" } }, "nbformat": 4, diff --git a/examples/basic_functionality/local_persistence.ipynb b/examples/basic_functionality/local_persistence.ipynb index 9f10326..e05d638 100644 --- a/examples/basic_functionality/local_persistence.ipynb +++ b/examples/basic_functionality/local_persistence.ipynb @@ -6,7 +6,7 @@ "metadata": {}, "source": [ "# Local Peristence Demo\n", - "This notebook demonstrates how to persist the in-memory version of Chroma to disk, then load it back in. " + "This notebook demonstrates how to configure Chroma to persist to disk, then load it back in. " ] }, { @@ -15,56 +15,29 @@ "metadata": {}, "outputs": [], "source": [ - "import chromadb\n", - "from chromadb.config import Settings" + "import chromadb" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a new Chroma client with persistence enabled. \n", + "persist_directory = \"db\"\n", + "\n", + "client = chromadb.PersistentClient(path=persist_directory)\n", + "\n", + "# Create a new chroma collection\n", + "collection_name = \"peristed_collection\"\n", + "collection = client.get_or_create_collection(name=collection_name)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running Chroma using direct local API.\n", - "No existing DB found in db, skipping load\n", - "No existing DB found in db, skipping load\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/antontroynikov/miniforge3/envs/chroma/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "# Create a new Chroma client with persistence enabled. 
\n", - "persist_directory = \"db\"\n", - "\n", - "client = chromadb.Client(\n", - " Settings(\n", - " persist_directory=persist_directory,\n", - " chroma_db_impl=\"duckdb+parquet\",\n", - " )\n", - ")\n", - "\n", - "# Start from scratch\n", - "client.reset()\n", - "\n", - "# Create a new chroma collection\n", - "collection_name = \"peristed_collection\"\n", - "collection = client.create_collection(name=collection_name)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, "outputs": [], "source": [ "# Add some data to the collection\n", @@ -96,56 +69,12 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Persisting DB to disk, putting it in the save folder db\n" - ] - }, - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Persist the DB. This also happens automatically when the client is garbage collected.\n", - "# In a notebook, prefer to call persist explicitly.\n", - "client.persist()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running Chroma using direct local API.\n", - "loaded in 8 embeddings\n", - "loaded in 1 collections\n" - ] - } - ], + "outputs": [], "source": [ "# Create a new client with the same settings\n", - "client = chromadb.Client(\n", - " Settings(\n", - " persist_directory=persist_directory,\n", - " chroma_db_impl=\"duckdb+parquet\",\n", - " )\n", - ")\n", + "client = chromadb.PersistentClient(path=persist_directory)\n", "\n", "# Load the collection\n", "collection = client.get_collection(collection_name)" @@ -153,14 +82,14 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'embeddings': [[[1.1, 2.3, 3.2]]], 'documents': [['doc5']], 'ids': [['id5']], 'metadatas': [[{'uri': 'img5.png', 'style': 'style1'}]], 'distances': [[0.0]]}\n" + "{'ids': [['id1']], 'distances': [[5.1159076593562386e-15]], 'metadatas': [[{'style': 'style1', 'uri': 'img1.png'}]], 'embeddings': None, 'documents': [['doc1']]}\n" ] } ], @@ -176,25 +105,57 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Persisting DB to disk, putting it in the save folder db\n" - ] + "data": { + "text/plain": [ + "{'ids': ['id1', 'id2', 'id3', 'id4', 'id5', 'id6', 'id7', 'id8'],\n", + " 'embeddings': [[1.100000023841858, 2.299999952316284, 3.200000047683716],\n", + " [4.5, 6.900000095367432, 4.400000095367432],\n", + " [1.100000023841858, 2.299999952316284, 3.200000047683716],\n", + " [4.5, 6.900000095367432, 4.400000095367432],\n", + " [1.100000023841858, 2.299999952316284, 3.200000047683716],\n", + " [4.5, 6.900000095367432, 4.400000095367432],\n", + " [1.100000023841858, 2.299999952316284, 3.200000047683716],\n", + " [4.5, 6.900000095367432, 4.400000095367432]],\n", + " 'metadatas': [{'style': 'style1', 'uri': 'img1.png'},\n", + " {'style': 'style2', 'uri': 'img2.png'},\n", + " {'style': 'style1', 'uri': 'img3.png'},\n", + " {'style': 'style1', 'uri': 'img4.png'},\n", + " {'style': 'style1', 'uri': 'img5.png'},\n", + " {'style': 'style1', 'uri': 'img6.png'},\n", + " {'style': 'style1', 'uri': 'img7.png'},\n", + " {'style': 
'style1', 'uri': 'img8.png'}],\n", + " 'documents': ['doc1', 'doc2', 'doc3', 'doc4', 'doc5', 'doc6', 'doc7', 'doc8']}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "# Clean up\n", - "client.reset()\n", - "client.persist()\n", - "\n", - "# You can also just delete the persist directory\n", - "!rm -rf db/" + "collection.get(include=[\"embeddings\", \"metadatas\", \"documents\"])" ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# Clean up\n", + "! rm -rf db" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -213,7 +174,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.10.8" }, "orig_nbformat": 4, "vscode": { diff --git a/examples/basic_functionality/where_filtering.ipynb b/examples/basic_functionality/where_filtering.ipynb index 6c09860..1042358 100644 --- a/examples/basic_functionality/where_filtering.ipynb +++ b/examples/basic_functionality/where_filtering.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -20,34 +20,18 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using embedded DuckDB without persistence: data will be transient\n" - ] - } - ], + "outputs": [], "source": [ "client = chromadb.Client()" ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 3, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "No embedding_function provided, using default embedding function: SentenceTransformerEmbeddingFunction\n" - ] - } - ], + "outputs": [], "source": [ "# Create a new chroma collection\n", "collection_name = \"filter_example_collection\"\n", @@ -56,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -89,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -97,11 +81,11 @@ "text/plain": [ "{'ids': ['id7'],\n", " 'embeddings': None,\n", - " 'documents': ['A document that discusses international affairs'],\n", - " 'metadatas': [{'status': 'read'}]}" + " 'metadatas': [{'status': 'read'}],\n", + " 'documents': ['A document that discusses international affairs']}" ] }, - "execution_count": 29, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -113,20 +97,20 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'ids': ['id8', 'id1'],\n", + "{'ids': ['id1', 'id8'],\n", " 'embeddings': None,\n", - " 'documents': ['A document that discusses global affairs',\n", - " 'A document that discusses domestic policy'],\n", - " 'metadatas': [{'status': 'unread'}, {'status': 'read'}]}" + " 'metadatas': [{'status': 'read'}, {'status': 'unread'}],\n", + " 'documents': ['A document that discusses domestic policy',\n", + " 'A document that discusses global affairs']}" ] }, - "execution_count": 30, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -138,24 +122,24 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": 
{ "text/plain": [ "{'ids': [['id7', 'id2', 'id8']],\n", - " 'embeddings': None,\n", - " 'documents': [['A document that discusses international affairs',\n", - " 'A document that discusses international affairs',\n", - " 'A document that discusses global affairs']],\n", + " 'distances': [[16.740001678466797, 87.22000122070312, 87.22000122070312]],\n", " 'metadatas': [[{'status': 'read'},\n", " {'status': 'unread'},\n", " {'status': 'unread'}]],\n", - " 'distances': [[16.740001678466797, 87.22000122070312, 87.22000122070312]]}" + " 'embeddings': None,\n", + " 'documents': [['A document that discusses international affairs',\n", + " 'A document that discusses international affairs',\n", + " 'A document that discusses global affairs']]}" ] }, - "execution_count": 31, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -165,6 +149,13 @@ "# Outputs 3 docs because collection only has 3 docs about affairs\n", "collection.query(query_embeddings=[[0, 0, 0]], where_document={\"$contains\": \"affairs\"}, n_results=5)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/pyproject.toml b/pyproject.toml index 0ba1678..69648b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,10 +18,8 @@ dependencies = [ 'pandas >= 1.3', 'requests >= 2.28', 'pydantic>=1.9,<2.0', - 'hnswlib >= 0.7', - 'clickhouse_connect >= 0.5.7', - 'duckdb >= 0.7.1', - 'fastapi==0.85.1', + 'chroma-hnswlib==0.7.1', + 'fastapi>=0.95.2, <0.100.0', 'uvicorn[standard] >= 0.18.3', 'numpy >= 1.21.6', 'posthog >= 2.4.0', @@ -29,8 +27,10 @@ dependencies = [ 'pulsar-client>=3.1.0', 'onnxruntime >= 1.14.1', 'tokenizers >= 0.13.2', + 'pypika >= 0.48.9', 'tqdm >= 4.65.0', 'overrides >= 7.3.1', + 'importlib-resources', 'graphlib_backport >= 1.0.3; python_version < "3.9"' ] diff --git a/requirements.txt b/requirements.txt index 71d4ee8..056132c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,7 @@ -chroma-hnswlib==0.7.0 -clickhouse-connect==0.5.7 -duckdb==0.7.1 -fastapi==0.85.1 +chroma-hnswlib==0.7.1 +fastapi>=0.95.2, <0.100.0 graphlib_backport==1.0.3; python_version < '3.9' +importlib-resources numpy==1.21.6 onnxruntime==1.14.1 overrides==7.3.1