[ENH]: CIP-5: Large Batch Handling Improvements Proposal (#1077)
- Including only CIP for review. Refs: #1049

## Description of changes

- Improvements & Bug fixes
- New proposal to handle large batches of embeddings gracefully

## Test plan

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js

## Documentation Changes

TBD

---------

Signed-off-by: sunilkumardash9 <sunilkumardash9@gmail.com>
Co-authored-by: Sunil Kumar Dash <47926185+sunilkumardash9@users.noreply.github.com>
chromadb/api/__init__.py
@@ -378,3 +378,10 @@ class API(Component, ABC):
         """
         pass
+
+    @property
+    @abstractmethod
+    def max_batch_size(self) -> int:
+        """Return the maximum number of records that can be submitted in a single call
+        to submit_embeddings."""
+        pass
 
chromadb/api/fastapi.py
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Optional, cast
+from typing import Optional, cast, Tuple
 from typing import Sequence
 from uuid import UUID
 
@@ -23,6 +23,7 @@ from chromadb.api.types import (
     GetResult,
     QueryResult,
     CollectionMetadata,
+    validate_batch,
 )
 from chromadb.auth import (
     ClientAuthProvider,
@@ -38,6 +39,7 @@ logger = logging.getLogger(__name__)
 
 class FastAPI(API):
     _settings: Settings
+    _max_batch_size: int = -1
 
     @staticmethod
     def _validate_host(host: str) -> None:
@@ -296,6 +298,29 @@ class FastAPI(API):
         raise_chroma_error(resp)
         return cast(IDs, resp.json())
 
+    def _submit_batch(
+        self,
+        batch: Tuple[
+            IDs, Optional[Embeddings], Optional[Metadatas], Optional[Documents]
+        ],
+        url: str,
+    ) -> requests.Response:
+        """
+        Submits a batch of embeddings to the database
+        """
+        resp = self._session.post(
+            self._api_url + url,
+            data=json.dumps(
+                {
+                    "ids": batch[0],
+                    "embeddings": batch[1],
+                    "metadatas": batch[2],
+                    "documents": batch[3],
+                }
+            ),
+        )
+        return resp
+
     @override
     def _add(
         self,
@@ -309,18 +334,9 @@ class FastAPI(API):
         Adds a batch of embeddings to the database
         - pass in column oriented data lists
         """
-        resp = self._session.post(
-            self._api_url + "/collections/" + str(collection_id) + "/add",
-            data=json.dumps(
-                {
-                    "ids": ids,
-                    "embeddings": embeddings,
-                    "metadatas": metadatas,
-                    "documents": documents,
-                }
-            ),
-        )
-
+        batch = (ids, embeddings, metadatas, documents)
+        validate_batch(batch, {"max_batch_size": self.max_batch_size})
+        resp = self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
         raise_chroma_error(resp)
         return True
 
@@ -337,18 +353,11 @@ class FastAPI(API):
         Updates a batch of embeddings in the database
         - pass in column oriented data lists
         """
-        resp = self._session.post(
-            self._api_url + "/collections/" + str(collection_id) + "/update",
-            data=json.dumps(
-                {
-                    "ids": ids,
-                    "embeddings": embeddings,
-                    "metadatas": metadatas,
-                    "documents": documents,
-                }
-            ),
-        )
+        batch = (ids, embeddings, metadatas, documents)
+        validate_batch(batch, {"max_batch_size": self.max_batch_size})
+        resp = self._submit_batch(
+            batch, "/collections/" + str(collection_id) + "/update"
+        )
 
         resp.raise_for_status()
         return True
@@ -365,18 +374,11 @@ class FastAPI(API):
         Upserts a batch of embeddings in the database
         - pass in column oriented data lists
         """
-        resp = self._session.post(
-            self._api_url + "/collections/" + str(collection_id) + "/upsert",
-            data=json.dumps(
-                {
-                    "ids": ids,
-                    "embeddings": embeddings,
-                    "metadatas": metadatas,
-                    "documents": documents,
-                }
-            ),
-        )
+        batch = (ids, embeddings, metadatas, documents)
+        validate_batch(batch, {"max_batch_size": self.max_batch_size})
+        resp = self._submit_batch(
+            batch, "/collections/" + str(collection_id) + "/upsert"
+        )
 
         resp.raise_for_status()
         return True
@@ -434,6 +436,15 @@ class FastAPI(API):
         """Returns the settings of the client"""
         return self._settings
 
+    @property
+    @override
+    def max_batch_size(self) -> int:
+        if self._max_batch_size == -1:
+            resp = self._session.get(self._api_url + "/pre-flight-checks")
+            raise_chroma_error(resp)
+            self._max_batch_size = cast(int, resp.json()["max_batch_size"])
+        return self._max_batch_size
+
 
 def raise_chroma_error(resp: requests.Response) -> None:
     """Raises an error if the response is not ok, using a ChromaError if possible"""
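Taken together, the FastAPI client hunks above give `max_batch_size` lazy, cached semantics: the first property access issues a single GET to `/pre-flight-checks`, and every later access reuses the cached value. A minimal sketch of that behavior (assuming a server on the default localhost:8000; not part of this diff):

    import chromadb

    client = chromadb.HttpClient()  # assumed local server at localhost:8000
    print(client.max_batch_size)  # first access: GET /pre-flight-checks
    print(client.max_batch_size)  # second access: served from the cached value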
chromadb/api/segment.py
@@ -26,6 +26,7 @@ from chromadb.api.types import (
     validate_update_metadata,
     validate_where,
     validate_where_document,
+    validate_batch,
 )
 from chromadb.telemetry.events import CollectionAddEvent, CollectionDeleteEvent
 
@@ -38,6 +39,7 @@ import time
 import logging
 import re
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -241,9 +243,18 @@ class SegmentAPI(API):
     ) -> bool:
         coll = self._get_collection(collection_id)
         self._manager.hint_use_collection(collection_id, t.Operation.ADD)
 
+        validate_batch(
+            (ids, embeddings, metadatas, documents),
+            {"max_batch_size": self.max_batch_size},
+        )
         records_to_submit = []
-        for r in _records(t.Operation.ADD, ids, embeddings, metadatas, documents):
+        for r in _records(
+            t.Operation.ADD,
+            ids=ids,
+            embeddings=embeddings,
+            metadatas=metadatas,
+            documents=documents,
+        ):
             self._validate_embedding_record(coll, r)
             records_to_submit.append(r)
         self._producer.submit_embeddings(coll["topic"], records_to_submit)
@@ -262,9 +273,18 @@ class SegmentAPI(API):
     ) -> bool:
         coll = self._get_collection(collection_id)
         self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
 
+        validate_batch(
+            (ids, embeddings, metadatas, documents),
+            {"max_batch_size": self.max_batch_size},
+        )
         records_to_submit = []
-        for r in _records(t.Operation.UPDATE, ids, embeddings, metadatas, documents):
+        for r in _records(
+            t.Operation.UPDATE,
+            ids=ids,
+            embeddings=embeddings,
+            metadatas=metadatas,
+            documents=documents,
+        ):
             self._validate_embedding_record(coll, r)
             records_to_submit.append(r)
         self._producer.submit_embeddings(coll["topic"], records_to_submit)
@@ -282,9 +302,18 @@ class SegmentAPI(API):
     ) -> bool:
         coll = self._get_collection(collection_id)
         self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
 
+        validate_batch(
+            (ids, embeddings, metadatas, documents),
+            {"max_batch_size": self.max_batch_size},
+        )
         records_to_submit = []
-        for r in _records(t.Operation.UPSERT, ids, embeddings, metadatas, documents):
+        for r in _records(
+            t.Operation.UPSERT,
+            ids=ids,
+            embeddings=embeddings,
+            metadatas=metadatas,
+            documents=documents,
+        ):
             self._validate_embedding_record(coll, r)
             records_to_submit.append(r)
         self._producer.submit_embeddings(coll["topic"], records_to_submit)
@@ -524,6 +553,11 @@ class SegmentAPI(API):
     def get_settings(self) -> Settings:
         return self._settings
 
+    @property
+    @override
+    def max_batch_size(self) -> int:
+        return self._producer.max_batch_size
+
     def _topic(self, collection_id: UUID) -> str:
         return f"persistent://{self._tenant_id}/{self._topic_ns}/{collection_id}"
 
chromadb/api/types.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union, Sequence, TypeVar, List, Dict, Any
+from typing import Optional, Union, Sequence, TypeVar, List, Dict, Any, Tuple
 from typing_extensions import Literal, TypedDict, Protocol
 import chromadb.errors as errors
 from chromadb.types import (
@@ -367,3 +367,13 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings:
             f"Expected each value in the embedding to be a int or float, got {embeddings}"
         )
     return embeddings
+
+
+def validate_batch(
+    batch: Tuple[IDs, Optional[Embeddings], Optional[Metadatas], Optional[Documents]],
+    limits: Dict[str, Any],
+) -> None:
+    if len(batch[0]) > limits["max_batch_size"]:
+        raise ValueError(
+            f"Batch size {len(batch[0])} exceeds maximum batch size {limits['max_batch_size']}"
+        )
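As defined above, `validate_batch` only needs the ids to measure batch size; a quick illustrative call (the numbers are made up):

    from chromadb.api.types import validate_batch

    ids = [str(i) for i in range(100)]
    # Raises: ValueError: Batch size 100 exceeds maximum batch size 50
    validate_batch((ids, None, None, None), {"max_batch_size": 50})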
chromadb/server/fastapi/__init__.py
@@ -126,6 +126,9 @@ class FastAPI(chromadb.server.Server):
         self.router.add_api_route("/api/v1/reset", self.reset, methods=["POST"])
         self.router.add_api_route("/api/v1/version", self.version, methods=["GET"])
         self.router.add_api_route("/api/v1/heartbeat", self.heartbeat, methods=["GET"])
+        self.router.add_api_route(
+            "/api/v1/pre-flight-checks", self.pre_flight_checks, methods=["GET"]
+        )
 
         self.router.add_api_route(
             "/api/v1/collections",
@@ -312,3 +315,8 @@ class FastAPI(chromadb.server.Server):
             include=query.include,
         )
         return nnresult
+
+    def pre_flight_checks(self) -> Dict[str, Any]:
+        return {
+            "max_batch_size": self._api.max_batch_size,
+        }
chromadb/test/property/test_add.py
@@ -1,11 +1,15 @@
-from typing import cast
+import random
+import uuid
+from random import randint
+from typing import cast, List, Any, Dict
 import pytest
 import hypothesis.strategies as st
 from hypothesis import given, settings
 from chromadb.api import API
-from chromadb.api.types import Embeddings
+from chromadb.api.types import Embeddings, Metadatas
 import chromadb.test.property.strategies as strategies
 import chromadb.test.property.invariants as invariants
+from chromadb.utils.batch_utils import create_batches
 
 collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll")
@@ -44,6 +48,79 @@ def test_add(
     )
 
 
+def create_large_recordset(
+    min_size: int = 45000,
+    max_size: int = 50000,
+) -> strategies.RecordSet:
+    size = randint(min_size, max_size)
+
+    ids = [str(uuid.uuid4()) for _ in range(size)]
+    metadatas = [{"some_key": f"{i}"} for i in range(size)]
+    documents = [f"Document {i}" for i in range(size)]
+    embeddings = [[1, 2, 3] for _ in range(size)]
+    record_set: Dict[str, List[Any]] = {
+        "ids": ids,
+        "embeddings": cast(Embeddings, embeddings),
+        "metadatas": metadatas,
+        "documents": documents,
+    }
+    return record_set
+
+
+@given(collection=collection_st)
+@settings(deadline=None, max_examples=1)
+def test_add_large(api: API, collection: strategies.Collection) -> None:
+    api.reset()
+    record_set = create_large_recordset(
+        min_size=api.max_batch_size,
+        max_size=api.max_batch_size + int(api.max_batch_size * random.random()),
+    )
+    coll = api.create_collection(
+        name=collection.name,
+        metadata=collection.metadata,
+        embedding_function=collection.embedding_function,
+    )
+    normalized_record_set = invariants.wrap_all(record_set)
+
+    if not invariants.is_metadata_valid(normalized_record_set):
+        with pytest.raises(Exception):
+            coll.add(**normalized_record_set)
+        return
+    for batch in create_batches(
+        api=api,
+        ids=cast(List[str], record_set["ids"]),
+        embeddings=cast(Embeddings, record_set["embeddings"]),
+        metadatas=cast(Metadatas, record_set["metadatas"]),
+        documents=cast(List[str], record_set["documents"]),
+    ):
+        coll.add(*batch)
+    invariants.count(coll, cast(strategies.RecordSet, normalized_record_set))
+
+
+@given(collection=collection_st)
+@settings(deadline=None, max_examples=1)
+def test_add_large_exceeding(api: API, collection: strategies.Collection) -> None:
+    api.reset()
+    record_set = create_large_recordset(
+        min_size=api.max_batch_size,
+        max_size=api.max_batch_size + int(api.max_batch_size * random.random()),
+    )
+    coll = api.create_collection(
+        name=collection.name,
+        metadata=collection.metadata,
+        embedding_function=collection.embedding_function,
+    )
+    normalized_record_set = invariants.wrap_all(record_set)
+
+    if not invariants.is_metadata_valid(normalized_record_set):
+        with pytest.raises(Exception):
+            coll.add(**normalized_record_set)
+        return
+    with pytest.raises(Exception) as e:
+        coll.add(**record_set)
+    assert "exceeds maximum batch size" in str(e.value)
+
+
 # TODO: This test fails right now because the ids are not sorted by the input order
 @pytest.mark.xfail(
     reason="This is expected to fail right now. We should change the API to sort the \
chromadb/test/test_api.py
@@ -1,6 +1,8 @@
 # type: ignore
+import requests
+
 import chromadb
 from chromadb.api.fastapi import FastAPI
 from chromadb.api.types import QueryResult
 from chromadb.config import Settings
 import chromadb.server.fastapi
@@ -164,6 +166,22 @@ def test_heartbeat(api):
     assert heartbeat > datetime.now() - timedelta(seconds=10)
 
 
+def test_max_batch_size(api):
+    print(api)
+    batch_size = api.max_batch_size
+    assert batch_size > 0
+
+
+def test_pre_flight_checks(api):
+    if not isinstance(api, FastAPI):
+        pytest.skip("Not a FastAPI instance")
+
+    resp = requests.get(f"{api._api_url}/pre-flight-checks")
+    assert resp.status_code == 200
+    assert resp.json() is not None
+    assert "max_batch_size" in resp.json().keys()
+
+
 batch_records = {
     "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
     "ids": ["https://example.com/1", "https://example.com/2"],
chromadb/utils/batch_utils.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from typing import Optional, Tuple, List
from chromadb.api import API
from chromadb.api.types import (
    Documents,
    Embeddings,
    IDs,
    Metadatas,
)


def create_batches(
    api: API,
    ids: IDs,
    embeddings: Optional[Embeddings] = None,
    metadatas: Optional[Metadatas] = None,
    documents: Optional[Documents] = None,
) -> List[Tuple[IDs, Embeddings, Optional[Metadatas], Optional[Documents]]]:
    _batches: List[
        Tuple[IDs, Embeddings, Optional[Metadatas], Optional[Documents]]
    ] = []
    if len(ids) > api.max_batch_size:
        # create split batches
        for i in range(0, len(ids), api.max_batch_size):
            _batches.append(
                (  # type: ignore
                    ids[i : i + api.max_batch_size],
                    embeddings[i : i + api.max_batch_size] if embeddings else None,
                    metadatas[i : i + api.max_batch_size] if metadatas else None,
                    documents[i : i + api.max_batch_size] if documents else None,
                )
            )
    else:
        _batches.append((ids, embeddings, metadatas, documents))  # type: ignore
    return _batches
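The slicing above produces one batch per `max_batch_size`-sized window over the ids, so 50,000 ids against a hypothetical limit of 20,000 yield chunks of 20,000, 20,000 and 10,000. A small sketch with a stubbed API object (the stub and numbers are illustrative only):

    from unittest.mock import Mock

    from chromadb.utils.batch_utils import create_batches

    api = Mock()
    api.max_batch_size = 20_000  # hypothetical limit
    batches = create_batches(api=api, ids=[str(i) for i in range(50_000)])
    print([len(b[0]) for b in batches])  # [20000, 20000, 10000]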
docs/CIP_5_Large_Batch_Handling_Improvements.md (new file, 59 lines)
@@ -0,0 +1,59 @@
# CIP-5: Large Batch Handling Improvements Proposal

## Status

Current Status: `Under Discussion`

## **Motivation**

As users start putting Chroma through its paces and storing ever-larger datasets, we must ensure that errors
related to large and potentially expensive batches are handled gracefully. This CIP proposes adding a new
`max_batch_size` property, exposed on the local segment API, and using it to split large batches into smaller ones.
## **Public Interfaces**

The following interfaces are impacted (a usage sketch follows the list):

- New Server API endpoint - `/pre-flight-checks`
- New `max_batch_size` property on the `API` interface
- Updated `_add`, `_update` and `_upsert` methods on `chromadb.api.segment.SegmentAPI`
- Updated `_add`, `_update` and `_upsert` methods on `chromadb.api.fastapi.FastAPI`
- New utility library `batch_utils.py`
- New exception raised when batch size exceeds `max_batch_size`
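For illustration, the new server surface could be exercised directly as follows (a sketch; the URL and the returned value are assumptions, since the actual limit depends on the underlying `Producer`):

    import requests

    resp = requests.get("http://localhost:8000/api/v1/pre-flight-checks")
    print(resp.json())  # e.g. {"max_batch_size": 41666}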
## **Proposed Changes**

We propose the following changes (an illustrative sketch follows the list):

- A new `max_batch_size` property is added to the `API` interface. The property relies on the underlying
  `Producer` class to fetch the actual value, and will be implemented by both `chromadb.api.segment.SegmentAPI`
  and `chromadb.api.fastapi.FastAPI`.
- `chromadb.api.segment.SegmentAPI` will implement the `max_batch_size` property by fetching the value from the
  `Producer` class.
- `chromadb.api.fastapi.FastAPI` will implement `max_batch_size` by fetching it from a new `/pre-flight-checks`
  endpoint on the Server.
- The new `/pre-flight-checks` endpoint on the Server will return a dictionary of pre-flight checks the client must
  fulfil to integrate with the server side. For now, we propose using this only for `max_batch_size`, but we can
  add more checks in the future. The pre-flight checks will be fetched only once per client and cached for the
  duration of the client's lifetime.
- Updated `_add`, `_update` and `_upsert` methods on `chromadb.api.segment.SegmentAPI` to validate batch size.
- Updated `_add`, `_update` and `_upsert` methods on `chromadb.api.fastapi.FastAPI` to validate batch size
  (client-side validation).
- A new utility library, `batch_utils.py`, will contain the logic for splitting batches into smaller ones.
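To make the intended client-side flow concrete, here is a sketch of adding an oversized record set with the proposed utility (assuming `api` is a client instance and `collection` an existing collection):

    from chromadb.utils.batch_utils import create_batches

    # Each yielded tuple is (ids, embeddings, metadatas, documents),
    # sized to fit within api.max_batch_size.
    for ids_b, embeddings_b, metadatas_b, documents_b in create_batches(
        api=api,
        ids=ids,
        embeddings=embeddings,
        metadatas=metadatas,
        documents=documents,
    ):
        collection.add(
            ids=ids_b,
            embeddings=embeddings_b,
            metadatas=metadatas_b,
            documents=documents_b,
        )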
## **Compatibility, Deprecation, and Migration Plan**

The change will be fully compatible with existing implementations and transparent to the user.
## **Test Plan**

New tests:

- Batch splitting tests for `chromadb.api.segment.SegmentAPI`
- Batch splitting tests for `chromadb.api.fastapi.FastAPI`
- Tests for the `/pre-flight-checks` endpoint

## **Rejected Alternatives**

N/A