mirror of
https://github.com/placeholder-soft/chroma.git
synced 2026-04-29 12:24:58 +08:00
[BUG]: URL Parsing And Validation (#1118)
## Description of changes
*Summarize the changes made by this PR.*
- Improvements & Bug fixes
- Added additional validations to URLs - URLs like api-gw.aws.com/dev
will now trigger an error asking the user to correctly specify the URL
with http or https
- When the full URL (http(s)://example.com) is provided by the user, the
port parameter is ignored (debug message is logged). An assumption is
made that the URL is entirely defined, thus not requiring additional
alterations such as injecting the port.
- Added negative test cases for invalid URLs
## Test plan
*How are these changes tested?*
- [x] Tests pass locally with `pytest` for python
## Documentation Changes
TBD
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, cast
|
||||
from typing import Sequence
|
||||
from uuid import UUID
|
||||
@@ -32,28 +33,54 @@ from chromadb.config import Settings, System
|
||||
from chromadb.telemetry import Telemetry
|
||||
from urllib.parse import urlparse, urlunparse, quote
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FastAPI(API):
|
||||
_settings: Settings
|
||||
|
||||
@staticmethod
|
||||
def _validate_host(host: str) -> None:
|
||||
parsed = urlparse(host)
|
||||
if "/" in host and parsed.scheme not in {"http", "https"}:
|
||||
raise ValueError(
|
||||
"Invalid URL. " f"Unrecognized protocol - {parsed.scheme}."
|
||||
)
|
||||
if "/" in host and (not host.startswith("http")):
|
||||
raise ValueError(
|
||||
"Invalid URL. "
|
||||
"Seems that you are trying to pass URL as a host but without specifying the protocol. "
|
||||
"Please add http:// or https:// to the host."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def resolve_url(
|
||||
chroma_server_host: str,
|
||||
chroma_server_ssl_enabled: Optional[bool] = False,
|
||||
default_api_path: Optional[str] = "",
|
||||
chroma_server_http_port: int = 8000,
|
||||
chroma_server_http_port: Optional[int] = 8000,
|
||||
) -> str:
|
||||
parsed = urlparse(chroma_server_host)
|
||||
_skip_port = False
|
||||
_chroma_server_host = chroma_server_host
|
||||
FastAPI._validate_host(_chroma_server_host)
|
||||
if _chroma_server_host.startswith("http"):
|
||||
logger.debug("Skipping port as the user is passing a full URL")
|
||||
_skip_port = True
|
||||
parsed = urlparse(_chroma_server_host)
|
||||
|
||||
scheme = "https" if chroma_server_ssl_enabled else parsed.scheme or "http"
|
||||
net_loc = parsed.netloc or parsed.hostname or chroma_server_host
|
||||
port = parsed.port or chroma_server_http_port
|
||||
port = (
|
||||
":" + str(parsed.port or chroma_server_http_port) if not _skip_port else ""
|
||||
)
|
||||
path = parsed.path or default_api_path
|
||||
|
||||
if not path or path == net_loc or not path.endswith(default_api_path or ""):
|
||||
if not path or path == net_loc:
|
||||
path = default_api_path if default_api_path else ""
|
||||
if not path.endswith(default_api_path or ""):
|
||||
path = path + default_api_path if default_api_path else ""
|
||||
full_url = urlunparse(
|
||||
(scheme, f"{net_loc}:{port}", quote(path.replace("//", "/")), "", "", "")
|
||||
(scheme, f"{net_loc}{port}", quote(path.replace("//", "/")), "", "", "")
|
||||
)
|
||||
|
||||
return full_url
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, strategies as st
|
||||
|
||||
from chromadb.api.fastapi import FastAPI
|
||||
@@ -28,7 +29,7 @@ def domain_strategy() -> st.SearchStrategy[str]:
|
||||
return st.tuples(label, tld).map(".".join)
|
||||
|
||||
|
||||
port_strategy = st.integers(min_value=1, max_value=65535)
|
||||
port_strategy = st.one_of(st.integers(min_value=1, max_value=65535), st.none())
|
||||
|
||||
ssl_enabled_strategy = st.booleans()
|
||||
|
||||
@@ -56,8 +57,21 @@ def is_valid_url(url: str) -> bool:
|
||||
|
||||
def generate_valid_domain_url() -> st.SearchStrategy[str]:
|
||||
return st.builds(
|
||||
lambda url_scheme, hostname, url_path: f"{url_scheme}://{hostname}{url_path}",
|
||||
url_scheme=st.sampled_from(["http", "https"]),
|
||||
lambda url_scheme, hostname, url_path: f"{url_scheme}{hostname}{url_path}",
|
||||
url_scheme=st.sampled_from(["http://", "https://"]),
|
||||
hostname=domain_strategy(),
|
||||
url_path=url_path_strategy(),
|
||||
)
|
||||
|
||||
|
||||
def generate_invalid_domain_url() -> st.SearchStrategy[str]:
|
||||
return st.builds(
|
||||
lambda url_scheme, hostname, url_path: f"{url_scheme}{hostname}{url_path}",
|
||||
url_scheme=st.builds(
|
||||
lambda scheme, suffix: f"{scheme}{suffix}",
|
||||
scheme=st.text(max_size=10),
|
||||
suffix=st.sampled_from(["://", ":///", ":////", ""]),
|
||||
),
|
||||
hostname=domain_strategy(),
|
||||
url_path=url_path_strategy(),
|
||||
)
|
||||
@@ -76,7 +90,7 @@ host_or_domain_strategy = st.one_of(
|
||||
)
|
||||
def test_url_resolve(
|
||||
hostname: str,
|
||||
port: int,
|
||||
port: Optional[int],
|
||||
ssl_enabled: bool,
|
||||
default_api_path: Optional[str],
|
||||
) -> None:
|
||||
@@ -90,5 +104,31 @@ def test_url_resolve(
|
||||
assert (
|
||||
_url.startswith("https") if ssl_enabled else _url.startswith("http")
|
||||
), f"Invalid URL: {_url} - SSL Enabled: {ssl_enabled}"
|
||||
if hostname.startswith("http"):
|
||||
assert ":" + str(port) not in _url, f"Port in URL not expected: {_url}"
|
||||
else:
|
||||
assert ":" + str(port) in _url, f"Port in URL expected: {_url}"
|
||||
if default_api_path:
|
||||
assert _url.endswith(default_api_path), f"Invalid URL: {_url}"
|
||||
|
||||
|
||||
@given(
|
||||
hostname=generate_invalid_domain_url(),
|
||||
port=port_strategy,
|
||||
ssl_enabled=ssl_enabled_strategy,
|
||||
default_api_path=st.sampled_from(["/api/v1", "/api/v2", None]),
|
||||
)
|
||||
def test_resolve_invalid(
|
||||
hostname: str,
|
||||
port: Optional[int],
|
||||
ssl_enabled: bool,
|
||||
default_api_path: Optional[str],
|
||||
) -> None:
|
||||
with pytest.raises(ValueError) as e:
|
||||
FastAPI.resolve_url(
|
||||
chroma_server_host=hostname,
|
||||
chroma_server_http_port=port,
|
||||
chroma_server_ssl_enabled=ssl_enabled,
|
||||
default_api_path=default_api_path,
|
||||
)
|
||||
assert "Invalid URL" in str(e.value)
|
||||
|
||||
Reference in New Issue
Block a user