diff --git a/.github/workflows/surrealml_core_test.yml b/.github/workflows/surrealml_core_test.yml index 4aea41f..e672b20 100644 --- a/.github/workflows/surrealml_core_test.yml +++ b/.github/workflows/surrealml_core_test.yml @@ -50,7 +50,6 @@ jobs: source venv/bin/activate export PYTHONPATH="." python tests/unit_tests/engine/test_sklearn.py - python -m unittest discover deactivate - name: Run Core Unit Tests diff --git a/.github/workflows/surrealml_core_torch_test.yml b/.github/workflows/surrealml_core_torch_test.yml new file mode 100644 index 0000000..b18ebe1 --- /dev/null +++ b/.github/workflows/surrealml_core_torch_test.yml @@ -0,0 +1,61 @@ +name: Run Torch Tests + +on: + pull_request: + types: [opened, reopened, synchronize] + +jobs: + test_core: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Rust + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.11' + + - name: Pre-test Setup + run: | + python3 -m venv venv + source venv/bin/activate + pip install --upgrade pip + # pip install -r requirements.txt + + # build the local version of the core module to be loaded into python + echo "Building local version of core module" + + pip install . + export PYTHONPATH="." + + python ./tests/scripts/ci_local_build.py + echo "Local build complete" + + # train the models for the tests + python ./tests/model_builder/torch_assets.py + deactivate + + - name: Run Python Unit Tests + run: | + source venv/bin/activate + export PYTHONPATH="." + python tests/unit_tests/engine/test_torch.py + python tests/unit_tests/test_rust_adapter.py + python tests/unit_tests/engine/test_surml_file.py + deactivate + + - name: Run Core Unit Tests + run: cd modules/core && cargo test --features torch-tests + + - name: Run HTTP Transfer Tests + run: cargo test diff --git a/tests/model_builder/torch_assets.py b/tests/model_builder/torch_assets.py index fd5215d..25dc0b3 100644 --- a/tests/model_builder/torch_assets.py +++ b/tests/model_builder/torch_assets.py @@ -1,3 +1,8 @@ +""" +This file trains and saves the torch linear model to the model stash directory for the core to test against +""" +from tests.model_builder.utils import install_package +install_package("torch==2.1.2") import os from surrealml.model_templates.torch.torch_linear import train_model as linear_torch_train_model