mirror of
https://github.com/placeholder-soft/chroma.git
synced 2026-04-30 21:01:46 +08:00
[PERF] Batch SQLite Embeddings Queue (#959)
## Description of changes Adds a batch submit_embeddings() to the producer API. Implements this behavior in sqlite based embeddings queue. Refactors tests to add a fixture for both single and batch submit and substitutes all use of produce in tests with the fixture so both methods are uniformly tested. This PR is stacked on #958, please read that before reading this. *Summarize the changes made by this PR.* - Improvements & Bug fixes - This PR makes the sqlite based embeddings queue batch calls to sqlite based on the batch size of calls to DML operations. ### Quick Benchmarks N = 100k D = 128 metadata = {simple key: simple value} Document = random 100 character string **Before - Overall time: 102s** <img width="587" alt="Screenshot 2023-08-09 at 5 43 12 PM" src="https://github.com/chroma-core/chroma/assets/5598697/752809a2-78a0-4d5b-a238-fc41ca2635cc"> **After - Overall time: 36s** <img width="583" alt="Screenshot 2023-08-09 at 7 37 41 PM" src="https://github.com/chroma-core/chroma/assets/5598697/eb1759c8-8f62-4ed7-ad88-450b3f72692b"> **_Overall, when compared to main this and #958 decrease the time of this benchmark from 469s -> 36s._** Todo: - [x] Clean up code - [x] Rethink assumptions about submit_embeddings ordering guarantees - [x] Add tests - [x] Add batching to other DML - upsert(), update(), delete() ## Test plan Existing tests cover a bulk of the functionality. New tests TODO: - [x] Producer/Consumer tests for submit_embeddings - [x] Segment tests with batch embeddings ## Documentation Changes None required. But a misc feature idea is hyperparameter sweep for the optimal batch_size for your platform
This commit is contained in:
@@ -242,9 +242,11 @@ class SegmentAPI(API):
|
|||||||
coll = self._get_collection(collection_id)
|
coll = self._get_collection(collection_id)
|
||||||
self._manager.hint_use_collection(collection_id, t.Operation.ADD)
|
self._manager.hint_use_collection(collection_id, t.Operation.ADD)
|
||||||
|
|
||||||
|
records_to_submit = []
|
||||||
for r in _records(t.Operation.ADD, ids, embeddings, metadatas, documents):
|
for r in _records(t.Operation.ADD, ids, embeddings, metadatas, documents):
|
||||||
self._validate_embedding_record(coll, r)
|
self._validate_embedding_record(coll, r)
|
||||||
self._producer.submit_embedding(coll["topic"], r)
|
records_to_submit.append(r)
|
||||||
|
self._producer.submit_embeddings(coll["topic"], records_to_submit)
|
||||||
|
|
||||||
self._telemetry_client.capture(CollectionAddEvent(str(collection_id), len(ids)))
|
self._telemetry_client.capture(CollectionAddEvent(str(collection_id), len(ids)))
|
||||||
return True
|
return True
|
||||||
@@ -261,9 +263,11 @@ class SegmentAPI(API):
|
|||||||
coll = self._get_collection(collection_id)
|
coll = self._get_collection(collection_id)
|
||||||
self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
|
self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
|
||||||
|
|
||||||
|
records_to_submit = []
|
||||||
for r in _records(t.Operation.UPDATE, ids, embeddings, metadatas, documents):
|
for r in _records(t.Operation.UPDATE, ids, embeddings, metadatas, documents):
|
||||||
self._validate_embedding_record(coll, r)
|
self._validate_embedding_record(coll, r)
|
||||||
self._producer.submit_embedding(coll["topic"], r)
|
records_to_submit.append(r)
|
||||||
|
self._producer.submit_embeddings(coll["topic"], records_to_submit)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -279,9 +283,11 @@ class SegmentAPI(API):
|
|||||||
coll = self._get_collection(collection_id)
|
coll = self._get_collection(collection_id)
|
||||||
self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
|
self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
|
||||||
|
|
||||||
|
records_to_submit = []
|
||||||
for r in _records(t.Operation.UPSERT, ids, embeddings, metadatas, documents):
|
for r in _records(t.Operation.UPSERT, ids, embeddings, metadatas, documents):
|
||||||
self._validate_embedding_record(coll, r)
|
self._validate_embedding_record(coll, r)
|
||||||
self._producer.submit_embedding(coll["topic"], r)
|
records_to_submit.append(r)
|
||||||
|
self._producer.submit_embeddings(coll["topic"], records_to_submit)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -376,9 +382,11 @@ class SegmentAPI(API):
|
|||||||
else:
|
else:
|
||||||
ids_to_delete = ids
|
ids_to_delete = ids
|
||||||
|
|
||||||
|
records_to_submit = []
|
||||||
for r in _records(t.Operation.DELETE, ids_to_delete):
|
for r in _records(t.Operation.DELETE, ids_to_delete):
|
||||||
self._validate_embedding_record(coll, r)
|
self._validate_embedding_record(coll, r)
|
||||||
self._producer.submit_embedding(coll["topic"], r)
|
records_to_submit.append(r)
|
||||||
|
self._producer.submit_embeddings(coll["topic"], records_to_submit)
|
||||||
|
|
||||||
self._telemetry_client.capture(
|
self._telemetry_client.capture(
|
||||||
CollectionDeleteEvent(str(collection_id), len(ids_to_delete))
|
CollectionDeleteEvent(str(collection_id), len(ids_to_delete))
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from chromadb.types import (
|
|||||||
from chromadb.config import System
|
from chromadb.config import System
|
||||||
from overrides import override
|
from overrides import override
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Tuple, Optional, Dict, Set, cast
|
from typing import Sequence, Tuple, Optional, Dict, Set, cast
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from pypika import Table, functions
|
from pypika import Table, functions
|
||||||
import uuid
|
import uuid
|
||||||
@@ -103,22 +103,32 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
|
|||||||
if not self._running:
|
if not self._running:
|
||||||
raise RuntimeError("Component not running")
|
raise RuntimeError("Component not running")
|
||||||
|
|
||||||
if embedding["embedding"]:
|
return self.submit_embeddings(topic_name, [embedding])[0]
|
||||||
encoding_type = cast(ScalarEncoding, embedding["encoding"])
|
|
||||||
encoding = encoding_type.value
|
|
||||||
embedding_bytes = encode_vector(embedding["embedding"], encoding_type)
|
|
||||||
|
|
||||||
else:
|
@override
|
||||||
embedding_bytes = None
|
def submit_embeddings(
|
||||||
encoding = None
|
self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord]
|
||||||
metadata = json.dumps(embedding["metadata"]) if embedding["metadata"] else None
|
) -> Sequence[SeqId]:
|
||||||
|
if not self._running:
|
||||||
|
raise RuntimeError("Component not running")
|
||||||
|
|
||||||
|
if len(embeddings) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
t = Table("embeddings_queue")
|
t = Table("embeddings_queue")
|
||||||
insert = (
|
insert = (
|
||||||
self.querybuilder()
|
self.querybuilder()
|
||||||
.into(t)
|
.into(t)
|
||||||
.columns(t.operation, t.topic, t.id, t.vector, t.encoding, t.metadata)
|
.columns(t.operation, t.topic, t.id, t.vector, t.encoding, t.metadata)
|
||||||
.insert(
|
)
|
||||||
|
id_to_idx: Dict[str, int] = {}
|
||||||
|
for embedding in embeddings:
|
||||||
|
(
|
||||||
|
embedding_bytes,
|
||||||
|
encoding,
|
||||||
|
metadata,
|
||||||
|
) = self._prepare_vector_encoding_metadata(embedding)
|
||||||
|
insert = insert.insert(
|
||||||
ParameterValue(_operation_codes[embedding["operation"]]),
|
ParameterValue(_operation_codes[embedding["operation"]]),
|
||||||
ParameterValue(topic_name),
|
ParameterValue(topic_name),
|
||||||
ParameterValue(embedding["id"]),
|
ParameterValue(embedding["id"]),
|
||||||
@@ -126,21 +136,34 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
|
|||||||
ParameterValue(encoding),
|
ParameterValue(encoding),
|
||||||
ParameterValue(metadata),
|
ParameterValue(metadata),
|
||||||
)
|
)
|
||||||
)
|
id_to_idx[embedding["id"]] = len(id_to_idx)
|
||||||
with self.tx() as cur:
|
with self.tx() as cur:
|
||||||
sql, params = get_sql(insert, self.parameter_format())
|
sql, params = get_sql(insert, self.parameter_format())
|
||||||
sql = f"{sql} RETURNING seq_id" # Pypika doesn't support RETURNING
|
# The returning clause does not guarantee order, so we need to do reorder
|
||||||
seq_id = int(cur.execute(sql, params).fetchone()[0])
|
# the results. https://www.sqlite.org/lang_returning.html
|
||||||
embedding_record = EmbeddingRecord(
|
sql = f"{sql} RETURNING seq_id, id" # Pypika doesn't support RETURNING
|
||||||
id=embedding["id"],
|
results = cur.execute(sql, params).fetchall()
|
||||||
seq_id=seq_id,
|
# Reorder the results
|
||||||
embedding=embedding["embedding"],
|
seq_ids = [cast(SeqId, None)] * len(
|
||||||
encoding=embedding["encoding"],
|
results
|
||||||
metadata=embedding["metadata"],
|
) # Lie to mypy: https://stackoverflow.com/questions/76694215/python-type-casting-when-preallocating-list
|
||||||
operation=embedding["operation"],
|
embedding_records = []
|
||||||
)
|
for seq_id, id in results:
|
||||||
self._notify_all(topic_name, embedding_record)
|
seq_ids[id_to_idx[id]] = seq_id
|
||||||
return seq_id
|
submit_embedding_record = embeddings[id_to_idx[id]]
|
||||||
|
# We allow notifying consumers out of order relative to one call to
|
||||||
|
# submit_embeddings so we do not reorder the records before submitting them
|
||||||
|
embedding_record = EmbeddingRecord(
|
||||||
|
id=id,
|
||||||
|
seq_id=seq_id,
|
||||||
|
embedding=submit_embedding_record["embedding"],
|
||||||
|
encoding=submit_embedding_record["encoding"],
|
||||||
|
metadata=submit_embedding_record["metadata"],
|
||||||
|
operation=submit_embedding_record["operation"],
|
||||||
|
)
|
||||||
|
embedding_records.append(embedding_record)
|
||||||
|
self._notify_all(topic_name, embedding_records)
|
||||||
|
return seq_ids
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def subscribe(
|
def subscribe(
|
||||||
@@ -185,6 +208,19 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
|
|||||||
def max_seqid(self) -> SeqId:
|
def max_seqid(self) -> SeqId:
|
||||||
return 2**63 - 1
|
return 2**63 - 1
|
||||||
|
|
||||||
|
def _prepare_vector_encoding_metadata(
|
||||||
|
self, embedding: SubmitEmbeddingRecord
|
||||||
|
) -> Tuple[Optional[bytes], Optional[str], Optional[str]]:
|
||||||
|
if embedding["embedding"]:
|
||||||
|
encoding_type = cast(ScalarEncoding, embedding["encoding"])
|
||||||
|
encoding = encoding_type.value
|
||||||
|
embedding_bytes = encode_vector(embedding["embedding"], encoding_type)
|
||||||
|
else:
|
||||||
|
embedding_bytes = None
|
||||||
|
encoding = None
|
||||||
|
metadata = json.dumps(embedding["metadata"]) if embedding["metadata"] else None
|
||||||
|
return embedding_bytes, encoding, metadata
|
||||||
|
|
||||||
def _backfill(self, subscription: Subscription) -> None:
|
def _backfill(self, subscription: Subscription) -> None:
|
||||||
"""Backfill the given subscription with any currently matching records in the
|
"""Backfill the given subscription with any currently matching records in the
|
||||||
DB"""
|
DB"""
|
||||||
@@ -211,14 +247,16 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
|
|||||||
vector = None
|
vector = None
|
||||||
self._notify_one(
|
self._notify_one(
|
||||||
subscription,
|
subscription,
|
||||||
EmbeddingRecord(
|
[
|
||||||
seq_id=row[0],
|
EmbeddingRecord(
|
||||||
operation=_operation_codes_inv[row[1]],
|
seq_id=row[0],
|
||||||
id=row[2],
|
operation=_operation_codes_inv[row[1]],
|
||||||
embedding=vector,
|
id=row[2],
|
||||||
encoding=encoding,
|
embedding=vector,
|
||||||
metadata=json.loads(row[5]) if row[5] else None,
|
encoding=encoding,
|
||||||
),
|
metadata=json.loads(row[5]) if row[5] else None,
|
||||||
|
)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _validate_range(
|
def _validate_range(
|
||||||
@@ -242,29 +280,37 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
|
|||||||
cur.execute(q.get_sql())
|
cur.execute(q.get_sql())
|
||||||
return int(cur.fetchone()[0]) + 1
|
return int(cur.fetchone()[0]) + 1
|
||||||
|
|
||||||
def _notify_all(self, topic: str, embedding: EmbeddingRecord) -> None:
|
def _notify_all(self, topic: str, embeddings: Sequence[EmbeddingRecord]) -> None:
|
||||||
"""Send a notification to each subscriber of the given topic."""
|
"""Send a notification to each subscriber of the given topic."""
|
||||||
if self._running:
|
if self._running:
|
||||||
for sub in self._subscriptions[topic]:
|
for sub in self._subscriptions[topic]:
|
||||||
self._notify_one(sub, embedding)
|
self._notify_one(sub, embeddings)
|
||||||
|
|
||||||
def _notify_one(self, sub: Subscription, embedding: EmbeddingRecord) -> None:
|
def _notify_one(
|
||||||
|
self, sub: Subscription, embeddings: Sequence[EmbeddingRecord]
|
||||||
|
) -> None:
|
||||||
"""Send a notification to a single subscriber."""
|
"""Send a notification to a single subscriber."""
|
||||||
if embedding["seq_id"] > sub.end:
|
# Filter out any embeddings that are not in the subscription range
|
||||||
self.unsubscribe(sub.id)
|
should_unsubscribe = False
|
||||||
return
|
filtered_embeddings = []
|
||||||
|
for embedding in embeddings:
|
||||||
if embedding["seq_id"] <= sub.start:
|
if embedding["seq_id"] <= sub.start:
|
||||||
return
|
continue
|
||||||
|
if embedding["seq_id"] > sub.end:
|
||||||
|
should_unsubscribe = True
|
||||||
|
break
|
||||||
|
filtered_embeddings.append(embedding)
|
||||||
|
|
||||||
# Log errors instead of throwing them to preserve async semantics
|
# Log errors instead of throwing them to preserve async semantics
|
||||||
# for consistency between local and distributed configurations
|
# for consistency between local and distributed configurations
|
||||||
try:
|
try:
|
||||||
sub.callback([embedding])
|
if len(filtered_embeddings) > 0:
|
||||||
|
sub.callback(filtered_embeddings)
|
||||||
|
if should_unsubscribe:
|
||||||
|
self.unsubscribe(sub.id)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
id = embedding.get("id", embedding.get("delete_id"))
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Exception occurred invoking consumer for subscription {sub.id}"
|
f"Exception occurred invoking consumer for subscription {sub.id}"
|
||||||
+ f"to topic {sub.topic_name} for embedding id {id} ",
|
+ f"to topic {sub.topic_name}",
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -52,6 +52,16 @@ class Producer(Component):
|
|||||||
"""Add an embedding record to the given topic. Returns the SeqID of the record."""
|
"""Add an embedding record to the given topic. Returns the SeqID of the record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def submit_embeddings(
|
||||||
|
self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord]
|
||||||
|
) -> Sequence[SeqId]:
|
||||||
|
"""Add a batch of embedding records to the given topic. Returns the SeqIDs of
|
||||||
|
the records. The returned SeqIDs will be in the same order as the given
|
||||||
|
SubmitEmbeddingRecords. However, it is not guaranteed that the SeqIDs will be
|
||||||
|
processed in the same order as the given SubmitEmbeddingRecords."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
ConsumerCallbackFn = Callable[[Sequence[EmbeddingRecord]], None]
|
ConsumerCallbackFn = Callable[[Sequence[EmbeddingRecord]], None]
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from chromadb.config import Settings, System
|
from chromadb.config import Settings, System
|
||||||
from chromadb.api import API
|
from chromadb.api import API
|
||||||
|
from chromadb.ingest import Producer
|
||||||
import chromadb.server.fastapi
|
import chromadb.server.fastapi
|
||||||
from requests.exceptions import ConnectionError
|
from requests.exceptions import ConnectionError
|
||||||
import hypothesis
|
import hypothesis
|
||||||
@@ -8,12 +9,23 @@ import os
|
|||||||
import uvicorn
|
import uvicorn
|
||||||
import time
|
import time
|
||||||
import pytest
|
import pytest
|
||||||
from typing import Generator, List, Callable, Optional, Tuple
|
from typing import (
|
||||||
|
Generator,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Callable,
|
||||||
|
)
|
||||||
|
from typing_extensions import Protocol
|
||||||
import shutil
|
import shutil
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
|
from chromadb.types import SeqId, SubmitEmbeddingRecord
|
||||||
|
|
||||||
root_logger = logging.getLogger()
|
root_logger = logging.getLogger()
|
||||||
root_logger.setLevel(logging.DEBUG) # This will only run when testing
|
root_logger.setLevel(logging.DEBUG) # This will only run when testing
|
||||||
|
|
||||||
@@ -189,3 +201,59 @@ def api(system: System) -> Generator[API, None, None]:
|
|||||||
system.reset_state()
|
system.reset_state()
|
||||||
api = system.instance(API)
|
api = system.instance(API)
|
||||||
yield api
|
yield api
|
||||||
|
|
||||||
|
|
||||||
|
# Producer / Consumer fixtures #
|
||||||
|
|
||||||
|
|
||||||
|
class ProducerFn(Protocol):
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
producer: Producer,
|
||||||
|
topic: str,
|
||||||
|
embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
|
n: int,
|
||||||
|
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def produce_n_single(
|
||||||
|
producer: Producer,
|
||||||
|
topic: str,
|
||||||
|
embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
|
n: int,
|
||||||
|
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
|
||||||
|
submitted_embeddings = []
|
||||||
|
seq_ids = []
|
||||||
|
for _ in range(n):
|
||||||
|
e = next(embeddings)
|
||||||
|
seq_id = producer.submit_embedding(topic, e)
|
||||||
|
submitted_embeddings.append(e)
|
||||||
|
seq_ids.append(seq_id)
|
||||||
|
return submitted_embeddings, seq_ids
|
||||||
|
|
||||||
|
|
||||||
|
def produce_n_batch(
|
||||||
|
producer: Producer,
|
||||||
|
topic: str,
|
||||||
|
embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
|
n: int,
|
||||||
|
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
|
||||||
|
submitted_embeddings = []
|
||||||
|
seq_ids: Sequence[SeqId] = []
|
||||||
|
for _ in range(n):
|
||||||
|
e = next(embeddings)
|
||||||
|
submitted_embeddings.append(e)
|
||||||
|
seq_ids = producer.submit_embeddings(topic, submitted_embeddings)
|
||||||
|
return submitted_embeddings, seq_ids
|
||||||
|
|
||||||
|
|
||||||
|
def produce_fn_fixtures() -> List[ProducerFn]:
|
||||||
|
return [produce_n_single, produce_n_batch]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", params=produce_fn_fixtures())
|
||||||
|
def produce_fns(
|
||||||
|
request: pytest.FixtureRequest,
|
||||||
|
) -> Generator[ProducerFn, None, None]:
|
||||||
|
yield request.param
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
from chromadb.ingest import Producer, Consumer
|
from chromadb.ingest import Producer, Consumer
|
||||||
from chromadb.db.impl.sqlite import SqliteDB
|
from chromadb.db.impl.sqlite import SqliteDB
|
||||||
|
from chromadb.test.conftest import ProducerFn
|
||||||
from chromadb.types import (
|
from chromadb.types import (
|
||||||
SubmitEmbeddingRecord,
|
SubmitEmbeddingRecord,
|
||||||
Operation,
|
Operation,
|
||||||
@@ -135,15 +136,13 @@ def assert_records_match(
|
|||||||
async def test_backfill(
|
async def test_backfill(
|
||||||
producer_consumer: Tuple[Producer, Consumer],
|
producer_consumer: Tuple[Producer, Consumer],
|
||||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
|
produce_fns: ProducerFn,
|
||||||
) -> None:
|
) -> None:
|
||||||
producer, consumer = producer_consumer
|
producer, consumer = producer_consumer
|
||||||
producer.reset_state()
|
producer.reset_state()
|
||||||
|
|
||||||
embeddings = [next(sample_embeddings) for _ in range(3)]
|
|
||||||
|
|
||||||
producer.create_topic("test_topic")
|
producer.create_topic("test_topic")
|
||||||
for e in embeddings:
|
embeddings = produce_fns(producer, "test_topic", sample_embeddings, 3)[0]
|
||||||
producer.submit_embedding("test_topic", e)
|
|
||||||
|
|
||||||
consume_fn = CapturingConsumeFn()
|
consume_fn = CapturingConsumeFn()
|
||||||
consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
|
consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
|
||||||
@@ -212,6 +211,7 @@ async def test_multiple_topics(
|
|||||||
async def test_start_seq_id(
|
async def test_start_seq_id(
|
||||||
producer_consumer: Tuple[Producer, Consumer],
|
producer_consumer: Tuple[Producer, Consumer],
|
||||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
|
produce_fns: ProducerFn,
|
||||||
) -> None:
|
) -> None:
|
||||||
producer, consumer = producer_consumer
|
producer, consumer = producer_consumer
|
||||||
producer.reset_state()
|
producer.reset_state()
|
||||||
@@ -222,22 +222,16 @@ async def test_start_seq_id(
|
|||||||
|
|
||||||
consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
|
consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
|
||||||
|
|
||||||
embeddings = []
|
embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0]
|
||||||
for _ in range(5):
|
|
||||||
e = next(sample_embeddings)
|
|
||||||
embeddings.append(e)
|
|
||||||
producer.submit_embedding("test_topic", e)
|
|
||||||
|
|
||||||
results_1 = await consume_fn_1.get(5)
|
results_1 = await consume_fn_1.get(5)
|
||||||
assert_records_match(embeddings, results_1)
|
assert_records_match(embeddings, results_1)
|
||||||
|
|
||||||
start = consume_fn_1.embeddings[-1]["seq_id"]
|
start = consume_fn_1.embeddings[-1]["seq_id"]
|
||||||
consumer.subscribe("test_topic", consume_fn_2, start=start)
|
consumer.subscribe("test_topic", consume_fn_2, start=start)
|
||||||
for _ in range(5):
|
second_embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0]
|
||||||
e = next(sample_embeddings)
|
assert isinstance(embeddings, list)
|
||||||
embeddings.append(e)
|
embeddings.extend(second_embeddings)
|
||||||
producer.submit_embedding("test_topic", e)
|
|
||||||
|
|
||||||
results_2 = await consume_fn_2.get(5)
|
results_2 = await consume_fn_2.get(5)
|
||||||
assert_records_match(embeddings[-5:], results_2)
|
assert_records_match(embeddings[-5:], results_2)
|
||||||
|
|
||||||
@@ -246,6 +240,7 @@ async def test_start_seq_id(
|
|||||||
async def test_end_seq_id(
|
async def test_end_seq_id(
|
||||||
producer_consumer: Tuple[Producer, Consumer],
|
producer_consumer: Tuple[Producer, Consumer],
|
||||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
|
produce_fns: ProducerFn,
|
||||||
) -> None:
|
) -> None:
|
||||||
producer, consumer = producer_consumer
|
producer, consumer = producer_consumer
|
||||||
producer.reset_state()
|
producer.reset_state()
|
||||||
@@ -256,11 +251,7 @@ async def test_end_seq_id(
|
|||||||
|
|
||||||
consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
|
consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
|
||||||
|
|
||||||
embeddings = []
|
embeddings = produce_fns(producer, "test_topic", sample_embeddings, 10)[0]
|
||||||
for _ in range(10):
|
|
||||||
e = next(sample_embeddings)
|
|
||||||
embeddings.append(e)
|
|
||||||
producer.submit_embedding("test_topic", e)
|
|
||||||
|
|
||||||
results_1 = await consume_fn_1.get(10)
|
results_1 = await consume_fn_1.get(10)
|
||||||
assert_records_match(embeddings, results_1)
|
assert_records_match(embeddings, results_1)
|
||||||
@@ -274,3 +265,60 @@ async def test_end_seq_id(
|
|||||||
# Should never produce a 7th
|
# Should never produce a 7th
|
||||||
with pytest.raises(TimeoutError):
|
with pytest.raises(TimeoutError):
|
||||||
_ = await wait_for(consume_fn_2.get(7), timeout=1)
|
_ = await wait_for(consume_fn_2.get(7), timeout=1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_batch(
|
||||||
|
producer_consumer: Tuple[Producer, Consumer],
|
||||||
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
|
) -> None:
|
||||||
|
producer, consumer = producer_consumer
|
||||||
|
producer.reset_state()
|
||||||
|
|
||||||
|
embeddings = [next(sample_embeddings) for _ in range(100)]
|
||||||
|
|
||||||
|
producer.create_topic("test_topic")
|
||||||
|
producer.submit_embeddings("test_topic", embeddings=embeddings)
|
||||||
|
|
||||||
|
consume_fn = CapturingConsumeFn()
|
||||||
|
consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
|
||||||
|
|
||||||
|
recieved = await consume_fn.get(100)
|
||||||
|
assert_records_match(embeddings, recieved)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_topics_batch(
|
||||||
|
producer_consumer: Tuple[Producer, Consumer],
|
||||||
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
|
produce_fns: ProducerFn,
|
||||||
|
) -> None:
|
||||||
|
producer, consumer = producer_consumer
|
||||||
|
producer.reset_state()
|
||||||
|
|
||||||
|
N_TOPICS = 100
|
||||||
|
consume_fns = [CapturingConsumeFn() for _ in range(N_TOPICS)]
|
||||||
|
for i in range(N_TOPICS):
|
||||||
|
producer.create_topic(f"test_topic_{i}")
|
||||||
|
consumer.subscribe(
|
||||||
|
f"test_topic_{i}", consume_fns[i], start=consumer.min_seqid()
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings_n: List[List[SubmitEmbeddingRecord]] = [[] for _ in range(N_TOPICS)]
|
||||||
|
|
||||||
|
PRODUCE_BATCH_SIZE = 10
|
||||||
|
N_TO_PRODUCE = 100
|
||||||
|
total_produced = 0
|
||||||
|
for i in range(N_TO_PRODUCE // PRODUCE_BATCH_SIZE):
|
||||||
|
for i in range(N_TOPICS):
|
||||||
|
embeddings_n[i].extend(
|
||||||
|
produce_fns(
|
||||||
|
producer,
|
||||||
|
f"test_topic_{i}",
|
||||||
|
sample_embeddings,
|
||||||
|
PRODUCE_BATCH_SIZE,
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
recieved = await consume_fns[i].get(total_produced + PRODUCE_BATCH_SIZE)
|
||||||
|
assert_records_match(embeddings_n[i], recieved)
|
||||||
|
total_produced += PRODUCE_BATCH_SIZE
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import tempfile
|
|||||||
import pytest
|
import pytest
|
||||||
from typing import Generator, List, Callable, Iterator, Dict, Optional, Union, Sequence
|
from typing import Generator, List, Callable, Iterator, Dict, Optional, Union, Sequence
|
||||||
from chromadb.config import System, Settings
|
from chromadb.config import System, Settings
|
||||||
|
from chromadb.test.conftest import ProducerFn
|
||||||
from chromadb.types import (
|
from chromadb.types import (
|
||||||
SubmitEmbeddingRecord,
|
SubmitEmbeddingRecord,
|
||||||
MetadataEmbeddingRecord,
|
MetadataEmbeddingRecord,
|
||||||
@@ -128,16 +129,16 @@ def sync(segment: MetadataReader, seq_id: SeqId) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_insert_and_count(
|
def test_insert_and_count(
|
||||||
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
|
system: System,
|
||||||
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
|
produce_fns: ProducerFn,
|
||||||
) -> None:
|
) -> None:
|
||||||
producer = system.instance(Producer)
|
producer = system.instance(Producer)
|
||||||
system.reset_state()
|
system.reset_state()
|
||||||
|
|
||||||
topic = str(segment_definition["topic"])
|
topic = str(segment_definition["topic"])
|
||||||
|
|
||||||
max_id = 0
|
max_id = produce_fns(producer, topic, sample_embeddings, 3)[1][-1]
|
||||||
for i in range(3):
|
|
||||||
max_id = producer.submit_embedding(topic, next(sample_embeddings))
|
|
||||||
|
|
||||||
segment = SqliteMetadataSegment(system, segment_definition)
|
segment = SqliteMetadataSegment(system, segment_definition)
|
||||||
segment.start()
|
segment.start()
|
||||||
@@ -166,17 +167,15 @@ def assert_equiv_records(
|
|||||||
|
|
||||||
|
|
||||||
def test_get(
|
def test_get(
|
||||||
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
|
system: System,
|
||||||
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
|
produce_fns: ProducerFn,
|
||||||
) -> None:
|
) -> None:
|
||||||
producer = system.instance(Producer)
|
producer = system.instance(Producer)
|
||||||
system.reset_state()
|
system.reset_state()
|
||||||
topic = str(segment_definition["topic"])
|
topic = str(segment_definition["topic"])
|
||||||
|
|
||||||
embeddings = [next(sample_embeddings) for i in range(10)]
|
embeddings, seq_ids = produce_fns(producer, topic, sample_embeddings, 10)
|
||||||
|
|
||||||
seq_ids = []
|
|
||||||
for e in embeddings:
|
|
||||||
seq_ids.append(producer.submit_embedding(topic, e))
|
|
||||||
|
|
||||||
segment = SqliteMetadataSegment(system, segment_definition)
|
segment = SqliteMetadataSegment(system, segment_definition)
|
||||||
segment.start()
|
segment.start()
|
||||||
@@ -270,7 +269,9 @@ def test_get(
|
|||||||
|
|
||||||
|
|
||||||
def test_fulltext(
|
def test_fulltext(
|
||||||
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
|
system: System,
|
||||||
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
|
produce_fns: ProducerFn,
|
||||||
) -> None:
|
) -> None:
|
||||||
producer = system.instance(Producer)
|
producer = system.instance(Producer)
|
||||||
system.reset_state()
|
system.reset_state()
|
||||||
@@ -279,9 +280,7 @@ def test_fulltext(
|
|||||||
segment = SqliteMetadataSegment(system, segment_definition)
|
segment = SqliteMetadataSegment(system, segment_definition)
|
||||||
segment.start()
|
segment.start()
|
||||||
|
|
||||||
max_id = 0
|
max_id = produce_fns(producer, topic, sample_embeddings, 100)[1][-1]
|
||||||
for i in range(100):
|
|
||||||
max_id = producer.submit_embedding(topic, next(sample_embeddings))
|
|
||||||
|
|
||||||
sync(segment, max_id)
|
sync(segment, max_id)
|
||||||
|
|
||||||
@@ -331,7 +330,9 @@ def test_fulltext(
|
|||||||
|
|
||||||
|
|
||||||
def test_delete(
|
def test_delete(
|
||||||
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
|
system: System,
|
||||||
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
|
produce_fns: ProducerFn,
|
||||||
) -> None:
|
) -> None:
|
||||||
producer = system.instance(Producer)
|
producer = system.instance(Producer)
|
||||||
system.reset_state()
|
system.reset_state()
|
||||||
@@ -340,11 +341,8 @@ def test_delete(
|
|||||||
segment = SqliteMetadataSegment(system, segment_definition)
|
segment = SqliteMetadataSegment(system, segment_definition)
|
||||||
segment.start()
|
segment.start()
|
||||||
|
|
||||||
embeddings = [next(sample_embeddings) for i in range(10)]
|
embeddings, seq_ids = produce_fns(producer, topic, sample_embeddings, 10)
|
||||||
|
max_id = seq_ids[-1]
|
||||||
max_id = 0
|
|
||||||
for e in embeddings:
|
|
||||||
max_id = producer.submit_embedding(topic, e)
|
|
||||||
|
|
||||||
sync(segment, max_id)
|
sync(segment, max_id)
|
||||||
|
|
||||||
@@ -353,16 +351,16 @@ def test_delete(
|
|||||||
assert_equiv_records(embeddings[:1], results)
|
assert_equiv_records(embeddings[:1], results)
|
||||||
|
|
||||||
# Delete by ID
|
# Delete by ID
|
||||||
max_id = producer.submit_embedding(
|
delete_embedding = SubmitEmbeddingRecord(
|
||||||
topic,
|
id="embedding_0",
|
||||||
SubmitEmbeddingRecord(
|
embedding=None,
|
||||||
id="embedding_0",
|
encoding=None,
|
||||||
embedding=None,
|
metadata=None,
|
||||||
encoding=None,
|
operation=Operation.DELETE,
|
||||||
metadata=None,
|
|
||||||
operation=Operation.DELETE,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
max_id = produce_fns(producer, topic, (delete_embedding for _ in range(1)), 1)[1][
|
||||||
|
-1
|
||||||
|
]
|
||||||
|
|
||||||
sync(segment, max_id)
|
sync(segment, max_id)
|
||||||
|
|
||||||
@@ -370,16 +368,9 @@ def test_delete(
|
|||||||
assert segment.get_metadata(ids=["embedding_0"]) == []
|
assert segment.get_metadata(ids=["embedding_0"]) == []
|
||||||
|
|
||||||
# Delete is idempotent
|
# Delete is idempotent
|
||||||
max_id = producer.submit_embedding(
|
max_id = produce_fns(producer, topic, (delete_embedding for _ in range(1)), 1)[1][
|
||||||
topic,
|
-1
|
||||||
SubmitEmbeddingRecord(
|
]
|
||||||
id="embedding_0",
|
|
||||||
embedding=None,
|
|
||||||
encoding=None,
|
|
||||||
metadata=None,
|
|
||||||
operation=Operation.DELETE,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
sync(segment, max_id)
|
sync(segment, max_id)
|
||||||
assert segment.count() == 9
|
assert segment.count() == 9
|
||||||
@@ -420,7 +411,9 @@ def test_update(
|
|||||||
|
|
||||||
|
|
||||||
def test_upsert(
|
def test_upsert(
|
||||||
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
|
system: System,
|
||||||
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
|
produce_fns: ProducerFn,
|
||||||
) -> None:
|
) -> None:
|
||||||
producer = system.instance(Producer)
|
producer = system.instance(Producer)
|
||||||
system.reset_state()
|
system.reset_state()
|
||||||
@@ -439,7 +432,12 @@ def test_upsert(
|
|||||||
encoding=None,
|
encoding=None,
|
||||||
operation=Operation.UPSERT,
|
operation=Operation.UPSERT,
|
||||||
)
|
)
|
||||||
max_id = producer.submit_embedding(topic, update_record)
|
max_id = produce_fns(
|
||||||
|
producer=producer,
|
||||||
|
topic=topic,
|
||||||
|
embeddings=(update_record for _ in range(1)),
|
||||||
|
n=1,
|
||||||
|
)[1][-1]
|
||||||
sync(segment, max_id)
|
sync(segment, max_id)
|
||||||
results = segment.get_metadata(ids=["no_such_id"])
|
results = segment.get_metadata(ids=["no_such_id"])
|
||||||
assert results[0]["metadata"] == {"foo": "bar"}
|
assert results[0]["metadata"] == {"foo": "bar"}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from typing import Generator, List, Callable, Iterator, Type, cast
|
from typing import Generator, List, Callable, Iterator, Type, cast
|
||||||
from chromadb.config import System, Settings
|
from chromadb.config import System, Settings
|
||||||
|
from chromadb.test.conftest import ProducerFn
|
||||||
from chromadb.types import (
|
from chromadb.types import (
|
||||||
SubmitEmbeddingRecord,
|
SubmitEmbeddingRecord,
|
||||||
VectorQuery,
|
VectorQuery,
|
||||||
@@ -129,6 +130,7 @@ def test_insert_and_count(
|
|||||||
system: System,
|
system: System,
|
||||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
vector_reader: Type[VectorReader],
|
vector_reader: Type[VectorReader],
|
||||||
|
produce_fns: ProducerFn,
|
||||||
) -> None:
|
) -> None:
|
||||||
producer = system.instance(Producer)
|
producer = system.instance(Producer)
|
||||||
|
|
||||||
@@ -136,9 +138,9 @@ def test_insert_and_count(
|
|||||||
segment_definition = create_random_segment_definition()
|
segment_definition = create_random_segment_definition()
|
||||||
topic = str(segment_definition["topic"])
|
topic = str(segment_definition["topic"])
|
||||||
|
|
||||||
max_id = 0
|
max_id = produce_fns(
|
||||||
for i in range(3):
|
producer=producer, topic=topic, n=3, embeddings=sample_embeddings
|
||||||
max_id = producer.submit_embedding(topic, next(sample_embeddings))
|
)[1][-1]
|
||||||
|
|
||||||
segment = vector_reader(system, segment_definition)
|
segment = vector_reader(system, segment_definition)
|
||||||
segment.start()
|
segment.start()
|
||||||
@@ -146,8 +148,10 @@ def test_insert_and_count(
|
|||||||
sync(segment, max_id)
|
sync(segment, max_id)
|
||||||
|
|
||||||
assert segment.count() == 3
|
assert segment.count() == 3
|
||||||
for i in range(3):
|
|
||||||
max_id = producer.submit_embedding(topic, next(sample_embeddings))
|
max_id = produce_fns(
|
||||||
|
producer=producer, topic=topic, n=3, embeddings=sample_embeddings
|
||||||
|
)[1][-1]
|
||||||
|
|
||||||
sync(segment, max_id)
|
sync(segment, max_id)
|
||||||
assert segment.count() == 6
|
assert segment.count() == 6
|
||||||
@@ -165,6 +169,7 @@ def test_get_vectors(
|
|||||||
system: System,
|
system: System,
|
||||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
vector_reader: Type[VectorReader],
|
vector_reader: Type[VectorReader],
|
||||||
|
produce_fns: ProducerFn,
|
||||||
) -> None:
|
) -> None:
|
||||||
producer = system.instance(Producer)
|
producer = system.instance(Producer)
|
||||||
system.reset_state()
|
system.reset_state()
|
||||||
@@ -174,11 +179,9 @@ def test_get_vectors(
|
|||||||
segment = vector_reader(system, segment_definition)
|
segment = vector_reader(system, segment_definition)
|
||||||
segment.start()
|
segment.start()
|
||||||
|
|
||||||
embeddings = [next(sample_embeddings) for i in range(10)]
|
embeddings, seq_ids = produce_fns(
|
||||||
|
producer=producer, topic=topic, embeddings=sample_embeddings, n=10
|
||||||
seq_ids: List[SeqId] = []
|
)
|
||||||
for e in embeddings:
|
|
||||||
seq_ids.append(producer.submit_embedding(topic, e))
|
|
||||||
|
|
||||||
sync(segment, seq_ids[-1])
|
sync(segment, seq_ids[-1])
|
||||||
|
|
||||||
@@ -210,6 +213,7 @@ def test_ann_query(
|
|||||||
system: System,
|
system: System,
|
||||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
vector_reader: Type[VectorReader],
|
vector_reader: Type[VectorReader],
|
||||||
|
produce_fns: ProducerFn,
|
||||||
) -> None:
|
) -> None:
|
||||||
producer = system.instance(Producer)
|
producer = system.instance(Producer)
|
||||||
system.reset_state()
|
system.reset_state()
|
||||||
@@ -219,11 +223,9 @@ def test_ann_query(
|
|||||||
segment = vector_reader(system, segment_definition)
|
segment = vector_reader(system, segment_definition)
|
||||||
segment.start()
|
segment.start()
|
||||||
|
|
||||||
embeddings = [next(sample_embeddings) for i in range(100)]
|
embeddings, seq_ids = produce_fns(
|
||||||
|
producer=producer, topic=topic, embeddings=sample_embeddings, n=100
|
||||||
seq_ids: List[SeqId] = []
|
)
|
||||||
for e in embeddings:
|
|
||||||
seq_ids.append(producer.submit_embedding(topic, e))
|
|
||||||
|
|
||||||
sync(segment, seq_ids[-1])
|
sync(segment, seq_ids[-1])
|
||||||
|
|
||||||
@@ -275,6 +277,7 @@ def test_delete(
|
|||||||
system: System,
|
system: System,
|
||||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
vector_reader: Type[VectorReader],
|
vector_reader: Type[VectorReader],
|
||||||
|
produce_fns: ProducerFn,
|
||||||
) -> None:
|
) -> None:
|
||||||
producer = system.instance(Producer)
|
producer = system.instance(Producer)
|
||||||
system.reset_state()
|
system.reset_state()
|
||||||
@@ -284,26 +287,28 @@ def test_delete(
|
|||||||
segment = vector_reader(system, segment_definition)
|
segment = vector_reader(system, segment_definition)
|
||||||
segment.start()
|
segment.start()
|
||||||
|
|
||||||
embeddings = [next(sample_embeddings) for i in range(5)]
|
embeddings, seq_ids = produce_fns(
|
||||||
|
producer=producer, topic=topic, embeddings=sample_embeddings, n=5
|
||||||
seq_ids: List[SeqId] = []
|
)
|
||||||
for e in embeddings:
|
|
||||||
seq_ids.append(producer.submit_embedding(topic, e))
|
|
||||||
|
|
||||||
sync(segment, seq_ids[-1])
|
sync(segment, seq_ids[-1])
|
||||||
assert segment.count() == 5
|
assert segment.count() == 5
|
||||||
|
|
||||||
|
delete_record = SubmitEmbeddingRecord(
|
||||||
|
id=embeddings[0]["id"],
|
||||||
|
embedding=None,
|
||||||
|
encoding=None,
|
||||||
|
metadata=None,
|
||||||
|
operation=Operation.DELETE,
|
||||||
|
)
|
||||||
|
assert isinstance(seq_ids, List)
|
||||||
seq_ids.append(
|
seq_ids.append(
|
||||||
producer.submit_embedding(
|
produce_fns(
|
||||||
topic,
|
producer=producer,
|
||||||
SubmitEmbeddingRecord(
|
topic=topic,
|
||||||
id=embeddings[0]["id"],
|
n=1,
|
||||||
embedding=None,
|
embeddings=(delete_record for _ in range(1)),
|
||||||
encoding=None,
|
)[1][0]
|
||||||
metadata=None,
|
|
||||||
operation=Operation.DELETE,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
sync(segment, seq_ids[-1])
|
sync(segment, seq_ids[-1])
|
||||||
@@ -334,16 +339,12 @@ def test_delete(
|
|||||||
|
|
||||||
# Delete is idempotent
|
# Delete is idempotent
|
||||||
seq_ids.append(
|
seq_ids.append(
|
||||||
producer.submit_embedding(
|
produce_fns(
|
||||||
topic,
|
producer=producer,
|
||||||
SubmitEmbeddingRecord(
|
topic=topic,
|
||||||
id=embeddings[0]["id"],
|
n=1,
|
||||||
embedding=None,
|
embeddings=(delete_record for _ in range(1)),
|
||||||
encoding=None,
|
)[1][0]
|
||||||
metadata=None,
|
|
||||||
operation=Operation.DELETE,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
sync(segment, seq_ids[-1])
|
sync(segment, seq_ids[-1])
|
||||||
@@ -416,6 +417,7 @@ def test_update(
|
|||||||
system: System,
|
system: System,
|
||||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
vector_reader: Type[VectorReader],
|
vector_reader: Type[VectorReader],
|
||||||
|
produce_fns: ProducerFn,
|
||||||
) -> None:
|
) -> None:
|
||||||
producer = system.instance(Producer)
|
producer = system.instance(Producer)
|
||||||
system.reset_state()
|
system.reset_state()
|
||||||
@@ -428,16 +430,19 @@ def test_update(
|
|||||||
_test_update(producer, topic, segment, sample_embeddings, Operation.UPDATE)
|
_test_update(producer, topic, segment, sample_embeddings, Operation.UPDATE)
|
||||||
|
|
||||||
# test updating a nonexistent record
|
# test updating a nonexistent record
|
||||||
seq_id = producer.submit_embedding(
|
update_record = SubmitEmbeddingRecord(
|
||||||
topic,
|
id="no_such_record",
|
||||||
SubmitEmbeddingRecord(
|
embedding=[10.0, 10.0],
|
||||||
id="no_such_record",
|
encoding=ScalarEncoding.FLOAT32,
|
||||||
embedding=[10.0, 10.0],
|
metadata=None,
|
||||||
encoding=ScalarEncoding.FLOAT32,
|
operation=Operation.UPDATE,
|
||||||
metadata=None,
|
|
||||||
operation=Operation.UPDATE,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
seq_id = produce_fns(
|
||||||
|
producer=producer,
|
||||||
|
topic=topic,
|
||||||
|
n=1,
|
||||||
|
embeddings=(update_record for _ in range(1)),
|
||||||
|
)[1][0]
|
||||||
|
|
||||||
sync(segment, seq_id)
|
sync(segment, seq_id)
|
||||||
|
|
||||||
@@ -449,6 +454,7 @@ def test_upsert(
|
|||||||
system: System,
|
system: System,
|
||||||
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
sample_embeddings: Iterator[SubmitEmbeddingRecord],
|
||||||
vector_reader: Type[VectorReader],
|
vector_reader: Type[VectorReader],
|
||||||
|
produce_fns: ProducerFn,
|
||||||
) -> None:
|
) -> None:
|
||||||
producer = system.instance(Producer)
|
producer = system.instance(Producer)
|
||||||
system.reset_state()
|
system.reset_state()
|
||||||
@@ -461,16 +467,19 @@ def test_upsert(
|
|||||||
_test_update(producer, topic, segment, sample_embeddings, Operation.UPSERT)
|
_test_update(producer, topic, segment, sample_embeddings, Operation.UPSERT)
|
||||||
|
|
||||||
# test updating a nonexistent record
|
# test updating a nonexistent record
|
||||||
seq_id = producer.submit_embedding(
|
upsert_record = SubmitEmbeddingRecord(
|
||||||
topic,
|
id="no_such_record",
|
||||||
SubmitEmbeddingRecord(
|
embedding=[42, 42],
|
||||||
id="no_such_record",
|
encoding=ScalarEncoding.FLOAT32,
|
||||||
embedding=[42, 42],
|
metadata=None,
|
||||||
encoding=ScalarEncoding.FLOAT32,
|
operation=Operation.UPSERT,
|
||||||
metadata=None,
|
|
||||||
operation=Operation.UPSERT,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
seq_id = produce_fns(
|
||||||
|
producer=producer,
|
||||||
|
topic=topic,
|
||||||
|
n=1,
|
||||||
|
embeddings=(upsert_record for _ in range(1)),
|
||||||
|
)[1][0]
|
||||||
|
|
||||||
sync(segment, seq_id)
|
sync(segment, seq_id)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user