diff --git a/.env.template b/.env.template index 4cf2a18..f4482d9 100644 --- a/.env.template +++ b/.env.template @@ -1,5 +1,6 @@ -DATABASE_TYPE=postgres -CONNECTION_URI=postgresql+psycopg://testuser:testpwd@localhost:5432/honcho +CONNECTION_URI=postgresql+psycopg://testuser:testpwd@localhost:5432/honcho # sample for local database + +# CONNECTION_URI=postgresql+psycopg://testuser:testpwd@database:5432/honcho # sample for docker-compose database OPENAI_API_KEY= diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml deleted file mode 100644 index adafe9f..0000000 --- a/.github/workflows/run_tests.yml +++ /dev/null @@ -1,46 +0,0 @@ -name: Run Tests -on: [push, pull_request] -jobs: - test: - permissions: - pull-requests: write - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.10 - uses: actions/setup-python@v3 - with: - python-version: "3.10" - - name: Install poetry - run: | - pip install poetry - - name: Start Database - run: | - cd api/local - docker compose up --wait - cd ../.. - - name: Start Server - run: | - cd api - poetry install --no-root - poetry run uvicorn src.main:app & - sleep 5 - cd .. - env: - DATABASE_TYPE: postgres - CONNECTION_URI: postgresql+psycopg://testuser:testpwd@localhost:5432/honcho - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - - name: Run Tests - run: | - cd sdk - poetry install - poetry run pytest - cd .. - - name: Stop Database - run: | - cd api/local - docker compose down - cd ../.. - - name: Stop Server - run: | - kill $(jobs -p) || true diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 14a2b11..67bd5ec 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -107,13 +107,17 @@ As mentioned earlier a `docker-compose` template is included for running Honcho. As an alternative to running Honcho locally it can also be run with the compose template. +The docker-compose template is set to use an environment file called `.env`. 
+You can create it by copying `.env.template` and filling in the appropriate values. + Copy the template and update the appropriate environment variables before launching the service. ```bash cd honcho/api +cp .env.template .env +# update the file with your OpenAI key and any other desired environment variables cp docker-compose.yml.example docker-compose.yml -[ update the file with openai key and other wanted environment variables ] docker compose up ``` diff --git a/docker-compose.yml.example b/docker-compose.yml.example index f68e350..1ba2026 100644 --- a/docker-compose.yml.example +++ b/docker-compose.yml.example @@ -11,24 +11,20 @@ services: - 8000:8000 volumes: - .:/app - environment: - - DATABASE_TYPE=postgres - - CONNECTION_URI=postgresql+psycopg://testuser:testpwd@database:5432/honcho - - OPENAI_API_KEY=[YOUR_OPENAI_API_KEY] - - OPENTELEMETRY_ENABLED=false - - SENTRY_ENABLED=false - - SENTRY_DSN= - - OTEL_SERVICE_NAME=honcho - - OTEL_PYTHON_LOGGING_AUTO_INSTRUMENTATION_ENABLED=true - - OTEL_PYTHON_LOG_CORRELATION=true - - OTEL_PYTHON_LOG_LEVEL= - - OTEL_EXPORTER_OTLP_PROTOCOL= - - OTEL_EXPORTER_OTLP_ENDPOINT= - - OTEL_EXPORTER_OTLP_HEADERS= - - OTEL_RESOURCE_ATTRIBUTES= - - DEBUG_LOG_OTEL_TO_PROVIDER=false - - DEBUG_LOG_OTEL_TO_CONSOLE=true - - USE_AUTH_SERVICE=false + env_file: + - .env + deriver: + build: + context: . + dockerfile: Dockerfile + entrypoint: ["python", "-m", "src.deriver"] + depends_on: + database: + condition: service_healthy + volumes: + - .:/app + env_file: + - .env database: image: ankane/pgvector restart: always diff --git a/poetry.lock b/poetry.lock index ff6af70..f8ec958 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,9 +1,10 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. 
[[package]] name = "annotated-types" version = "0.6.0" description = "Reusable constraint types to use with typing.Annotated" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -15,6 +16,7 @@ files = [ name = "anyio" version = "4.3.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -37,6 +39,7 @@ trio = ["trio (>=0.23)"] name = "asgiref" version = "3.8.1" description = "ASGI specs, helper code, and adapters" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -54,6 +57,7 @@ tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"] name = "certifi" version = "2024.2.2" description = "Python package for providing Mozilla's CA Bundle." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -65,6 +69,7 @@ files = [ name = "charset-normalizer" version = "3.3.2" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -164,6 +169,7 @@ files = [ name = "click" version = "8.1.7" description = "Composable command line interface toolkit" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -178,6 +184,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -189,6 +196,7 @@ files = [ name = "deprecated" version = "1.2.14" description = "Python @deprecated decorator to deprecate old python classes, functions or methods." 
+category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -206,6 +214,7 @@ dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] name = "distro" version = "1.9.0" description = "Distro - an OS platform information API" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -217,6 +226,7 @@ files = [ name = "docstring-parser" version = "0.15" description = "Parse Python docstrings in reST, Google and Numpydoc format" +category = "main" optional = false python-versions = ">=3.6,<4.0" files = [ @@ -228,6 +238,7 @@ files = [ name = "exceptiongroup" version = "1.2.1" description = "Backport of PEP 654 (exception groups)" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -242,6 +253,7 @@ test = ["pytest (>=6)"] name = "fastapi" version = "0.109.2" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -261,6 +273,7 @@ all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)" name = "fastapi-pagination" version = "0.12.24" description = "FastAPI pagination" +category = "main" optional = false python-versions = "<4.0,>=3.8" files = [ @@ -294,6 +307,7 @@ tortoise = ["tortoise-orm (>=0.16.18,<0.21.0)"] name = "googleapis-common-protos" version = "1.63.0" description = "Common protobufs used in Google APIs" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -311,6 +325,7 @@ grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] name = "greenlet" version = "3.0.3" description = "Lightweight in-process concurrent programming" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -382,6 +397,7 @@ test = ["objgraph", "psutil"] name = "grpcio" version = "1.63.0" description = "HTTP/2-based RPC framework" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -440,6 +456,7 
@@ protobuf = ["grpcio-tools (>=1.63.0)"] name = "h11" version = "0.14.0" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -451,6 +468,7 @@ files = [ name = "httpcore" version = "1.0.5" description = "A minimal low-level HTTP client." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -465,13 +483,14 @@ h11 = ">=0.13,<0.15" [package.extras] asyncio = ["anyio (>=4.0,<5.0)"] http2 = ["h2 (>=3,<5)"] -socks = ["socksio (==1.*)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] trio = ["trio (>=0.22.0,<0.26.0)"] [[package]] name = "httptools" version = "0.6.1" description = "A collection of framework independent HTTP protocol utils." +category = "main" optional = false python-versions = ">=3.8.0" files = [ @@ -520,6 +539,7 @@ test = ["Cython (>=0.29.24,<0.30.0)"] name = "httpx" version = "0.27.0" description = "The next generation HTTP client." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -530,20 +550,21 @@ files = [ [package.dependencies] anyio = "*" certifi = "*" -httpcore = "==1.*" +httpcore = ">=1.0.0,<2.0.0" idna = "*" sniffio = "*" [package.extras] brotli = ["brotli", "brotlicffi"] -cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] -socks = ["socksio (==1.*)"] +socks = ["socksio (>=1.0.0,<2.0.0)"] [[package]] name = "idna" version = "3.7" description = "Internationalized Domain Names in Applications (IDNA)" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -555,6 +576,7 @@ files = [ name = "importlib-metadata" version = "6.11.0" description = "Read metadata from Python packages" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -574,6 +596,7 @@ testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs name = "importlib-resources" 
version = "6.4.0" description = "Read resources from Python packages" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -592,6 +615,7 @@ testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "p name = "limits" version = "3.12.0" description = "Rate limiting utilities" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -621,6 +645,7 @@ rediscluster = ["redis (>=4.2.0,!=4.5.2,!=4.5.3)"] name = "mirascope" version = "0.12.3" description = "LLM toolkit for lightning-fast, high-quality development" +category = "main" optional = false python-versions = "<4.0,>=3.9" files = [ @@ -652,6 +677,7 @@ weave = ["weave (>=0.50.2,<1.0.0)"] name = "numpy" version = "1.26.4" description = "Fundamental package for array computing in Python" +category = "main" optional = false python-versions = ">=3.9" files = [ @@ -697,6 +723,7 @@ files = [ name = "openai" version = "1.30.1" description = "The official Python library for the openai API" +category = "main" optional = false python-versions = ">=3.7.1" files = [ @@ -720,6 +747,7 @@ datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] name = "opentelemetry-api" version = "1.23.0" description = "OpenTelemetry Python API" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -735,6 +763,7 @@ importlib-metadata = ">=6.0,<7.0" name = "opentelemetry-exporter-otlp" version = "1.23.0" description = "OpenTelemetry Collector Exporters" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -750,6 +779,7 @@ opentelemetry-exporter-otlp-proto-http = "1.23.0" name = "opentelemetry-exporter-otlp-proto-common" version = "1.23.0" description = "OpenTelemetry Protobuf encoding" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -764,6 +794,7 @@ opentelemetry-proto = "1.23.0" name = "opentelemetry-exporter-otlp-proto-grpc" version = "1.23.0" description = "OpenTelemetry Collector Protobuf over gRPC 
Exporter" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -787,6 +818,7 @@ test = ["pytest-grpc"] name = "opentelemetry-exporter-otlp-proto-http" version = "1.23.0" description = "OpenTelemetry Collector Protobuf over HTTP Exporter" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -810,6 +842,7 @@ test = ["responses (>=0.22.0,<0.25)"] name = "opentelemetry-instrumentation" version = "0.44b0" description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -826,6 +859,7 @@ wrapt = ">=1.0.0,<2.0.0" name = "opentelemetry-instrumentation-asgi" version = "0.44b0" description = "ASGI instrumentation for OpenTelemetry" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -848,6 +882,7 @@ test = ["opentelemetry-instrumentation-asgi[instruments]", "opentelemetry-test-u name = "opentelemetry-instrumentation-fastapi" version = "0.44b0" description = "OpenTelemetry FastAPI Instrumentation" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -870,6 +905,7 @@ test = ["httpx (>=0.22,<1.0)", "opentelemetry-instrumentation-fastapi[instrument name = "opentelemetry-instrumentation-logging" version = "0.44b0" description = "OpenTelemetry Logging instrumentation" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -888,6 +924,7 @@ test = ["opentelemetry-test-utils (==0.44b0)"] name = "opentelemetry-instrumentation-sqlalchemy" version = "0.44b0" description = "OpenTelemetry SQLAlchemy instrumentation" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -910,6 +947,7 @@ test = ["opentelemetry-instrumentation-sqlalchemy[instruments]", "opentelemetry- name = "opentelemetry-proto" version = "1.23.0" description = "OpenTelemetry Python Proto" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -924,6 +962,7 @@ protobuf = 
">=3.19,<5.0" name = "opentelemetry-sdk" version = "1.23.0" description = "OpenTelemetry Python SDK" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -940,6 +979,7 @@ typing-extensions = ">=3.7.4" name = "opentelemetry-semantic-conventions" version = "0.44b0" description = "OpenTelemetry Semantic Conventions" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -951,6 +991,7 @@ files = [ name = "opentelemetry-util-http" version = "0.44b0" description = "Web util for OpenTelemetry" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -962,6 +1003,7 @@ files = [ name = "packaging" version = "24.0" description = "Core utilities for Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -973,6 +1015,7 @@ files = [ name = "pgvector" version = "0.2.5" description = "pgvector support for Python" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -986,6 +1029,7 @@ numpy = "*" name = "protobuf" version = "4.25.3" description = "" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1006,6 +1050,7 @@ files = [ name = "psycopg" version = "3.1.19" description = "PostgreSQL database adapter for Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1030,6 +1075,7 @@ test = ["anyio (>=3.6.2,<4.0)", "mypy (>=1.4.1)", "pproxy (>=2.7)", "pytest (>=6 name = "psycopg-binary" version = "3.1.19" description = "PostgreSQL database adapter for Python -- C optimisation distribution" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1102,6 +1148,7 @@ files = [ name = "pydantic" version = "2.7.1" description = "Data validation using Python type hints" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1121,6 +1168,7 @@ email = ["email-validator (>=2.0.0)"] name = "pydantic-core" version = "2.18.2" description = "Core functionality for Pydantic validation and 
serialization" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1208,24 +1256,11 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" -[[package]] -name = "python-dateutil" -version = "2.9.0.post0" -description = "Extensions to the standard Python datetime module" -optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" -files = [ - {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, - {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, -] - -[package.dependencies] -six = ">=1.5" - [[package]] name = "python-dotenv" version = "1.0.1" description = "Read key-value pairs from a .env file and set them as environment variables" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1236,26 +1271,11 @@ files = [ [package.extras] cli = ["click (>=5.0)"] -[[package]] -name = "realtime" -version = "1.0.4" -description = "" -optional = false -python-versions = "<4.0,>=3.8" -files = [ - {file = "realtime-1.0.4-py3-none-any.whl", hash = "sha256:b06bea001985f089167320bda1e91c6b2d866f56ca810bb8d768ee3cf695ee21"}, - {file = "realtime-1.0.4.tar.gz", hash = "sha256:a9095f60121a365e84656c582e6ccd8dc8b3a732ddddb2ccd26cc3d32b77bdf6"}, -] - -[package.dependencies] -python-dateutil = ">=2.8.1,<3.0.0" -typing-extensions = ">=4.11.0,<5.0.0" -websockets = ">=11,<13" - [[package]] name = "requests" version = "2.31.0" description = "Python HTTP for Humans." 
+category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1277,6 +1297,7 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "sentry-sdk" version = "1.45.0" description = "Python client for Sentry (https://sentry.io)" +category = "main" optional = false python-versions = "*" files = [ @@ -1326,6 +1347,7 @@ tornado = ["tornado (>=5)"] name = "setuptools" version = "69.5.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1338,21 +1360,11 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] -[[package]] -name = "six" -version = "1.16.0" -description = "Python 2 and 3 compatibility utilities" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" -files = [ - {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, - {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, -] - [[package]] name = "slowapi" version = "0.1.9" description = "A rate limiting extension for Starlette and Fastapi" +category = "main" optional = false python-versions 
= ">=3.7,<4.0" files = [ @@ -1370,6 +1382,7 @@ redis = ["redis (>=3.4.1,<4.0.0)"] name = "sniffio" version = "1.3.1" description = "Sniff out which async library your code is running under" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1381,6 +1394,7 @@ files = [ name = "sqlalchemy" version = "2.0.30" description = "Database Abstraction Library" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1468,6 +1482,7 @@ sqlcipher = ["sqlcipher3_binary"] name = "starlette" version = "0.36.3" description = "The little ASGI library that shines." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1486,6 +1501,7 @@ full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7 name = "tenacity" version = "8.3.0" description = "Retry code until it succeeds" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1501,6 +1517,7 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"] name = "tqdm" version = "4.66.4" description = "Fast, Extensible Progress Meter" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1521,6 +1538,7 @@ telegram = ["requests"] name = "typing-extensions" version = "4.11.0" description = "Backported and Experimental Type Hints for Python 3.8+" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1532,6 +1550,7 @@ files = [ name = "tzdata" version = "2024.1" description = "Provider of IANA time zone data" +category = "main" optional = false python-versions = ">=2" files = [ @@ -1543,6 +1562,7 @@ files = [ name = "urllib3" version = "2.2.1" description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1560,6 +1580,7 @@ zstd = ["zstandard (>=0.18.0)"] name = "uvicorn" version = "0.24.0.post1" description = "The lightning-fast ASGI server." 
+category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1579,6 +1600,7 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", name = "uvloop" version = "0.19.0" description = "Fast implementation of asyncio event loop on top of libuv" +category = "main" optional = false python-versions = ">=3.8.0" files = [ @@ -1619,91 +1641,11 @@ files = [ docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx-rtd-theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] test = ["Cython (>=0.29.36,<0.30.0)", "aiohttp (==3.9.0b0)", "aiohttp (>=3.8.1)", "flake8 (>=5.0,<6.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=23.0.0,<23.1.0)", "pycodestyle (>=2.9.0,<2.10.0)"] -[[package]] -name = "websockets" -version = "12.0" -description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -optional = false -python-versions = ">=3.8" -files = [ - {file = "websockets-12.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d554236b2a2006e0ce16315c16eaa0d628dab009c33b63ea03f41c6107958374"}, - {file = "websockets-12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2d225bb6886591b1746b17c0573e29804619c8f755b5598d875bb4235ea639be"}, - {file = "websockets-12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eb809e816916a3b210bed3c82fb88eaf16e8afcf9c115ebb2bacede1797d2547"}, - {file = "websockets-12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c588f6abc13f78a67044c6b1273a99e1cf31038ad51815b3b016ce699f0d75c2"}, - {file = "websockets-12.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5aa9348186d79a5f232115ed3fa9020eab66d6c3437d72f9d2c8ac0c6858c558"}, - {file = "websockets-12.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6350b14a40c95ddd53e775dbdbbbc59b124a5c8ecd6fbb09c2e52029f7a9f480"}, - {file = "websockets-12.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = 
"sha256:70ec754cc2a769bcd218ed8d7209055667b30860ffecb8633a834dde27d6307c"}, - {file = "websockets-12.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6e96f5ed1b83a8ddb07909b45bd94833b0710f738115751cdaa9da1fb0cb66e8"}, - {file = "websockets-12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4d87be612cbef86f994178d5186add3d94e9f31cc3cb499a0482b866ec477603"}, - {file = "websockets-12.0-cp310-cp310-win32.whl", hash = "sha256:befe90632d66caaf72e8b2ed4d7f02b348913813c8b0a32fae1cc5fe3730902f"}, - {file = "websockets-12.0-cp310-cp310-win_amd64.whl", hash = "sha256:363f57ca8bc8576195d0540c648aa58ac18cf85b76ad5202b9f976918f4219cf"}, - {file = "websockets-12.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5d873c7de42dea355d73f170be0f23788cf3fa9f7bed718fd2830eefedce01b4"}, - {file = "websockets-12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3f61726cae9f65b872502ff3c1496abc93ffbe31b278455c418492016e2afc8f"}, - {file = "websockets-12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed2fcf7a07334c77fc8a230755c2209223a7cc44fc27597729b8ef5425aa61a3"}, - {file = "websockets-12.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e332c210b14b57904869ca9f9bf4ca32f5427a03eeb625da9b616c85a3a506c"}, - {file = "websockets-12.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5693ef74233122f8ebab026817b1b37fe25c411ecfca084b29bc7d6efc548f45"}, - {file = "websockets-12.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e9e7db18b4539a29cc5ad8c8b252738a30e2b13f033c2d6e9d0549b45841c04"}, - {file = "websockets-12.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6e2df67b8014767d0f785baa98393725739287684b9f8d8a1001eb2839031447"}, - {file = "websockets-12.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:bea88d71630c5900690fcb03161ab18f8f244805c59e2e0dc4ffadae0a7ee0ca"}, - {file = 
"websockets-12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dff6cdf35e31d1315790149fee351f9e52978130cef6c87c4b6c9b3baf78bc53"}, - {file = "websockets-12.0-cp311-cp311-win32.whl", hash = "sha256:3e3aa8c468af01d70332a382350ee95f6986db479ce7af14d5e81ec52aa2b402"}, - {file = "websockets-12.0-cp311-cp311-win_amd64.whl", hash = "sha256:25eb766c8ad27da0f79420b2af4b85d29914ba0edf69f547cc4f06ca6f1d403b"}, - {file = "websockets-12.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0e6e2711d5a8e6e482cacb927a49a3d432345dfe7dea8ace7b5790df5932e4df"}, - {file = "websockets-12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:dbcf72a37f0b3316e993e13ecf32f10c0e1259c28ffd0a85cee26e8549595fbc"}, - {file = "websockets-12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:12743ab88ab2af1d17dd4acb4645677cb7063ef4db93abffbf164218a5d54c6b"}, - {file = "websockets-12.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b645f491f3c48d3f8a00d1fce07445fab7347fec54a3e65f0725d730d5b99cb"}, - {file = "websockets-12.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9893d1aa45a7f8b3bc4510f6ccf8db8c3b62120917af15e3de247f0780294b92"}, - {file = "websockets-12.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f38a7b376117ef7aff996e737583172bdf535932c9ca021746573bce40165ed"}, - {file = "websockets-12.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:f764ba54e33daf20e167915edc443b6f88956f37fb606449b4a5b10ba42235a5"}, - {file = "websockets-12.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1e4b3f8ea6a9cfa8be8484c9221ec0257508e3a1ec43c36acdefb2a9c3b00aa2"}, - {file = "websockets-12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9fdf06fd06c32205a07e47328ab49c40fc1407cdec801d698a7c41167ea45113"}, - {file = "websockets-12.0-cp312-cp312-win32.whl", hash = 
"sha256:baa386875b70cbd81798fa9f71be689c1bf484f65fd6fb08d051a0ee4e79924d"}, - {file = "websockets-12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ae0a5da8f35a5be197f328d4727dbcfafa53d1824fac3d96cdd3a642fe09394f"}, - {file = "websockets-12.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5f6ffe2c6598f7f7207eef9a1228b6f5c818f9f4d53ee920aacd35cec8110438"}, - {file = "websockets-12.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9edf3fc590cc2ec20dc9d7a45108b5bbaf21c0d89f9fd3fd1685e223771dc0b2"}, - {file = "websockets-12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8572132c7be52632201a35f5e08348137f658e5ffd21f51f94572ca6c05ea81d"}, - {file = "websockets-12.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:604428d1b87edbf02b233e2c207d7d528460fa978f9e391bd8aaf9c8311de137"}, - {file = "websockets-12.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a9d160fd080c6285e202327aba140fc9a0d910b09e423afff4ae5cbbf1c7205"}, - {file = "websockets-12.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87b4aafed34653e465eb77b7c93ef058516cb5acf3eb21e42f33928616172def"}, - {file = "websockets-12.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b2ee7288b85959797970114deae81ab41b731f19ebcd3bd499ae9ca0e3f1d2c8"}, - {file = "websockets-12.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:7fa3d25e81bfe6a89718e9791128398a50dec6d57faf23770787ff441d851967"}, - {file = "websockets-12.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a571f035a47212288e3b3519944f6bf4ac7bc7553243e41eac50dd48552b6df7"}, - {file = "websockets-12.0-cp38-cp38-win32.whl", hash = "sha256:3c6cc1360c10c17463aadd29dd3af332d4a1adaa8796f6b0e9f9df1fdb0bad62"}, - {file = "websockets-12.0-cp38-cp38-win_amd64.whl", hash = "sha256:1bf386089178ea69d720f8db6199a0504a406209a0fc23e603b27b300fdd6892"}, - {file = 
"websockets-12.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ab3d732ad50a4fbd04a4490ef08acd0517b6ae6b77eb967251f4c263011a990d"}, - {file = "websockets-12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a1d9697f3337a89691e3bd8dc56dea45a6f6d975f92e7d5f773bc715c15dde28"}, - {file = "websockets-12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1df2fbd2c8a98d38a66f5238484405b8d1d16f929bb7a33ed73e4801222a6f53"}, - {file = "websockets-12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23509452b3bc38e3a057382c2e941d5ac2e01e251acce7adc74011d7d8de434c"}, - {file = "websockets-12.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e5fc14ec6ea568200ea4ef46545073da81900a2b67b3e666f04adf53ad452ec"}, - {file = "websockets-12.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46e71dbbd12850224243f5d2aeec90f0aaa0f2dde5aeeb8fc8df21e04d99eff9"}, - {file = "websockets-12.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b81f90dcc6c85a9b7f29873beb56c94c85d6f0dac2ea8b60d995bd18bf3e2aae"}, - {file = "websockets-12.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a02413bc474feda2849c59ed2dfb2cddb4cd3d2f03a2fedec51d6e959d9b608b"}, - {file = "websockets-12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bbe6013f9f791944ed31ca08b077e26249309639313fff132bfbf3ba105673b9"}, - {file = "websockets-12.0-cp39-cp39-win32.whl", hash = "sha256:cbe83a6bbdf207ff0541de01e11904827540aa069293696dd528a6640bd6a5f6"}, - {file = "websockets-12.0-cp39-cp39-win_amd64.whl", hash = "sha256:fc4e7fa5414512b481a2483775a8e8be7803a35b30ca805afa4998a84f9fd9e8"}, - {file = "websockets-12.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:248d8e2446e13c1d4326e0a6a4e9629cb13a11195051a73acf414812700badbd"}, - {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:f44069528d45a933997a6fef143030d8ca8042f0dfaad753e2906398290e2870"}, - {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c4e37d36f0d19f0a4413d3e18c0d03d0c268ada2061868c1e6f5ab1a6d575077"}, - {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d829f975fc2e527a3ef2f9c8f25e553eb7bc779c6665e8e1d52aa22800bb38b"}, - {file = "websockets-12.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2c71bd45a777433dd9113847af751aae36e448bc6b8c361a566cb043eda6ec30"}, - {file = "websockets-12.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0bee75f400895aef54157b36ed6d3b308fcab62e5260703add87f44cee9c82a6"}, - {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:423fc1ed29f7512fceb727e2d2aecb952c46aa34895e9ed96071821309951123"}, - {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27a5e9964ef509016759f2ef3f2c1e13f403725a5e6a1775555994966a66e931"}, - {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3181df4583c4d3994d31fb235dc681d2aaad744fbdbf94c4802485ececdecf2"}, - {file = "websockets-12.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:b067cb952ce8bf40115f6c19f478dc71c5e719b7fbaa511359795dfd9d1a6468"}, - {file = "websockets-12.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:00700340c6c7ab788f176d118775202aadea7602c5cc6be6ae127761c16d6b0b"}, - {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e469d01137942849cff40517c97a30a93ae79917752b34029f0ec72df6b46399"}, - {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", 
hash = "sha256:ffefa1374cd508d633646d51a8e9277763a9b78ae71324183693959cf94635a7"}, - {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba0cab91b3956dfa9f512147860783a1829a8d905ee218a9837c18f683239611"}, - {file = "websockets-12.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2cb388a5bfb56df4d9a406783b7f9dbefb888c09b71629351cc6b036e9259370"}, - {file = "websockets-12.0-py3-none-any.whl", hash = "sha256:dc284bbc8d7c78a6c69e0c7325ab46ee5e40bb4d50e494d8131a07ef47500e9e"}, - {file = "websockets-12.0.tar.gz", hash = "sha256:81df9cbcbb6c260de1e007e58c011bfebe2dafc8435107b0537f393dd38c8b1b"}, -] - [[package]] name = "wrapt" version = "1.16.0" description = "Module for decorators, wrappers and monkey patching." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -1783,6 +1725,7 @@ files = [ name = "zipp" version = "3.18.1" description = "Backport of pathlib-compatible object wrapper for zip files" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1797,4 +1740,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "7f35bbdd2708308a7a4af0ba9f34fc8e3202894ceeea867a85a070989448e9d3" +content-hash = "767598b63aad8e05f1b142a6a60fee18d458f7f64f85a7bb8581335eb835bbe9" diff --git a/pyproject.toml b/pyproject.toml index 1fd791b..066ea4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,6 @@ opentelemetry-exporter-otlp = "^1.23.0" opentelemetry-instrumentation-sqlalchemy = "^0.44b0" opentelemetry-instrumentation-logging = "^0.44b0" greenlet = "^3.0.3" -realtime = "^1.0.2" psycopg = {extras = ["binary"], version = "^3.1.18"} httpx = "^0.27.0" uvloop = "^0.19.0" diff --git a/src/crud.py b/src/crud.py index 2b342e5..20605c8 100644 --- a/src/crud.py +++ b/src/crud.py @@ -285,7 +285,7 @@ async def create_message( user_id: uuid.UUID, 
session_id: uuid.UUID, ) -> models.Message: - honcho_session = get_session( + honcho_session = await get_session( db, app_id=app_id, session_id=session_id, user_id=user_id ) if honcho_session is None: @@ -299,6 +299,7 @@ async def create_message( ) db.add(honcho_message) await db.commit() + # await db.refresh(honcho_message, attribute_names=["id", "content", "h_metadata"]) await db.refresh(honcho_message) return honcho_message diff --git a/src/db.py b/src/db.py index e45a964..848b76c 100644 --- a/src/db.py +++ b/src/db.py @@ -9,10 +9,10 @@ connect_args = {} -if ( - os.environ["DATABASE_TYPE"] == "sqlite" -): # https://fastapi.tiangolo.com/tutorial/sql-databases/#note - connect_args = {"check_same_thread": False} +# if ( +# os.environ["DATABASE_TYPE"] == "sqlite" +# ): # https://fastapi.tiangolo.com/tutorial/sql-databases/#note +# connect_args = {"check_same_thread": False} engine = create_async_engine( os.environ["CONNECTION_URI"], diff --git a/src/deriver.py b/src/deriver.py index 163ddae..603aa10 100644 --- a/src/deriver.py +++ b/src/deriver.py @@ -5,11 +5,13 @@ from typing import List import sentry_sdk +import uvloop from dotenv import load_dotenv from mirascope.openai import OpenAICall, OpenAICallParams -from realtime.connection import Socket + +# from realtime.connection import Socket from sqlalchemy import select -from sqlalchemy.orm import selectinload +from sqlalchemy.ext.asyncio import AsyncSession from . 
import crud, models, schemas from .db import SessionLocal @@ -24,8 +26,8 @@ ) -SUPABASE_ID = os.getenv("SUPABASE_ID") -SUPABASE_API_KEY = os.getenv("SUPABASE_API_KEY") +# SUPABASE_ID = os.getenv("SUPABASE_ID") +# SUPABASE_API_KEY = os.getenv("SUPABASE_API_KEY") class DeriveFacts(OpenAICall): @@ -60,48 +62,54 @@ class CheckDups(OpenAICall): call_params = OpenAICallParams(model="gpt-4o-2024-05-13") -async def callback(payload): +async def process_item(db: AsyncSession, payload: dict): # print(payload["record"]["is_user"]) # print(type(payload["record"]["is_user"])) - if payload["record"]["is_user"]: # Check if the message is from a user - session_id = payload["record"]["session_id"] - message_id = payload["record"]["id"] - content = payload["record"]["content"] - - # Example of querying for a user_id based on session_id, adjust according to your schema - session: models.Session - user_id: uuid.UUID - app_id: uuid.UUID - async with SessionLocal() as db: - stmt = ( - select(models.Session) - .join(models.Session.messages) - .where(models.Message.id == message_id) - .where(models.Session.id == session_id) - .options(selectinload(models.Session.user)) - ) - result = await db.execute(stmt) - session = result.scalars().one() - user = session.user - user_id = user.id - app_id = user.app_id - collection: models.Collection - async with SessionLocal() as db: - collection = await crud.get_collection_by_name( - db, app_id, user_id, "honcho" - ) - if collection is None: - collection_create = schemas.CollectionCreate(name="honcho", metadata={}) - collection = await crud.create_collection( - db, - collection=collection_create, - app_id=app_id, - user_id=user_id, - ) - collection_id = collection.id - await process_user_message( - content, app_id, user_id, session_id, collection_id, message_id + # if payload["record"]["is_user"]: # Check if the message is from a user + # session_id = payload["record"]["session_id"] + # message_id = payload["record"]["id"] + # content = 
payload["record"]["content"] + + # # Example of querying for a user_id based on session_id, adjust according to your schema + # session: models.Session + # user_id: uuid.UUID + # app_id: uuid.UUID + # async with SessionLocal() as db: + # stmt = ( + # select(models.Session) + # .join(models.Session.messages) + # .where(models.Message.id == message_id) + # .where(models.Session.id == session_id) + # .options(selectinload(models.Session.user)) + # ) + # result = await db.execute(stmt) + # session = result.scalars().one() + # user = session.user + # user_id = user.id + # app_id = user.app_id + collection: models.Collection + # async with SessionLocal() as db: + collection = await crud.get_collection_by_name( + db, payload["app_id"], payload["user_id"], "honcho" + ) + if collection is None: + collection_create = schemas.CollectionCreate(name="honcho", metadata={}) + collection = await crud.create_collection( + db, + collection=collection_create, + app_id=payload["app_id"], + user_id=payload["user_id"], ) + collection_id = collection.id + await process_user_message( + payload["content"], + payload["app_id"], + payload["user_id"], + payload["session_id"], + collection_id, + payload["message_id"], + db, + ) return @@ -112,20 +120,23 @@ async def process_user_message( session_id: uuid.UUID, collection_id: uuid.UUID, message_id: uuid.UUID, + db: AsyncSession, ): """ Process a user message and derive facts from it (check for duplicates before writing to the collection). 
""" - async with SessionLocal() as db: - messages_stmt = await crud.get_messages( - db=db, app_id=app_id, user_id=user_id, session_id=session_id, reverse=True - ) - messages_stmt = messages_stmt.limit(10) - response = await db.execute(messages_stmt) - messages = response.scalars().all() - messages = messages[::-1] - # contents = [m.content for m in messages] - # print(contents) + messages_stmt = await crud.get_messages( + db=db, app_id=app_id, user_id=user_id, session_id=session_id, reverse=True + ) + messages_stmt = messages_stmt.limit(10) + response = await db.execute(messages_stmt) + messages = response.scalars().all() + messages = messages[::-1] + contents = [m.content for m in messages] + print("===================") + print("Contents") + print(contents) + print("===================") chat_history_str = "\n".join( [f"user: {m.content}" if m.is_user else f"ai: {m.content}" for m in messages] @@ -189,6 +200,11 @@ async def check_dups( check_duplication.existing_facts = existing_facts check_duplication.fact = fact response = await check_duplication.call_async() + print("==================") + print("Dedupe Responses") + print(response) + print(response.content) + print("==================") if response.content == "True": new_facts.append(fact) print(f"New Fact: {fact}") @@ -201,13 +217,77 @@ async def check_dups( return new_facts +async def dequeue(semaphore: asyncio.Semaphore, queue_empty_flag: asyncio.Event): + async with semaphore, SessionLocal() as db: + try: + result = await db.execute( + select(models.QueueItem) + .where(models.QueueItem.processed == False) + .with_for_update(skip_locked=True) + .limit(1) + ) + item = result.scalar_one_or_none() + + if item: + print("========") + print("Processing") + print("========") + await process_item(db, payload=item.payload) + item.processed = True + await db.commit() + else: + # No items to process, set the queue_empty_flag + queue_empty_flag.set() + + except Exception as e: + print("==========") + 
print("Exception") + print(e) + print("==========") + await db.rollback() + + +async def polling_loop(semaphore: asyncio.Semaphore, queue_empty_flag: asyncio.Event): + while True: + if queue_empty_flag.is_set(): + print("========") + print("Queue is empty flag") + print("========") + await asyncio.sleep(5) # Sleep briefly if the queue is empty + queue_empty_flag.clear() # Reset the flag + continue + if semaphore.locked(): + print("========") + print("Semaphore Locked") + print("========") + await asyncio.sleep(2) # Sleep briefly if the semaphore is fully locked + continue + task = asyncio.create_task(dequeue(semaphore, queue_empty_flag)) + task.add_done_callback(lambda t: print(f"Task done: {t}")) + await asyncio.sleep(0) # Yield control to allow tasks to run + # tasks = [] + # for _ in range(5): + # tasks.append(task) + # await asyncio.gather(*tasks) + # await dequeue() + # await asyncio.sleep(5) + + +async def main(): + semaphore = asyncio.Semaphore(1) # Limit to 5 concurrent dequeuing operations + queue_empty_flag = asyncio.Event() # Event to signal when the queue is empty + await polling_loop(semaphore, queue_empty_flag) + + if __name__ == "__main__": + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + asyncio.run(main()) # URL = f"wss://{SUPABASE_ID}.supabase.co/realtime/v1/websocket?apikey={SUPABASE_API_KEY}&vsn=1.0.0" - URL = f"ws://127.0.0.1:54321/realtime/v1/websocket?apikey={SUPABASE_API_KEY}" # For local Supabase + # URL = f"ws://127.0.0.1:54321/realtime/v1/websocket?apikey={SUPABASE_API_KEY}" # For local Supabase # listen_to_websocket(URL) - s = Socket(URL) - s.connect() + # s = Socket(URL) + # s.connect() - channel = s.set_channel("realtime:public:messages") - channel.join().on("INSERT", lambda payload: asyncio.create_task(callback(payload))) - s.listen() + # channel = s.set_channel("realtime:public:messages") + # channel.join().on("INSERT", lambda payload: asyncio.create_task(callback(payload))) + # s.listen() diff --git a/src/models.py 
b/src/models.py index 0790b73..81cc6ef 100644 --- a/src/models.py +++ b/src/models.py @@ -6,9 +6,11 @@ from pgvector.sqlalchemy import Vector from sqlalchemy import ( JSON, + Boolean, Column, DateTime, ForeignKey, + Integer, String, UniqueConstraint, Uuid, @@ -20,9 +22,9 @@ load_dotenv() -DATABASE_TYPE = os.getenv("DATABASE_TYPE", "postgres") +# DATABASE_TYPE = os.getenv("DATABASE_TYPE", "postgres") -ColumnType = JSONB if DATABASE_TYPE == "postgres" else JSON +# ColumnType = JSONB if DATABASE_TYPE == "postgres" else JSON class App(Base): @@ -35,7 +37,7 @@ class App(Base): created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), default=datetime.datetime.utcnow ) - h_metadata: Mapped[dict] = mapped_column("metadata", ColumnType, default={}) + h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) # Add any additional fields for an app here @@ -45,7 +47,7 @@ class User(Base): primary_key=True, index=True, default=uuid.uuid4 ) name: Mapped[str] = mapped_column(String(512), index=True) - h_metadata: Mapped[dict] = mapped_column("metadata", ColumnType, default={}) + h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), default=datetime.datetime.utcnow ) @@ -67,7 +69,7 @@ class Session(Base): ) location_id: Mapped[str] = mapped_column(String(512), index=True, default="default") is_active: Mapped[bool] = mapped_column(default=True) - h_metadata: Mapped[dict] = mapped_column("metadata", ColumnType, default={}) + h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), default=datetime.datetime.utcnow ) @@ -87,7 +89,7 @@ class Message(Base): session_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("sessions.id"), index=True) is_user: Mapped[bool] content: Mapped[str] = mapped_column(String(65535)) - h_metadata: Mapped[dict] = 
mapped_column("metadata", ColumnType, default={}) + h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), default=datetime.datetime.utcnow @@ -112,7 +114,7 @@ class Metamessage(Base): created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), default=datetime.datetime.utcnow ) - h_metadata: Mapped[dict] = mapped_column("metadata", ColumnType, default={}) + h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) def __repr__(self) -> str: return f"Metamessages(id={self.id}, message_id={self.message_id}, metamessage_type={self.metamessage_type}, content={self.content[10:]})" @@ -127,7 +129,7 @@ class Collection(Base): created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), default=datetime.datetime.utcnow ) - h_metadata: Mapped[dict] = mapped_column("metadata", ColumnType, default={}) + h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) documents = relationship( "Document", back_populates="collection", cascade="all, delete, delete-orphan" ) @@ -144,7 +146,7 @@ class Document(Base): id: Mapped[uuid.UUID] = mapped_column( primary_key=True, index=True, default=uuid.uuid4 ) - h_metadata: Mapped[dict] = mapped_column("metadata", ColumnType, default={}) + h_metadata: Mapped[dict] = mapped_column("metadata", JSONB, default={}) content: Mapped[str] = mapped_column(String(65535)) embedding = mapped_column(Vector(1536)) created_at: Mapped[datetime.datetime] = mapped_column( @@ -153,3 +155,10 @@ class Document(Base): collection_id = Column(Uuid, ForeignKey("collections.id"), index=True) collection = relationship("Collection", back_populates="documents") + + +class QueueItem(Base): + __tablename__ = "queue" + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + payload: Mapped[dict] = mapped_column(JSONB, nullable=False) + processed: Mapped[bool] = mapped_column(Boolean, 
default=False) diff --git a/src/routers/messages.py b/src/routers/messages.py index 72294f4..bd7123e 100644 --- a/src/routers/messages.py +++ b/src/routers/messages.py @@ -2,12 +2,15 @@ import uuid from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request from fastapi_pagination import Page from fastapi_pagination.ext.sqlalchemy import paginate +from sqlalchemy.ext.asyncio import AsyncSession from src import crud, schemas +from src.db import SessionLocal from src.dependencies import db +from src.models import QueueItem from src.security import auth router = APIRouter( @@ -16,6 +19,24 @@ ) +async def enqueue(payload: dict): + async with SessionLocal() as db: + try: + processed_payload = { + k: str(v) if isinstance(v, uuid.UUID) else v for k, v in payload.items() + } + item = QueueItem(payload=processed_payload) + db.add(item) + await db.commit() + return + except Exception as e: + print("=====================") + print("FAILURE: in enqueue") + print("=====================") + print(e) + await db.rollback() + + @router.post("", response_model=schemas.Message) async def create_message_for_session( request: Request, @@ -23,6 +44,7 @@ async def create_message_for_session( user_id: uuid.UUID, session_id: uuid.UUID, message: schemas.MessageCreate, + background_tasks: BackgroundTasks, db=db, auth=Depends(auth), ): @@ -44,10 +66,27 @@ async def create_message_for_session( """ try: - return await crud.create_message( + honcho_message = await crud.create_message( db, message=message, app_id=app_id, user_id=user_id, session_id=session_id ) + if message.is_user: + print("=======") + print("Should be enqueued") + payload = { + "app_id": app_id, + "user_id": user_id, + "session_id": session_id, + "message_id": honcho_message.id, + "content": honcho_message.content, + "metadata": honcho_message.h_metadata, + } + background_tasks.add_task(enqueue, payload) # type: 
ignore + + return honcho_message except ValueError: + print("=====================") + print("FAILURE: in create message") + print("=====================") raise HTTPException(status_code=404, detail="Session not found") from None