Mirror of https://github.com/placeholder-soft/chroma.git, synced 2026-01-12 17:02:54 +08:00
[ENH] Pulsar Producer & Consumer (#921)
## Description of changes

- New functionality
  - Adds a basic Pulsar producer, consumer, and associated tests, as well as a docker compose file for the distributed version of Chroma.

## Test plan

We added bin/cluster-test.sh, which starts Pulsar and allows test_producer_consumer to run against the pulsar fixture.

## Documentation Changes

None required.
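As a usage sketch (not part of the diff): the new backend is selected purely through Settings, mirroring the pulsar test fixture added below. The localhost host and ports are assumptions matching the standalone Pulsar started by bin/cluster-test.sh.

from chromadb.config import Settings, System
from chromadb.ingest import Consumer, Producer

# Sketch only: assumes a standalone Pulsar reachable on localhost, as brought up
# by docker-compose.cluster.yml via bin/cluster-test.sh.
system = System(
    Settings(
        allow_reset=True,
        chroma_producer_impl="chromadb.ingest.impl.pulsar.PulsarProducer",
        chroma_consumer_impl="chromadb.ingest.impl.pulsar.PulsarConsumer",
        pulsar_broker_url="localhost",
        pulsar_admin_port="8080",
        pulsar_broker_port="6650",
    )
)
producer = system.require(Producer)
consumer = system.require(Consumer)
system.start()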
.github/workflows/chroma-cluster-test.yml (vendored, new file, 31 lines)
@@ -0,0 +1,31 @@
name: Chroma Cluster Tests

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
      - '**'
  workflow_dispatch:

jobs:
  test:
    strategy:
      matrix:
        python: ['3.7']
        platform: [ubuntu-latest]
        testfile: ["chromadb/test/ingest/test_producer_consumer.py"] # Just this one test for now
    runs-on: ${{ matrix.platform }}
    steps:
      - name: Checkout
        uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python }}
      - name: Install test dependencies
        run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt
      - name: Integration Test
        run: bin/cluster-test.sh ${{ matrix.testfile }}
@@ -32,4 +32,4 @@ repos:
     hooks:
       - id: mypy
         args: [--strict, --ignore-missing-imports, --follow-imports=silent, --disable-error-code=type-abstract]
-        additional_dependencies: ["types-requests", "pydantic", "overrides", "hypothesis", "pytest", "pypika", "numpy"]
+        additional_dependencies: ["types-requests", "pydantic", "overrides", "hypothesis", "pytest", "pypika", "numpy", "types-protobuf"]
bin/cluster-test.sh (new executable file, 16 lines)
@@ -0,0 +1,16 @@
#!/usr/bin/env bash

set -e

function cleanup {
  docker compose -f docker-compose.cluster.yml down --rmi local --volumes
}

trap cleanup EXIT

docker compose -f docker-compose.cluster.yml up -d --wait pulsar

export CHROMA_CLUSTER_TEST_ONLY=1

echo testing: python -m pytest "$@"
python -m pytest "$@"
@@ -1,6 +1,7 @@
 from chromadb.api import API
 from chromadb.config import Settings, System
 from chromadb.db.system import SysDB
+from chromadb.ingest.impl.utils import create_topic_name
 from chromadb.segment import SegmentManager, MetadataReader, VectorReader
 from chromadb.telemetry import Telemetry
 from chromadb.ingest import Producer
@@ -130,6 +131,9 @@ class SegmentAPI(API):
         coll = t.Collection(
             id=id, name=name, metadata=metadata, topic=self._topic(id), dimension=None
         )
+        # TODO: Topic creation right now lives in the producer but it should be moved to the coordinator,
+        # and the producer should just be responsible for publishing messages. Coordinator should
+        # be responsible for all management of topics.
         self._producer.create_topic(coll["topic"])
         segments = self._manager.create_segments(coll)
         self._sysdb.create_collection(coll)
@@ -559,7 +563,7 @@ class SegmentAPI(API):
         return self._producer.max_batch_size

     def _topic(self, collection_id: UUID) -> str:
-        return f"persistent://{self._tenant_id}/{self._topic_ns}/{collection_id}"
+        return create_topic_name(self._tenant_id, self._topic_ns, str(collection_id))

     # TODO: This could potentially cause race conditions in a distributed version of the
     # system, since the cache is only local.
@@ -92,6 +92,10 @@ class Settings(BaseSettings):  # type: ignore
     chroma_server_grpc_port: Optional[str] = None
     chroma_server_cors_allow_origins: List[str] = []  # eg ["http://localhost:3000"]

+    pulsar_broker_url: Optional[str] = None
+    pulsar_admin_port: Optional[str] = None
+    pulsar_broker_port: Optional[str] = None
+
     chroma_server_auth_provider: Optional[str] = None

     @validator("chroma_server_auth_provider", pre=True, always=True, allow_reuse=True)
chromadb/ingest/impl/pulsar.py (new file, 304 lines)
@@ -0,0 +1,304 @@
from __future__ import annotations
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
import uuid
from chromadb.config import Settings, System
from chromadb.ingest import Consumer, ConsumerCallbackFn, Producer
from overrides import overrides, EnforceOverrides
from uuid import UUID
from chromadb.ingest.impl.pulsar_admin import PulsarAdmin
from chromadb.ingest.impl.utils import create_pulsar_connection_str
from chromadb.proto.convert import from_proto_submit, to_proto_submit
import chromadb.proto.chroma_pb2 as proto
from chromadb.types import SeqId, SubmitEmbeddingRecord
import pulsar
from concurrent.futures import wait, Future

from chromadb.utils.messageid import int_to_pulsar, pulsar_to_int


class PulsarProducer(Producer, EnforceOverrides):
    _connection_str: str
    _topic_to_producer: Dict[str, pulsar.Producer]
    _client: pulsar.Client
    _admin: PulsarAdmin
    _settings: Settings

    def __init__(self, system: System) -> None:
        pulsar_host = system.settings.require("pulsar_broker_url")
        pulsar_port = system.settings.require("pulsar_broker_port")
        self._connection_str = create_pulsar_connection_str(pulsar_host, pulsar_port)
        self._topic_to_producer = {}
        self._settings = system.settings
        self._admin = PulsarAdmin(system)
        super().__init__(system)

    @overrides
    def start(self) -> None:
        self._client = pulsar.Client(self._connection_str)
        super().start()

    @overrides
    def stop(self) -> None:
        self._client.close()
        super().stop()

    @overrides
    def create_topic(self, topic_name: str) -> None:
        self._admin.create_topic(topic_name)

    @overrides
    def delete_topic(self, topic_name: str) -> None:
        self._admin.delete_topic(topic_name)

    @overrides
    def submit_embedding(
        self, topic_name: str, embedding: SubmitEmbeddingRecord
    ) -> SeqId:
        """Add an embedding record to the given topic. Returns the SeqID of the record."""
        producer = self._get_or_create_producer(topic_name)
        proto_submit: proto.SubmitEmbeddingRecord = to_proto_submit(embedding)
        # TODO: batch performance / async
        msg_id: pulsar.MessageId = producer.send(proto_submit.SerializeToString())
        return pulsar_to_int(msg_id)

    @overrides
    def submit_embeddings(
        self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord]
    ) -> Sequence[SeqId]:
        if not self._running:
            raise RuntimeError("Component not running")

        if len(embeddings) == 0:
            return []

        if len(embeddings) > self.max_batch_size:
            raise ValueError(
                f"""
                Cannot submit more than {self.max_batch_size:,} embeddings at once.
                Please submit your embeddings in batches of size
                {self.max_batch_size:,} or less.
                """
            )

        producer = self._get_or_create_producer(topic_name)
        protos_to_submit = [to_proto_submit(embedding) for embedding in embeddings]

        def create_producer_callback(
            future: Future[int],
        ) -> Callable[[Any, pulsar.MessageId], None]:
            def producer_callback(res: Any, msg_id: pulsar.MessageId) -> None:
                if msg_id:
                    future.set_result(pulsar_to_int(msg_id))
                else:
                    future.set_exception(
                        Exception(
                            "Unknown error while submitting embedding in producer_callback"
                        )
                    )

            return producer_callback

        futures = []
        for proto_to_submit in protos_to_submit:
            future: Future[int] = Future()
            producer.send_async(
                proto_to_submit.SerializeToString(),
                callback=create_producer_callback(future),
            )
            futures.append(future)

        wait(futures)

        results: List[SeqId] = []
        for future in futures:
            exception = future.exception()
            if exception is not None:
                raise exception
            results.append(future.result())

        return results

    @property
    @overrides
    def max_batch_size(self) -> int:
        # For now, we use 1,000
        # TODO: tune this to a reasonable value by default
        return 1000

    def _get_or_create_producer(self, topic_name: str) -> pulsar.Producer:
        if topic_name not in self._topic_to_producer:
            producer = self._client.create_producer(topic_name)
            self._topic_to_producer[topic_name] = producer
        return self._topic_to_producer[topic_name]

    @overrides
    def reset_state(self) -> None:
        if not self._settings.require("allow_reset"):
            raise ValueError(
                "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
            )
        for topic_name in self._topic_to_producer:
            self._admin.delete_topic(topic_name)
        self._topic_to_producer = {}
        super().reset_state()


class PulsarConsumer(Consumer, EnforceOverrides):
    class PulsarSubscription:
        id: UUID
        topic_name: str
        start: int
        end: int
        callback: ConsumerCallbackFn
        consumer: pulsar.Consumer

        def __init__(
            self,
            id: UUID,
            topic_name: str,
            start: int,
            end: int,
            callback: ConsumerCallbackFn,
            consumer: pulsar.Consumer,
        ):
            self.id = id
            self.topic_name = topic_name
            self.start = start
            self.end = end
            self.callback = callback
            self.consumer = consumer

    _connection_str: str
    _client: pulsar.Client
    _subscriptions: Dict[str, Set[PulsarSubscription]]
    _settings: Settings

    def __init__(self, system: System) -> None:
        pulsar_host = system.settings.require("pulsar_broker_url")
        pulsar_port = system.settings.require("pulsar_broker_port")
        self._connection_str = create_pulsar_connection_str(pulsar_host, pulsar_port)
        self._subscriptions = defaultdict(set)
        self._settings = system.settings
        super().__init__(system)

    @overrides
    def start(self) -> None:
        self._client = pulsar.Client(self._connection_str)
        super().start()

    @overrides
    def stop(self) -> None:
        self._client.close()
        super().stop()

    @overrides
    def subscribe(
        self,
        topic_name: str,
        consume_fn: ConsumerCallbackFn,
        start: Optional[SeqId] = None,
        end: Optional[SeqId] = None,
        id: Optional[UUID] = None,
    ) -> UUID:
        """Register a function that will be called to receive embeddings for a given
        topic. The given function may be called any number of times, with any number of
        records, and may be called concurrently.

        Only records between start (exclusive) and end (inclusive) SeqIDs will be
        returned. If start is None, the first record returned will be the next record
        generated, not including those generated before creating the subscription. If
        end is None, the consumer will consume indefinitely, otherwise it will
        automatically be unsubscribed when the end SeqID is reached.

        If the function throws an exception, the function may be called again with the
        same or different records.

        Takes an optional UUID as a unique subscription ID. If no ID is provided, a new
        ID will be generated and returned."""
        if not self._running:
            raise RuntimeError("Consumer must be started before subscribing")

        subscription_id = (
            id or uuid.uuid4()
        )  # TODO: this should really be created by the coordinator and stored in sysdb

        start, end = self._validate_range(start, end)

        def wrap_callback(consumer: pulsar.Consumer, message: pulsar.Message) -> None:
            msg_data = message.data()
            msg_id = pulsar_to_int(message.message_id())
            submit_embedding_record = proto.SubmitEmbeddingRecord()
            proto.SubmitEmbeddingRecord.ParseFromString(
                submit_embedding_record, msg_data
            )
            embedding_record = from_proto_submit(submit_embedding_record, msg_id)
            consume_fn([embedding_record])
            consumer.acknowledge(message)
            if msg_id == end:
                self.unsubscribe(subscription_id)

        consumer = self._client.subscribe(
            topic_name,
            subscription_id.hex,
            message_listener=wrap_callback,
        )

        subscription = self.PulsarSubscription(
            subscription_id, topic_name, start, end, consume_fn, consumer
        )
        self._subscriptions[topic_name].add(subscription)

        # NOTE: For some reason the seek() method expects a shadowed MessageId type
        # which resides in _msg_id.
        consumer.seek(int_to_pulsar(start)._msg_id)

        return subscription_id

    def _validate_range(
        self, start: Optional[SeqId], end: Optional[SeqId]
    ) -> Tuple[int, int]:
        """Validate and normalize the start and end SeqIDs for a subscription using this
        impl."""
        start = start or pulsar_to_int(pulsar.MessageId.latest)
        end = end or self.max_seqid()
        if not isinstance(start, int) or not isinstance(end, int):
            raise ValueError("SeqIDs must be integers")
        if start >= end:
            raise ValueError(f"Invalid SeqID range: {start} to {end}")
        return start, end

    @overrides
    def unsubscribe(self, subscription_id: UUID) -> None:
        """Unregister a subscription. The consume function will no longer be invoked,
        and resources associated with the subscription will be released."""
        for topic_name, subscriptions in self._subscriptions.items():
            for subscription in subscriptions:
                if subscription.id == subscription_id:
                    subscription.consumer.close()
                    subscriptions.remove(subscription)
                    if len(subscriptions) == 0:
                        del self._subscriptions[topic_name]
                    return

    @overrides
    def min_seqid(self) -> SeqId:
        """Return the minimum possible SeqID in this implementation."""
        return pulsar_to_int(pulsar.MessageId.earliest)

    @overrides
    def max_seqid(self) -> SeqId:
        """Return the maximum possible SeqID in this implementation."""
        return 2**192 - 1

    @overrides
    def reset_state(self) -> None:
        if not self._settings.require("allow_reset"):
            raise ValueError(
                "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
            )
        for topic_name, subscriptions in self._subscriptions.items():
            for subscription in subscriptions:
                subscription.consumer.close()
        self._subscriptions = defaultdict(set)
        super().reset_state()
chromadb/ingest/impl/pulsar_admin.py (new file, 81 lines)
@@ -0,0 +1,81 @@
# A thin wrapper around the pulsar admin api
import requests

from chromadb.config import System
from chromadb.ingest.impl.utils import parse_topic_name


class PulsarAdmin:
    """A thin wrapper around the pulsar admin api, only used for interim development towards distributed chroma.
    This functionality will be moved to the chroma coordinator."""

    _connection_str: str

    def __init__(self, system: System):
        pulsar_host = system.settings.require("pulsar_broker_url")
        pulsar_port = system.settings.require("pulsar_admin_port")
        self._connection_str = f"http://{pulsar_host}:{pulsar_port}"

        # Create the default tenant and namespace
        # This is a temporary workaround until we have a proper tenant/namespace management system
        self.create_tenant("default")
        self.create_namespace("default", "default")

    def create_tenant(self, tenant: str) -> None:
        """Make a PUT request to the admin api to create the tenant"""

        path = f"/admin/v2/tenants/{tenant}"
        url = self._connection_str + path
        response = requests.put(
            url, json={"allowedClusters": ["standalone"], "adminRoles": []}
        )  # TODO: how to manage clusters?

        if response.status_code != 204 and response.status_code != 409:
            raise RuntimeError(f"Failed to create tenant {tenant}")

    def create_namespace(self, tenant: str, namespace: str) -> None:
        """Make a PUT request to the admin api to create the namespace"""

        path = f"/admin/v2/namespaces/{tenant}/{namespace}"
        url = self._connection_str + path
        response = requests.put(url)

        if response.status_code != 204 and response.status_code != 409:
            raise RuntimeError(f"Failed to create namespace {namespace}")

    def create_topic(self, topic: str) -> None:
        # TODO: support non-persistent topics?
        tenant, namespace, topic_name = parse_topic_name(topic)

        if tenant != "default":
            raise ValueError(f"Only the default tenant is supported, got {tenant}")
        if namespace != "default":
            raise ValueError(
                f"Only the default namespace is supported, got {namespace}"
            )

        # Make a PUT request to the admin api to create the topic
        path = f"/admin/v2/persistent/{tenant}/{namespace}/{topic_name}"
        url = self._connection_str + path
        response = requests.put(url)

        if response.status_code != 204 and response.status_code != 409:
            raise RuntimeError(f"Failed to create topic {topic_name}")

    def delete_topic(self, topic: str) -> None:
        tenant, namespace, topic_name = parse_topic_name(topic)

        if tenant != "default":
            raise ValueError(f"Only the default tenant is supported, got {tenant}")
        if namespace != "default":
            raise ValueError(
                f"Only the default namespace is supported, got {namespace}"
            )

        # Make a DELETE request to the admin api to delete the topic
        path = f"/admin/v2/persistent/{tenant}/{namespace}/{topic_name}"
        # Force delete the topic
        path += "?force=true"
        url = self._connection_str + path
        response = requests.delete(url)
        if response.status_code != 204 and response.status_code != 409:
            raise RuntimeError(f"Failed to delete topic {topic_name}")
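For orientation (not part of the diff): the wrapper only issues Pulsar's standard admin REST calls, so a create_topic call is roughly equivalent to the sketch below, assuming a standalone broker whose admin API is reachable on localhost:8080.

import requests

# Hypothetical direct equivalent of PulsarAdmin.create_topic("persistent://default/default/my-topic"),
# assuming a local standalone broker with the admin API on port 8080.
resp = requests.put("http://localhost:8080/admin/v2/persistent/default/default/my-topic")
assert resp.status_code in (204, 409)  # 409 simply means the topic already exists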
chromadb/ingest/impl/utils.py (new file, 20 lines)
@@ -0,0 +1,20 @@
import re
from typing import Tuple

topic_regex = r"persistent:\/\/(?P<tenant>.+)\/(?P<namespace>.+)\/(?P<topic>.+)"


def parse_topic_name(topic_name: str) -> Tuple[str, str, str]:
    """Parse the topic name into the tenant, namespace and topic name"""
    match = re.match(topic_regex, topic_name)
    if not match:
        raise ValueError(f"Invalid topic name: {topic_name}")
    return match.group("tenant"), match.group("namespace"), match.group("topic")


def create_pulsar_connection_str(host: str, port: str) -> str:
    return f"pulsar://{host}:{port}"


def create_topic_name(tenant: str, namespace: str, topic: str) -> str:
    return f"persistent://{tenant}/{namespace}/{topic}"
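A quick round-trip sketch of these helpers (the values are illustrative):

# create_topic_name and parse_topic_name are inverses for well-formed names.
name = create_topic_name("default", "default", "test_topic")
# name == "persistent://default/default/test_topic"
assert parse_topic_name(name) == ("default", "default", "test_topic")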
chromadb/proto/chroma.proto (new file, 40 lines)
@@ -0,0 +1,40 @@
syntax = "proto3";

package chroma;

enum Operation {
    ADD = 0;
    UPDATE = 1;
    UPSERT = 2;
    DELETE = 3;
}

enum ScalarEncoding {
    FLOAT32 = 0;
    INT32 = 1;
}

message Vector {
    int32 dimension = 1;
    bytes vector = 2;
    ScalarEncoding encoding = 3;
}

message UpdateMetadataValue {
    oneof value {
        string string_value = 1;
        int64 int_value = 2;
        double float_value = 3;
    }
}

message UpdateMetadata {
    map<string, UpdateMetadataValue> metadata = 1;
}

message SubmitEmbeddingRecord {
    string id = 1;
    optional Vector vector = 2;
    optional UpdateMetadata metadata = 3;
    Operation operation = 4;
}
chromadb/proto/chroma_pb2.py (new file, 42 lines)
@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: chromadb/proto/chroma.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder

# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
    b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01\x62\x06proto3'
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(
    DESCRIPTOR, "chromadb.proto.chroma_pb2", _globals
)
if _descriptor._USE_C_DESCRIPTORS == False:
    DESCRIPTOR._options = None
    _UPDATEMETADATA_METADATAENTRY._options = None
    _UPDATEMETADATA_METADATAENTRY._serialized_options = b"8\001"
    _globals["_OPERATION"]._serialized_start = 563
    _globals["_OPERATION"]._serialized_end = 619
    _globals["_SCALARENCODING"]._serialized_start = 621
    _globals["_SCALARENCODING"]._serialized_end = 661
    _globals["_VECTOR"]._serialized_start = 39
    _globals["_VECTOR"]._serialized_end = 124
    _globals["_UPDATEMETADATAVALUE"]._serialized_start = 126
    _globals["_UPDATEMETADATAVALUE"]._serialized_end = 224
    _globals["_UPDATEMETADATA"]._serialized_start = 227
    _globals["_UPDATEMETADATA"]._serialized_end = 377
    _globals["_UPDATEMETADATA_METADATAENTRY"]._serialized_start = 301
    _globals["_UPDATEMETADATA_METADATAENTRY"]._serialized_end = 377
    _globals["_SUBMITEMBEDDINGRECORD"]._serialized_start = 380
    _globals["_SUBMITEMBEDDINGRECORD"]._serialized_end = 561
# @@protoc_insertion_point(module_scope)
chromadb/proto/chroma_pb2.pyi (new file, 247 lines)
@@ -0,0 +1,247 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing

if sys.version_info >= (3, 10):
    import typing as typing_extensions
else:
    import typing_extensions

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

class _Operation:
    ValueType = typing.NewType("ValueType", builtins.int)
    V: typing_extensions.TypeAlias = ValueType

class _OperationEnumTypeWrapper(
    google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_Operation.ValueType],
    builtins.type,
):
    DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
    ADD: _Operation.ValueType  # 0
    UPDATE: _Operation.ValueType  # 1
    UPSERT: _Operation.ValueType  # 2
    DELETE: _Operation.ValueType  # 3

class Operation(_Operation, metaclass=_OperationEnumTypeWrapper): ...

ADD: Operation.ValueType  # 0
UPDATE: Operation.ValueType  # 1
UPSERT: Operation.ValueType  # 2
DELETE: Operation.ValueType  # 3
global___Operation = Operation

class _ScalarEncoding:
    ValueType = typing.NewType("ValueType", builtins.int)
    V: typing_extensions.TypeAlias = ValueType

class _ScalarEncodingEnumTypeWrapper(
    google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[
        _ScalarEncoding.ValueType
    ],
    builtins.type,
):
    DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
    FLOAT32: _ScalarEncoding.ValueType  # 0
    INT32: _ScalarEncoding.ValueType  # 1

class ScalarEncoding(_ScalarEncoding, metaclass=_ScalarEncodingEnumTypeWrapper): ...

FLOAT32: ScalarEncoding.ValueType  # 0
INT32: ScalarEncoding.ValueType  # 1
global___ScalarEncoding = ScalarEncoding

@typing_extensions.final
class Vector(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    DIMENSION_FIELD_NUMBER: builtins.int
    VECTOR_FIELD_NUMBER: builtins.int
    ENCODING_FIELD_NUMBER: builtins.int
    dimension: builtins.int
    vector: builtins.bytes
    encoding: global___ScalarEncoding.ValueType
    def __init__(
        self,
        *,
        dimension: builtins.int = ...,
        vector: builtins.bytes = ...,
        encoding: global___ScalarEncoding.ValueType = ...,
    ) -> None: ...
    def ClearField(
        self,
        field_name: typing_extensions.Literal[
            "dimension", b"dimension", "encoding", b"encoding", "vector", b"vector"
        ],
    ) -> None: ...

global___Vector = Vector

@typing_extensions.final
class UpdateMetadataValue(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    STRING_VALUE_FIELD_NUMBER: builtins.int
    INT_VALUE_FIELD_NUMBER: builtins.int
    FLOAT_VALUE_FIELD_NUMBER: builtins.int
    string_value: builtins.str
    int_value: builtins.int
    float_value: builtins.float
    def __init__(
        self,
        *,
        string_value: builtins.str = ...,
        int_value: builtins.int = ...,
        float_value: builtins.float = ...,
    ) -> None: ...
    def HasField(
        self,
        field_name: typing_extensions.Literal[
            "float_value", b"float_value", "int_value", b"int_value", "string_value", b"string_value", "value", b"value"
        ],
    ) -> builtins.bool: ...
    def ClearField(
        self,
        field_name: typing_extensions.Literal[
            "float_value", b"float_value", "int_value", b"int_value", "string_value", b"string_value", "value", b"value"
        ],
    ) -> None: ...
    def WhichOneof(
        self, oneof_group: typing_extensions.Literal["value", b"value"]
    ) -> (
        typing_extensions.Literal["string_value", "int_value", "float_value"] | None
    ): ...

global___UpdateMetadataValue = UpdateMetadataValue

@typing_extensions.final
class UpdateMetadata(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    @typing_extensions.final
    class MetadataEntry(google.protobuf.message.Message):
        DESCRIPTOR: google.protobuf.descriptor.Descriptor

        KEY_FIELD_NUMBER: builtins.int
        VALUE_FIELD_NUMBER: builtins.int
        key: builtins.str
        @property
        def value(self) -> global___UpdateMetadataValue: ...
        def __init__(
            self,
            *,
            key: builtins.str = ...,
            value: global___UpdateMetadataValue | None = ...,
        ) -> None: ...
        def HasField(
            self, field_name: typing_extensions.Literal["value", b"value"]
        ) -> builtins.bool: ...
        def ClearField(
            self,
            field_name: typing_extensions.Literal["key", b"key", "value", b"value"],
        ) -> None: ...

    METADATA_FIELD_NUMBER: builtins.int
    @property
    def metadata(
        self,
    ) -> google.protobuf.internal.containers.MessageMap[
        builtins.str, global___UpdateMetadataValue
    ]: ...
    def __init__(
        self,
        *,
        metadata: collections.abc.Mapping[builtins.str, global___UpdateMetadataValue]
        | None = ...,
    ) -> None: ...
    def ClearField(
        self, field_name: typing_extensions.Literal["metadata", b"metadata"]
    ) -> None: ...

global___UpdateMetadata = UpdateMetadata

@typing_extensions.final
class SubmitEmbeddingRecord(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    ID_FIELD_NUMBER: builtins.int
    VECTOR_FIELD_NUMBER: builtins.int
    METADATA_FIELD_NUMBER: builtins.int
    OPERATION_FIELD_NUMBER: builtins.int
    id: builtins.str
    @property
    def vector(self) -> global___Vector: ...
    @property
    def metadata(self) -> global___UpdateMetadata: ...
    operation: global___Operation.ValueType
    def __init__(
        self,
        *,
        id: builtins.str = ...,
        vector: global___Vector | None = ...,
        metadata: global___UpdateMetadata | None = ...,
        operation: global___Operation.ValueType = ...,
    ) -> None: ...
    def HasField(
        self,
        field_name: typing_extensions.Literal[
            "_metadata", b"_metadata", "_vector", b"_vector", "metadata", b"metadata", "vector", b"vector"
        ],
    ) -> builtins.bool: ...
    def ClearField(
        self,
        field_name: typing_extensions.Literal[
            "_metadata", b"_metadata", "_vector", b"_vector", "id", b"id", "metadata", b"metadata", "operation", b"operation", "vector", b"vector"
        ],
    ) -> None: ...
    @typing.overload
    def WhichOneof(
        self, oneof_group: typing_extensions.Literal["_metadata", b"_metadata"]
    ) -> typing_extensions.Literal["metadata"] | None: ...
    @typing.overload
    def WhichOneof(
        self, oneof_group: typing_extensions.Literal["_vector", b"_vector"]
    ) -> typing_extensions.Literal["vector"] | None: ...

global___SubmitEmbeddingRecord = SubmitEmbeddingRecord
chromadb/proto/convert.py (new file, 150 lines)
@@ -0,0 +1,150 @@
import array
from typing import Optional, Tuple, Union
from chromadb.api.types import Embedding
import chromadb.proto.chroma_pb2 as proto
from chromadb.types import (
    EmbeddingRecord,
    Metadata,
    Operation,
    ScalarEncoding,
    SeqId,
    SubmitEmbeddingRecord,
    Vector,
)


def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> proto.Vector:
    if encoding == ScalarEncoding.FLOAT32:
        as_bytes = array.array("f", vector).tobytes()
        proto_encoding = proto.ScalarEncoding.FLOAT32
    elif encoding == ScalarEncoding.INT32:
        as_bytes = array.array("i", vector).tobytes()
        proto_encoding = proto.ScalarEncoding.INT32
    else:
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of {ScalarEncoding.FLOAT32} \
            or {ScalarEncoding.INT32}"
        )

    return proto.Vector(dimension=len(vector), vector=as_bytes, encoding=proto_encoding)


def from_proto_vector(vector: proto.Vector) -> Tuple[Embedding, ScalarEncoding]:
    encoding = vector.encoding
    as_array: array.array[float] | array.array[int]
    if encoding == proto.ScalarEncoding.FLOAT32:
        as_array = array.array("f")
        out_encoding = ScalarEncoding.FLOAT32
    elif encoding == proto.ScalarEncoding.INT32:
        as_array = array.array("i")
        out_encoding = ScalarEncoding.INT32
    else:
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of \
            {proto.ScalarEncoding.FLOAT32} or {proto.ScalarEncoding.INT32}"
        )

    as_array.frombytes(vector.vector)
    return (as_array.tolist(), out_encoding)


def from_proto_operation(operation: proto.Operation.ValueType) -> Operation:
    if operation == proto.Operation.ADD:
        return Operation.ADD
    elif operation == proto.Operation.UPDATE:
        return Operation.UPDATE
    elif operation == proto.Operation.UPSERT:
        return Operation.UPSERT
    elif operation == proto.Operation.DELETE:
        return Operation.DELETE
    else:
        raise RuntimeError(f"Unknown operation {operation}")  # TODO: full error


def from_proto_metadata(metadata: proto.UpdateMetadata) -> Optional[Metadata]:
    if not metadata.metadata:
        return None
    out_metadata = {}
    for key, value in metadata.metadata.items():
        if value.HasField("string_value"):
            out_metadata[key] = value.string_value
        elif value.HasField("int_value"):
            out_metadata[key] = value.int_value
        elif value.HasField("float_value"):
            out_metadata[key] = value.float_value
        else:
            raise RuntimeError(f"Unknown metadata value type {value}")
    return out_metadata


def from_proto_submit(
    submit_embedding_record: proto.SubmitEmbeddingRecord, seq_id: SeqId
) -> EmbeddingRecord:
    embedding, encoding = from_proto_vector(submit_embedding_record.vector)
    record = EmbeddingRecord(
        id=submit_embedding_record.id,
        seq_id=seq_id,
        embedding=embedding,
        encoding=encoding,
        metadata=from_proto_metadata(submit_embedding_record.metadata),
        operation=from_proto_operation(submit_embedding_record.operation),
    )
    return record


def to_proto_metadata_update_value(
    value: Union[str, int, float, None]
) -> proto.UpdateMetadataValue:
    if isinstance(value, str):
        return proto.UpdateMetadataValue(string_value=value)
    elif isinstance(value, int):
        return proto.UpdateMetadataValue(int_value=value)
    elif isinstance(value, float):
        return proto.UpdateMetadataValue(float_value=value)
    elif value is None:
        return proto.UpdateMetadataValue()
    else:
        raise ValueError(
            f"Unknown metadata value type {type(value)}, expected one of str, int, \
            float, or None"
        )


def to_proto_operation(operation: Operation) -> proto.Operation.ValueType:
    if operation == Operation.ADD:
        return proto.Operation.ADD
    elif operation == Operation.UPDATE:
        return proto.Operation.UPDATE
    elif operation == Operation.UPSERT:
        return proto.Operation.UPSERT
    elif operation == Operation.DELETE:
        return proto.Operation.DELETE
    else:
        raise ValueError(
            f"Unknown operation {operation}, expected one of {Operation.ADD}, \
            {Operation.UPDATE}, {Operation.UPSERT}, or {Operation.DELETE}"
        )


def to_proto_submit(
    submit_record: SubmitEmbeddingRecord,
) -> proto.SubmitEmbeddingRecord:
    vector = None
    if submit_record["embedding"] is not None and submit_record["encoding"] is not None:
        vector = to_proto_vector(submit_record["embedding"], submit_record["encoding"])

    metadata = None
    if submit_record["metadata"] is not None:
        metadata = {
            k: to_proto_metadata_update_value(v)
            for k, v in submit_record["metadata"].items()
        }

    return proto.SubmitEmbeddingRecord(
        id=submit_record["id"],
        vector=vector,
        metadata=proto.UpdateMetadata(metadata=metadata)
        if metadata is not None
        else None,
        operation=to_proto_operation(submit_record["operation"]),
    )
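A hedged round-trip sketch of these converters (not part of the diff; it assumes SubmitEmbeddingRecord and EmbeddingRecord are the dict-style records from chromadb.types, as they are indexed above, and the field values are illustrative):

from chromadb.proto.convert import from_proto_submit, to_proto_submit
from chromadb.types import Operation, ScalarEncoding, SubmitEmbeddingRecord

record: SubmitEmbeddingRecord = {
    "id": "embedding_0",              # illustrative values
    "embedding": [0.25, 0.5, 0.75],   # exactly representable as float32
    "encoding": ScalarEncoding.FLOAT32,
    "metadata": {"source": "example"},
    "operation": Operation.ADD,
}
msg = to_proto_submit(record)              # -> proto.SubmitEmbeddingRecord
restored = from_proto_submit(msg, seq_id=1)
assert restored["id"] == "embedding_0" and restored["seq_id"] == 1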
@@ -469,9 +469,9 @@ class SqliteMetadataSegment(MetadataReader):

 def _encode_seq_id(seq_id: SeqId) -> bytes:
     """Encode a SeqID into a byte array"""
-    if seq_id.bit_length() < 64:
+    if seq_id.bit_length() <= 64:
         return int.to_bytes(seq_id, 8, "big")
-    elif seq_id.bit_length() < 192:
+    elif seq_id.bit_length() <= 192:
         return int.to_bytes(seq_id, 24, "big")
     else:
         raise ValueError(f"Unsupported SeqID: {seq_id}")
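Context for the boundary change: the Pulsar consumer above reports max_seqid() as 2**192 - 1, whose bit_length() is exactly 192, so the previous strict comparison raised ValueError for a value that fits exactly in 24 bytes. A quick check (illustrative, not from the diff):

max_seq_id = 2**192 - 1
assert max_seq_id.bit_length() == 192
int.to_bytes(max_seq_id, 24, "big")  # fits in exactly 24 bytes; "< 192" used to reject it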
@@ -207,7 +207,6 @@ class PersistentLocalHnswSegment(LocalHnswSegment):
         """Add a batch of embeddings to the index"""
         if not self._running:
             raise RuntimeError("Cannot add embeddings to stopped component")
-
         with WriteRWLock(self._lock):
             for record in records:
                 if record["embedding"] is not None:
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -16,6 +17,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
from chromadb.ingest import Producer, Consumer
|
from chromadb.ingest import Producer, Consumer
|
||||||
from chromadb.db.impl.sqlite import SqliteDB
|
from chromadb.db.impl.sqlite import SqliteDB
|
||||||
|
from chromadb.ingest.impl.utils import create_topic_name
|
||||||
from chromadb.test.conftest import ProducerFn
|
from chromadb.test.conftest import ProducerFn
|
||||||
from chromadb.types import (
|
from chromadb.types import (
|
||||||
SubmitEmbeddingRecord,
|
SubmitEmbeddingRecord,
|
||||||
@@ -51,8 +53,33 @@ def sqlite_persistent() -> Generator[Tuple[Producer, Consumer], None, None]:
|
|||||||
shutil.rmtree(save_path)
|
shutil.rmtree(save_path)
|
||||||
|
|
||||||
|
|
||||||
|
def pulsar() -> Generator[Tuple[Producer, Consumer], None, None]:
|
||||||
|
"""Fixture generator for pulsar Producer + Consumer. This fixture requires a running
|
||||||
|
pulsar cluster. You can use bin/cluster-test.sh to start a standalone pulsar and run this test
|
||||||
|
"""
|
||||||
|
system = System(
|
||||||
|
Settings(
|
||||||
|
allow_reset=True,
|
||||||
|
chroma_producer_impl="chromadb.ingest.impl.pulsar.PulsarProducer",
|
||||||
|
chroma_consumer_impl="chromadb.ingest.impl.pulsar.PulsarConsumer",
|
||||||
|
pulsar_broker_url="localhost",
|
||||||
|
pulsar_admin_port="8080",
|
||||||
|
pulsar_broker_port="6650",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
producer = system.require(Producer)
|
||||||
|
consumer = system.require(Consumer)
|
||||||
|
system.start()
|
||||||
|
yield producer, consumer
|
||||||
|
system.stop()
|
||||||
|
|
||||||
|
|
||||||
def fixtures() -> List[Callable[[], Generator[Tuple[Producer, Consumer], None, None]]]:
|
def fixtures() -> List[Callable[[], Generator[Tuple[Producer, Consumer], None, None]]]:
|
||||||
return [sqlite, sqlite_persistent]
|
fixtures = [sqlite, sqlite_persistent]
|
||||||
|
if "CHROMA_CLUSTER_TEST_ONLY" in os.environ:
|
||||||
|
fixtures = [pulsar]
|
||||||
|
|
||||||
|
return fixtures
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=fixtures())
|
@pytest.fixture(scope="module", params=fixtures())
|
||||||
@@ -89,14 +116,20 @@ class CapturingConsumeFn:
|
|||||||
waiters: List[Tuple[int, Event]]
|
waiters: List[Tuple[int, Event]]
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
"""A function that captures embeddings and allows you to wait for a certain
|
||||||
|
number of embeddings to be available. It must be constructed in the thread with
|
||||||
|
the main event loop
|
||||||
|
"""
|
||||||
self.embeddings = []
|
self.embeddings = []
|
||||||
self.waiters = []
|
self.waiters = []
|
||||||
|
self._loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
def __call__(self, embeddings: Sequence[EmbeddingRecord]) -> None:
|
def __call__(self, embeddings: Sequence[EmbeddingRecord]) -> None:
|
||||||
self.embeddings.extend(embeddings)
|
self.embeddings.extend(embeddings)
|
||||||
for n, event in self.waiters:
|
for n, event in self.waiters:
|
||||||
if len(self.embeddings) >= n:
|
if len(self.embeddings) >= n:
|
||||||
event.set()
|
# event.set() is not thread safe, so we need to call it in the main event loop
|
||||||
|
self._loop.call_soon_threadsafe(event.set)
|
||||||
|
|
||||||
async def get(self, n: int, timeout_secs: int = 10) -> Sequence[EmbeddingRecord]:
|
async def get(self, n: int, timeout_secs: int = 10) -> Sequence[EmbeddingRecord]:
|
||||||
"Wait until at least N embeddings are available, then return all embeddings"
|
"Wait until at least N embeddings are available, then return all embeddings"
|
||||||
@@ -132,6 +165,10 @@ def assert_records_match(
|
|||||||
assert_approx_equal(inserted["embedding"], consumed["embedding"])
|
assert_approx_equal(inserted["embedding"], consumed["embedding"])
|
||||||
|
|
||||||
|
|
||||||
|
def full_topic_name(topic_name: str) -> str:
|
||||||
|
return create_topic_name("default", "default", topic_name)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_backfill(
|
async def test_backfill(
|
||||||
producer_consumer: Tuple[Producer, Consumer],
|
producer_consumer: Tuple[Producer, Consumer],
|
||||||
@@ -140,12 +177,14 @@ async def test_backfill(
|
|||||||
) -> None:
|
) -> None:
|
||||||
producer, consumer = producer_consumer
|
producer, consumer = producer_consumer
|
||||||
producer.reset_state()
|
producer.reset_state()
|
||||||
|
consumer.reset_state()
|
||||||
|
|
||||||
producer.create_topic("test_topic")
|
topic_name = full_topic_name("test_topic")
|
||||||
embeddings = produce_fns(producer, "test_topic", sample_embeddings, 3)[0]
|
producer.create_topic(topic_name)
|
||||||
|
embeddings = produce_fns(producer, topic_name, sample_embeddings, 3)[0]
|
||||||
|
|
||||||
consume_fn = CapturingConsumeFn()
|
consume_fn = CapturingConsumeFn()
|
||||||
consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
|
consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid())
|
||||||
|
|
||||||
recieved = await consume_fn.get(3)
|
recieved = await consume_fn.get(3)
|
||||||
assert_records_match(embeddings, recieved)
|
assert_records_match(embeddings, recieved)
|
||||||
@@ -158,18 +197,21 @@ async def test_notifications(
|
|||||||
) -> None:
|
) -> None:
|
||||||
producer, consumer = producer_consumer
|
producer, consumer = producer_consumer
|
||||||
producer.reset_state()
|
producer.reset_state()
|
||||||
producer.create_topic("test_topic")
|
consumer.reset_state()
|
||||||
|
topic_name = full_topic_name("test_topic")
|
||||||
|
|
||||||
|
producer.create_topic(topic_name)
|
||||||
|
|
||||||
embeddings: List[SubmitEmbeddingRecord] = []
|
embeddings: List[SubmitEmbeddingRecord] = []
|
||||||
|
|
||||||
consume_fn = CapturingConsumeFn()
|
consume_fn = CapturingConsumeFn()
|
||||||
|
|
||||||
consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
|
consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid())
|
||||||
|
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
e = next(sample_embeddings)
|
e = next(sample_embeddings)
|
||||||
embeddings.append(e)
|
embeddings.append(e)
|
||||||
producer.submit_embedding("test_topic", e)
|
producer.submit_embedding(topic_name, e)
|
||||||
received = await consume_fn.get(i + 1)
|
received = await consume_fn.get(i + 1)
|
||||||
assert_records_match(embeddings, received)
|
assert_records_match(embeddings, received)
|
||||||
|
|
||||||
@@ -181,8 +223,11 @@ async def test_multiple_topics(
|
|||||||
) -> None:
|
) -> None:
|
||||||
producer, consumer = producer_consumer
|
producer, consumer = producer_consumer
|
||||||
producer.reset_state()
|
producer.reset_state()
|
||||||
producer.create_topic("test_topic_1")
|
consumer.reset_state()
|
||||||
producer.create_topic("test_topic_2")
|
topic_name_1 = full_topic_name("test_topic_1")
|
||||||
|
topic_name_2 = full_topic_name("test_topic_2")
|
||||||
|
producer.create_topic(topic_name_1)
|
||||||
|
producer.create_topic(topic_name_2)
|
||||||
|
|
||||||
embeddings_1: List[SubmitEmbeddingRecord] = []
|
embeddings_1: List[SubmitEmbeddingRecord] = []
|
||||||
embeddings_2: List[SubmitEmbeddingRecord] = []
|
embeddings_2: List[SubmitEmbeddingRecord] = []
|
||||||
@@ -190,19 +235,19 @@ async def test_multiple_topics(
|
|||||||
consume_fn_1 = CapturingConsumeFn()
|
consume_fn_1 = CapturingConsumeFn()
|
||||||
consume_fn_2 = CapturingConsumeFn()
|
consume_fn_2 = CapturingConsumeFn()
|
||||||
|
|
||||||
consumer.subscribe("test_topic_1", consume_fn_1, start=consumer.min_seqid())
|
consumer.subscribe(topic_name_1, consume_fn_1, start=consumer.min_seqid())
|
||||||
consumer.subscribe("test_topic_2", consume_fn_2, start=consumer.min_seqid())
|
consumer.subscribe(topic_name_2, consume_fn_2, start=consumer.min_seqid())
|
||||||
|
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
e_1 = next(sample_embeddings)
|
e_1 = next(sample_embeddings)
|
||||||
embeddings_1.append(e_1)
|
embeddings_1.append(e_1)
|
||||||
producer.submit_embedding("test_topic_1", e_1)
|
producer.submit_embedding(topic_name_1, e_1)
|
||||||
results_2 = await consume_fn_1.get(i + 1)
|
results_2 = await consume_fn_1.get(i + 1)
|
||||||
assert_records_match(embeddings_1, results_2)
|
assert_records_match(embeddings_1, results_2)
|
||||||
|
|
||||||
e_2 = next(sample_embeddings)
|
e_2 = next(sample_embeddings)
|
||||||
embeddings_2.append(e_2)
|
embeddings_2.append(e_2)
|
||||||
producer.submit_embedding("test_topic_2", e_2)
|
producer.submit_embedding(topic_name_2, e_2)
|
||||||
results_2 = await consume_fn_2.get(i + 1)
|
results_2 = await consume_fn_2.get(i + 1)
|
||||||
assert_records_match(embeddings_2, results_2)
|
assert_records_match(embeddings_2, results_2)
|
||||||
|
|
||||||
@@ -215,21 +260,23 @@ async def test_start_seq_id(
|
|||||||
) -> None:
|
) -> None:
|
||||||
producer, consumer = producer_consumer
|
producer, consumer = producer_consumer
|
||||||
producer.reset_state()
|
producer.reset_state()
|
||||||
producer.create_topic("test_topic")
|
consumer.reset_state()
|
||||||
|
topic_name = full_topic_name("test_topic")
|
||||||
|
producer.create_topic(topic_name)
|
||||||
|
|
||||||
consume_fn_1 = CapturingConsumeFn()
|
consume_fn_1 = CapturingConsumeFn()
|
||||||
consume_fn_2 = CapturingConsumeFn()
|
consume_fn_2 = CapturingConsumeFn()
|
||||||
|
|
||||||
consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
|
consumer.subscribe(topic_name, consume_fn_1, start=consumer.min_seqid())
|
||||||
|
|
||||||
embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0]
|
embeddings = produce_fns(producer, topic_name, sample_embeddings, 5)[0]
|
||||||
|
|
||||||
results_1 = await consume_fn_1.get(5)
|
results_1 = await consume_fn_1.get(5)
|
||||||
assert_records_match(embeddings, results_1)
|
assert_records_match(embeddings, results_1)
|
||||||
|
|
||||||
start = consume_fn_1.embeddings[-1]["seq_id"]
|
start = consume_fn_1.embeddings[-1]["seq_id"]
|
||||||
consumer.subscribe("test_topic", consume_fn_2, start=start)
|
consumer.subscribe(topic_name, consume_fn_2, start=start)
|
||||||
second_embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0]
|
second_embeddings = produce_fns(producer, topic_name, sample_embeddings, 5)[0]
|
||||||
assert isinstance(embeddings, list)
|
assert isinstance(embeddings, list)
|
||||||
embeddings.extend(second_embeddings)
|
embeddings.extend(second_embeddings)
|
||||||
results_2 = await consume_fn_2.get(5)
|
results_2 = await consume_fn_2.get(5)
|
||||||
@@ -244,20 +291,22 @@ async def test_end_seq_id(
|
|||||||
) -> None:
|
) -> None:
|
||||||
producer, consumer = producer_consumer
|
producer, consumer = producer_consumer
|
||||||
producer.reset_state()
|
producer.reset_state()
|
||||||
producer.create_topic("test_topic")
|
consumer.reset_state()
|
||||||
|
topic_name = full_topic_name("test_topic")
|
||||||
|
producer.create_topic(topic_name)
|
||||||
|
|
||||||
consume_fn_1 = CapturingConsumeFn()
|
consume_fn_1 = CapturingConsumeFn()
|
||||||
consume_fn_2 = CapturingConsumeFn()
|
consume_fn_2 = CapturingConsumeFn()
|
||||||
|
|
||||||
consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
|
consumer.subscribe(topic_name, consume_fn_1, start=consumer.min_seqid())
|
||||||
|
|
||||||
embeddings = produce_fns(producer, "test_topic", sample_embeddings, 10)[0]
|
embeddings = produce_fns(producer, topic_name, sample_embeddings, 10)[0]
|
||||||
|
|
||||||
results_1 = await consume_fn_1.get(10)
|
results_1 = await consume_fn_1.get(10)
|
||||||
assert_records_match(embeddings, results_1)
|
assert_records_match(embeddings, results_1)
|
||||||
|
|
||||||
end = consume_fn_1.embeddings[-5]["seq_id"]
|
end = consume_fn_1.embeddings[-5]["seq_id"]
|
||||||
consumer.subscribe("test_topic", consume_fn_2, start=consumer.min_seqid(), end=end)
|
consumer.subscribe(topic_name, consume_fn_2, start=consumer.min_seqid(), end=end)
|
||||||
|
|
||||||
results_2 = await consume_fn_2.get(6)
|
results_2 = await consume_fn_2.get(6)
|
||||||
assert_records_match(embeddings[:6], results_2)
|
assert_records_match(embeddings[:6], results_2)
|
||||||
@@ -274,14 +323,16 @@ async def test_submit_batch(
) -> None:
    producer, consumer = producer_consumer
    producer.reset_state()
+    consumer.reset_state()
+    topic_name = full_topic_name("test_topic")

    embeddings = [next(sample_embeddings) for _ in range(100)]

-    producer.create_topic("test_topic")
-    producer.submit_embeddings("test_topic", embeddings=embeddings)
+    producer.create_topic(topic_name)
+    producer.submit_embeddings(topic_name, embeddings=embeddings)

    consume_fn = CapturingConsumeFn()
-    consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
+    consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid())

    recieved = await consume_fn.get(100)
    assert_records_match(embeddings, recieved)
@@ -295,13 +346,16 @@ async def test_multiple_topics_batch(
) -> None:
    producer, consumer = producer_consumer
    producer.reset_state()
+    consumer.reset_state()

-    N_TOPICS = 100
+    N_TOPICS = 2
    consume_fns = [CapturingConsumeFn() for _ in range(N_TOPICS)]
    for i in range(N_TOPICS):
-        producer.create_topic(f"test_topic_{i}")
+        producer.create_topic(full_topic_name(f"test_topic_{i}"))
        consumer.subscribe(
-            f"test_topic_{i}", consume_fns[i], start=consumer.min_seqid()
+            full_topic_name(f"test_topic_{i}"),
+            consume_fns[i],
+            start=consumer.min_seqid(),
        )

    embeddings_n: List[List[SubmitEmbeddingRecord]] = [[] for _ in range(N_TOPICS)]
@@ -310,17 +364,17 @@ async def test_multiple_topics_batch(
    N_TO_PRODUCE = 100
    total_produced = 0
    for i in range(N_TO_PRODUCE // PRODUCE_BATCH_SIZE):
-        for i in range(N_TOPICS):
-            embeddings_n[i].extend(
+        for n in range(N_TOPICS):
+            embeddings_n[n].extend(
                produce_fns(
                    producer,
-                    f"test_topic_{i}",
+                    full_topic_name(f"test_topic_{n}"),
                    sample_embeddings,
                    PRODUCE_BATCH_SIZE,
                )[0]
            )
-            recieved = await consume_fns[i].get(total_produced + PRODUCE_BATCH_SIZE)
-            assert_records_match(embeddings_n[i], recieved)
+            recieved = await consume_fns[n].get(total_produced + PRODUCE_BATCH_SIZE)
+            assert_records_match(embeddings_n[n], recieved)
        total_produced += PRODUCE_BATCH_SIZE

@@ -331,19 +385,21 @@ async def test_max_batch_size(
) -> None:
    producer, consumer = producer_consumer
    producer.reset_state()
-    max_batch_size = producer_consumer[0].max_batch_size
+    consumer.reset_state()
+    topic_name = full_topic_name("test_topic")
+    max_batch_size = producer.max_batch_size
    assert max_batch_size > 0

    # Make sure that we can produce a batch of size max_batch_size
    embeddings = [next(sample_embeddings) for _ in range(max_batch_size)]
    consume_fn = CapturingConsumeFn()
-    consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
-    producer.submit_embeddings("test_topic", embeddings=embeddings)
+    consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid())
+    producer.submit_embeddings(topic_name, embeddings=embeddings)
    received = await consume_fn.get(max_batch_size, timeout_secs=120)
    assert_records_match(embeddings, received)

    embeddings = [next(sample_embeddings) for _ in range(max_batch_size + 1)]
    # Make sure that we can't produce a batch of size > max_batch_size
    with pytest.raises(ValueError) as e:
-        producer.submit_embeddings("test_topic", embeddings=embeddings)
+        producer.submit_embeddings(topic_name, embeddings=embeddings)
    assert "Cannot submit more than" in str(e.value)
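test_max_batch_size checks that a batch of exactly max_batch_size is accepted while one extra record raises a ValueError containing "Cannot submit more than". A hypothetical guard of the kind a producer could apply; the class below is an illustration only, not the PulsarProducer's code:

from typing import Sequence

class BatchLimitedProducer:
    def __init__(self, max_batch_size: int) -> None:
        self.max_batch_size = max_batch_size

    def submit_embeddings(self, topic_name: str, embeddings: Sequence[object]) -> None:
        # Reject oversized batches up front so callers get a clear error
        # instead of a broker-side failure.
        if len(embeddings) > self.max_batch_size:
            raise ValueError(
                f"Cannot submit more than {self.max_batch_size} embeddings at once"
            )
        # ... hand the validated batch to the underlying message bus ...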
66
docker-compose.cluster.yml
Normal file
@@ -0,0 +1,66 @@
# This docker compose file is not meant to be used. It is a work in progress
# for the distributed version of Chroma. It is not yet functional.

version: '3.9'

networks:
  net:
    driver: bridge

services:
  server:
    image: server
    build:
      context: .
      dockerfile: Dockerfile
    volumes:
      - ./:/chroma
      - index_data:/index_data
    command: uvicorn chromadb.app:app --reload --workers 1 --host 0.0.0.0 --port 8000 --log-config log_config.yml
    environment:
      - IS_PERSISTENT=TRUE
      - CHROMA_PRODUCER_IMPL=chromadb.ingest.impl.pulsar.PulsarProducer
      - CHROMA_CONSUMER_IMPL=chromadb.ingest.impl.pulsar.PulsarConsumer
      - PULSAR_BROKER_URL=pulsar
      - PULSAR_BROKER_PORT=6650
      - PULSAR_ADMIN_PORT=8080
    ports:
      - 8000:8000
    depends_on:
      pulsar:
        condition: service_healthy
    networks:
      - net

  pulsar:
    image: apachepulsar/pulsar
    volumes:
      - pulsardata:/pulsar/data
      - pulsarconf:/pulsar/conf
    command: bin/pulsar standalone
    ports:
      - 6650:6650
      - 8080:8080
    networks:
      - net
    healthcheck:
      test:
        [
          "CMD",
          "curl",
          "-f",
          "localhost:8080/admin/v2/brokers/health"
        ]
      interval: 3s
      timeout: 1m
      retries: 10

volumes:
  index_data:
    driver: local
  backups:
    driver: local
  pulsardata:
    driver: local
  pulsarconf:
    driver: local
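The server container locates the broker through PULSAR_BROKER_URL and PULSAR_BROKER_PORT. As a sketch of how a process on the same compose network could use those variables with the stock pulsar-client package (the actual chromadb.ingest.impl.pulsar implementation may wire this up differently):

import os
import pulsar

# Defaults mirror docker-compose.cluster.yml: host "pulsar", binary port 6650.
broker_host = os.environ.get("PULSAR_BROKER_URL", "pulsar")
broker_port = os.environ.get("PULSAR_BROKER_PORT", "6650")

client = pulsar.Client(f"pulsar://{broker_host}:{broker_port}")
producer = client.create_producer("persistent://public/default/test_topic")
producer.send(b"hello from the compose network")
client.close()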
@@ -3,8 +3,10 @@ build
httpx
hypothesis
hypothesis[numpy]
+mypy-protobuf
pre-commit
pytest
pytest-asyncio
setuptools_scm
+types-protobuf
types-requests==2.30.0.0