Skip to content

Commit

Permalink
Merge pull request privacera#165 from vinayakbagal7/PAIG-2210
Browse files Browse the repository at this point in the history
Implemented guardrail service with APIs to effectively manage guardrails in the system privacera#160
  • Loading branch information
pravin-bansod authored Feb 28, 2025
2 parents 7ccb2df + 66d3f58 commit cab49b4
Show file tree
Hide file tree
Showing 94 changed files with 8,083 additions and 172 deletions.
2 changes: 1 addition & 1 deletion paig-server/backend/paig/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import signal
import click
import uvicorn
import psutil

PID_FILE = os.path.join(ROOT_DIR, "paig_server.pid")

Expand Down Expand Up @@ -180,6 +179,7 @@ def is_server_running():
with open(PID_FILE, "r") as f:
pid = int(f.read().strip())
# Check if the PID belongs to a process with the expected command
import psutil
if psutil.pid_exists(pid):
process = psutil.Process(pid)
cmdline = " ".join(process.cmdline())
Expand Down
6 changes: 4 additions & 2 deletions paig-server/backend/paig/alembic_db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
import asyncio

config = load_config_file()
default_ai_app = load_default_ai_config()
database_url = config["database"]["url"]
security_config = config["security"]

def create_or_update_tables(root_dir: str = None):
try:
Expand All @@ -37,6 +35,8 @@ def create_or_update_tables(root_dir: str = None):


async def check_and_create_default_user():
security_config = config["security"]

engine = None
try:
engine = create_async_engine(url=database_url)
Expand Down Expand Up @@ -76,6 +76,8 @@ def get_bind(self, mapper=None, clause=None, **kwargs):


async def check_and_create_default_ai_application():
default_ai_app = load_default_ai_config()

engine = None
try:
engine = create_async_engine(url=database_url)
Expand Down
2 changes: 2 additions & 0 deletions paig-server/backend/paig/alembic_db/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from api.audit.RDS_service.db_models import access_audit_model
from api.encryption.database.db_models import encryption_master_key_model, encryption_key_model
from api.evaluation.database.db_models import eval_model, eval_targets, eval_config
from api.guardrails.database.db_models import guardrail_model, gr_connection_model
from api.guardrails.database.db_models import response_template_model
from core.db_session.session import Base
target_metadata = Base.metadata

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Added Guardrail service tables
Revision ID: 67a256363095
Revises: db6e7b60cb0a
Create Date: 2025-01-31 16:59:29.619823
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import core.db_models.utils


# revision identifiers, used by Alembic.
revision: str = '67a256363095'
down_revision: Union[str, None] = 'db6e7b60cb0a'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('guardrail',
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.String(length=4000), nullable=True),
sa.Column('version', sa.Integer(), nullable=False),
sa.Column('guardrail_provider', sa.Enum('AWS', name='guardrailprovider'), nullable=True),
sa.Column('guardrail_connection_name', sa.String(length=255), nullable=True),
sa.Column('guardrail_configs', sa.JSON(), nullable=False),
sa.Column('guardrail_provider_response', sa.JSON(), nullable=True),
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('status', sa.Integer(), nullable=False),
sa.Column('create_time', sa.DateTime(), nullable=False),
sa.Column('update_time', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_guardrail_create_time'), 'guardrail', ['create_time'], unique=False)
op.create_index(op.f('ix_guardrail_id'), 'guardrail', ['id'], unique=False)
op.create_index(op.f('ix_guardrail_update_time'), 'guardrail', ['update_time'], unique=False)
op.create_table('guardrail_connection',
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.String(length=4000), nullable=True),
sa.Column('guardrail_provider', sa.Enum('AWS', name='guardrailprovider'), nullable=False),
sa.Column('connection_details', sa.JSON(), nullable=False),
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('status', sa.Integer(), nullable=False),
sa.Column('create_time', sa.DateTime(), nullable=False),
sa.Column('update_time', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_guardrail_connection_create_time'), 'guardrail_connection', ['create_time'], unique=False)
op.create_index(op.f('ix_guardrail_connection_id'), 'guardrail_connection', ['id'], unique=False)
op.create_index(op.f('ix_guardrail_connection_update_time'), 'guardrail_connection', ['update_time'], unique=False)
op.create_table('response_template',
sa.Column('response', sa.String(length=4000), nullable=False),
sa.Column('description', sa.String(length=4000), nullable=True),
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('status', sa.Integer(), nullable=False),
sa.Column('create_time', sa.DateTime(), nullable=False),
sa.Column('update_time', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_response_template_create_time'), 'response_template', ['create_time'], unique=False)
op.create_index(op.f('ix_response_template_id'), 'response_template', ['id'], unique=False)
op.create_index(op.f('ix_response_template_update_time'), 'response_template', ['update_time'], unique=False)
op.create_table('guardrail_version_history',
sa.Column('guardrail_id', sa.Integer(), nullable=False),
sa.Column('version', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.String(length=4000), nullable=True),
sa.Column('guardrail_provider', sa.Enum('AWS', name='guardrailprovider'), nullable=True),
sa.Column('guardrail_connection_name', sa.String(length=255), nullable=True),
sa.Column('guardrail_configs', sa.JSON(), nullable=False),
sa.Column('guardrail_provider_response', sa.JSON(), nullable=True),
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('status', sa.Integer(), nullable=False),
sa.Column('create_time', sa.DateTime(), nullable=False),
sa.Column('update_time', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['guardrail_id'], ['guardrail.id'], name='fk_guardrail_version_history_guardrail_id', ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_guardrail_version_history_create_time'), 'guardrail_version_history', ['create_time'], unique=False)
op.create_index(op.f('ix_guardrail_version_history_id'), 'guardrail_version_history', ['id'], unique=False)
op.create_index(op.f('ix_guardrail_version_history_update_time'), 'guardrail_version_history', ['update_time'], unique=False)

op.add_column('ai_application', sa.Column('guardrails', core.db_models.utils.CommaSeparatedList(length=255), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('ai_application', 'guardrails')

op.drop_index(op.f('ix_guardrail_version_history_update_time'), table_name='guardrail_version_history')
op.drop_index(op.f('ix_guardrail_version_history_id'), table_name='guardrail_version_history')
op.drop_index(op.f('ix_guardrail_version_history_create_time'), table_name='guardrail_version_history')
op.drop_table('guardrail_version_history')
op.drop_index(op.f('ix_response_template_update_time'), table_name='response_template')
op.drop_index(op.f('ix_response_template_id'), table_name='response_template')
op.drop_index(op.f('ix_response_template_create_time'), table_name='response_template')
op.drop_table('response_template')
op.drop_index(op.f('ix_guardrail_connection_update_time'), table_name='guardrail_connection')
op.drop_index(op.f('ix_guardrail_connection_id'), table_name='guardrail_connection')
op.drop_index(op.f('ix_guardrail_connection_create_time'), table_name='guardrail_connection')
op.drop_table('guardrail_connection')
op.drop_index(op.f('ix_guardrail_update_time'), table_name='guardrail')
op.drop_index(op.f('ix_guardrail_id'), table_name='guardrail')
op.drop_index(op.f('ix_guardrail_create_time'), table_name='guardrail')
op.drop_table('guardrail')
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Adding new encryption type in enum
Revision ID: db6e7b60cb0a
Revises: 701ddf55a1b4
Create Date: 2024-12-17 10:59:41.630340
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from alembic.context import get_context


# revision identifiers, used by Alembic.
revision: str = 'db6e7b60cb0a'
down_revision: Union[str, None] = '701ddf55a1b4'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# Check the database dialect
dialect = get_context().dialect.name

if dialect == 'sqlite':
# Skip execution for SQLite
print("Skipping this migration for SQLite.")
return

# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('encryption_key', 'key_type',
existing_type=sa.VARCHAR(length=18),
type_=sa.Enum('MSG_PROTECT_SHIELD', 'MSG_PROTECT_PLUGIN', 'CRDS_PROTECT_GUARDRAIL', name='encryptionkeytype'),
existing_nullable=False)
# ### end Alembic commands ###


def downgrade() -> None:
# Check the database dialect
dialect = get_context().dialect.name

if dialect == 'sqlite':
# Skip execution for SQLite
print("Skipping downgrade for SQLite.")
return

# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('encryption_key', 'key_type',
existing_type=sa.Enum('MSG_PROTECT_SHIELD', 'MSG_PROTECT_PLUGIN', 'CRDS_PROTECT_GUARDRAIL', name='encryptionkeytype'),
type_=sa.VARCHAR(length=18),
existing_nullable=False)
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
class EncryptionKeyType(Enum):
MSG_PROTECT_SHIELD = 'MSG_PROTECT_SHIELD'
MSG_PROTECT_PLUGIN = 'MSG_PROTECT_PLUGIN'
CRDS_PROTECT_GUARDRAIL = 'CRDS_PROTECT_GUARDRAIL'


class EncryptionKeyStatus(Enum):
Expand Down
3 changes: 3 additions & 0 deletions paig-server/backend/paig/api/encryption/events/startup.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ async def create_default_encryption_keys_if_not_exists():
# Create MSG_PROTECT_PLUGIN encryption key
await create_encryption_keys_if_not_exists(encryption_key_service, EncryptionKeyType.MSG_PROTECT_PLUGIN)

# Create CRDS_PROTECT_GUARDRAIL encryption key
await create_encryption_keys_if_not_exists(encryption_key_service, EncryptionKeyType.CRDS_PROTECT_GUARDRAIL)


async def create_default_encryption_keys():
context = set_session_context(session_id="encryption_startup")
Expand Down
16 changes: 15 additions & 1 deletion paig-server/backend/paig/api/governance/api_schemas/ai_app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, List

from pydantic import Field
from pydantic import Field, BaseModel

from core.factory.database_initiator import BaseAPIFilter
from core.api_schemas.base_view import BaseView
Expand All @@ -27,6 +27,7 @@ class AIApplicationView(BaseView):
application_key: Optional[str] = Field(None, description="The application key", alias="applicationKey")
vector_dbs: Optional[List[str]] = Field([], description="The vector databases associated with the AI application", alias="vectorDBs")
guardrail_details: Optional[str] = Field(None, description="The guardrail details", alias="guardrailDetails")
guardrails: Optional[List[str]] = Field([], description="The guardrails associated with AI application", alias="guardrails")

vector_db_id: Optional[int] = Field(None, description="The vector databases id with the AI application",
alias="vectorDBId")
Expand All @@ -51,6 +52,18 @@ def to_ai_application_data(self):
)


class GuardrailApplicationsAssociation(BaseModel):
"""
A model representing an AI application guardrail update request.
Attributes:
guardrail (str): The guardrail to update.
applications (List[str]): The applications to update.
"""
guardrail: str = Field(..., description="The guardrail to update")
applications: List[str] = Field(..., description="The applications to update")


class AIApplicationFilter(BaseAPIFilter):
"""
Filter class for AI application queries.
Expand All @@ -66,3 +79,4 @@ class AIApplicationFilter(BaseAPIFilter):
name: Optional[str] = Field(default=None, description="Filter by name")
application_key: Optional[str] = Field(default=None, description="Filter by application key", alias="applicationKey")
vector_dbs: Optional[str] = Field(default=None, description="Filter by vector db", alias="vectorDB")
guardrails: Optional[str] = Field(default=None, description="Filter by guardrail details", alias="guardrail")
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from core.controllers.paginated_response import Pageable
from core.db_session import Transactional, Propagation
from api.governance.api_schemas.ai_app import AIApplicationView, AIApplicationFilter
from api.governance.api_schemas.ai_app import AIApplicationView, AIApplicationFilter, GuardrailApplicationsAssociation
from api.governance.api_schemas.ai_app_config import AIApplicationConfigView
from api.governance.api_schemas.ai_app_policy import AIApplicationPolicyView
from api.governance.services.ai_app_config_service import AIAppConfigService
Expand Down Expand Up @@ -125,3 +125,19 @@ async def delete_ai_application(self, id: int):
"""
await self.ai_app_service.delete_ai_application(id)
await background_capture_event(event=DeleteAIApplicationEvent())

@Transactional(propagation=Propagation.REQUIRED)
async def update_guardrail_application_association(self, request: GuardrailApplicationsAssociation):
"""
Associates or disassociates applications with a given guardrail.
- Applications in `request.applications` will be associated with the guardrail.
- Applications currently linked to the guardrail but missing from `request.applications` will be disassociated.
Args:
request (GuardrailAssociationRequest): Guardrail name and list of applications.
Returns:
GuardrailApplicationsAssociation: Updated associations.
"""
return await self.ai_app_service.update_guardrail_application_association(request)
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class AIApplicationModel(BaseSQLModel):
application_key = Column(String(255), nullable=False)
vector_dbs = Column(CommaSeparatedList(255), nullable=True)
guardrail_details = Column(String(255), nullable=True)
guardrails = Column(CommaSeparatedList(255), nullable=True)

app_config = relationship("AIApplicationConfigModel", back_populates="ai_app", uselist=False, cascade="all, delete-orphan")
app_policies = relationship("AIApplicationPolicyModel", back_populates="ai_app", cascade="all, delete-orphan")
Expand Down
13 changes: 12 additions & 1 deletion paig-server/backend/paig/api/governance/routes/ai_app_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import APIRouter, Depends, status, Query

from core.controllers.paginated_response import Pageable
from api.governance.api_schemas.ai_app import AIApplicationView, AIApplicationFilter
from api.governance.api_schemas.ai_app import AIApplicationView, AIApplicationFilter, GuardrailApplicationsAssociation
from api.governance.controllers.ai_app_controller import AIAppController
from core.utils import SingletonDepends

Expand Down Expand Up @@ -37,6 +37,17 @@ async def create_application(
return await ai_app_controller.create_ai_application(create_ai_app_request)


@ai_app_router.put("/guardrails")
async def update_application_guardrails(
request: GuardrailApplicationsAssociation,
ai_app_controller: AIAppController = ai_app_controller_instance
):
"""
Associates or disassociates applications with a given guardrail.
"""
return await ai_app_controller.update_guardrail_application_association(request)


@ai_app_router.get("/{id}", response_model=AIApplicationView)
async def get_application(
id: int,
Expand Down
Loading

0 comments on commit cab49b4

Please sign in to comment.