mirror of
https://github.com/placeholder-soft/chroma.git
synced 2026-01-12 17:02:54 +08:00
[PERF] Batch SQLite Embeddings Queue (#959)
## Description of changes Adds a batch submit_embeddings() to the producer API. Implements this behavior in the SQLite-based embeddings queue. Refactors tests to add a fixture for both single and batch submit, and substitutes all uses of produce in tests with the fixture so both methods are uniformly tested. This PR is stacked on #958; please read that before reading this. *Summarize the changes made by this PR.* - Improvements & Bug fixes - This PR makes the SQLite-based embeddings queue batch its calls to SQLite according to the batch size of the calls to DML operations. ### Quick Benchmarks N = 100k D = 128 metadata = {simple key: simple value} Document = random 100-character string **Before - Overall time: 102s** <img width="587" alt="Screenshot 2023-08-09 at 5 43 12 PM" src="https://github.com/chroma-core/chroma/assets/5598697/752809a2-78a0-4d5b-a238-fc41ca2635cc"> **After - Overall time: 36s** <img width="583" alt="Screenshot 2023-08-09 at 7 37 41 PM" src="https://github.com/chroma-core/chroma/assets/5598697/eb1759c8-8f62-4ed7-ad88-450b3f72692b"> **_Overall, compared to main, this PR and #958 decrease the time of this benchmark from 469s -> 36s._** Todo: - [x] Clean up code - [x] Rethink assumptions about submit_embeddings ordering guarantees - [x] Add tests - [x] Add batching to other DML - upsert(), update(), delete() ## Test plan Existing tests cover the bulk of the functionality. New tests TODO: - [x] Producer/Consumer tests for submit_embeddings - [x] Segment tests with batch embeddings ## Documentation Changes None required, but a miscellaneous feature idea is a hyperparameter sweep for the optimal batch_size for your platform.
This commit is contained in:
@@ -242,9 +242,11 @@ class SegmentAPI(API):
|
||||
coll = self._get_collection(collection_id)
|
||||
self._manager.hint_use_collection(collection_id, t.Operation.ADD)
|
||||
|
||||
records_to_submit = []
|
||||
for r in _records(t.Operation.ADD, ids, embeddings, metadatas, documents):
|
||||
self._validate_embedding_record(coll, r)
|
||||
self._producer.submit_embedding(coll["topic"], r)
|
||||
records_to_submit.append(r)
|
||||
self._producer.submit_embeddings(coll["topic"], records_to_submit)
|
||||
|
||||
self._telemetry_client.capture(CollectionAddEvent(str(collection_id), len(ids)))
|
||||
return True
|
||||
@@ -261,9 +263,11 @@ class SegmentAPI(API):
|
||||
coll = self._get_collection(collection_id)
|
||||
self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
|
||||
|
||||
records_to_submit = []
|
||||
for r in _records(t.Operation.UPDATE, ids, embeddings, metadatas, documents):
|
||||
self._validate_embedding_record(coll, r)
|
||||
self._producer.submit_embedding(coll["topic"], r)
|
||||
records_to_submit.append(r)
|
||||
self._producer.submit_embeddings(coll["topic"], records_to_submit)
|
||||
|
||||
return True
|
||||
|
||||
@@ -279,9 +283,11 @@ class SegmentAPI(API):
|
||||
coll = self._get_collection(collection_id)
|
||||
self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
|
||||
|
||||
records_to_submit = []
|
||||
for r in _records(t.Operation.UPSERT, ids, embeddings, metadatas, documents):
|
||||
self._validate_embedding_record(coll, r)
|
||||
self._producer.submit_embedding(coll["topic"], r)
|
||||
records_to_submit.append(r)
|
||||
self._producer.submit_embeddings(coll["topic"], records_to_submit)
|
||||
|
||||
return True
|
||||
|
||||
@@ -376,9 +382,11 @@ class SegmentAPI(API):
|
||||
else:
|
||||
ids_to_delete = ids
|
||||
|
||||
records_to_submit = []
|
||||
for r in _records(t.Operation.DELETE, ids_to_delete):
|
||||
self._validate_embedding_record(coll, r)
|
||||
self._producer.submit_embedding(coll["topic"], r)
|
||||
records_to_submit.append(r)
|
||||
self._producer.submit_embeddings(coll["topic"], records_to_submit)
|
||||
|
||||
self._telemetry_client.capture(
|
||||
CollectionDeleteEvent(str(collection_id), len(ids_to_delete))
|
||||
|
||||
@@ -16,7 +16,7 @@ from chromadb.types import (
|
||||
from chromadb.config import System
|
||||
from overrides import override
|
||||
from collections import defaultdict
|
||||
from typing import Tuple, Optional, Dict, Set, cast
|
||||
from typing import Sequence, Tuple, Optional, Dict, Set, cast
|
||||
from uuid import UUID
|
||||
from pypika import Table, functions
|
||||
import uuid
|
||||
@@ -103,22 +103,32 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
|
||||
if not self._running:
|
||||
raise RuntimeError("Component not running")
|
||||
|
||||
if embedding["embedding"]:
|
||||
encoding_type = cast(ScalarEncoding, embedding["encoding"])
|
||||
encoding = encoding_type.value
|
||||
embedding_bytes = encode_vector(embedding["embedding"], encoding_type)
|
||||
return self.submit_embeddings(topic_name, [embedding])[0]
|
||||
|
||||
else:
|
||||
embedding_bytes = None
|
||||
encoding = None
|
||||
metadata = json.dumps(embedding["metadata"]) if embedding["metadata"] else None
|
||||
@override
|
||||
def submit_embeddings(
|
||||
self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord]
|
||||
) -> Sequence[SeqId]:
|
||||
if not self._running:
|
||||
raise RuntimeError("Component not running")
|
||||
|
||||
if len(embeddings) == 0:
|
||||
return []
|
||||
|
||||
t = Table("embeddings_queue")
|
||||
insert = (
|
||||
self.querybuilder()
|
||||
.into(t)
|
||||
.columns(t.operation, t.topic, t.id, t.vector, t.encoding, t.metadata)
|
||||
.insert(
|
||||
)
|
||||
id_to_idx: Dict[str, int] = {}
|
||||
for embedding in embeddings:
|
||||
(
|
||||
embedding_bytes,
|
||||
encoding,
|
||||
metadata,
|
||||
) = self._prepare_vector_encoding_metadata(embedding)
|
||||
insert = insert.insert(
|
||||
ParameterValue(_operation_codes[embedding["operation"]]),
|
||||
ParameterValue(topic_name),
|
||||
ParameterValue(embedding["id"]),
|
||||
@@ -126,21 +136,34 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
|
||||
ParameterValue(encoding),
|
||||
ParameterValue(metadata),
|
||||
)
|
||||
)
|
||||
id_to_idx[embedding["id"]] = len(id_to_idx)
|
||||
with self.tx() as cur:
|
||||
sql, params = get_sql(insert, self.parameter_format())
|
||||
sql = f"{sql} RETURNING seq_id" # Pypika doesn't support RETURNING
|
||||
seq_id = int(cur.execute(sql, params).fetchone()[0])
|
||||
embedding_record = EmbeddingRecord(
|
||||
id=embedding["id"],
|
||||
seq_id=seq_id,
|
||||
embedding=embedding["embedding"],
|
||||
encoding=embedding["encoding"],
|
||||
metadata=embedding["metadata"],
|
||||
operation=embedding["operation"],
|
||||
)
|
||||
self._notify_all(topic_name, embedding_record)
|
||||
return seq_id
|
||||
# The returning clause does not guarantee order, so we need to do reorder
|
||||
# the results. https://www.sqlite.org/lang_returning.html
|
||||
sql = f"{sql} RETURNING seq_id, id" # Pypika doesn't support RETURNING
|
||||
results = cur.execute(sql, params).fetchall()
|
||||
# Reorder the results
|
||||
seq_ids = [cast(SeqId, None)] * len(
|
||||
results
|
||||
) # Lie to mypy: https://stackoverflow.com/questions/76694215/python-type-casting-when-preallocating-list
|
||||
embedding_records = []
|
||||
for seq_id, id in results:
|
||||
seq_ids[id_to_idx[id]] = seq_id
|
||||
submit_embedding_record = embeddings[id_to_idx[id]]
|
||||
# We allow notifying consumers out of order relative to one call to
|
||||
# submit_embeddings so we do not reorder the records before submitting them
|
||||
embedding_record = EmbeddingRecord(
|
||||
id=id,
|
||||
seq_id=seq_id,
|
||||
embedding=submit_embedding_record["embedding"],
|
||||
encoding=submit_embedding_record["encoding"],
|
||||
metadata=submit_embedding_record["metadata"],
|
||||
operation=submit_embedding_record["operation"],
|
||||
)
|
||||
embedding_records.append(embedding_record)
|
||||
self._notify_all(topic_name, embedding_records)
|
||||
return seq_ids
|
||||
|
||||
@override
|
||||
def subscribe(
|
||||
@@ -185,6 +208,19 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
|
||||
def max_seqid(self) -> SeqId:
|
||||
return 2**63 - 1
|
||||
|
||||
def _prepare_vector_encoding_metadata(
|
||||
self, embedding: SubmitEmbeddingRecord
|
||||
) -> Tuple[Optional[bytes], Optional[str], Optional[str]]:
|
||||
if embedding["embedding"]:
|
||||
encoding_type = cast(ScalarEncoding, embedding["encoding"])
|
||||
encoding = encoding_type.value
|
||||
embedding_bytes = encode_vector(embedding["embedding"], encoding_type)
|
||||
else:
|
||||
embedding_bytes = None
|
||||
encoding = None
|
||||
metadata = json.dumps(embedding["metadata"]) if embedding["metadata"] else None
|
||||
return embedding_bytes, encoding, metadata
|
||||
|
||||
def _backfill(self, subscription: Subscription) -> None:
|
||||
"""Backfill the given subscription with any currently matching records in the
|
||||
DB"""
|
||||
@@ -211,14 +247,16 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
|
||||
vector = None
|
||||
self._notify_one(
|
||||
subscription,
|
||||
EmbeddingRecord(
|
||||
seq_id=row[0],
|
||||
operation=_operation_codes_inv[row[1]],
|
||||
id=row[2],
|
||||
embedding=vector,
|
||||
encoding=encoding,
|
||||
metadata=json.loads(row[5]) if row[5] else None,
|
||||
),
|
||||
[
|
||||
EmbeddingRecord(
|
||||
seq_id=row[0],
|
||||
operation=_operation_codes_inv[row[1]],
|
||||
id=row[2],
|
||||
embedding=vector,
|
||||
encoding=encoding,
|
||||
metadata=json.loads(row[5]) if row[5] else None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def _validate_range(
|
||||
@@ -242,29 +280,37 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
|
||||
cur.execute(q.get_sql())
|
||||
return int(cur.fetchone()[0]) + 1
|
||||
|
||||
def _notify_all(self, topic: str, embedding: EmbeddingRecord) -> None:
|
||||
def _notify_all(self, topic: str, embeddings: Sequence[EmbeddingRecord]) -> None:
|
||||
"""Send a notification to each subscriber of the given topic."""
|
||||
if self._running:
|
||||
for sub in self._subscriptions[topic]:
|
||||
self._notify_one(sub, embedding)
|
||||
self._notify_one(sub, embeddings)
|
||||
|
||||
def _notify_one(self, sub: Subscription, embedding: EmbeddingRecord) -> None:
|
||||
def _notify_one(
|
||||
self, sub: Subscription, embeddings: Sequence[EmbeddingRecord]
|
||||
) -> None:
|
||||
"""Send a notification to a single subscriber."""
|
||||
if embedding["seq_id"] > sub.end:
|
||||
self.unsubscribe(sub.id)
|
||||
return
|
||||
|
||||
if embedding["seq_id"] <= sub.start:
|
||||
return
|
||||
# Filter out any embeddings that are not in the subscription range
|
||||
should_unsubscribe = False
|
||||
filtered_embeddings = []
|
||||
for embedding in embeddings:
|
||||
if embedding["seq_id"] <= sub.start:
|
||||
continue
|
||||
if embedding["seq_id"] > sub.end:
|
||||
should_unsubscribe = True
|
||||
break
|
||||
filtered_embeddings.append(embedding)
|
||||
|
||||
# Log errors instead of throwing them to preserve async semantics
|
||||
# for consistency between local and distributed configurations
|
||||
try:
|
||||
sub.callback([embedding])
|
||||
if len(filtered_embeddings) > 0:
|
||||
sub.callback(filtered_embeddings)
|
||||
if should_unsubscribe:
|
||||
self.unsubscribe(sub.id)
|
||||
except BaseException as e:
|
||||
id = embedding.get("id", embedding.get("delete_id"))
|
||||
logger.error(
|
||||
f"Exception occurred invoking consumer for subscription {sub.id}"
|
||||
+ f"to topic {sub.topic_name} for embedding id {id} ",
|
||||
+ f"to topic {sub.topic_name}",
|
||||
e,
|
||||
)
|
||||
|
||||
@@ -52,6 +52,16 @@ class Producer(Component):
|
||||
"""Add an embedding record to the given topic. Returns the SeqID of the record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
def submit_embeddings(
    self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord]
) -> Sequence[SeqId]:
    """Add a batch of embedding records to the given topic.

    Returns the SeqIDs of the records, in the same order as the given
    SubmitEmbeddingRecords. It is not guaranteed, however, that the records
    will be processed in the same order as they were given.
    """
    pass
|
||||
|
||||
|
||||
ConsumerCallbackFn = Callable[[Sequence[EmbeddingRecord]], None]
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from chromadb.config import Settings, System
|
||||
from chromadb.api import API
|
||||
from chromadb.ingest import Producer
|
||||
import chromadb.server.fastapi
|
||||
from requests.exceptions import ConnectionError
|
||||
import hypothesis
|
||||
@@ -8,12 +9,23 @@ import os
|
||||
import uvicorn
|
||||
import time
|
||||
import pytest
|
||||
from typing import Generator, List, Callable, Optional, Tuple
|
||||
from typing import (
|
||||
Generator,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Callable,
|
||||
)
|
||||
from typing_extensions import Protocol
|
||||
import shutil
|
||||
import logging
|
||||
import socket
|
||||
import multiprocessing
|
||||
|
||||
from chromadb.types import SeqId, SubmitEmbeddingRecord
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(logging.DEBUG) # This will only run when testing
|
||||
|
||||
@@ -189,3 +201,59 @@ def api(system: System) -> Generator[API, None, None]:
|
||||
system.reset_state()
|
||||
api = system.instance(API)
|
||||
yield api
|
||||
|
||||
|
||||
# Producer / Consumer fixtures #
|
||||
|
||||
|
||||
class ProducerFn(Protocol):
    """Call signature shared by the ``produce_n_*`` helpers.

    An implementation takes ``n`` records from ``embeddings``, submits them
    to ``topic`` through ``producer``, and returns a
    ``(submitted_records, seq_ids)`` pair.
    """

    def __call__(
        self,
        producer: Producer,
        topic: str,
        embeddings: Iterator[SubmitEmbeddingRecord],
        n: int,
    ) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
        ...
|
||||
|
||||
|
||||
def produce_n_single(
    producer: Producer,
    topic: str,
    embeddings: Iterator[SubmitEmbeddingRecord],
    n: int,
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
    """Submit ``n`` records one at a time via ``producer.submit_embedding``.

    Records are pulled from ``embeddings`` with ``next()`` immediately before
    each submission. Returns the submitted records and the SeqIds returned
    for them, both as lists in submission order.
    """
    submitted_embeddings: List[SubmitEmbeddingRecord] = []
    seq_ids: List[SeqId] = []
    for _ in range(n):
        record = next(embeddings)
        seq_ids.append(producer.submit_embedding(topic, record))
        submitted_embeddings.append(record)
    return submitted_embeddings, seq_ids
|
||||
|
||||
|
||||
def produce_n_batch(
    producer: Producer,
    topic: str,
    embeddings: Iterator[SubmitEmbeddingRecord],
    n: int,
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
    """Submit ``n`` records in one ``producer.submit_embeddings`` call.

    All ``n`` records are pulled from ``embeddings`` first and then submitted
    as a single batch. Returns the list of submitted records together with
    whatever sequence of SeqIds the producer returned for the batch.
    """
    submitted_embeddings = [next(embeddings) for _ in range(n)]
    seq_ids = producer.submit_embeddings(topic, submitted_embeddings)
    return submitted_embeddings, seq_ids
|
||||
|
||||
|
||||
def produce_fn_fixtures() -> List[ProducerFn]:
    """Return the producer helpers used as parameters of the ``produce_fns`` fixture.

    The list holds the single-record helper followed by the batch helper.
    """
    return [produce_n_single, produce_n_batch]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=produce_fn_fixtures())
def produce_fns(
    request: pytest.FixtureRequest,
) -> Generator[ProducerFn, None, None]:
    """Module-scoped fixture parametrized over ``produce_fn_fixtures()``.

    Yields one producer helper per parameter, so a test that requests this
    fixture runs once for each helper in that list.
    """
    yield request.param
|
||||
|
||||
@@ -16,6 +16,7 @@ from typing import (
|
||||
)
|
||||
from chromadb.ingest import Producer, Consumer
|
||||
from chromadb.db.impl.sqlite import SqliteDB
|
||||
from chromadb.test.conftest import ProducerFn
|
||||
from chromadb.types import (
|
||||
SubmitEmbeddingRecord,
|
||||
Operation,
|
||||
@@ -135,15 +136,13 @@ def assert_records_match(
|
||||
async def test_backfill(
|
||||
producer_consumer: Tuple[Producer, Consumer],
|
||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||
produce_fns: ProducerFn,
|
||||
) -> None:
|
||||
producer, consumer = producer_consumer
|
||||
producer.reset_state()
|
||||
|
||||
embeddings = [next(sample_embeddings) for _ in range(3)]
|
||||
|
||||
producer.create_topic("test_topic")
|
||||
for e in embeddings:
|
||||
producer.submit_embedding("test_topic", e)
|
||||
embeddings = produce_fns(producer, "test_topic", sample_embeddings, 3)[0]
|
||||
|
||||
consume_fn = CapturingConsumeFn()
|
||||
consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
|
||||
@@ -212,6 +211,7 @@ async def test_multiple_topics(
|
||||
async def test_start_seq_id(
|
||||
producer_consumer: Tuple[Producer, Consumer],
|
||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||
produce_fns: ProducerFn,
|
||||
) -> None:
|
||||
producer, consumer = producer_consumer
|
||||
producer.reset_state()
|
||||
@@ -222,22 +222,16 @@ async def test_start_seq_id(
|
||||
|
||||
consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
|
||||
|
||||
embeddings = []
|
||||
for _ in range(5):
|
||||
e = next(sample_embeddings)
|
||||
embeddings.append(e)
|
||||
producer.submit_embedding("test_topic", e)
|
||||
embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0]
|
||||
|
||||
results_1 = await consume_fn_1.get(5)
|
||||
assert_records_match(embeddings, results_1)
|
||||
|
||||
start = consume_fn_1.embeddings[-1]["seq_id"]
|
||||
consumer.subscribe("test_topic", consume_fn_2, start=start)
|
||||
for _ in range(5):
|
||||
e = next(sample_embeddings)
|
||||
embeddings.append(e)
|
||||
producer.submit_embedding("test_topic", e)
|
||||
|
||||
second_embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0]
|
||||
assert isinstance(embeddings, list)
|
||||
embeddings.extend(second_embeddings)
|
||||
results_2 = await consume_fn_2.get(5)
|
||||
assert_records_match(embeddings[-5:], results_2)
|
||||
|
||||
@@ -246,6 +240,7 @@ async def test_start_seq_id(
|
||||
async def test_end_seq_id(
|
||||
producer_consumer: Tuple[Producer, Consumer],
|
||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||
produce_fns: ProducerFn,
|
||||
) -> None:
|
||||
producer, consumer = producer_consumer
|
||||
producer.reset_state()
|
||||
@@ -256,11 +251,7 @@ async def test_end_seq_id(
|
||||
|
||||
consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
|
||||
|
||||
embeddings = []
|
||||
for _ in range(10):
|
||||
e = next(sample_embeddings)
|
||||
embeddings.append(e)
|
||||
producer.submit_embedding("test_topic", e)
|
||||
embeddings = produce_fns(producer, "test_topic", sample_embeddings, 10)[0]
|
||||
|
||||
results_1 = await consume_fn_1.get(10)
|
||||
assert_records_match(embeddings, results_1)
|
||||
@@ -274,3 +265,60 @@ async def test_end_seq_id(
|
||||
# Should never produce a 7th
|
||||
with pytest.raises(TimeoutError):
|
||||
_ = await wait_for(consume_fn_2.get(7), timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_submit_batch(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
) -> None:
    """Submit 100 records in one batch, then subscribe and expect all of them.

    The subscription is created after the batch has been submitted and
    starts at ``consumer.min_seqid()``.
    """
    producer, consumer = producer_consumer
    producer.reset_state()

    embeddings = [next(sample_embeddings) for _ in range(100)]

    producer.create_topic("test_topic")
    producer.submit_embeddings("test_topic", embeddings=embeddings)

    consume_fn = CapturingConsumeFn()
    consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())

    received = await consume_fn.get(100)
    assert_records_match(embeddings, received)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_multiple_topics_batch(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    produce_fns: ProducerFn,
) -> None:
    """Produce to many topics in rounds and check each topic's consumer.

    One consumer is subscribed per topic before anything is produced. Each
    round produces ``PRODUCE_BATCH_SIZE`` records for every topic and asserts
    that the topic's consumer has received every record produced for that
    topic so far.
    """
    producer, consumer = producer_consumer
    producer.reset_state()

    N_TOPICS = 100
    consume_fns = [CapturingConsumeFn() for _ in range(N_TOPICS)]
    for i in range(N_TOPICS):
        producer.create_topic(f"test_topic_{i}")
        consumer.subscribe(
            f"test_topic_{i}", consume_fns[i], start=consumer.min_seqid()
        )

    embeddings_n: List[List[SubmitEmbeddingRecord]] = [[] for _ in range(N_TOPICS)]

    PRODUCE_BATCH_SIZE = 10
    N_TO_PRODUCE = 100
    total_produced = 0
    # The round counter itself is unused; only the topic index matters below.
    for _ in range(N_TO_PRODUCE // PRODUCE_BATCH_SIZE):
        for i in range(N_TOPICS):
            embeddings_n[i].extend(
                produce_fns(
                    producer,
                    f"test_topic_{i}",
                    sample_embeddings,
                    PRODUCE_BATCH_SIZE,
                )[0]
            )
            # Wait for the cumulative number of records produced to topic i.
            received = await consume_fns[i].get(total_produced + PRODUCE_BATCH_SIZE)
            assert_records_match(embeddings_n[i], received)
        total_produced += PRODUCE_BATCH_SIZE
|
||||
|
||||
@@ -4,6 +4,7 @@ import tempfile
|
||||
import pytest
|
||||
from typing import Generator, List, Callable, Iterator, Dict, Optional, Union, Sequence
|
||||
from chromadb.config import System, Settings
|
||||
from chromadb.test.conftest import ProducerFn
|
||||
from chromadb.types import (
|
||||
SubmitEmbeddingRecord,
|
||||
MetadataEmbeddingRecord,
|
||||
@@ -128,16 +129,16 @@ def sync(segment: MetadataReader, seq_id: SeqId) -> None:
|
||||
|
||||
|
||||
def test_insert_and_count(
|
||||
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
|
||||
system: System,
|
||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||
produce_fns: ProducerFn,
|
||||
) -> None:
|
||||
producer = system.instance(Producer)
|
||||
system.reset_state()
|
||||
|
||||
topic = str(segment_definition["topic"])
|
||||
|
||||
max_id = 0
|
||||
for i in range(3):
|
||||
max_id = producer.submit_embedding(topic, next(sample_embeddings))
|
||||
max_id = produce_fns(producer, topic, sample_embeddings, 3)[1][-1]
|
||||
|
||||
segment = SqliteMetadataSegment(system, segment_definition)
|
||||
segment.start()
|
||||
@@ -166,17 +167,15 @@ def assert_equiv_records(
|
||||
|
||||
|
||||
def test_get(
|
||||
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
|
||||
system: System,
|
||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||
produce_fns: ProducerFn,
|
||||
) -> None:
|
||||
producer = system.instance(Producer)
|
||||
system.reset_state()
|
||||
topic = str(segment_definition["topic"])
|
||||
|
||||
embeddings = [next(sample_embeddings) for i in range(10)]
|
||||
|
||||
seq_ids = []
|
||||
for e in embeddings:
|
||||
seq_ids.append(producer.submit_embedding(topic, e))
|
||||
embeddings, seq_ids = produce_fns(producer, topic, sample_embeddings, 10)
|
||||
|
||||
segment = SqliteMetadataSegment(system, segment_definition)
|
||||
segment.start()
|
||||
@@ -270,7 +269,9 @@ def test_get(
|
||||
|
||||
|
||||
def test_fulltext(
|
||||
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
|
||||
system: System,
|
||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||
produce_fns: ProducerFn,
|
||||
) -> None:
|
||||
producer = system.instance(Producer)
|
||||
system.reset_state()
|
||||
@@ -279,9 +280,7 @@ def test_fulltext(
|
||||
segment = SqliteMetadataSegment(system, segment_definition)
|
||||
segment.start()
|
||||
|
||||
max_id = 0
|
||||
for i in range(100):
|
||||
max_id = producer.submit_embedding(topic, next(sample_embeddings))
|
||||
max_id = produce_fns(producer, topic, sample_embeddings, 100)[1][-1]
|
||||
|
||||
sync(segment, max_id)
|
||||
|
||||
@@ -331,7 +330,9 @@ def test_fulltext(
|
||||
|
||||
|
||||
def test_delete(
|
||||
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
|
||||
system: System,
|
||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||
produce_fns: ProducerFn,
|
||||
) -> None:
|
||||
producer = system.instance(Producer)
|
||||
system.reset_state()
|
||||
@@ -340,11 +341,8 @@ def test_delete(
|
||||
segment = SqliteMetadataSegment(system, segment_definition)
|
||||
segment.start()
|
||||
|
||||
embeddings = [next(sample_embeddings) for i in range(10)]
|
||||
|
||||
max_id = 0
|
||||
for e in embeddings:
|
||||
max_id = producer.submit_embedding(topic, e)
|
||||
embeddings, seq_ids = produce_fns(producer, topic, sample_embeddings, 10)
|
||||
max_id = seq_ids[-1]
|
||||
|
||||
sync(segment, max_id)
|
||||
|
||||
@@ -353,16 +351,16 @@ def test_delete(
|
||||
assert_equiv_records(embeddings[:1], results)
|
||||
|
||||
# Delete by ID
|
||||
max_id = producer.submit_embedding(
|
||||
topic,
|
||||
SubmitEmbeddingRecord(
|
||||
id="embedding_0",
|
||||
embedding=None,
|
||||
encoding=None,
|
||||
metadata=None,
|
||||
operation=Operation.DELETE,
|
||||
),
|
||||
delete_embedding = SubmitEmbeddingRecord(
|
||||
id="embedding_0",
|
||||
embedding=None,
|
||||
encoding=None,
|
||||
metadata=None,
|
||||
operation=Operation.DELETE,
|
||||
)
|
||||
max_id = produce_fns(producer, topic, (delete_embedding for _ in range(1)), 1)[1][
|
||||
-1
|
||||
]
|
||||
|
||||
sync(segment, max_id)
|
||||
|
||||
@@ -370,16 +368,9 @@ def test_delete(
|
||||
assert segment.get_metadata(ids=["embedding_0"]) == []
|
||||
|
||||
# Delete is idempotent
|
||||
max_id = producer.submit_embedding(
|
||||
topic,
|
||||
SubmitEmbeddingRecord(
|
||||
id="embedding_0",
|
||||
embedding=None,
|
||||
encoding=None,
|
||||
metadata=None,
|
||||
operation=Operation.DELETE,
|
||||
),
|
||||
)
|
||||
max_id = produce_fns(producer, topic, (delete_embedding for _ in range(1)), 1)[1][
|
||||
-1
|
||||
]
|
||||
|
||||
sync(segment, max_id)
|
||||
assert segment.count() == 9
|
||||
@@ -420,7 +411,9 @@ def test_update(
|
||||
|
||||
|
||||
def test_upsert(
|
||||
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
|
||||
system: System,
|
||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||
produce_fns: ProducerFn,
|
||||
) -> None:
|
||||
producer = system.instance(Producer)
|
||||
system.reset_state()
|
||||
@@ -439,7 +432,12 @@ def test_upsert(
|
||||
encoding=None,
|
||||
operation=Operation.UPSERT,
|
||||
)
|
||||
max_id = producer.submit_embedding(topic, update_record)
|
||||
max_id = produce_fns(
|
||||
producer=producer,
|
||||
topic=topic,
|
||||
embeddings=(update_record for _ in range(1)),
|
||||
n=1,
|
||||
)[1][-1]
|
||||
sync(segment, max_id)
|
||||
results = segment.get_metadata(ids=["no_such_id"])
|
||||
assert results[0]["metadata"] == {"foo": "bar"}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
from typing import Generator, List, Callable, Iterator, Type, cast
|
||||
from chromadb.config import System, Settings
|
||||
from chromadb.test.conftest import ProducerFn
|
||||
from chromadb.types import (
|
||||
SubmitEmbeddingRecord,
|
||||
VectorQuery,
|
||||
@@ -129,6 +130,7 @@ def test_insert_and_count(
|
||||
system: System,
|
||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||
vector_reader: Type[VectorReader],
|
||||
produce_fns: ProducerFn,
|
||||
) -> None:
|
||||
producer = system.instance(Producer)
|
||||
|
||||
@@ -136,9 +138,9 @@ def test_insert_and_count(
|
||||
segment_definition = create_random_segment_definition()
|
||||
topic = str(segment_definition["topic"])
|
||||
|
||||
max_id = 0
|
||||
for i in range(3):
|
||||
max_id = producer.submit_embedding(topic, next(sample_embeddings))
|
||||
max_id = produce_fns(
|
||||
producer=producer, topic=topic, n=3, embeddings=sample_embeddings
|
||||
)[1][-1]
|
||||
|
||||
segment = vector_reader(system, segment_definition)
|
||||
segment.start()
|
||||
@@ -146,8 +148,10 @@ def test_insert_and_count(
|
||||
sync(segment, max_id)
|
||||
|
||||
assert segment.count() == 3
|
||||
for i in range(3):
|
||||
max_id = producer.submit_embedding(topic, next(sample_embeddings))
|
||||
|
||||
max_id = produce_fns(
|
||||
producer=producer, topic=topic, n=3, embeddings=sample_embeddings
|
||||
)[1][-1]
|
||||
|
||||
sync(segment, max_id)
|
||||
assert segment.count() == 6
|
||||
@@ -165,6 +169,7 @@ def test_get_vectors(
|
||||
system: System,
|
||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||
vector_reader: Type[VectorReader],
|
||||
produce_fns: ProducerFn,
|
||||
) -> None:
|
||||
producer = system.instance(Producer)
|
||||
system.reset_state()
|
||||
@@ -174,11 +179,9 @@ def test_get_vectors(
|
||||
segment = vector_reader(system, segment_definition)
|
||||
segment.start()
|
||||
|
||||
embeddings = [next(sample_embeddings) for i in range(10)]
|
||||
|
||||
seq_ids: List[SeqId] = []
|
||||
for e in embeddings:
|
||||
seq_ids.append(producer.submit_embedding(topic, e))
|
||||
embeddings, seq_ids = produce_fns(
|
||||
producer=producer, topic=topic, embeddings=sample_embeddings, n=10
|
||||
)
|
||||
|
||||
sync(segment, seq_ids[-1])
|
||||
|
||||
@@ -210,6 +213,7 @@ def test_ann_query(
|
||||
system: System,
|
||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||
vector_reader: Type[VectorReader],
|
||||
produce_fns: ProducerFn,
|
||||
) -> None:
|
||||
producer = system.instance(Producer)
|
||||
system.reset_state()
|
||||
@@ -219,11 +223,9 @@ def test_ann_query(
|
||||
segment = vector_reader(system, segment_definition)
|
||||
segment.start()
|
||||
|
||||
embeddings = [next(sample_embeddings) for i in range(100)]
|
||||
|
||||
seq_ids: List[SeqId] = []
|
||||
for e in embeddings:
|
||||
seq_ids.append(producer.submit_embedding(topic, e))
|
||||
embeddings, seq_ids = produce_fns(
|
||||
producer=producer, topic=topic, embeddings=sample_embeddings, n=100
|
||||
)
|
||||
|
||||
sync(segment, seq_ids[-1])
|
||||
|
||||
@@ -275,6 +277,7 @@ def test_delete(
|
||||
system: System,
|
||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||
vector_reader: Type[VectorReader],
|
||||
produce_fns: ProducerFn,
|
||||
) -> None:
|
||||
producer = system.instance(Producer)
|
||||
system.reset_state()
|
||||
@@ -284,26 +287,28 @@ def test_delete(
|
||||
segment = vector_reader(system, segment_definition)
|
||||
segment.start()
|
||||
|
||||
embeddings = [next(sample_embeddings) for i in range(5)]
|
||||
|
||||
seq_ids: List[SeqId] = []
|
||||
for e in embeddings:
|
||||
seq_ids.append(producer.submit_embedding(topic, e))
|
||||
embeddings, seq_ids = produce_fns(
|
||||
producer=producer, topic=topic, embeddings=sample_embeddings, n=5
|
||||
)
|
||||
|
||||
sync(segment, seq_ids[-1])
|
||||
assert segment.count() == 5
|
||||
|
||||
delete_record = SubmitEmbeddingRecord(
|
||||
id=embeddings[0]["id"],
|
||||
embedding=None,
|
||||
encoding=None,
|
||||
metadata=None,
|
||||
operation=Operation.DELETE,
|
||||
)
|
||||
assert isinstance(seq_ids, List)
|
||||
seq_ids.append(
|
||||
producer.submit_embedding(
|
||||
topic,
|
||||
SubmitEmbeddingRecord(
|
||||
id=embeddings[0]["id"],
|
||||
embedding=None,
|
||||
encoding=None,
|
||||
metadata=None,
|
||||
operation=Operation.DELETE,
|
||||
),
|
||||
)
|
||||
produce_fns(
|
||||
producer=producer,
|
||||
topic=topic,
|
||||
n=1,
|
||||
embeddings=(delete_record for _ in range(1)),
|
||||
)[1][0]
|
||||
)
|
||||
|
||||
sync(segment, seq_ids[-1])
|
||||
@@ -334,16 +339,12 @@ def test_delete(
|
||||
|
||||
# Delete is idempotent
|
||||
seq_ids.append(
|
||||
producer.submit_embedding(
|
||||
topic,
|
||||
SubmitEmbeddingRecord(
|
||||
id=embeddings[0]["id"],
|
||||
embedding=None,
|
||||
encoding=None,
|
||||
metadata=None,
|
||||
operation=Operation.DELETE,
|
||||
),
|
||||
)
|
||||
produce_fns(
|
||||
producer=producer,
|
||||
topic=topic,
|
||||
n=1,
|
||||
embeddings=(delete_record for _ in range(1)),
|
||||
)[1][0]
|
||||
)
|
||||
|
||||
sync(segment, seq_ids[-1])
|
||||
@@ -416,6 +417,7 @@ def test_update(
|
||||
system: System,
|
||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||
vector_reader: Type[VectorReader],
|
||||
produce_fns: ProducerFn,
|
||||
) -> None:
|
||||
producer = system.instance(Producer)
|
||||
system.reset_state()
|
||||
@@ -428,16 +430,19 @@ def test_update(
|
||||
_test_update(producer, topic, segment, sample_embeddings, Operation.UPDATE)
|
||||
|
||||
# test updating a nonexistent record
|
||||
seq_id = producer.submit_embedding(
|
||||
topic,
|
||||
SubmitEmbeddingRecord(
|
||||
id="no_such_record",
|
||||
embedding=[10.0, 10.0],
|
||||
encoding=ScalarEncoding.FLOAT32,
|
||||
metadata=None,
|
||||
operation=Operation.UPDATE,
|
||||
),
|
||||
update_record = SubmitEmbeddingRecord(
|
||||
id="no_such_record",
|
||||
embedding=[10.0, 10.0],
|
||||
encoding=ScalarEncoding.FLOAT32,
|
||||
metadata=None,
|
||||
operation=Operation.UPDATE,
|
||||
)
|
||||
seq_id = produce_fns(
|
||||
producer=producer,
|
||||
topic=topic,
|
||||
n=1,
|
||||
embeddings=(update_record for _ in range(1)),
|
||||
)[1][0]
|
||||
|
||||
sync(segment, seq_id)
|
||||
|
||||
@@ -449,6 +454,7 @@ def test_upsert(
|
||||
system: System,
|
||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||
vector_reader: Type[VectorReader],
|
||||
produce_fns: ProducerFn,
|
||||
) -> None:
|
||||
producer = system.instance(Producer)
|
||||
system.reset_state()
|
||||
@@ -461,16 +467,19 @@ def test_upsert(
|
||||
_test_update(producer, topic, segment, sample_embeddings, Operation.UPSERT)
|
||||
|
||||
# test updating a nonexistent record
|
||||
seq_id = producer.submit_embedding(
|
||||
topic,
|
||||
SubmitEmbeddingRecord(
|
||||
id="no_such_record",
|
||||
embedding=[42, 42],
|
||||
encoding=ScalarEncoding.FLOAT32,
|
||||
metadata=None,
|
||||
operation=Operation.UPSERT,
|
||||
),
|
||||
upsert_record = SubmitEmbeddingRecord(
|
||||
id="no_such_record",
|
||||
embedding=[42, 42],
|
||||
encoding=ScalarEncoding.FLOAT32,
|
||||
metadata=None,
|
||||
operation=Operation.UPSERT,
|
||||
)
|
||||
seq_id = produce_fns(
|
||||
producer=producer,
|
||||
topic=topic,
|
||||
n=1,
|
||||
embeddings=(upsert_record for _ in range(1)),
|
||||
)[1][0]
|
||||
|
||||
sync(segment, seq_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user