Add seed and migration scripts.
bLopata committed Jan 14, 2025
1 parent c2a1bb1 commit 2705492
Showing 3 changed files with 365 additions and 0 deletions.
175 changes: 175 additions & 0 deletions src/migrate_honcho_v2.py
@@ -0,0 +1,175 @@
import asyncio
import json
import os
from typing import Dict
from uuid import UUID

from dotenv import load_dotenv
from nanoid import generate as generate_nanoid
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine

load_dotenv()

SOURCE_SCHEMA = 'honcho_old'
DEST_SCHEMA = 'honcho_new'

def create_db_engine(url: str) -> AsyncEngine:
    """Create an async database engine from a connection URL"""
    # Swap in the asyncpg driver so SQLAlchemy's async engine can use the URL
    if url.startswith('postgresql://'):
        url = url.replace('postgresql://', 'postgresql+asyncpg://', 1)
    return create_async_engine(url, echo=False, pool_pre_ping=True)

async def migrate_data():
    """Migrate data between schemas in the same database"""
    print("Starting schema migration...")
    print(f"From: {SOURCE_SCHEMA} schema")
    print(f"To: {DEST_SCHEMA} schema")

    connection_uri = os.getenv('CONNECTION_URI')
    if not connection_uri:
        raise ValueError("CONNECTION_URI environment variable is not set")

    engine = create_db_engine(connection_uri)

    async with AsyncSession(engine) as session:
        async with session.begin():
            await migrate_schemas(session)

    print("Migration complete!")

async def migrate_schemas(session: AsyncSession):
    """Migrate data between schemas"""
    id_mappings: Dict[str, Dict[UUID, str]] = {
        'apps': {},
        'users': {},
        'sessions': {},
        'messages': {}
    }

    # Migrate apps
    print("Migrating apps...")
    result = await session.execute(text(f'''
        SELECT id::text, name, created_at, metadata
        FROM {SOURCE_SCHEMA}.apps
        ORDER BY created_at ASC
    '''))
    for row in result.mappings():
        public_id = generate_nanoid()
        id_mappings['apps'][UUID(row['id'])] = public_id

        await session.execute(text(f'''
            INSERT INTO {DEST_SCHEMA}.apps (
                public_id,
                name,
                created_at,
                metadata
            ) VALUES (
                :public_id,
                :name,
                :created_at,
                cast(:metadata as jsonb)
            )
        '''), {
            'public_id': public_id,
            'name': row['name'],
            'created_at': row['created_at'],
            'metadata': json.dumps(row['metadata'] or {})
        })

    # Migrate users
    print("Migrating users...")
    result = await session.execute(text(f'''
        SELECT id::text, name, app_id::text, created_at, metadata
        FROM {SOURCE_SCHEMA}.users
        ORDER BY created_at ASC
    '''))
    for row in result.mappings():
        public_id = generate_nanoid()
        id_mappings['users'][UUID(row['id'])] = public_id

        metadata = row['metadata'] or {}
        if isinstance(metadata, str):
            metadata = json.loads(metadata)

        # Preserve the original UUID in metadata so old references stay traceable
        metadata.update({
            'legacy_id': str(row['id'])
        })

        await session.execute(text(f'''
            INSERT INTO {DEST_SCHEMA}.users (public_id, name, app_id, created_at, metadata)
            VALUES (:public_id, :name, :app_id, :created_at, cast(:metadata as jsonb))
        '''), {
            'public_id': public_id,
            'name': row['name'],
            'app_id': id_mappings['apps'][UUID(row['app_id'])],
            'created_at': row['created_at'],
            'metadata': json.dumps(metadata)
        })

    # Migrate sessions
    print("Migrating sessions...")
    result = await session.execute(text(f'''
        SELECT id::text, user_id::text, is_active, created_at, metadata
        FROM {SOURCE_SCHEMA}.sessions
        ORDER BY created_at ASC
    '''))
    for row in result.mappings():
        public_id = generate_nanoid()
        id_mappings['sessions'][UUID(row['id'])] = public_id

        metadata = row['metadata'] or {}
        if isinstance(metadata, str):
            metadata = json.loads(metadata)

        metadata.update({
            'legacy_id': str(row['id'])
        })

        await session.execute(text(f'''
            INSERT INTO {DEST_SCHEMA}.sessions (
                public_id,
                user_id,
                is_active,
                created_at,
                metadata
            ) VALUES (
                :public_id,
                :user_id,
                :is_active,
                :created_at,
                cast(:metadata as jsonb)
            )
        '''), {
            'public_id': public_id,
            'user_id': id_mappings['users'][UUID(row['user_id'])],
            'is_active': row['is_active'],
            'created_at': row['created_at'],
            'metadata': json.dumps(metadata)
        })

    # Migrate messages
    print("Migrating messages...")
    result = await session.execute(text(f'''
        SELECT id::text, session_id::text, is_user, content, created_at, metadata
        FROM {SOURCE_SCHEMA}.messages
        ORDER BY created_at ASC
    '''))
    for row in result.mappings():
        public_id = generate_nanoid()
        # Recorded for symmetry with the other tables; nothing downstream reads it
        id_mappings['messages'][UUID(row['id'])] = public_id

        await session.execute(text(f'''
            INSERT INTO {DEST_SCHEMA}.messages (public_id, session_id, is_user, content, created_at, metadata)
            VALUES (:public_id, :session_id, :is_user, :content, :created_at, cast(:metadata as jsonb))
        '''), {
            'public_id': public_id,
            'session_id': id_mappings['sessions'][UUID(row['session_id'])],
            'is_user': row['is_user'],
            'content': row['content'],
            'created_at': row['created_at'],
            'metadata': json.dumps(row['metadata'] or {})
        })

if __name__ == "__main__":
    asyncio.run(migrate_data())
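
A quick sanity check after the migration is to compare row counts between the two schemas. The sketch below assumes it sits alongside the definitions above (it reuses create_db_engine, SOURCE_SCHEMA, DEST_SCHEMA, and the module's imports, plus the same CONNECTION_URI environment variable); the verify_counts helper is illustrative and not part of this commit.

async def verify_counts():
    """Hypothetical post-migration check: per-table row counts should match."""
    engine = create_db_engine(os.environ['CONNECTION_URI'])
    async with AsyncSession(engine) as session:
        for table in ('apps', 'users', 'sessions', 'messages'):
            old = await session.scalar(text(f'SELECT count(*) FROM {SOURCE_SCHEMA}.{table}'))
            new = await session.scalar(text(f'SELECT count(*) FROM {DEST_SCHEMA}.{table}'))
            # A mismatch means some rows were dropped or duplicated in transit
            print(f"{table}: {old} -> {new} {'OK' if old == new else 'MISMATCH'}")

if __name__ == "__main__":
    asyncio.run(verify_counts())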
36 changes: 36 additions & 0 deletions src/seed_export_cli.py
@@ -0,0 +1,36 @@
import asyncio
import os

from dotenv import load_dotenv
from sqlalchemy import text

from src.db import scaffold_db, engine
from src.seed_from_export import seed_from_export

async def drop_schema():
    """Drop the schema if it exists"""
    async with engine.begin() as conn:
        await conn.execute(text("DROP SCHEMA IF EXISTS honcho_old CASCADE"))

async def main():
    """Main function to scaffold database and seed from export"""
    load_dotenv()

    # Default to the legacy schema unless the caller set one explicitly
    if 'DATABASE_SCHEMA' not in os.environ:
        os.environ['DATABASE_SCHEMA'] = 'honcho_old'

    print("Dropping existing schema...")
    await drop_schema()

    print("Scaffolding database...")
    scaffold_db()

    print("Seeding database from export...")
    await seed_from_export()

    print("Database seeding complete!")

if __name__ == "__main__":
    asyncio.run(main())
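
Run order matters: the CLI above seeds the legacy honcho_old schema from the CSV export, and migrate_honcho_v2.py then copies it into honcho_new. A minimal orchestration sketch under this commit's assumptions (both modules importable as packaged here, CONNECTION_URI set, and src.db reading DATABASE_SCHEMA from the environment; if src.db binds the schema at import time, export DATABASE_SCHEMA before launching instead):

import asyncio

from src.seed_export_cli import main as seed_main
from src.migrate_honcho_v2 import migrate_data

async def run_all():
    await seed_main()     # drop, scaffold, and seed honcho_old from the CSV dump
    await migrate_data()  # copy honcho_old into honcho_new with nanoid public_ids

if __name__ == "__main__":
    asyncio.run(run_all())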
154 changes: 154 additions & 0 deletions src/seed_from_export.py
@@ -0,0 +1,154 @@
import asyncio
import csv
import datetime
import json
import uuid
from pathlib import Path

from sqlalchemy import text

from .db import SessionLocal, engine
from .old_models import App, User, Session as ChatSession, Message, OldBase

SOURCE_SCHEMA = 'honcho_old'

async def parse_csv_file(file_path: Path) -> list[dict]:
    """Parse a CSV file and return a list of dictionaries"""
    with open(str(file_path), 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        rows = list(reader)
    # Sort by created_at ascending; lexicographic order is chronological
    # for uniformly formatted ISO timestamps
    return sorted(rows, key=lambda x: x['created_at'])

async def parse_metadata(metadata_str: str) -> dict:
    """Parse metadata string into dictionary, handling empty cases"""
    if not metadata_str or metadata_str == '{}':
        return {}
    try:
        return json.loads(metadata_str)
    except json.JSONDecodeError:
        print(f"Warning: Could not parse metadata: {metadata_str}")
        return {}

async def seed_from_export(dump_dir: str = "src/yousim_dump"):
    """Seed the database with data from exported CSV files"""
    dump_path = Path(dump_dir)

    # Create schema if it doesn't exist
    print("Ensuring schema exists...")
    async with engine.begin() as conn:
        await conn.execute(text(f'CREATE SCHEMA IF NOT EXISTS {SOURCE_SCHEMA}'))

    # Drop existing tables and create new ones
    print("Dropping existing tables...")
    async with engine.begin() as conn:
        await conn.run_sync(OldBase.metadata.drop_all)
    print("Creating new tables...")
    async with engine.begin() as conn:
        await conn.run_sync(OldBase.metadata.create_all)

    # Track stats for reporting
    stats = {
        'apps': {'imported': 0, 'skipped': 0},
        'users': {'imported': 0, 'skipped': 0},
        'sessions': {'imported': 0, 'skipped': 0},
        'messages': {'imported': 0, 'skipped': 0}
    }

    # Store mappings for foreign key relationships
    app_id_mapping = {}
    user_id_mapping = {}
    session_id_mapping = {}

    # Import Apps
    async with SessionLocal() as session:
        async with session.begin():
            apps_data = await parse_csv_file(dump_path / "apps_rows (1).csv")
            for app_row in apps_data:
                try:
                    metadata = await parse_metadata(app_row['metadata'])
                    app = App(
                        id=uuid.UUID(app_row['id']),
                        name=app_row['name'],
                        created_at=datetime.datetime.fromisoformat(app_row['created_at']),
                        h_metadata=metadata
                    )
                    session.add(app)
                    app_id_mapping[app_row['id']] = app.id
                    stats['apps']['imported'] += 1
                except Exception as e:
                    print(f"Error importing app {app_row['id']}: {str(e)}")
                    stats['apps']['skipped'] += 1

    # Import Users
    async with SessionLocal() as session:
        async with session.begin():
            users_data = await parse_csv_file(dump_path / "users_rows (1).csv")
            for user_row in users_data:
                try:
                    metadata = await parse_metadata(user_row['metadata'])
                    user = User(
                        id=uuid.UUID(user_row['id']),
                        name=user_row['name'],
                        app_id=app_id_mapping[user_row['app_id']],
                        created_at=datetime.datetime.fromisoformat(user_row['created_at']),
                        h_metadata=metadata
                    )
                    session.add(user)
                    user_id_mapping[user_row['id']] = user.id
                    stats['users']['imported'] += 1
                except Exception as e:
                    print(f"Error importing user {user_row['id']}: {str(e)}")
                    stats['users']['skipped'] += 1

    # Import Sessions
    async with SessionLocal() as session:
        async with session.begin():
            sessions_data = await parse_csv_file(dump_path / "sessions_rows.csv")
            for session_row in sessions_data:
                try:
                    metadata = await parse_metadata(session_row['metadata'])
                    # Metadata is stored as-is; legacy_id tagging happens later,
                    # in migrate_honcho_v2.py
                    chat_session = ChatSession(
                        id=uuid.UUID(session_row['id']),
                        is_active=session_row['is_active'].lower() == 'true',
                        user_id=user_id_mapping[session_row['user_id']],
                        created_at=datetime.datetime.fromisoformat(session_row['created_at']),
                        h_metadata=metadata
                    )
                    session.add(chat_session)
                    session_id_mapping[session_row['id']] = chat_session.id
                    stats['sessions']['imported'] += 1
                except Exception as e:
                    print(f"Error importing session {session_row['id']}: {str(e)}")
                    stats['sessions']['skipped'] += 1

    # Import Messages
    async with SessionLocal() as session:
        async with session.begin():
            messages_data = await parse_csv_file(dump_path / "messages_rows.csv")
            for message_row in messages_data:
                try:
                    metadata = await parse_metadata(message_row['metadata'])
                    message = Message(
                        id=uuid.UUID(message_row['id']),
                        session_id=session_id_mapping[message_row['session_id']],
                        content=message_row['content'],
                        is_user=message_row['is_user'].lower() == 'true',
                        created_at=datetime.datetime.fromisoformat(message_row['created_at']),
                        h_metadata=metadata
                    )
                    session.add(message)
                    stats['messages']['imported'] += 1
                except Exception as e:
                    print(f"Error importing message {message_row['id']}: {str(e)}")
                    stats['messages']['skipped'] += 1

    # Print import statistics
    print("\nImport Statistics:")
    for entity, counts in stats.items():
        print(f"{entity.title()}:")
        print(f"  Imported: {counts['imported']}")
        print(f"  Skipped: {counts['skipped']}")

if __name__ == "__main__":
    asyncio.run(seed_from_export())
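
For reference, the fields each loop reads imply the following layout for the CSV files under src/yousim_dump. The mapping below is derived from the code above; the EXPECTED_COLUMNS name itself is illustrative, not part of the commit.

# Columns each importer expects, per export file (as read in seed_from_export)
EXPECTED_COLUMNS = {
    "apps_rows (1).csv": ["id", "name", "created_at", "metadata"],
    "users_rows (1).csv": ["id", "name", "app_id", "created_at", "metadata"],
    "sessions_rows.csv": ["id", "user_id", "is_active", "created_at", "metadata"],
    "messages_rows.csv": ["id", "session_id", "is_user", "content", "created_at", "metadata"],
}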
