[ENH] Pulsar Producer & Consumer (#921)
## Description of changes

- New functionality
  - Adds a basic Pulsar producer and consumer with associated tests, as well as a docker compose file for the distributed version of Chroma.

## Test plan

We added bin/cluster-test.sh, which starts Pulsar and allows test_producer_consumer to run against the pulsar fixture.

## Documentation Changes

None required.
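For anyone trying the new path locally: wiring Chroma to Pulsar is just a matter of pointing the producer/consumer impls at the new classes. A minimal sketch, mirroring the pulsar() test fixture added below (assumes a standalone Pulsar on localhost, e.g. started via bin/cluster-test.sh):

# Sketch only; settings mirror the pulsar() fixture in test_producer_consumer.py.
from chromadb.config import Settings, System
from chromadb.ingest import Consumer, Producer

system = System(
    Settings(
        allow_reset=True,
        chroma_producer_impl="chromadb.ingest.impl.pulsar.PulsarProducer",
        chroma_consumer_impl="chromadb.ingest.impl.pulsar.PulsarConsumer",
        pulsar_broker_url="localhost",
        pulsar_admin_port="8080",
        pulsar_broker_port="6650",
    )
)
producer = system.require(Producer)
consumer = system.require(Consumer)
system.start()  # connects the pulsar.Client for both components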
.github/workflows/chroma-cluster-test.yml (new file, 31 lines, vendored)
@@ -0,0 +1,31 @@
name: Chroma Cluster Tests

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
      - '**'
  workflow_dispatch:

jobs:
  test:
    strategy:
      matrix:
        python: ['3.7']
        platform: [ubuntu-latest]
        testfile: ["chromadb/test/ingest/test_producer_consumer.py"] # Just this one test for now
    runs-on: ${{ matrix.platform }}
    steps:
      - name: Checkout
        uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python }}
      - name: Install test dependencies
        run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt
      - name: Integration Test
        run: bin/cluster-test.sh ${{ matrix.testfile }}
.pre-commit-config.yaml (modified)
@@ -32,4 +32,4 @@ repos:
     hooks:
       - id: mypy
         args: [--strict, --ignore-missing-imports, --follow-imports=silent, --disable-error-code=type-abstract]
-        additional_dependencies: ["types-requests", "pydantic", "overrides", "hypothesis", "pytest", "pypika", "numpy"]
+        additional_dependencies: ["types-requests", "pydantic", "overrides", "hypothesis", "pytest", "pypika", "numpy", "types-protobuf"]
bin/cluster-test.sh (new executable file, 16 lines)
@@ -0,0 +1,16 @@
#!/usr/bin/env bash

set -e

function cleanup {
  docker compose -f docker-compose.cluster.yml down --rmi local --volumes
}

trap cleanup EXIT

docker compose -f docker-compose.cluster.yml up -d --wait pulsar

export CHROMA_CLUSTER_TEST_ONLY=1

echo testing: python -m pytest "$@"
python -m pytest "$@"
chromadb/api/segment.py (modified)
@@ -1,6 +1,7 @@
 from chromadb.api import API
 from chromadb.config import Settings, System
 from chromadb.db.system import SysDB
+from chromadb.ingest.impl.utils import create_topic_name
 from chromadb.segment import SegmentManager, MetadataReader, VectorReader
 from chromadb.telemetry import Telemetry
 from chromadb.ingest import Producer
@@ -130,6 +131,9 @@ class SegmentAPI(API):
         coll = t.Collection(
             id=id, name=name, metadata=metadata, topic=self._topic(id), dimension=None
         )
+        # TODO: Topic creation right now lives in the producer but it should be moved to the coordinator,
+        # and the producer should just be responsible for publishing messages. Coordinator should
+        # be responsible for all management of topics.
         self._producer.create_topic(coll["topic"])
         segments = self._manager.create_segments(coll)
         self._sysdb.create_collection(coll)
@@ -559,7 +563,7 @@ class SegmentAPI(API):
         return self._producer.max_batch_size

     def _topic(self, collection_id: UUID) -> str:
-        return f"persistent://{self._tenant_id}/{self._topic_ns}/{collection_id}"
+        return create_topic_name(self._tenant_id, self._topic_ns, str(collection_id))

     # TODO: This could potentially cause race conditions in a distributed version of the
     # system, since the cache is only local.
chromadb/config.py (modified)
@@ -92,6 +92,10 @@ class Settings(BaseSettings):  # type: ignore
     chroma_server_grpc_port: Optional[str] = None
     chroma_server_cors_allow_origins: List[str] = []  # eg ["http://localhost:3000"]

+    pulsar_broker_url: Optional[str] = None
+    pulsar_admin_port: Optional[str] = None
+    pulsar_broker_port: Optional[str] = None
+
     chroma_server_auth_provider: Optional[str] = None

     @validator("chroma_server_auth_provider", pre=True, always=True, allow_reuse=True)
chromadb/ingest/impl/pulsar.py (new file, 304 lines)
@@ -0,0 +1,304 @@
from __future__ import annotations
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
import uuid
from chromadb.config import Settings, System
from chromadb.ingest import Consumer, ConsumerCallbackFn, Producer
from overrides import overrides, EnforceOverrides
from uuid import UUID
from chromadb.ingest.impl.pulsar_admin import PulsarAdmin
from chromadb.ingest.impl.utils import create_pulsar_connection_str
from chromadb.proto.convert import from_proto_submit, to_proto_submit
import chromadb.proto.chroma_pb2 as proto
from chromadb.types import SeqId, SubmitEmbeddingRecord
import pulsar
from concurrent.futures import wait, Future

from chromadb.utils.messageid import int_to_pulsar, pulsar_to_int


class PulsarProducer(Producer, EnforceOverrides):
    _connection_str: str
    _topic_to_producer: Dict[str, pulsar.Producer]
    _client: pulsar.Client
    _admin: PulsarAdmin
    _settings: Settings

    def __init__(self, system: System) -> None:
        pulsar_host = system.settings.require("pulsar_broker_url")
        pulsar_port = system.settings.require("pulsar_broker_port")
        self._connection_str = create_pulsar_connection_str(pulsar_host, pulsar_port)
        self._topic_to_producer = {}
        self._settings = system.settings
        self._admin = PulsarAdmin(system)
        super().__init__(system)

    @overrides
    def start(self) -> None:
        self._client = pulsar.Client(self._connection_str)
        super().start()

    @overrides
    def stop(self) -> None:
        self._client.close()
        super().stop()

    @overrides
    def create_topic(self, topic_name: str) -> None:
        self._admin.create_topic(topic_name)

    @overrides
    def delete_topic(self, topic_name: str) -> None:
        self._admin.delete_topic(topic_name)

    @overrides
    def submit_embedding(
        self, topic_name: str, embedding: SubmitEmbeddingRecord
    ) -> SeqId:
        """Add an embedding record to the given topic. Returns the SeqID of the record."""
        producer = self._get_or_create_producer(topic_name)
        proto_submit: proto.SubmitEmbeddingRecord = to_proto_submit(embedding)
        # TODO: batch performance / async
        msg_id: pulsar.MessageId = producer.send(proto_submit.SerializeToString())
        return pulsar_to_int(msg_id)

    @overrides
    def submit_embeddings(
        self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord]
    ) -> Sequence[SeqId]:
        if not self._running:
            raise RuntimeError("Component not running")

        if len(embeddings) == 0:
            return []

        if len(embeddings) > self.max_batch_size:
            raise ValueError(
                f"""
                Cannot submit more than {self.max_batch_size:,} embeddings at once.
                Please submit your embeddings in batches of size
                {self.max_batch_size:,} or less.
                """
            )

        producer = self._get_or_create_producer(topic_name)
        protos_to_submit = [to_proto_submit(embedding) for embedding in embeddings]

        def create_producer_callback(
            future: Future[int],
        ) -> Callable[[Any, pulsar.MessageId], None]:
            def producer_callback(res: Any, msg_id: pulsar.MessageId) -> None:
                if msg_id:
                    future.set_result(pulsar_to_int(msg_id))
                else:
                    future.set_exception(
                        Exception(
                            "Unknown error while submitting embedding in producer_callback"
                        )
                    )

            return producer_callback

        futures = []
        for proto_to_submit in protos_to_submit:
            future: Future[int] = Future()
            producer.send_async(
                proto_to_submit.SerializeToString(),
                callback=create_producer_callback(future),
            )
            futures.append(future)

        wait(futures)

        results: List[SeqId] = []
        for future in futures:
            exception = future.exception()
            if exception is not None:
                raise exception
            results.append(future.result())

        return results

    @property
    @overrides
    def max_batch_size(self) -> int:
        # For now, we use 1,000
        # TODO: tune this to a reasonable value by default
        return 1000

    def _get_or_create_producer(self, topic_name: str) -> pulsar.Producer:
        if topic_name not in self._topic_to_producer:
            producer = self._client.create_producer(topic_name)
            self._topic_to_producer[topic_name] = producer
        return self._topic_to_producer[topic_name]

    @overrides
    def reset_state(self) -> None:
        if not self._settings.require("allow_reset"):
            raise ValueError(
                "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
            )
        for topic_name in self._topic_to_producer:
            self._admin.delete_topic(topic_name)
        self._topic_to_producer = {}
        super().reset_state()


class PulsarConsumer(Consumer, EnforceOverrides):
    class PulsarSubscription:
        id: UUID
        topic_name: str
        start: int
        end: int
        callback: ConsumerCallbackFn
        consumer: pulsar.Consumer

        def __init__(
            self,
            id: UUID,
            topic_name: str,
            start: int,
            end: int,
            callback: ConsumerCallbackFn,
            consumer: pulsar.Consumer,
        ):
            self.id = id
            self.topic_name = topic_name
            self.start = start
            self.end = end
            self.callback = callback
            self.consumer = consumer

    _connection_str: str
    _client: pulsar.Client
    _subscriptions: Dict[str, Set[PulsarSubscription]]
    _settings: Settings

    def __init__(self, system: System) -> None:
        pulsar_host = system.settings.require("pulsar_broker_url")
        pulsar_port = system.settings.require("pulsar_broker_port")
        self._connection_str = create_pulsar_connection_str(pulsar_host, pulsar_port)
        self._subscriptions = defaultdict(set)
        self._settings = system.settings
        super().__init__(system)

    @overrides
    def start(self) -> None:
        self._client = pulsar.Client(self._connection_str)
        super().start()

    @overrides
    def stop(self) -> None:
        self._client.close()
        super().stop()

    @overrides
    def subscribe(
        self,
        topic_name: str,
        consume_fn: ConsumerCallbackFn,
        start: Optional[SeqId] = None,
        end: Optional[SeqId] = None,
        id: Optional[UUID] = None,
    ) -> UUID:
        """Register a function that will be called to receive embeddings for a given
        topic. The given function may be called any number of times, with any number of
        records, and may be called concurrently.

        Only records between start (exclusive) and end (inclusive) SeqIDs will be
        returned. If start is None, the first record returned will be the next record
        generated, not including those generated before creating the subscription. If
        end is None, the consumer will consume indefinitely, otherwise it will
        automatically be unsubscribed when the end SeqID is reached.

        If the function throws an exception, the function may be called again with the
        same or different records.

        Takes an optional UUID as a unique subscription ID. If no ID is provided, a new
        ID will be generated and returned."""
        if not self._running:
            raise RuntimeError("Consumer must be started before subscribing")

        subscription_id = (
            id or uuid.uuid4()
        )  # TODO: this should really be created by the coordinator and stored in sysdb

        start, end = self._validate_range(start, end)

        def wrap_callback(consumer: pulsar.Consumer, message: pulsar.Message) -> None:
            msg_data = message.data()
            msg_id = pulsar_to_int(message.message_id())
            submit_embedding_record = proto.SubmitEmbeddingRecord()
            proto.SubmitEmbeddingRecord.ParseFromString(
                submit_embedding_record, msg_data
            )
            embedding_record = from_proto_submit(submit_embedding_record, msg_id)
            consume_fn([embedding_record])
            consumer.acknowledge(message)
            if msg_id == end:
                self.unsubscribe(subscription_id)

        consumer = self._client.subscribe(
            topic_name,
            subscription_id.hex,
            message_listener=wrap_callback,
        )

        subscription = self.PulsarSubscription(
            subscription_id, topic_name, start, end, consume_fn, consumer
        )
        self._subscriptions[topic_name].add(subscription)

        # NOTE: For some reason the seek() method expects a shadowed MessageId type
        # which resides in _msg_id.
        consumer.seek(int_to_pulsar(start)._msg_id)

        return subscription_id

    def _validate_range(
        self, start: Optional[SeqId], end: Optional[SeqId]
    ) -> Tuple[int, int]:
        """Validate and normalize the start and end SeqIDs for a subscription using this
        impl."""
        start = start or pulsar_to_int(pulsar.MessageId.latest)
        end = end or self.max_seqid()
        if not isinstance(start, int) or not isinstance(end, int):
            raise ValueError("SeqIDs must be integers")
        if start >= end:
            raise ValueError(f"Invalid SeqID range: {start} to {end}")
        return start, end

    @overrides
    def unsubscribe(self, subscription_id: UUID) -> None:
        """Unregister a subscription. The consume function will no longer be invoked,
        and resources associated with the subscription will be released."""
        for topic_name, subscriptions in self._subscriptions.items():
            for subscription in subscriptions:
                if subscription.id == subscription_id:
                    subscription.consumer.close()
                    subscriptions.remove(subscription)
                    if len(subscriptions) == 0:
                        del self._subscriptions[topic_name]
                    return

    @overrides
    def min_seqid(self) -> SeqId:
        """Return the minimum possible SeqID in this implementation."""
        return pulsar_to_int(pulsar.MessageId.earliest)

    @overrides
    def max_seqid(self) -> SeqId:
        """Return the maximum possible SeqID in this implementation."""
        return 2**192 - 1

    @overrides
    def reset_state(self) -> None:
        if not self._settings.require("allow_reset"):
            raise ValueError(
                "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
            )
        for topic_name, subscriptions in self._subscriptions.items():
            for subscription in subscriptions:
                subscription.consumer.close()
        self._subscriptions = defaultdict(set)
        super().reset_state()
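To make the new API concrete, a usage sketch (not part of the diff; assumes the producer/consumer wiring from the PR description, and that SubmitEmbeddingRecord takes the keys to_proto_submit reads):

# Usage sketch only. Topic names must be fully qualified; only the
# default tenant/namespace are supported by PulsarAdmin for now.
from chromadb.ingest.impl.utils import create_topic_name
from chromadb.types import Operation, ScalarEncoding, SubmitEmbeddingRecord

topic = create_topic_name("default", "default", "example_topic")
producer.create_topic(topic)  # delegates to the PulsarAdmin REST wrapper

record = SubmitEmbeddingRecord(
    id="embedding-0",
    embedding=[0.0, 0.5, 1.0],
    encoding=ScalarEncoding.FLOAT32,
    metadata=None,
    operation=Operation.ADD,
)
seq_id = producer.submit_embedding(topic, record)  # returns pulsar_to_int(MessageId)

# Replay the topic from the beginning; the callback may run on a Pulsar thread.
subscription_id = consumer.subscribe(topic, print, start=consumer.min_seqid())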
chromadb/ingest/impl/pulsar_admin.py (new file, 81 lines)
@@ -0,0 +1,81 @@
# A thin wrapper around the pulsar admin api
import requests
from chromadb.config import System
from chromadb.ingest.impl.utils import parse_topic_name


class PulsarAdmin:
    """A thin wrapper around the pulsar admin api, only used for interim development towards distributed chroma.
    This functionality will be moved to the chroma coordinator."""

    _connection_str: str

    def __init__(self, system: System):
        pulsar_host = system.settings.require("pulsar_broker_url")
        pulsar_port = system.settings.require("pulsar_admin_port")
        self._connection_str = f"http://{pulsar_host}:{pulsar_port}"

        # Create the default tenant and namespace
        # This is a temporary workaround until we have a proper tenant/namespace management system
        self.create_tenant("default")
        self.create_namespace("default", "default")

    def create_tenant(self, tenant: str) -> None:
        """Make a PUT request to the admin api to create the tenant"""

        path = f"/admin/v2/tenants/{tenant}"
        url = self._connection_str + path
        response = requests.put(
            url, json={"allowedClusters": ["standalone"], "adminRoles": []}
        )  # TODO: how to manage clusters?

        if response.status_code != 204 and response.status_code != 409:
            raise RuntimeError(f"Failed to create tenant {tenant}")

    def create_namespace(self, tenant: str, namespace: str) -> None:
        """Make a PUT request to the admin api to create the namespace"""

        path = f"/admin/v2/namespaces/{tenant}/{namespace}"
        url = self._connection_str + path
        response = requests.put(url)

        if response.status_code != 204 and response.status_code != 409:
            raise RuntimeError(f"Failed to create namespace {namespace}")

    def create_topic(self, topic: str) -> None:
        # TODO: support non-persistent topics?
        tenant, namespace, topic_name = parse_topic_name(topic)

        if tenant != "default":
            raise ValueError(f"Only the default tenant is supported, got {tenant}")
        if namespace != "default":
            raise ValueError(
                f"Only the default namespace is supported, got {namespace}"
            )

        # Make a PUT request to the admin api to create the topic
        path = f"/admin/v2/persistent/{tenant}/{namespace}/{topic_name}"
        url = self._connection_str + path
        response = requests.put(url)

        if response.status_code != 204 and response.status_code != 409:
            raise RuntimeError(f"Failed to create topic {topic_name}")

    def delete_topic(self, topic: str) -> None:
        tenant, namespace, topic_name = parse_topic_name(topic)

        if tenant != "default":
            raise ValueError(f"Only the default tenant is supported, got {tenant}")
        if namespace != "default":
            raise ValueError(
                f"Only the default namespace is supported, got {namespace}"
            )

        # Make a DELETE request to the admin api to delete the topic
        path = f"/admin/v2/persistent/{tenant}/{namespace}/{topic_name}"
        # Force delete the topic
        path += "?force=true"
        url = self._connection_str + path
        response = requests.delete(url)
        if response.status_code != 204 and response.status_code != 409:
            raise RuntimeError(f"Failed to delete topic {topic_name}")
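Since PulsarAdmin is just HTTP, the calls are easy to reproduce by hand. A sketch of the create_topic equivalent against the standalone broker (assuming the 8080 admin port from the compose file):

# Rough manual equivalent of PulsarAdmin.create_topic("persistent://default/default/example_topic").
# 204 = created, 409 = already exists; anything else is an error.
import requests

resp = requests.put("http://localhost:8080/admin/v2/persistent/default/default/example_topic")
assert resp.status_code in (204, 409)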
chromadb/ingest/impl/utils.py (new file, 20 lines)
@@ -0,0 +1,20 @@
import re
from typing import Tuple

topic_regex = r"persistent:\/\/(?P<tenant>.+)\/(?P<namespace>.+)\/(?P<topic>.+)"


def parse_topic_name(topic_name: str) -> Tuple[str, str, str]:
    """Parse the topic name into the tenant, namespace and topic name"""
    match = re.match(topic_regex, topic_name)
    if not match:
        raise ValueError(f"Invalid topic name: {topic_name}")
    return match.group("tenant"), match.group("namespace"), match.group("topic")


def create_pulsar_connection_str(host: str, port: str) -> str:
    return f"pulsar://{host}:{port}"


def create_topic_name(tenant: str, namespace: str, topic: str) -> str:
    return f"persistent://{tenant}/{namespace}/{topic}"
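create_topic_name and parse_topic_name are inverses for well-formed names, which the tests rely on via full_topic_name:

# Quick sanity check of the helpers above.
name = create_topic_name("default", "default", "my_topic")
assert name == "persistent://default/default/my_topic"
assert parse_topic_name(name) == ("default", "default", "my_topic")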
chromadb/proto/chroma.proto (new file, 40 lines)
@@ -0,0 +1,40 @@
syntax = "proto3";

package chroma;

enum Operation {
  ADD = 0;
  UPDATE = 1;
  UPSERT = 2;
  DELETE = 3;
}

enum ScalarEncoding {
  FLOAT32 = 0;
  INT32 = 1;
}

message Vector {
  int32 dimension = 1;
  bytes vector = 2;
  ScalarEncoding encoding = 3;
}

message UpdateMetadataValue {
  oneof value {
    string string_value = 1;
    int64 int_value = 2;
    double float_value = 3;
  }
}

message UpdateMetadata {
  map<string, UpdateMetadataValue> metadata = 1;
}

message SubmitEmbeddingRecord {
  string id = 1;
  optional Vector vector = 2;
  optional UpdateMetadata metadata = 3;
  Operation operation = 4;
}
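On the wire, the producer serializes this message and the consumer parses it back. A sketch of that round trip using the generated Python bindings:

# Build, serialize, and parse a SubmitEmbeddingRecord, as the producer and
# consumer callback do. The vector bytes here are a placeholder (2 float32 zeros).
import chromadb.proto.chroma_pb2 as proto

rec = proto.SubmitEmbeddingRecord(
    id="embedding-0",
    vector=proto.Vector(dimension=2, vector=b"\x00" * 8, encoding=proto.ScalarEncoding.FLOAT32),
    operation=proto.Operation.ADD,
)
payload = rec.SerializeToString()  # what PulsarProducer sends
parsed = proto.SubmitEmbeddingRecord()
parsed.ParseFromString(payload)  # what the consumer's wrap_callback does
assert parsed.id == "embedding-0"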
chromadb/proto/chroma_pb2.py (new file, 42 lines, generated)
@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: chromadb/proto/chroma.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder

# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
    b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01\x62\x06proto3'
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(
    DESCRIPTOR, "chromadb.proto.chroma_pb2", _globals
)
if _descriptor._USE_C_DESCRIPTORS == False:
    DESCRIPTOR._options = None
    _UPDATEMETADATA_METADATAENTRY._options = None
    _UPDATEMETADATA_METADATAENTRY._serialized_options = b"8\001"
    _globals["_OPERATION"]._serialized_start = 563
    _globals["_OPERATION"]._serialized_end = 619
    _globals["_SCALARENCODING"]._serialized_start = 621
    _globals["_SCALARENCODING"]._serialized_end = 661
    _globals["_VECTOR"]._serialized_start = 39
    _globals["_VECTOR"]._serialized_end = 124
    _globals["_UPDATEMETADATAVALUE"]._serialized_start = 126
    _globals["_UPDATEMETADATAVALUE"]._serialized_end = 224
    _globals["_UPDATEMETADATA"]._serialized_start = 227
    _globals["_UPDATEMETADATA"]._serialized_end = 377
    _globals["_UPDATEMETADATA_METADATAENTRY"]._serialized_start = 301
    _globals["_UPDATEMETADATA_METADATAENTRY"]._serialized_end = 377
    _globals["_SUBMITEMBEDDINGRECORD"]._serialized_start = 380
    _globals["_SUBMITEMBEDDINGRECORD"]._serialized_end = 561
# @@protoc_insertion_point(module_scope)
chromadb/proto/chroma_pb2.pyi (new file, 247 lines, generated)
@@ -0,0 +1,247 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing

if sys.version_info >= (3, 10):
    import typing as typing_extensions
else:
    import typing_extensions

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

class _Operation:
    ValueType = typing.NewType("ValueType", builtins.int)
    V: typing_extensions.TypeAlias = ValueType

class _OperationEnumTypeWrapper(
    google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_Operation.ValueType],
    builtins.type,
):
    DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
    ADD: _Operation.ValueType  # 0
    UPDATE: _Operation.ValueType  # 1
    UPSERT: _Operation.ValueType  # 2
    DELETE: _Operation.ValueType  # 3

class Operation(_Operation, metaclass=_OperationEnumTypeWrapper): ...

ADD: Operation.ValueType  # 0
UPDATE: Operation.ValueType  # 1
UPSERT: Operation.ValueType  # 2
DELETE: Operation.ValueType  # 3
global___Operation = Operation

class _ScalarEncoding:
    ValueType = typing.NewType("ValueType", builtins.int)
    V: typing_extensions.TypeAlias = ValueType

class _ScalarEncodingEnumTypeWrapper(
    google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[
        _ScalarEncoding.ValueType
    ],
    builtins.type,
):
    DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
    FLOAT32: _ScalarEncoding.ValueType  # 0
    INT32: _ScalarEncoding.ValueType  # 1

class ScalarEncoding(_ScalarEncoding, metaclass=_ScalarEncodingEnumTypeWrapper): ...

FLOAT32: ScalarEncoding.ValueType  # 0
INT32: ScalarEncoding.ValueType  # 1
global___ScalarEncoding = ScalarEncoding

@typing_extensions.final
class Vector(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    DIMENSION_FIELD_NUMBER: builtins.int
    VECTOR_FIELD_NUMBER: builtins.int
    ENCODING_FIELD_NUMBER: builtins.int
    dimension: builtins.int
    vector: builtins.bytes
    encoding: global___ScalarEncoding.ValueType
    def __init__(
        self,
        *,
        dimension: builtins.int = ...,
        vector: builtins.bytes = ...,
        encoding: global___ScalarEncoding.ValueType = ...,
    ) -> None: ...
    def ClearField(
        self,
        field_name: typing_extensions.Literal[
            "dimension", b"dimension", "encoding", b"encoding", "vector", b"vector"
        ],
    ) -> None: ...

global___Vector = Vector

@typing_extensions.final
class UpdateMetadataValue(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    STRING_VALUE_FIELD_NUMBER: builtins.int
    INT_VALUE_FIELD_NUMBER: builtins.int
    FLOAT_VALUE_FIELD_NUMBER: builtins.int
    string_value: builtins.str
    int_value: builtins.int
    float_value: builtins.float
    def __init__(
        self,
        *,
        string_value: builtins.str = ...,
        int_value: builtins.int = ...,
        float_value: builtins.float = ...,
    ) -> None: ...
    def HasField(
        self,
        field_name: typing_extensions.Literal[
            "float_value",
            b"float_value",
            "int_value",
            b"int_value",
            "string_value",
            b"string_value",
            "value",
            b"value",
        ],
    ) -> builtins.bool: ...
    def ClearField(
        self,
        field_name: typing_extensions.Literal[
            "float_value",
            b"float_value",
            "int_value",
            b"int_value",
            "string_value",
            b"string_value",
            "value",
            b"value",
        ],
    ) -> None: ...
    def WhichOneof(
        self, oneof_group: typing_extensions.Literal["value", b"value"]
    ) -> (
        typing_extensions.Literal["string_value", "int_value", "float_value"] | None
    ): ...

global___UpdateMetadataValue = UpdateMetadataValue

@typing_extensions.final
class UpdateMetadata(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    @typing_extensions.final
    class MetadataEntry(google.protobuf.message.Message):
        DESCRIPTOR: google.protobuf.descriptor.Descriptor

        KEY_FIELD_NUMBER: builtins.int
        VALUE_FIELD_NUMBER: builtins.int
        key: builtins.str
        @property
        def value(self) -> global___UpdateMetadataValue: ...
        def __init__(
            self,
            *,
            key: builtins.str = ...,
            value: global___UpdateMetadataValue | None = ...,
        ) -> None: ...
        def HasField(
            self, field_name: typing_extensions.Literal["value", b"value"]
        ) -> builtins.bool: ...
        def ClearField(
            self,
            field_name: typing_extensions.Literal["key", b"key", "value", b"value"],
        ) -> None: ...

    METADATA_FIELD_NUMBER: builtins.int
    @property
    def metadata(
        self,
    ) -> google.protobuf.internal.containers.MessageMap[
        builtins.str, global___UpdateMetadataValue
    ]: ...
    def __init__(
        self,
        *,
        metadata: collections.abc.Mapping[builtins.str, global___UpdateMetadataValue]
        | None = ...,
    ) -> None: ...
    def ClearField(
        self, field_name: typing_extensions.Literal["metadata", b"metadata"]
    ) -> None: ...

global___UpdateMetadata = UpdateMetadata

@typing_extensions.final
class SubmitEmbeddingRecord(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    ID_FIELD_NUMBER: builtins.int
    VECTOR_FIELD_NUMBER: builtins.int
    METADATA_FIELD_NUMBER: builtins.int
    OPERATION_FIELD_NUMBER: builtins.int
    id: builtins.str
    @property
    def vector(self) -> global___Vector: ...
    @property
    def metadata(self) -> global___UpdateMetadata: ...
    operation: global___Operation.ValueType
    def __init__(
        self,
        *,
        id: builtins.str = ...,
        vector: global___Vector | None = ...,
        metadata: global___UpdateMetadata | None = ...,
        operation: global___Operation.ValueType = ...,
    ) -> None: ...
    def HasField(
        self,
        field_name: typing_extensions.Literal[
            "_metadata",
            b"_metadata",
            "_vector",
            b"_vector",
            "metadata",
            b"metadata",
            "vector",
            b"vector",
        ],
    ) -> builtins.bool: ...
    def ClearField(
        self,
        field_name: typing_extensions.Literal[
            "_metadata",
            b"_metadata",
            "_vector",
            b"_vector",
            "id",
            b"id",
            "metadata",
            b"metadata",
            "operation",
            b"operation",
            "vector",
            b"vector",
        ],
    ) -> None: ...
    @typing.overload
    def WhichOneof(
        self, oneof_group: typing_extensions.Literal["_metadata", b"_metadata"]
    ) -> typing_extensions.Literal["metadata"] | None: ...
    @typing.overload
    def WhichOneof(
        self, oneof_group: typing_extensions.Literal["_vector", b"_vector"]
    ) -> typing_extensions.Literal["vector"] | None: ...

global___SubmitEmbeddingRecord = SubmitEmbeddingRecord
chromadb/proto/convert.py (new file, 150 lines)
@@ -0,0 +1,150 @@
import array
from typing import Optional, Tuple, Union
from chromadb.api.types import Embedding
import chromadb.proto.chroma_pb2 as proto
from chromadb.types import (
    EmbeddingRecord,
    Metadata,
    Operation,
    ScalarEncoding,
    SeqId,
    SubmitEmbeddingRecord,
    Vector,
)


def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> proto.Vector:
    if encoding == ScalarEncoding.FLOAT32:
        as_bytes = array.array("f", vector).tobytes()
        proto_encoding = proto.ScalarEncoding.FLOAT32
    elif encoding == ScalarEncoding.INT32:
        as_bytes = array.array("i", vector).tobytes()
        proto_encoding = proto.ScalarEncoding.INT32
    else:
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of {ScalarEncoding.FLOAT32} \
            or {ScalarEncoding.INT32}"
        )

    return proto.Vector(dimension=len(vector), vector=as_bytes, encoding=proto_encoding)


def from_proto_vector(vector: proto.Vector) -> Tuple[Embedding, ScalarEncoding]:
    encoding = vector.encoding
    as_array: array.array[float] | array.array[int]
    if encoding == proto.ScalarEncoding.FLOAT32:
        as_array = array.array("f")
        out_encoding = ScalarEncoding.FLOAT32
    elif encoding == proto.ScalarEncoding.INT32:
        as_array = array.array("i")
        out_encoding = ScalarEncoding.INT32
    else:
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of \
            {proto.ScalarEncoding.FLOAT32} or {proto.ScalarEncoding.INT32}"
        )

    as_array.frombytes(vector.vector)
    return (as_array.tolist(), out_encoding)


def from_proto_operation(operation: proto.Operation.ValueType) -> Operation:
    if operation == proto.Operation.ADD:
        return Operation.ADD
    elif operation == proto.Operation.UPDATE:
        return Operation.UPDATE
    elif operation == proto.Operation.UPSERT:
        return Operation.UPSERT
    elif operation == proto.Operation.DELETE:
        return Operation.DELETE
    else:
        raise RuntimeError(f"Unknown operation {operation}")  # TODO: full error


def from_proto_metadata(metadata: proto.UpdateMetadata) -> Optional[Metadata]:
    if not metadata.metadata:
        return None
    out_metadata = {}
    for key, value in metadata.metadata.items():
        if value.HasField("string_value"):
            out_metadata[key] = value.string_value
        elif value.HasField("int_value"):
            out_metadata[key] = value.int_value
        elif value.HasField("float_value"):
            out_metadata[key] = value.float_value
        else:
            raise RuntimeError(f"Unknown metadata value type {value}")
    return out_metadata


def from_proto_submit(
    submit_embedding_record: proto.SubmitEmbeddingRecord, seq_id: SeqId
) -> EmbeddingRecord:
    embedding, encoding = from_proto_vector(submit_embedding_record.vector)
    record = EmbeddingRecord(
        id=submit_embedding_record.id,
        seq_id=seq_id,
        embedding=embedding,
        encoding=encoding,
        metadata=from_proto_metadata(submit_embedding_record.metadata),
        operation=from_proto_operation(submit_embedding_record.operation),
    )
    return record


def to_proto_metadata_update_value(
    value: Union[str, int, float, None]
) -> proto.UpdateMetadataValue:
    if isinstance(value, str):
        return proto.UpdateMetadataValue(string_value=value)
    elif isinstance(value, int):
        return proto.UpdateMetadataValue(int_value=value)
    elif isinstance(value, float):
        return proto.UpdateMetadataValue(float_value=value)
    elif value is None:
        return proto.UpdateMetadataValue()
    else:
        raise ValueError(
            f"Unknown metadata value type {type(value)}, expected one of str, int, \
            float, or None"
        )


def to_proto_operation(operation: Operation) -> proto.Operation.ValueType:
    if operation == Operation.ADD:
        return proto.Operation.ADD
    elif operation == Operation.UPDATE:
        return proto.Operation.UPDATE
    elif operation == Operation.UPSERT:
        return proto.Operation.UPSERT
    elif operation == Operation.DELETE:
        return proto.Operation.DELETE
    else:
        raise ValueError(
            f"Unknown operation {operation}, expected one of {Operation.ADD}, \
            {Operation.UPDATE}, {Operation.UPSERT}, or {Operation.DELETE}"
        )


def to_proto_submit(
    submit_record: SubmitEmbeddingRecord,
) -> proto.SubmitEmbeddingRecord:
    vector = None
    if submit_record["embedding"] is not None and submit_record["encoding"] is not None:
        vector = to_proto_vector(submit_record["embedding"], submit_record["encoding"])

    metadata = None
    if submit_record["metadata"] is not None:
        metadata = {
            k: to_proto_metadata_update_value(v)
            for k, v in submit_record["metadata"].items()
        }

    return proto.SubmitEmbeddingRecord(
        id=submit_record["id"],
        vector=vector,
        metadata=proto.UpdateMetadata(metadata=metadata)
        if metadata is not None
        else None,
        operation=to_proto_operation(submit_record["operation"]),
    )
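A round-trip sketch through the converters (assumes SubmitEmbeddingRecord/EmbeddingRecord accept the keys read above):

from chromadb.proto.convert import from_proto_submit, to_proto_submit
from chromadb.types import Operation, ScalarEncoding, SubmitEmbeddingRecord

submit = SubmitEmbeddingRecord(
    id="embedding-0",
    embedding=[1.0, 2.0],
    encoding=ScalarEncoding.FLOAT32,
    metadata={"source": "example"},
    operation=Operation.ADD,
)
msg = to_proto_submit(submit)
record = from_proto_submit(msg, seq_id=42)  # seq_id normally comes from the Pulsar MessageId
assert record["embedding"] == [1.0, 2.0] and record["seq_id"] == 42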
chromadb/segment/impl/metadata/sqlite.py (modified)
@@ -469,9 +469,9 @@ class SqliteMetadataSegment(MetadataReader):

 def _encode_seq_id(seq_id: SeqId) -> bytes:
     """Encode a SeqID into a byte array"""
-    if seq_id.bit_length() < 64:
+    if seq_id.bit_length() <= 64:
         return int.to_bytes(seq_id, 8, "big")
-    elif seq_id.bit_length() < 192:
+    elif seq_id.bit_length() <= 192:
         return int.to_bytes(seq_id, 24, "big")
     else:
         raise ValueError(f"Unsupported SeqID: {seq_id}")
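The boundary change matters because bit_length() is exactly 64 for any value in [2**63, 2**64), and those values still fit in 8 unsigned big-endian bytes; previously they spilled into the 24-byte encoding:

# Why "<=" is correct here:
seq_id = 2**64 - 1
assert seq_id.bit_length() == 64
assert int.to_bytes(seq_id, 8, "big") == b"\xff" * 8  # fits in 8 bytes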
chromadb/segment/impl/vector/local_persistent_hnsw.py (modified)
@@ -207,7 +207,6 @@ class PersistentLocalHnswSegment(LocalHnswSegment):
         """Add a batch of embeddings to the index"""
         if not self._running:
             raise RuntimeError("Cannot add embeddings to stopped component")

         with WriteRWLock(self._lock):
             for record in records:
                 if record["embedding"] is not None:
chromadb/test/ingest/test_producer_consumer.py (modified)
@@ -1,3 +1,4 @@
+import asyncio
 import os
 import shutil
 import tempfile
@@ -16,6 +17,7 @@ from typing import (
 )
 from chromadb.ingest import Producer, Consumer
 from chromadb.db.impl.sqlite import SqliteDB
+from chromadb.ingest.impl.utils import create_topic_name
 from chromadb.test.conftest import ProducerFn
 from chromadb.types import (
     SubmitEmbeddingRecord,
@@ -51,8 +53,33 @@ def sqlite_persistent() -> Generator[Tuple[Producer, Consumer], None, None]:
     shutil.rmtree(save_path)


+def pulsar() -> Generator[Tuple[Producer, Consumer], None, None]:
+    """Fixture generator for pulsar Producer + Consumer. This fixture requires a running
+    pulsar cluster. You can use bin/cluster-test.sh to start a standalone pulsar and run this test
+    """
+    system = System(
+        Settings(
+            allow_reset=True,
+            chroma_producer_impl="chromadb.ingest.impl.pulsar.PulsarProducer",
+            chroma_consumer_impl="chromadb.ingest.impl.pulsar.PulsarConsumer",
+            pulsar_broker_url="localhost",
+            pulsar_admin_port="8080",
+            pulsar_broker_port="6650",
+        )
+    )
+    producer = system.require(Producer)
+    consumer = system.require(Consumer)
+    system.start()
+    yield producer, consumer
+    system.stop()
+
+
 def fixtures() -> List[Callable[[], Generator[Tuple[Producer, Consumer], None, None]]]:
-    return [sqlite, sqlite_persistent]
+    fixtures = [sqlite, sqlite_persistent]
+    if "CHROMA_CLUSTER_TEST_ONLY" in os.environ:
+        fixtures = [pulsar]
+
+    return fixtures


 @pytest.fixture(scope="module", params=fixtures())
@@ -89,14 +116,20 @@ class CapturingConsumeFn:
     waiters: List[Tuple[int, Event]]

     def __init__(self) -> None:
+        """A function that captures embeddings and allows you to wait for a certain
+        number of embeddings to be available. It must be constructed in the thread with
+        the main event loop
+        """
         self.embeddings = []
         self.waiters = []
+        self._loop = asyncio.get_event_loop()

     def __call__(self, embeddings: Sequence[EmbeddingRecord]) -> None:
         self.embeddings.extend(embeddings)
         for n, event in self.waiters:
             if len(self.embeddings) >= n:
-                event.set()
+                # event.set() is not thread safe, so we need to call it in the main event loop
+                self._loop.call_soon_threadsafe(event.set)

     async def get(self, n: int, timeout_secs: int = 10) -> Sequence[EmbeddingRecord]:
         "Wait until at least N embeddings are available, then return all embeddings"
@@ -132,6 +165,10 @@ def assert_records_match(
     assert_approx_equal(inserted["embedding"], consumed["embedding"])


+def full_topic_name(topic_name: str) -> str:
+    return create_topic_name("default", "default", topic_name)
+
+
 @pytest.mark.asyncio
 async def test_backfill(
     producer_consumer: Tuple[Producer, Consumer],
@@ -140,12 +177,14 @@
 ) -> None:
     producer, consumer = producer_consumer
     producer.reset_state()
+    consumer.reset_state()

-    producer.create_topic("test_topic")
-    embeddings = produce_fns(producer, "test_topic", sample_embeddings, 3)[0]
+    topic_name = full_topic_name("test_topic")
+    producer.create_topic(topic_name)
+    embeddings = produce_fns(producer, topic_name, sample_embeddings, 3)[0]

     consume_fn = CapturingConsumeFn()
-    consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
+    consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid())

     recieved = await consume_fn.get(3)
     assert_records_match(embeddings, recieved)
@@ -158,18 +197,21 @@ async def test_notifications(
 ) -> None:
     producer, consumer = producer_consumer
     producer.reset_state()
-    producer.create_topic("test_topic")
+    consumer.reset_state()
+    topic_name = full_topic_name("test_topic")
+
+    producer.create_topic(topic_name)

     embeddings: List[SubmitEmbeddingRecord] = []

     consume_fn = CapturingConsumeFn()

-    consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
+    consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid())

     for i in range(10):
         e = next(sample_embeddings)
         embeddings.append(e)
-        producer.submit_embedding("test_topic", e)
+        producer.submit_embedding(topic_name, e)
         received = await consume_fn.get(i + 1)
         assert_records_match(embeddings, received)
@@ -181,8 +223,11 @@ async def test_multiple_topics(
 ) -> None:
     producer, consumer = producer_consumer
     producer.reset_state()
-    producer.create_topic("test_topic_1")
-    producer.create_topic("test_topic_2")
+    consumer.reset_state()
+    topic_name_1 = full_topic_name("test_topic_1")
+    topic_name_2 = full_topic_name("test_topic_2")
+    producer.create_topic(topic_name_1)
+    producer.create_topic(topic_name_2)

     embeddings_1: List[SubmitEmbeddingRecord] = []
     embeddings_2: List[SubmitEmbeddingRecord] = []
@@ -190,19 +235,19 @@
     consume_fn_1 = CapturingConsumeFn()
     consume_fn_2 = CapturingConsumeFn()

-    consumer.subscribe("test_topic_1", consume_fn_1, start=consumer.min_seqid())
-    consumer.subscribe("test_topic_2", consume_fn_2, start=consumer.min_seqid())
+    consumer.subscribe(topic_name_1, consume_fn_1, start=consumer.min_seqid())
+    consumer.subscribe(topic_name_2, consume_fn_2, start=consumer.min_seqid())

     for i in range(10):
         e_1 = next(sample_embeddings)
         embeddings_1.append(e_1)
-        producer.submit_embedding("test_topic_1", e_1)
+        producer.submit_embedding(topic_name_1, e_1)
         results_2 = await consume_fn_1.get(i + 1)
         assert_records_match(embeddings_1, results_2)

         e_2 = next(sample_embeddings)
         embeddings_2.append(e_2)
-        producer.submit_embedding("test_topic_2", e_2)
+        producer.submit_embedding(topic_name_2, e_2)
         results_2 = await consume_fn_2.get(i + 1)
         assert_records_match(embeddings_2, results_2)
@@ -215,21 +260,23 @@ async def test_start_seq_id(
 ) -> None:
     producer, consumer = producer_consumer
     producer.reset_state()
-    producer.create_topic("test_topic")
+    consumer.reset_state()
+    topic_name = full_topic_name("test_topic")
+    producer.create_topic(topic_name)

     consume_fn_1 = CapturingConsumeFn()
     consume_fn_2 = CapturingConsumeFn()

-    consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
+    consumer.subscribe(topic_name, consume_fn_1, start=consumer.min_seqid())

-    embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0]
+    embeddings = produce_fns(producer, topic_name, sample_embeddings, 5)[0]

     results_1 = await consume_fn_1.get(5)
     assert_records_match(embeddings, results_1)

     start = consume_fn_1.embeddings[-1]["seq_id"]
-    consumer.subscribe("test_topic", consume_fn_2, start=start)
-    second_embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0]
+    consumer.subscribe(topic_name, consume_fn_2, start=start)
+    second_embeddings = produce_fns(producer, topic_name, sample_embeddings, 5)[0]
     assert isinstance(embeddings, list)
     embeddings.extend(second_embeddings)
     results_2 = await consume_fn_2.get(5)
@@ -244,20 +291,22 @@ async def test_end_seq_id(
 ) -> None:
     producer, consumer = producer_consumer
     producer.reset_state()
-    producer.create_topic("test_topic")
+    consumer.reset_state()
+    topic_name = full_topic_name("test_topic")
+    producer.create_topic(topic_name)

     consume_fn_1 = CapturingConsumeFn()
     consume_fn_2 = CapturingConsumeFn()

-    consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
+    consumer.subscribe(topic_name, consume_fn_1, start=consumer.min_seqid())

-    embeddings = produce_fns(producer, "test_topic", sample_embeddings, 10)[0]
+    embeddings = produce_fns(producer, topic_name, sample_embeddings, 10)[0]

     results_1 = await consume_fn_1.get(10)
     assert_records_match(embeddings, results_1)

     end = consume_fn_1.embeddings[-5]["seq_id"]
-    consumer.subscribe("test_topic", consume_fn_2, start=consumer.min_seqid(), end=end)
+    consumer.subscribe(topic_name, consume_fn_2, start=consumer.min_seqid(), end=end)

     results_2 = await consume_fn_2.get(6)
     assert_records_match(embeddings[:6], results_2)
@@ -274,14 +323,16 @@ async def test_submit_batch(
 ) -> None:
     producer, consumer = producer_consumer
     producer.reset_state()
+    consumer.reset_state()
+    topic_name = full_topic_name("test_topic")

     embeddings = [next(sample_embeddings) for _ in range(100)]

-    producer.create_topic("test_topic")
-    producer.submit_embeddings("test_topic", embeddings=embeddings)
+    producer.create_topic(topic_name)
+    producer.submit_embeddings(topic_name, embeddings=embeddings)

     consume_fn = CapturingConsumeFn()
-    consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
+    consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid())

     recieved = await consume_fn.get(100)
     assert_records_match(embeddings, recieved)
@@ -295,13 +346,16 @@ async def test_multiple_topics_batch(
 ) -> None:
     producer, consumer = producer_consumer
     producer.reset_state()
+    consumer.reset_state()

-    N_TOPICS = 100
+    N_TOPICS = 2
     consume_fns = [CapturingConsumeFn() for _ in range(N_TOPICS)]
     for i in range(N_TOPICS):
-        producer.create_topic(f"test_topic_{i}")
+        producer.create_topic(full_topic_name(f"test_topic_{i}"))
         consumer.subscribe(
-            f"test_topic_{i}", consume_fns[i], start=consumer.min_seqid()
+            full_topic_name(f"test_topic_{i}"),
+            consume_fns[i],
+            start=consumer.min_seqid(),
         )

     embeddings_n: List[List[SubmitEmbeddingRecord]] = [[] for _ in range(N_TOPICS)]
@@ -310,17 +364,17 @@
     N_TO_PRODUCE = 100
     total_produced = 0
     for i in range(N_TO_PRODUCE // PRODUCE_BATCH_SIZE):
-        for i in range(N_TOPICS):
-            embeddings_n[i].extend(
+        for n in range(N_TOPICS):
+            embeddings_n[n].extend(
                 produce_fns(
                     producer,
-                    f"test_topic_{i}",
+                    full_topic_name(f"test_topic_{n}"),
                     sample_embeddings,
                     PRODUCE_BATCH_SIZE,
                 )[0]
             )
-            recieved = await consume_fns[i].get(total_produced + PRODUCE_BATCH_SIZE)
-            assert_records_match(embeddings_n[i], recieved)
+            recieved = await consume_fns[n].get(total_produced + PRODUCE_BATCH_SIZE)
+            assert_records_match(embeddings_n[n], recieved)
         total_produced += PRODUCE_BATCH_SIZE
@@ -331,19 +385,21 @@ async def test_max_batch_size(
 ) -> None:
     producer, consumer = producer_consumer
     producer.reset_state()
-    max_batch_size = producer_consumer[0].max_batch_size
+    consumer.reset_state()
+    topic_name = full_topic_name("test_topic")
+    max_batch_size = producer.max_batch_size
     assert max_batch_size > 0

     # Make sure that we can produce a batch of size max_batch_size
     embeddings = [next(sample_embeddings) for _ in range(max_batch_size)]
     consume_fn = CapturingConsumeFn()
-    consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
-    producer.submit_embeddings("test_topic", embeddings=embeddings)
+    consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid())
+    producer.submit_embeddings(topic_name, embeddings=embeddings)
     received = await consume_fn.get(max_batch_size, timeout_secs=120)
     assert_records_match(embeddings, received)

     embeddings = [next(sample_embeddings) for _ in range(max_batch_size + 1)]
     # Make sure that we can't produce a batch of size > max_batch_size
     with pytest.raises(ValueError) as e:
-        producer.submit_embeddings("test_topic", embeddings=embeddings)
+        producer.submit_embeddings(topic_name, embeddings=embeddings)
     assert "Cannot submit more than" in str(e.value)
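The CapturingConsumeFn change is the subtle part of this file: asyncio.Event.set() is not thread safe, and the Pulsar client invokes message listeners on its own threads, so the event must be scheduled onto the loop that owns it. A standalone sketch of the pattern:

# Minimal illustration of loop.call_soon_threadsafe(event.set).
import asyncio
import threading

async def main() -> None:
    loop = asyncio.get_event_loop()
    event = asyncio.Event()
    # A foreign thread (like a Pulsar listener thread) must not call
    # event.set() directly; it hands the call to the owning loop instead.
    threading.Thread(target=lambda: loop.call_soon_threadsafe(event.set)).start()
    await asyncio.wait_for(event.wait(), timeout=5)

asyncio.run(main())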
docker-compose.cluster.yml (new file, 66 lines)
@@ -0,0 +1,66 @@
# This docker compose file is not meant to be used. It is a work in progress
# for the distributed version of Chroma. It is not yet functional.

version: '3.9'

networks:
  net:
    driver: bridge

services:
  server:
    image: server
    build:
      context: .
      dockerfile: Dockerfile
    volumes:
      - ./:/chroma
      - index_data:/index_data
    command: uvicorn chromadb.app:app --reload --workers 1 --host 0.0.0.0 --port 8000 --log-config log_config.yml
    environment:
      - IS_PERSISTENT=TRUE
      - CHROMA_PRODUCER_IMPL=chromadb.ingest.impl.pulsar.PulsarProducer
      - CHROMA_CONSUMER_IMPL=chromadb.ingest.impl.pulsar.PulsarConsumer
      - PULSAR_BROKER_URL=pulsar
      - PULSAR_BROKER_PORT=6650
      - PULSAR_ADMIN_PORT=8080
    ports:
      - 8000:8000
    depends_on:
      pulsar:
        condition: service_healthy
    networks:
      - net

  pulsar:
    image: apachepulsar/pulsar
    volumes:
      - pulsardata:/pulsar/data
      - pulsarconf:/pulsar/conf
    command: bin/pulsar standalone
    ports:
      - 6650:6650
      - 8080:8080
    networks:
      - net
    healthcheck:
      test:
        [
          "CMD",
          "curl",
          "-f",
          "localhost:8080/admin/v2/brokers/health"
        ]
      interval: 3s
      timeout: 1m
      retries: 10

volumes:
  index_data:
    driver: local
  backups:
    driver: local
  pulsardata:
    driver: local
  pulsarconf:
    driver: local
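The healthcheck above is what bin/cluster-test.sh's --wait relies on; an equivalent probe from Python, assuming the port mapping in this file:

# Equivalent of the compose healthcheck: curl -f localhost:8080/admin/v2/brokers/health
import requests

healthy = requests.get("http://localhost:8080/admin/v2/brokers/health").status_code == 200
print("pulsar standalone healthy:", healthy)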
requirements_dev.txt (modified)
@@ -3,8 +3,10 @@ build
 httpx
 hypothesis
 hypothesis[numpy]
+mypy-protobuf
 pre-commit
 pytest
 pytest-asyncio
 setuptools_scm
+types-protobuf
 types-requests==2.30.0.0