[BUG]: URL Parsing And Validation (#1118)

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Added additional validations to URLs - URLs like api-gw.aws.com/dev
will now trigger an error asking the user to correctly specify the URL
with http or https
- When the full URL (http(s)://example.com) is provided by the user, the
port parameter is ignored (debug message is logged). An assumption is
made that the URL is entirely defined, thus not requiring additional
alterations such as injecting the port.
    - Added negative test cases for invalid URLs

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python

## Documentation Changes
TBD
This commit is contained in:
Trayan Azarov
2023-09-11 17:58:20 +03:00
committed by GitHub
parent ea73f05bdf
commit 9c1979c931
2 changed files with 76 additions and 9 deletions

View File

@@ -1,4 +1,5 @@
import json
import logging
from typing import Optional, cast
from typing import Sequence
from uuid import UUID
@@ -32,28 +33,54 @@ from chromadb.config import Settings, System
from chromadb.telemetry import Telemetry
from urllib.parse import urlparse, urlunparse, quote
logger = logging.getLogger(__name__)
class FastAPI(API):
_settings: Settings
@staticmethod
def _validate_host(host: str) -> None:
parsed = urlparse(host)
if "/" in host and parsed.scheme not in {"http", "https"}:
raise ValueError(
"Invalid URL. " f"Unrecognized protocol - {parsed.scheme}."
)
if "/" in host and (not host.startswith("http")):
raise ValueError(
"Invalid URL. "
"Seems that you are trying to pass URL as a host but without specifying the protocol. "
"Please add http:// or https:// to the host."
)
@staticmethod
def resolve_url(
chroma_server_host: str,
chroma_server_ssl_enabled: Optional[bool] = False,
default_api_path: Optional[str] = "",
chroma_server_http_port: int = 8000,
chroma_server_http_port: Optional[int] = 8000,
) -> str:
parsed = urlparse(chroma_server_host)
_skip_port = False
_chroma_server_host = chroma_server_host
FastAPI._validate_host(_chroma_server_host)
if _chroma_server_host.startswith("http"):
logger.debug("Skipping port as the user is passing a full URL")
_skip_port = True
parsed = urlparse(_chroma_server_host)
scheme = "https" if chroma_server_ssl_enabled else parsed.scheme or "http"
net_loc = parsed.netloc or parsed.hostname or chroma_server_host
port = parsed.port or chroma_server_http_port
port = (
":" + str(parsed.port or chroma_server_http_port) if not _skip_port else ""
)
path = parsed.path or default_api_path
if not path or path == net_loc or not path.endswith(default_api_path or ""):
if not path or path == net_loc:
path = default_api_path if default_api_path else ""
if not path.endswith(default_api_path or ""):
path = path + default_api_path if default_api_path else ""
full_url = urlunparse(
(scheme, f"{net_loc}:{port}", quote(path.replace("//", "/")), "", "", "")
(scheme, f"{net_loc}{port}", quote(path.replace("//", "/")), "", "", "")
)
return full_url

View File

@@ -1,6 +1,7 @@
from typing import Optional
from urllib.parse import urlparse
import pytest
from hypothesis import given, strategies as st
from chromadb.api.fastapi import FastAPI
@@ -28,7 +29,7 @@ def domain_strategy() -> st.SearchStrategy[str]:
return st.tuples(label, tld).map(".".join)
port_strategy = st.integers(min_value=1, max_value=65535)
port_strategy = st.one_of(st.integers(min_value=1, max_value=65535), st.none())
ssl_enabled_strategy = st.booleans()
@@ -56,8 +57,21 @@ def is_valid_url(url: str) -> bool:
def generate_valid_domain_url() -> st.SearchStrategy[str]:
return st.builds(
lambda url_scheme, hostname, url_path: f"{url_scheme}://{hostname}{url_path}",
url_scheme=st.sampled_from(["http", "https"]),
lambda url_scheme, hostname, url_path: f"{url_scheme}{hostname}{url_path}",
url_scheme=st.sampled_from(["http://", "https://"]),
hostname=domain_strategy(),
url_path=url_path_strategy(),
)
def generate_invalid_domain_url() -> st.SearchStrategy[str]:
return st.builds(
lambda url_scheme, hostname, url_path: f"{url_scheme}{hostname}{url_path}",
url_scheme=st.builds(
lambda scheme, suffix: f"{scheme}{suffix}",
scheme=st.text(max_size=10),
suffix=st.sampled_from(["://", ":///", ":////", ""]),
),
hostname=domain_strategy(),
url_path=url_path_strategy(),
)
@@ -76,7 +90,7 @@ host_or_domain_strategy = st.one_of(
)
def test_url_resolve(
hostname: str,
port: int,
port: Optional[int],
ssl_enabled: bool,
default_api_path: Optional[str],
) -> None:
@@ -90,5 +104,31 @@ def test_url_resolve(
assert (
_url.startswith("https") if ssl_enabled else _url.startswith("http")
), f"Invalid URL: {_url} - SSL Enabled: {ssl_enabled}"
if hostname.startswith("http"):
assert ":" + str(port) not in _url, f"Port in URL not expected: {_url}"
else:
assert ":" + str(port) in _url, f"Port in URL expected: {_url}"
if default_api_path:
assert _url.endswith(default_api_path), f"Invalid URL: {_url}"
@given(
hostname=generate_invalid_domain_url(),
port=port_strategy,
ssl_enabled=ssl_enabled_strategy,
default_api_path=st.sampled_from(["/api/v1", "/api/v2", None]),
)
def test_resolve_invalid(
hostname: str,
port: Optional[int],
ssl_enabled: bool,
default_api_path: Optional[str],
) -> None:
with pytest.raises(ValueError) as e:
FastAPI.resolve_url(
chroma_server_host=hostname,
chroma_server_http_port=port,
chroma_server_ssl_enabled=ssl_enabled,
default_api_path=default_api_path,
)
assert "Invalid URL" in str(e.value)