[ENH]: CIP-4: In and Not In Metadata Filters (#1081)

Cherry-picked from #1029

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - Added support for `$in` and `$nin` metadata filters

> Note: See CIP in `docs/` or example notebook for more info

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python

## Documentation Changes
TBD

---------

Co-authored-by: Hammad Bashir <HammadB@users.noreply.github.com>
This commit is contained in:
Trayan Azarov
2023-09-05 20:42:01 +03:00
committed by GitHub
parent 750f2edbfa
commit 6dd2d4af0b
8 changed files with 354 additions and 20 deletions

View File

@@ -207,6 +207,8 @@ def validate_where(where: Where) -> Where:
if (
key != "$and"
and key != "$or"
and key != "$in"
and key != "$nin"
and not isinstance(value, (str, int, float, dict))
):
raise ValueError(
@@ -238,15 +240,37 @@ def validate_where(where: Where) -> Where:
raise ValueError(
f"Expected operand value to be an int or a float for operator {operator}, got {operand}"
)
if operator not in ["$gt", "$gte", "$lt", "$lte", "$ne", "$eq"]:
if operator in ["$in", "$nin"]:
if not isinstance(operand, list):
raise ValueError(
f"Expected operand value to be an list for operator {operator}, got {operand}"
)
if operator not in [
"$gt",
"$gte",
"$lt",
"$lte",
"$ne",
"$eq",
"$in",
"$nin",
]:
raise ValueError(
f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, got {operator}"
f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, "
f"got {operator}"
)
if not isinstance(operand, (str, int, float)):
if not isinstance(operand, (str, int, float, list)):
raise ValueError(
f"Expected where operand value to be a str, int, or float, got {operand}"
f"Expected where operand value to be a str, int, float, or list of those type, got {operand}"
)
if isinstance(operand, list) and (
len(operand) == 0
or not all(isinstance(x, type(operand[0])) for x in operand)
):
raise ValueError(
f"Expected where operand value to be a non-empty list, and all values to obe of the same type "
f"got {operand}"
)
return where

View File

@@ -1,8 +1,8 @@
from typing import Optional, Sequence, Any, Tuple, cast, Generator, Union, Dict
from typing import Optional, Sequence, Any, Tuple, cast, Generator, Union, Dict, List
from chromadb.segment import MetadataReader
from chromadb.ingest import Consumer
from chromadb.config import System
from chromadb.types import Segment
from chromadb.types import Segment, InclusionExclusionOperator
from chromadb.db.impl.sqlite import SqliteDB
from overrides import override
from chromadb.db.base import (
@@ -146,7 +146,6 @@ class SqliteMetadataSegment(MetadataReader):
limit = limit or 2**63 - 1
offset = offset or 0
with self._db.tx() as cur:
return list(islice(self._records(cur, q), offset, offset + limit))
@@ -405,7 +404,6 @@ class SqliteMetadataSegment(MetadataReader):
self, q: QueryBuilder, where: Where, embeddings_t: Table, metadata_t: Table
) -> Criterion:
clause: list[Criterion] = []
for k, v in where.items():
if k == "$and":
criteria = [
@@ -419,8 +417,32 @@ class SqliteMetadataSegment(MetadataReader):
for w in cast(Sequence[Where], v)
]
clause.append(reduce(lambda x, y: x | y, criteria))
elif k == "$in":
expr = cast(
Dict[InclusionExclusionOperator, List[LiteralValue]], {k: v}
)
sq = (
self._db.querybuilder()
.from_(metadata_t)
.select(metadata_t.id)
.where(metadata_t.key.isin(ParameterValue(k)))
.where(_where_clause(expr, metadata_t))
)
clause.append(embeddings_t.id.isin(sq))
elif k == "$nin":
expr = cast(
Dict[InclusionExclusionOperator, List[LiteralValue]], {k: v}
)
sq = (
self._db.querybuilder()
.from_(metadata_t)
.select(metadata_t.id)
.where(metadata_t.key.notin(ParameterValue(k)))
.where(_where_clause(expr, metadata_t))
)
clause.append(embeddings_t.id.notin(sq))
else:
expr = cast(Union[LiteralValue, Dict[WhereOperator, LiteralValue]], v)
expr = cast(Union[LiteralValue, Dict[WhereOperator, LiteralValue]], v) # type: ignore
sq = (
self._db.querybuilder()
.from_(metadata_t)
@@ -492,24 +514,31 @@ def _decode_seq_id(seq_id_bytes: bytes) -> SeqId:
def _where_clause(
expr: Union[LiteralValue, Dict[WhereOperator, LiteralValue]],
expr: Union[
LiteralValue,
Dict[WhereOperator, LiteralValue],
Dict[InclusionExclusionOperator, List[LiteralValue]],
],
table: Table,
) -> Criterion:
"""Given a field name, an expression, and a table, construct a Pypika Criterion"""
# Literal value case
if isinstance(expr, (str, int, float, bool)):
return _where_clause({"$eq": expr}, table)
return _where_clause({cast(WhereOperator, "$eq"): expr}, table)
# Operator dict case
operator, value = next(iter(expr.items()))
return _value_criterion(value, operator, table)
def _value_criterion(value: LiteralValue, op: WhereOperator, table: Table) -> Criterion:
def _value_criterion(
value: Union[LiteralValue, List[LiteralValue]],
op: Union[WhereOperator, InclusionExclusionOperator],
table: Table,
) -> Criterion:
"""Return a criterion to compare a value with the appropriate columns given its type
and the operation type."""
if isinstance(value, str):
cols = [table.string_value]
# isinstance(True, int) evaluates to True, so we need to check for bools separately
@@ -519,6 +548,37 @@ def _value_criterion(value: LiteralValue, op: WhereOperator, table: Table) -> Cr
cols = [table.int_value]
elif isinstance(value, float) and op in ("$eq", "$ne"):
cols = [table.float_value]
elif isinstance(value, list) and op in ("$in", "$nin"):
_v = value
if len(_v) == 0:
raise ValueError(f"Empty list for {op} operator")
if isinstance(value[0], str):
col_exprs = [
table.string_value.isin(_v)
if op == "$in"
else table.str_value.notin(_v)
]
elif isinstance(value[0], bool):
col_exprs = [
table.bool_value.isin(_v) if op == "$in" else table.bool_value.notin(_v)
]
elif isinstance(value[0], int):
col_exprs = [
table.int_value.isin(_v) if op == "$in" else table.int_value.notin(_v)
]
elif isinstance(value[0], float):
col_exprs = [
table.float_value.isin(_v)
if op == "$in"
else table.float_value.notin(_v)
]
elif isinstance(value, list) and op in ("$in", "$nin"):
col_exprs = [
table.int_value.isin(value),
table.float_value.isin(value)
if op == "$in"
else table.float_value.notin(value),
]
else:
cols = [table.int_value, table.float_value]

View File

@@ -14,6 +14,7 @@ from hypothesis.stateful import RuleBasedStateMachine
from dataclasses import dataclass
from chromadb.api.types import Documents, Embeddings, Metadata
from chromadb.types import LiteralValue
# Set the random seed for reproducibility
np.random.seed(0) # unnecessary, hypothesis does this for us
@@ -448,6 +449,26 @@ class DeterministicRuleStrategy(SearchStrategy): # type: ignore
return True
def opposite_value(value: LiteralValue) -> SearchStrategy[Any]:
"""
Returns a strategy that will generate all valid values except the input value - testing of $nin
"""
if isinstance(value, float):
return st.floats(allow_nan=False, allow_infinity=False).filter(
lambda x: x != value
)
elif isinstance(value, str):
return safe_text.filter(lambda x: x != value)
elif isinstance(value, bool):
return st.booleans().filter(lambda x: x != value)
elif isinstance(value, int):
return st.integers(min_value=-(2**31), max_value=2**31 - 1).filter(
lambda x: x != value
)
else:
return st.from_type(type(value)).filter(lambda x: x != value)
@st.composite
def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
"""Generate a filter that could be used in a query against the given collection"""
@@ -457,7 +478,7 @@ def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
key = draw(st.sampled_from(known_keys))
value = collection.known_metadata_keys[key]
legal_ops: List[Optional[str]] = [None, "$eq", "$ne"]
legal_ops: List[Optional[str]] = [None, "$eq", "$ne", "$in", "$nin"]
if not isinstance(value, str) and not isinstance(value, bool):
legal_ops.extend(["$gt", "$lt", "$lte", "$gte"])
if isinstance(value, float):
@@ -468,6 +489,14 @@ def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
if op is None:
return {key: value}
elif op == "$in":
if isinstance(value, str) and not value:
return {}
return {key: {op: [value, *[draw(opposite_value(value)) for _ in range(3)]]}}
elif op == "$nin":
if isinstance(value, str) and not value:
return {}
return {key: {op: [draw(opposite_value(value)) for _ in range(3)]}}
else:
return {key: {op: value}}

View File

@@ -42,11 +42,16 @@ def _filter_where_clause(clause: Where, metadata: Metadata) -> bool:
if key == "$or":
assert isinstance(expr, list)
return any(_filter_where_clause(clause, metadata) for clause in expr)
if key == "$in":
assert isinstance(expr, list)
return metadata[key] in expr if key in metadata else False
if key == "$nin":
assert isinstance(expr, list)
return metadata[key] not in expr
# expr is an operator expression
assert isinstance(expr, dict)
op, val = list(expr.items())[0]
assert isinstance(metadata, dict)
if key not in metadata:
return False
@@ -55,6 +60,10 @@ def _filter_where_clause(clause: Where, metadata: Metadata) -> bool:
return key in metadata and metadata_key == val
elif op == "$ne":
return key in metadata and metadata_key != val
elif op == "$in":
return key in metadata and metadata_key in val
elif op == "$nin":
return key in metadata and metadata_key not in val
# The following conditions only make sense for numeric values
assert isinstance(metadata_key, int) or isinstance(metadata_key, float)
@@ -132,7 +141,6 @@ def _filter_embedding_set(
)
if not _filter_where_doc_clause(filter["where_document"], documents[i]):
ids.discard(normalized_record_set["ids"][i])
return list(ids)
@@ -174,7 +182,6 @@ def test_filterable_metadata_get(
return
coll.add(**record_set)
for filter in filters:
result_ids = coll.get(**filter)["ids"]
expected_ids = _filter_embedding_set(record_set, filter)

View File

@@ -122,7 +122,11 @@ WhereOperator = Union[
Literal["$ne"],
Literal["$eq"],
]
OperatorExpression = Dict[Union[WhereOperator, LogicalOperator], LiteralValue]
InclusionExclusionOperator = Union[Literal["$in"], Literal["$nin"]]
OperatorExpression = Union[
Dict[Union[WhereOperator, LogicalOperator], LiteralValue],
Dict[InclusionExclusionOperator, List[LiteralValue]],
]
Where = Dict[
Union[str, LogicalOperator], Union[LiteralValue, OperatorExpression, List["Where"]]

View File

@@ -191,5 +191,5 @@ test('wrong code returns an error', async () => {
// @ts-ignore - supposed to fail
const results = await collection.get({ where: { "test": { "$contains": "hello" } } });
expect(results.error).toBeDefined()
expect(results.error).toBe("ValueError('Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, got $contains')")
expect(results.error).toContain("ValueError('Expected where operator")
})

View File

@@ -0,0 +1,61 @@
# CIP-4: In and Not In Metadata Filters Proposal
## Status
Current Status: `Under Discussion`
## **Motivation**
Currently, Chroma does not provide a way to filter metadata through `in` and `not in`. This appears to be a frequent ask
from community members.
## **Public Interfaces**
The changes will affect the following public interfaces:
- `Where` and `OperatorExpression`
classes - https://github.com/chroma-core/chroma/blob/48700dd07f14bcfd8b206dc3b2e2795d5531094d/chromadb/types.py#L125-L129
- `collection.get()`
- `collection.query()`
## **Proposed Changes**
We suggest the introduction of two new operators `$in` and `$nin` that will be used to filter metadata. We call these
operators `InclusionExclusionOperator`.
We suggest the following new operator definition:
```python
InclusionExclusionOperator = Union[Literal["$in"], Literal["$nin"]]
```
Additionally, we suggest that those operators are added to `OperatorExpression` for seamless integration with
existing `Where` semantics:
```python
OperatorExpression = Union[
Dict[Union[WhereOperator, LogicalOperator], LiteralValue],
Dict[InclusionExclusionOperator, List[LiteralValue]],
]
```
An example of a query using the new operators would be:
```python
collection.query(query_texts=query,
where={"$and": [{"author": {'$in': ['john', 'jill']}}, {"article_type": {"$eq": "blog"}}]},
n_results=3)
```
## **Compatibility, Deprecation, and Migration Plan**
The change is compatible with existing release 0.4.x.
## **Test Plan**
Property tests will be updated to ensure boundary conditions are covered as well as interoperability with existing `Where`
operators.
## **Rejected Alternatives**
N/A

View File

@@ -0,0 +1,149 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2023-08-30T12:48:38.227653Z",
"start_time": "2023-08-30T12:48:27.744069Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Number of requested results 10 is greater than number of elements in index 3, updating n_results = 3\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'ids': [['1', '3']], 'distances': [[0.28824201226234436, 1.017508625984192]], 'metadatas': [[{'author': 'john'}, {'author': 'jill'}]], 'embeddings': None, 'documents': [['Article by john', 'Article by Jill']]}\n",
"{'ids': ['1', '3'], 'embeddings': None, 'metadatas': [{'author': 'john'}, {'author': 'jill'}], 'documents': ['Article by john', 'Article by Jill']}\n"
]
}
],
"source": [
"import chromadb\n",
"\n",
"from chromadb.utils import embedding_functions\n",
"\n",
"sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=\"all-MiniLM-L6-v2\")\n",
"\n",
"\n",
"client = chromadb.Client()\n",
"# client.heartbeat()\n",
"# client.reset()\n",
"collection = client.get_or_create_collection(\"test-where-list\", embedding_function=sentence_transformer_ef)\n",
"collection.add(documents=[\"Article by john\", \"Article by Jack\", \"Article by Jill\"],\n",
" metadatas=[{\"author\": \"john\"}, {\"author\": \"jack\"}, {\"author\": \"jill\"}], ids=[\"1\", \"2\", \"3\"])\n",
"\n",
"query = [\"Give me articles by john\"]\n",
"res = collection.query(query_texts=query,where={'author': {'$in': ['john', 'jill']}}, n_results=10)\n",
"print(res)\n",
"\n",
"res_get = collection.get(where={'author': {'$in': ['john', 'jill']}})\n",
"print(res_get)\n"
]
},
{
"cell_type": "markdown",
"source": [
"# Interactions with existing Where operators"
],
"metadata": {
"collapsed": false
},
"id": "752cef843ba2f900"
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [
{
"data": {
"text/plain": "{'ids': [['1']],\n 'distances': [[0.28824201226234436]],\n 'metadatas': [[{'article_type': 'blog', 'author': 'john'}]],\n 'embeddings': None,\n 'documents': [['Article by john']]}"
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"collection.upsert(documents=[\"Article by john\", \"Article by Jack\", \"Article by Jill\"],\n",
" metadatas=[{\"author\": \"john\",\"article_type\":\"blog\"}, {\"author\": \"jack\",\"article_type\":\"social\"}, {\"author\": \"jill\",\"article_type\":\"paper\"}], ids=[\"1\", \"2\", \"3\"])\n",
"\n",
"collection.query(query_texts=query,where={\"$and\":[{\"author\": {'$in': ['john', 'jill']}},{\"article_type\":{\"$eq\":\"blog\"}}]}, n_results=3)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-08-30T12:48:49.974353Z",
"start_time": "2023-08-30T12:48:49.938985Z"
}
},
"id": "ca56cda318f9e94d"
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"data": {
"text/plain": "{'ids': [['1', '3']],\n 'distances': [[0.28824201226234436, 1.017508625984192]],\n 'metadatas': [[{'article_type': 'blog', 'author': 'john'},\n {'article_type': 'paper', 'author': 'jill'}]],\n 'embeddings': None,\n 'documents': [['Article by john', 'Article by Jill']]}"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"collection.query(query_texts=query,where={\"$or\":[{\"author\": {'$in': ['john']}},{\"article_type\":{\"$in\":[\"paper\"]}}]}, n_results=3)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-08-30T12:48:53.501431Z",
"start_time": "2023-08-30T12:48:53.481571Z"
}
},
"id": "f10e79ec90c797c1"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "d97b8b6dd96261d0"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}