Showing 3 changed files with 365 additions and 0 deletions.
@@ -0,0 +1,175 @@
import asyncio
import json
import os
from typing import Dict
from uuid import UUID

from nanoid import generate as generate_nanoid
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from dotenv import load_dotenv

load_dotenv()

SOURCE_SCHEMA = 'honcho_old'
DEST_SCHEMA = 'honcho_new'

def create_db_engine(url: str) -> AsyncEngine:
    """Create an async database engine from a connection URL"""
    if url.startswith('postgresql://'):
        url = url.replace('postgresql://', 'postgresql+asyncpg://', 1)
    return create_async_engine(url, echo=False, pool_pre_ping=True)

async def migrate_data():
    """Migrate data between schemas in the same database"""
    print("Starting schema migration...")
    print(f"From: {SOURCE_SCHEMA} schema")
    print(f"To: {DEST_SCHEMA} schema")

    connection_uri = os.getenv('CONNECTION_URI')
    if not connection_uri:
        raise ValueError("CONNECTION_URI environment variable is not set")

    engine = create_db_engine(connection_uri)

    async with AsyncSession(engine) as session:
        async with session.begin():
            await migrate_schemas(session)

    # Release pooled connections before the event loop shuts down
    await engine.dispose()

    print("Migration complete!")

async def migrate_schemas(session: AsyncSession):
    """Migrate data between schemas"""
    id_mappings: Dict[str, Dict[UUID, str]] = {
        'apps': {},
        'users': {},
        'sessions': {},
        'messages': {}
    }
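    # Maps each table's old UUID primary keys to the newly generated nanoid
    # public_ids, so that foreign keys (app_id, user_id, session_id) can be
    # rewritten when the dependent rows are inserted below.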

    # Migrate apps
    print("Migrating apps...")
    result = await session.execute(text(f'''
        SELECT id::text, name, created_at, metadata
        FROM {SOURCE_SCHEMA}.apps
        ORDER BY created_at ASC
    '''))
    for row in result.mappings():
        public_id = generate_nanoid()
        id_mappings['apps'][UUID(row['id'])] = public_id

        await session.execute(text(f'''
            INSERT INTO {DEST_SCHEMA}.apps (
                public_id,
                name,
                created_at,
                metadata
            ) VALUES (
                :public_id,
                :name,
                :created_at,
                cast(:metadata as jsonb)
            )
        '''), {
            'public_id': public_id,
            'name': row['name'],
            'created_at': row['created_at'],
            'metadata': json.dumps(row['metadata'] or {})
        })

    # Migrate users
    print("Migrating users...")
    result = await session.execute(text(f'''
        SELECT id::text, name, app_id::text, created_at, metadata
        FROM {SOURCE_SCHEMA}.users
        ORDER BY created_at ASC
    '''))
    for row in result.mappings():
        public_id = generate_nanoid()
        id_mappings['users'][UUID(row['id'])] = public_id

        metadata = row['metadata'] or {}
        if isinstance(metadata, str):
            metadata = json.loads(metadata)

        metadata.update({
            'legacy_id': str(row['id'])
        })

        await session.execute(text(f'''
            INSERT INTO {DEST_SCHEMA}.users (public_id, name, app_id, created_at, metadata)
            VALUES (:public_id, :name, :app_id, :created_at, cast(:metadata as jsonb))
        '''), {
            'public_id': public_id,
            'name': row['name'],
            'app_id': id_mappings['apps'][UUID(row['app_id'])],
            'created_at': row['created_at'],
            'metadata': json.dumps(metadata)
        })

    # Migrate sessions
    print("Migrating sessions...")
    result = await session.execute(text(f'''
        SELECT id::text, user_id::text, is_active, created_at, metadata
        FROM {SOURCE_SCHEMA}.sessions
        ORDER BY created_at ASC
    '''))
    for row in result.mappings():
        public_id = generate_nanoid()
        id_mappings['sessions'][UUID(row['id'])] = public_id

        metadata = row['metadata'] or {}
        if isinstance(metadata, str):
            metadata = json.loads(metadata)

        metadata.update({
            'legacy_id': str(row['id'])
        })

        await session.execute(text(f'''
            INSERT INTO {DEST_SCHEMA}.sessions (
                public_id,
                user_id,
                is_active,
                created_at,
                metadata
            ) VALUES (
                :public_id,
                :user_id,
                :is_active,
                :created_at,
                cast(:metadata as jsonb)
            )
        '''), {
            'public_id': public_id,
            'user_id': id_mappings['users'][UUID(row['user_id'])],
            'is_active': row['is_active'],
            'created_at': row['created_at'],
            'metadata': json.dumps(metadata)
        })

    # Migrate messages
    print("Migrating messages...")
    result = await session.execute(text(f'''
        SELECT id::text, session_id::text, is_user, content, created_at, metadata
        FROM {SOURCE_SCHEMA}.messages
        ORDER BY created_at ASC
    '''))
    for row in result.mappings():
        public_id = generate_nanoid()
        # Recorded for symmetry with the other tables; nothing downstream reads it
        id_mappings['messages'][UUID(row['id'])] = public_id

        await session.execute(text(f'''
            INSERT INTO {DEST_SCHEMA}.messages (public_id, session_id, is_user, content, created_at, metadata)
            VALUES (:public_id, :session_id, :is_user, :content, :created_at, cast(:metadata as jsonb))
        '''), {
            'public_id': public_id,
            'session_id': id_mappings['sessions'][UUID(row['session_id'])],
            'is_user': row['is_user'],
            'content': row['content'],
            'created_at': row['created_at'],
            'metadata': json.dumps(row['metadata'] or {})
        })

if __name__ == "__main__":
    asyncio.run(migrate_data())
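After the migration commits, it is worth confirming that both schemas hold the same number of rows. The following is a minimal sketch, not part of the commit: it reuses the script's own create_db_engine helper and schema constants, and the function name is illustrative.

async def verify_row_counts():
    """Compare row counts between the old and new schemas (sketch, not part of this commit)."""
    engine = create_db_engine(os.environ['CONNECTION_URI'])
    async with AsyncSession(engine) as session:
        for table in ('apps', 'users', 'sessions', 'messages'):
            src = await session.execute(text(f'SELECT count(*) FROM {SOURCE_SCHEMA}.{table}'))
            dst = await session.execute(text(f'SELECT count(*) FROM {DEST_SCHEMA}.{table}'))
            print(f"{table}: {src.scalar()} -> {dst.scalar()}")
    await engine.dispose()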
@@ -0,0 +1,36 @@
import asyncio
import os

from dotenv import load_dotenv
from sqlalchemy import text

from src.db import scaffold_db, engine
from src.seed_from_export import seed_from_export

async def drop_schema():
    """Drop the schema if it exists"""
    async with engine.begin() as conn:
        await conn.execute(text("DROP SCHEMA IF EXISTS honcho_old CASCADE"))

async def main():
    """Main function to scaffold database and seed from export"""
    load_dotenv()

    # Ensure we're using the right schema
    if 'DATABASE_SCHEMA' not in os.environ:
        os.environ['DATABASE_SCHEMA'] = 'honcho_old'

    print("Dropping existing schema...")
    await drop_schema()

    print("Scaffolding database...")
    scaffold_db()

    print("Seeding database from export...")
    await seed_from_export()

    print("Database seeding complete!")

if __name__ == "__main__":
    asyncio.run(main())
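drop_schema removes the honcho_old schema and everything in it via CASCADE, and the script runs it unconditionally. A small interactive guard along these lines could be called in main() before the drop; this is purely a sketch, the helper name is made up and it is not part of the commit.

def confirm_drop(schema: str = "honcho_old") -> bool:
    """Ask for explicit confirmation before the destructive schema drop (sketch only)."""
    answer = input(f"This will DROP SCHEMA {schema} CASCADE. Type 'yes' to continue: ")
    return answer.strip().lower() == "yes"

main() could then return early when confirm_drop() is False, leaving the schema untouched.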
@@ -0,0 +1,154 @@
import asyncio
import csv
import datetime
import json
import uuid
from pathlib import Path
from sqlalchemy import text

from .db import SessionLocal, engine
from .old_models import App, User, Session as ChatSession, Message, OldBase

SOURCE_SCHEMA = 'honcho_old'

async def parse_csv_file(file_path: Path) -> list[dict]:
    """Parse a CSV file and return a list of dictionaries"""
    with open(str(file_path), 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        rows = [row for row in reader]
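    # Note: created_at is compared as a raw string below. This is assumed to be
    # safe because the export writes ISO-8601 timestamps, which sort
    # lexicographically in chronological order when they share the same format.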
    # Sort by created_at in ascending order (oldest first)
    return sorted(rows, key=lambda x: x['created_at'])

async def parse_metadata(metadata_str: str) -> dict:
    """Parse metadata string into dictionary, handling empty cases"""
    if not metadata_str or metadata_str == '{}':
        return {}
    try:
        return json.loads(metadata_str)
    except json.JSONDecodeError:
        print(f"Warning: Could not parse metadata: {metadata_str}")
        return {}

async def seed_from_export(dump_dir: str = "src/yousim_dump"):
    """Seed the database with data from exported CSV files"""
    dump_path = Path(dump_dir)

    # Create schema if it doesn't exist
    print("Ensuring schema exists...")
    async with engine.begin() as conn:
        await conn.execute(text(f'CREATE SCHEMA IF NOT EXISTS {SOURCE_SCHEMA}'))

    # Drop existing tables and create new ones
    print("Dropping existing tables...")
    async with engine.begin() as conn:
        await conn.run_sync(OldBase.metadata.drop_all)
    print("Creating new tables...")
    async with engine.begin() as conn:
        await conn.run_sync(OldBase.metadata.create_all)

    # Track stats for reporting
    stats = {
        'apps': {'imported': 0, 'skipped': 0},
        'users': {'imported': 0, 'skipped': 0},
        'sessions': {'imported': 0, 'skipped': 0},
        'messages': {'imported': 0, 'skipped': 0}
    }

    # Store mappings for foreign key relationships
    app_id_mapping = {}
    user_id_mapping = {}
    session_id_mapping = {}

    # Import Apps
    async with SessionLocal() as session:
        async with session.begin():
            apps_data = await parse_csv_file(dump_path / "apps_rows (1).csv")
            for app_row in apps_data:
                try:
                    metadata = await parse_metadata(app_row['metadata'])
                    app = App(
                        id=uuid.UUID(app_row['id']),
                        name=app_row['name'],
                        created_at=datetime.datetime.fromisoformat(app_row['created_at']),
                        h_metadata=metadata
                    )
                    session.add(app)
                    app_id_mapping[app_row['id']] = app.id
                    stats['apps']['imported'] += 1
                except Exception as e:
                    print(f"Error importing app {app_row['id']}: {str(e)}")
                    stats['apps']['skipped'] += 1

    # Import Users
    async with SessionLocal() as session:
        async with session.begin():
            users_data = await parse_csv_file(dump_path / "users_rows (1).csv")
            for user_row in users_data:
                try:
                    metadata = await parse_metadata(user_row['metadata'])
                    user = User(
                        id=uuid.UUID(user_row['id']),
                        name=user_row['name'],
                        app_id=app_id_mapping[user_row['app_id']],
                        created_at=datetime.datetime.fromisoformat(user_row['created_at']),
                        h_metadata=metadata
                    )
                    session.add(user)
                    user_id_mapping[user_row['id']] = user.id
                    stats['users']['imported'] += 1
                except Exception as e:
                    print(f"Error importing user {user_row['id']}: {str(e)}")
                    stats['users']['skipped'] += 1

    # Import Sessions
    async with SessionLocal() as session:
        async with session.begin():
            sessions_data = await parse_csv_file(dump_path / "sessions_rows.csv")
            for session_row in sessions_data:
                try:
                    metadata = await parse_metadata(session_row['metadata'])
                    # Removed legacy ID updates from here
                    chat_session = ChatSession(
                        id=uuid.UUID(session_row['id']),
                        is_active=session_row['is_active'].lower() == 'true',
                        user_id=user_id_mapping[session_row['user_id']],
                        created_at=datetime.datetime.fromisoformat(session_row['created_at']),
                        h_metadata=metadata  # Using original metadata without modifications
                    )
                    session.add(chat_session)
                    session_id_mapping[session_row['id']] = chat_session.id
                    stats['sessions']['imported'] += 1
                except Exception as e:
                    print(f"Error importing session {session_row['id']}: {str(e)}")
                    stats['sessions']['skipped'] += 1

    # Import Messages
    async with SessionLocal() as session:
        async with session.begin():
            messages_data = await parse_csv_file(dump_path / "messages_rows.csv")
            for message_row in messages_data:
                try:
                    metadata = await parse_metadata(message_row['metadata'])
                    message = Message(
                        id=uuid.UUID(message_row['id']),
                        session_id=session_id_mapping[message_row['session_id']],
                        content=message_row['content'],
                        is_user=message_row['is_user'].lower() == 'true',
                        created_at=datetime.datetime.fromisoformat(message_row['created_at']),
                        h_metadata=metadata
                    )
                    session.add(message)
                    stats['messages']['imported'] += 1
                except Exception as e:
                    print(f"Error importing message {message_row['id']}: {str(e)}")
                    stats['messages']['skipped'] += 1

    # Print import statistics
    print("\nImport Statistics:")
    for entity, counts in stats.items():
        print(f"{entity.title()}:")
        print(f"  Imported: {counts['imported']}")
        print(f"  Skipped: {counts['skipped']}")

if __name__ == "__main__":
    asyncio.run(seed_from_export())
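The seeding step expects four specific CSVs under src/yousim_dump. A small pre-flight check like the one below makes a missing or misnamed export fail fast instead of partway through the import; it is only a sketch, the helper name is made up, and it simply reuses the filenames referenced above.

def check_dump_files(dump_dir: str = "src/yousim_dump") -> None:
    """Fail fast if any expected export CSV is missing (sketch, not part of this commit)."""
    expected = [
        "apps_rows (1).csv",
        "users_rows (1).csv",
        "sessions_rows.csv",
        "messages_rows.csv",
    ]
    missing = [name for name in expected if not (Path(dump_dir) / name).exists()]
    if missing:
        raise FileNotFoundError(f"Missing export files in {dump_dir}: {missing}")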