"""Pytest configuration for v2 schema migration tests.
|
|
|
|
Provides shared fixtures and configuration for testing the v2 schema migration
|
|
components including database setup, test data, and cleanup procedures.
|
|
"""
|
|
|
|
import pytest
|
|
import os
|
|
import tempfile
|
|
import shutil
|
|
from datetime import datetime, timezone
|
|
from sqlalchemy import create_engine, text
|
|
from sqlalchemy.orm import sessionmaker
|
|
from typing import Generator
|
|
|
|
from src.database.models import Base, register_model
|
|
from src.database.connection import get_database_url
|
|
|
|
|
|
@pytest.fixture(scope="session")
def test_db_url() -> str:
    """Return the URL of the test database.

    Honors the ``TEST_DATABASE_URL`` environment variable so CI or local
    setups can point the suite at a different server; falls back to the
    conventional local test database otherwise.

    Returns:
        A SQLAlchemy-compatible PostgreSQL URL.
    """
    # Use a separate test database to avoid affecting production data
    return os.environ.get("TEST_DATABASE_URL", "postgresql://localhost/trax_test")
|
|
|
|
|
|
@pytest.fixture(scope="session")
def test_db_engine(test_db_url: str):
    """Session-scoped engine bound to the test database.

    Builds the full ORM schema up front, hands the engine to the tests,
    then tears the schema down and disposes the connection pool once the
    test session is over.
    """
    engine = create_engine(test_db_url)
    Base.metadata.create_all(engine)

    yield engine

    # Teardown: remove every table we created, then release pooled connections.
    Base.metadata.drop_all(engine)
    engine.dispose()
|
|
|
|
|
|
@pytest.fixture
def db_session(test_db_engine) -> Generator:
    """Yield a fresh SQLAlchemy session for a single test.

    Any changes the test left uncommitted are rolled back before the
    session is closed, so tests stay isolated from one another.
    """
    session_factory = sessionmaker(bind=test_db_engine)
    session = session_factory()

    yield session

    # Discard uncommitted work, then return the connection to the pool.
    session.rollback()
    session.close()
|
|
|
|
|
|
@pytest.fixture
def sample_v1_transcripts(db_session):
    """Persist three sample v1 TranscriptionResult rows and yield them."""
    from src.database.models import TranscriptionResult

    # Build all three sample rows up front, then persist them in one go.
    transcripts = [
        TranscriptionResult(
            content={"text": f"Sample v1 transcript {idx}"},
            accuracy=0.85 + (idx * 0.05),
            processing_time=10.0 + (idx * 2.0),
        )
        for idx in range(3)
    ]
    db_session.add_all(transcripts)
    db_session.commit()

    yield transcripts

    # Cleanup is handled by the db_session fixture.
|
|
|
|
|
|
@pytest.fixture
def sample_media_files(db_session):
    """Persist two sample MediaFile rows (1MB/60s and 2MB/90s) and yield them."""
    from src.database.models import MediaFile

    # Build both sample rows up front, then persist them in one go.
    media_files = [
        MediaFile(
            filename=f"test_audio_{idx}.wav",
            file_size=1024 * 1024 * (idx + 1),  # 1MB, 2MB
            duration=60.0 + (idx * 30.0),  # 60s, 90s
            mime_type="audio/wav",
            source_path=f"/path/to/source_{idx}.wav",
            local_path=f"/path/to/local_{idx}.wav",
            file_hash=f"hash_{idx}",
            status="ready",
        )
        for idx in range(2)
    ]
    db_session.add_all(media_files)
    db_session.commit()

    yield media_files

    # Cleanup is handled by the db_session fixture.
|
|
|
|
|
|
@pytest.fixture
def sample_youtube_videos(db_session):
    """Persist two sample YouTubeVideo rows and yield them."""
    from src.database.models import YouTubeVideo

    # Build both sample rows up front, then persist them in one go.
    videos = [
        YouTubeVideo(
            youtube_id=f"test_id_{idx}",
            title=f"Test Video {idx}",
            channel=f"Test Channel {idx}",
            description=f"Test description {idx}",
            duration_seconds=300 + (idx * 60),  # 5min, 6min
            url=f"https://youtube.com/watch?v=test_id_{idx}",
        )
        for idx in range(2)
    ]
    db_session.add_all(videos)
    db_session.commit()

    yield videos

    # Cleanup is handled by the db_session fixture.
|
|
|
|
|
|
@pytest.fixture
def temp_migration_dir():
    """Yield a throwaway directory path for migration testing.

    The try/finally guarantees the directory is removed even when the
    generator is closed or thrown into without being resumed (e.g. a
    KeyboardInterrupt during fixture finalization); the original version
    would leak the directory in that case. Removal is best-effort
    (``ignore_errors=True``) so teardown never masks a test failure.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        yield temp_dir
    finally:
        # Cleanup
        shutil.rmtree(temp_dir, ignore_errors=True)
|
|
|
|
|
|
@pytest.fixture
def mock_alembic_config(temp_migration_dir):
    """Create mock Alembic configuration for testing.

    Builds a self-contained Alembic project under ``temp_migration_dir``:
    an ``alembic.ini`` pointing at the test database, a ``migrations/``
    script directory with an empty ``versions/`` folder, an ``env.py``
    wired to this project's ``Base.metadata``, and a ``script.py.mako``
    revision template. Returns the path to the generated ``alembic.ini``.
    """
    import configparser

    # Create alembic.ini
    # NOTE(review): ConfigParser applies %-interpolation on write/read; a
    # database URL containing a literal '%' (e.g. URL-encoded password)
    # would need escaping — confirm what get_database_url() can return.
    config = configparser.ConfigParser()
    config.add_section('alembic')
    config.set('alembic', 'script_location', os.path.join(temp_migration_dir, 'migrations'))
    # Point Alembic at the test database by swapping the database name.
    config.set('alembic', 'sqlalchemy.url', get_database_url().replace("/trax", "/trax_test"))

    ini_path = os.path.join(temp_migration_dir, 'alembic.ini')
    with open(ini_path, 'w') as f:
        config.write(f)

    # Create migrations directory structure
    migrations_dir = os.path.join(temp_migration_dir, 'migrations')
    os.makedirs(migrations_dir, exist_ok=True)

    versions_dir = os.path.join(migrations_dir, 'versions')
    os.makedirs(versions_dir, exist_ok=True)

    # Create env.py — a minimal Alembic environment targeting Base.metadata
    # so autogenerate can diff against this project's models.
    env_py_content = '''
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from src.database.models import Base

config = context.config
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

target_metadata = Base.metadata

def run_migrations_offline() -> None:
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()

def run_migrations_online() -> None:
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection, target_metadata=target_metadata
        )

        with context.begin_transaction():
            context.run_migrations()

if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
'''

    with open(os.path.join(migrations_dir, 'env.py'), 'w') as f:
        f.write(env_py_content)

    # Create script.py.mako — the Mako template Alembic renders when
    # generating a new revision file.
    script_mako_content = '''"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
'''

    with open(os.path.join(migrations_dir, 'script.py.mako'), 'w') as f:
        f.write(script_mako_content)

    return ini_path
|
|
|
|
|
|
@pytest.fixture
def test_data_cleanup(db_session):
    """Best-effort removal of rows a test left behind.

    Runs after the test body; deletes from every table the migration
    tests may populate, children before parents. All errors are
    deliberately swallowed so cleanup never masks the real test outcome.
    """
    yield

    # Clean up any remaining test data
    try:
        # Delete test data from all tables (child tables first so FK
        # constraints are satisfied).
        tables = ['speaker_profiles', 'processing_jobs', 'transcription_results',
                  'media_files', 'youtube_videos']

        for table in tables:
            try:
                db_session.execute(text(f"DELETE FROM {table}"))
            except Exception:
                # Table might not exist yet, which is fine — but roll the
                # session back: PostgreSQL aborts the whole transaction on
                # any error, so without this rollback every subsequent
                # DELETE and the final commit would also fail, skipping
                # cleanup entirely (bug in the original version).
                db_session.rollback()

        db_session.commit()
    except Exception:
        # Ignore cleanup errors
        pass
|
|
|
|
|
|
# Pytest configuration
def pytest_configure(config):
    """Register the custom markers used by the v2 schema migration suite."""
    marker_lines = (
        "slow: marks tests as slow (deselect with '-m \"not slow\"')",
        "integration: marks tests as integration tests",
        "migration: marks tests as migration tests",
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)
|
|
|
|
|
|
def pytest_collection_modifyitems(config, items):
    """Auto-tag collected tests with markers inferred from their node IDs."""
    for item in items:
        node_id = item.nodeid.lower()

        # Migration tests
        if "migration" in node_id:
            item.add_marker(pytest.mark.migration)

        # Integration tests (repository tests count as integration too)
        if "integration" in node_id or "repository" in node_id:
            item.add_marker(pytest.mark.integration)

        # Performance and migration suites are the long-running ones
        if "performance" in node_id or "migration" in node_id:
            item.add_marker(pytest.mark.slow)
|