568 lines
19 KiB
Python
568 lines
19 KiB
Python
"""Chat API endpoints for RAG-powered video conversations."""
|
|
|
|
import logging
|
|
from typing import List, Dict, Any, Optional
|
|
from datetime import datetime
|
|
|
|
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks, Query
|
|
from pydantic import BaseModel, Field
|
|
from sqlalchemy.orm import Session
|
|
|
|
from backend.core.database_registry import registry
|
|
from backend.models.chat import ChatSession, ChatMessage
|
|
from backend.models.summary import Summary
|
|
from backend.services.rag_service import RAGService, RAGError
|
|
from backend.services.auth_service import AuthService
|
|
from backend.models.user import User
|
|
|
|
# Module logger, named after this module's import path.
logger = logging.getLogger(name=__name__)

# Service singletons shared by every request handled by this router.
rag_service = RAGService()
# NOTE(review): not referenced by any endpoint in this module; presumably kept
# for when user resolution is implemented — confirm before removing.
auth_service = AuthService()

# Every chat endpoint is mounted under /api/chat and tagged "chat" in OpenAPI.
router = APIRouter(tags=["chat"], prefix="/api/chat")
|
|
|
|
|
|
# Request/Response Models
|
|
class CreateSessionRequest(BaseModel):
    """Body for opening a chat session: the target video plus an optional title."""

    # Required YouTube video identifier.
    video_id: str = Field(..., description="YouTube video ID")
    # Optional display title; ``None`` when the client omits it.
    title: Optional[str] = Field(default=None, description="Optional session title")
|
|
|
|
|
|
class ChatSessionResponse(BaseModel):
    """Serialized view of a stored chat session."""

    session_id: str
    video_id: str
    title: str
    # Owner ID; the key is required but may be null (anonymous sessions).
    user_id: Optional[str]
    message_count: int
    is_active: bool
    # Timestamps as strings; the endpoints fill them via ``datetime.isoformat``.
    created_at: str
    last_message_at: Optional[str]
    # Details taken from the linked Summary row, when one exists.
    video_metadata: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
class ChatQueryRequest(BaseModel):
    """Body of a chat message: the question plus optional retrieval tuning."""

    # Question text, validated to 1..2000 characters.
    query: str = Field(
        ...,
        description="User's question",
        min_length=1,
        max_length=2000,
    )
    # Retrieval strategy forwarded verbatim to the RAG service.
    # NOTE(review): not validated against the listed modes here.
    search_mode: Optional[str] = Field(
        default="hybrid",
        description="Search strategy: vector, hybrid, traditional",
    )
    # Cap on retrieved chunks (1..10); ``None`` is forwarded as-is.
    max_context_chunks: Optional[int] = Field(
        default=None,
        description="Maximum context chunks to use",
        ge=1,
        le=10,
    )
|
|
|
|
|
|
class ChatMessageResponse(BaseModel):
    """One entry of a session's message history."""

    id: str
    # Presumably distinguishes user vs. assistant messages — confirm in RAGService.
    message_type: str
    content: str
    # Passed through unchanged from the RAG service's history records.
    created_at: str
    # Retrieval sources attached to the message, if the service provides them.
    sources: Optional[List[Dict[str, Any]]] = None
    total_sources: Optional[int] = None
|
|
|
|
|
|
class ChatQueryResponse(BaseModel):
    """Answer produced by the RAG pipeline for a single chat query."""

    # ``model_used`` begins with pydantic's reserved ``model_`` prefix; clearing
    # the protected namespaces lets that field be declared.
    model_config = dict(protected_namespaces=())

    response: str
    sources: List[Dict[str, Any]]
    total_sources: int
    # Query text as reported back by the RAG service.
    query: str
    context_chunks_used: int
    model_used: str
    processing_time_seconds: float
    # Timestamp string supplied by the RAG service.
    timestamp: str
    # Optional flag from the service; absent means "not reported".
    no_context_found: Optional[bool] = None
|
|
|
|
|
|
class IndexVideoRequest(BaseModel):
    """Body for indexing a video's transcript so it can be searched in chat."""

    video_id: str = Field(..., description="YouTube video ID")
    # Validation rejects transcripts shorter than 100 characters.
    transcript: str = Field(
        ...,
        description="Video transcript text",
        min_length=100,
    )
    summary_id: Optional[str] = Field(default=None, description="Optional summary ID")
|
|
|
|
|
|
class IndexVideoResponse(BaseModel):
    """Outcome of indexing a video's transcript, as reported by the RAG service."""

    video_id: str
    # Presumably: chunks produced from the transcript vs. chunks stored.
    chunks_created: int
    chunks_indexed: int
    processing_time_seconds: float
    indexed: bool
    # Free-form statistics from the chunking step.
    chunking_stats: Dict[str, Any]
|
|
|
|
|
|
# Dependency functions
|
|
def get_db() -> Session:
    """Hand out a database session from the shared registry.

    NOTE(review): the session is returned rather than yielded, so nothing here
    closes it — the caller owns cleanup. Elsewhere in this module
    ``registry.get_session()`` is used as a context manager; confirm its
    return value really is a ``Session``.
    """
    db_session = registry.get_session()
    return db_session
|
|
|
|
|
|
def get_current_user_optional() -> Optional[User]:
    """Resolve the requesting user, or ``None`` when nobody is authenticated.

    Authentication is not wired in yet, so this always yields ``None``. Every
    endpoint depending on it therefore runs in anonymous (demo) mode.
    """
    # TODO: resolve a real user once authentication is enabled.
    anonymous: Optional[User] = None
    return anonymous
|
|
|
|
|
|
# Guards one-time RAG initialization. It is created lazily inside the first
# request's coroutine, so it is built while the serving event loop is running.
_rag_init_lock = None


async def get_rag_service() -> RAGService:
    """Return the shared RAG service, initializing it on first use.

    Concurrent first requests share a single ``initialize()`` call: the lock
    is taken and the ``_initialized`` marker is re-checked under it. The
    marker is set only after ``initialize()`` returns, so a failed attempt
    is retried by the next caller.

    Returns:
        The module-level ``rag_service`` singleton, initialized.

    Raises:
        Whatever ``rag_service.initialize()`` raises.
    """
    global _rag_init_lock

    # getattr (not hasattr): an explicitly pre-declared ``_initialized = False``
    # on the service must still trigger initialization.
    if getattr(rag_service, "_initialized", False):
        return rag_service

    if _rag_init_lock is None:
        import asyncio

        # No await between the check and the assignment, so this is atomic
        # with respect to other coroutines on the same loop.
        _rag_init_lock = asyncio.Lock()

    async with _rag_init_lock:
        # Another coroutine may have finished initializing while we waited.
        if not getattr(rag_service, "_initialized", False):
            await rag_service.initialize()
            rag_service._initialized = True
    return rag_service
|
|
|
|
|
|
# API Endpoints
|
|
@router.post("/sessions", response_model=Dict[str, Any])
async def create_chat_session(
    request: CreateSessionRequest,
    current_user: Optional[User] = Depends(get_current_user_optional),
    rag_service: RAGService = Depends(get_rag_service)
):
    """Create a new chat session for a video.

    The video must already have a ``Summary`` row, i.e. it has been processed.
    Only that row's existence is checked here, not whether the video has been
    indexed for retrieval.

    Args:
        request: Session creation request (video ID and optional title).
        current_user: Authenticated user, or ``None`` in demo mode. When
            present, its ID is passed to the RAG service as the session owner.
        rag_service: Initialized RAG service instance.

    Returns:
        ``{"success": True, "session": ..., "message": ...}``, where ``session``
        is whatever ``rag_service.create_chat_session`` returned.

    Raises:
        HTTPException: 404 if no summary exists for ``request.video_id``;
            500 if the RAG service or anything else fails.
    """
    try:
        logger.info(f"Creating chat session for video {request.video_id}")

        # Only a Summary row is required. The DB ``with`` block is exited
        # before the (potentially slow) RAG call below.
        with registry.get_session() as session:
            summary = session.query(Summary).filter(
                Summary.video_id == request.video_id
            ).first()
            video_processed = summary is not None

        if not video_processed:
            raise HTTPException(
                status_code=404,
                detail=f"Video {request.video_id} not found. Please process the video first."
            )

        session_info = await rag_service.create_chat_session(
            video_id=request.video_id,
            user_id=str(current_user.id) if current_user else None,
            title=request.title
        )

        return {
            "success": True,
            "session": session_info,
            "message": "Chat session created successfully"
        }

    except HTTPException:
        # Propagate deliberate HTTP errors (the 404 above) unchanged; otherwise
        # the generic ``except Exception`` below would turn them into a 500.
        raise
    except RAGError as e:
        logger.error(f"RAG error creating session: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    except Exception as e:
        # logger.exception records the traceback for unexpected failures.
        logger.exception(f"Unexpected error creating session: {e}")
        raise HTTPException(status_code=500, detail="Failed to create chat session") from e
|
|
|
|
|
|
@router.get("/sessions/{session_id}", response_model=ChatSessionResponse)
async def get_chat_session(
    session_id: str,
    current_user: Optional[User] = Depends(get_current_user_optional)
):
    """Fetch one chat session together with metadata about its video.

    The ownership check is skipped for anonymous callers (``current_user``
    is ``None``) and for sessions without an owner.

    Args:
        session_id: Primary key of the ``ChatSession`` row.
        current_user: Authenticated user, or ``None`` in demo mode.

    Returns:
        A ``ChatSessionResponse``. ``video_metadata`` holds title, channel and
        duration from the linked ``Summary`` when one is found.

    Raises:
        HTTPException: 404 if the session does not exist, 403 if it belongs to
            another user, 500 on any other failure.
    """
    try:
        with registry.get_session() as db:
            record = (
                db.query(ChatSession)
                .filter(ChatSession.id == session_id)
                .first()
            )
            if not record:
                raise HTTPException(status_code=404, detail="Chat session not found")

            foreign_owner = (
                current_user
                and record.user_id
                and record.user_id != str(current_user.id)
            )
            if foreign_owner:
                raise HTTPException(status_code=403, detail="Access denied")

            summary = None
            if record.summary_id:
                summary = (
                    db.query(Summary)
                    .filter(Summary.id == record.summary_id)
                    .first()
                )

            video_metadata = (
                {
                    'title': summary.video_title,
                    'channel': getattr(summary, 'channel_name', None),
                    'duration': getattr(summary, 'video_duration', None),
                }
                if summary
                else None
            )

            created = record.created_at
            last_activity = record.last_message_at
            return ChatSessionResponse(
                session_id=record.id,
                video_id=record.video_id,
                title=record.title,
                user_id=record.user_id,
                message_count=record.message_count or 0,
                is_active=record.is_active,
                created_at=created.isoformat() if created else "",
                last_message_at=last_activity.isoformat() if last_activity else None,
                video_metadata=video_metadata,
            )

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Error getting session: {exc}")
        raise HTTPException(status_code=500, detail="Failed to get session")
|
|
|
|
|
|
@router.post("/sessions/{session_id}/messages", response_model=ChatQueryResponse)
async def send_chat_message(
    session_id: str,
    request: ChatQueryRequest,
    current_user: Optional[User] = Depends(get_current_user_optional),
    rag_service: RAGService = Depends(get_rag_service)
):
    """Post a user question to a chat session and return the RAG answer.

    The session must exist, be active, and (for authenticated callers) be
    owned by the caller or unowned.

    Args:
        session_id: Primary key of the ``ChatSession``.
        request: Question text plus optional search mode and chunk cap.
        current_user: Authenticated user, or ``None`` in demo mode.
        rag_service: Initialized RAG service instance.

    Returns:
        ``ChatQueryResponse`` built from the dict returned by
        ``rag_service.chat_query``.

    Raises:
        HTTPException: 404 for an unknown session, 400 for an inactive one,
            403 for another user's session, 500 for RAG or unexpected errors.
    """
    try:
        logger.info(f"Processing chat message for session {session_id}")

        # Validate the session before spending time on retrieval/generation.
        with registry.get_session() as db:
            record = (
                db.query(ChatSession)
                .filter(ChatSession.id == session_id)
                .first()
            )
            if not record:
                raise HTTPException(status_code=404, detail="Chat session not found")
            if not record.is_active:
                raise HTTPException(status_code=400, detail="Chat session is not active")
            if current_user and record.user_id and record.user_id != str(current_user.id):
                raise HTTPException(status_code=403, detail="Access denied")

        requester_id = str(current_user.id) if current_user else None
        rag_reply = await rag_service.chat_query(
            session_id=session_id,
            query=request.query,
            user_id=requester_id,
            search_mode=request.search_mode,
            max_context_chunks=request.max_context_chunks,
        )
        return ChatQueryResponse(**rag_reply)

    except HTTPException:
        raise
    except RAGError as exc:
        logger.error(f"RAG error processing message: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
    except Exception as exc:
        logger.error(f"Unexpected error processing message: {exc}")
        raise HTTPException(status_code=500, detail="Failed to process message")
|
|
|
|
|
|
@router.get("/sessions/{session_id}/history", response_model=List[ChatMessageResponse])
async def get_chat_history(
    session_id: str,
    limit: int = Query(50, ge=1, le=200, description="Maximum number of messages"),
    current_user: Optional[User] = Depends(get_current_user_optional),
    rag_service: RAGService = Depends(get_rag_service)
):
    """Return up to ``limit`` stored messages of a chat session.

    The session row is checked for existence and ownership first. The messages
    themselves come from ``rag_service.get_chat_history``.

    Args:
        session_id: Primary key of the ``ChatSession``.
        limit: Maximum number of messages to request (1-200).
        current_user: Authenticated user, or ``None`` in demo mode.
        rag_service: Initialized RAG service instance.

    Returns:
        ``ChatMessageResponse`` objects, in the order the RAG service returned them.

    Raises:
        HTTPException: 404 for an unknown session, 403 for another user's
            session, 500 for anything else.
    """
    try:
        with registry.get_session() as db:
            record = (
                db.query(ChatSession)
                .filter(ChatSession.id == session_id)
                .first()
            )
            if not record:
                raise HTTPException(status_code=404, detail="Chat session not found")
            if current_user and record.user_id and record.user_id != str(current_user.id):
                raise HTTPException(status_code=403, detail="Access denied")

        raw_messages = await rag_service.get_chat_history(session_id, limit)

        history = []
        for entry in raw_messages:
            # Required keys are indexed directly; optional ones default to None.
            history.append(
                ChatMessageResponse(
                    id=entry['id'],
                    message_type=entry['message_type'],
                    content=entry['content'],
                    created_at=entry['created_at'],
                    sources=entry.get('sources'),
                    total_sources=entry.get('total_sources'),
                )
            )
        return history

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Error getting chat history: {exc}")
        raise HTTPException(status_code=500, detail="Failed to get chat history")
|
|
|
|
|
|
@router.delete("/sessions/{session_id}")
async def end_chat_session(
    session_id: str,
    current_user: Optional[User] = Depends(get_current_user_optional)
):
    """Deactivate a chat session. The row is kept, not deleted.

    Args:
        session_id: Primary key of the ``ChatSession``.
        current_user: Authenticated user, or ``None`` in demo mode.

    Returns:
        ``{"success": True, "message": ...}`` once the change is committed.

    Raises:
        HTTPException: 404 for an unknown session, 403 for another user's
            session, 500 for anything else.
    """
    try:
        with registry.get_session() as db:
            record = (
                db.query(ChatSession)
                .filter(ChatSession.id == session_id)
                .first()
            )
            if not record:
                raise HTTPException(status_code=404, detail="Chat session not found")
            if current_user and record.user_id and record.user_id != str(current_user.id):
                raise HTTPException(status_code=403, detail="Access denied")

            # Mark inactive and stamp the end time (naive local time).
            record.is_active = False
            record.ended_at = datetime.now()
            db.commit()

        return {"success": True, "message": "Chat session ended successfully"}

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Error ending session: {exc}")
        raise HTTPException(status_code=500, detail="Failed to end session")
|
|
|
|
|
|
@router.post("/index", response_model=IndexVideoResponse)
async def index_video_content(
    request: IndexVideoRequest,
    background_tasks: BackgroundTasks,
    current_user: Optional[User] = Depends(get_current_user_optional),
    rag_service: RAGService = Depends(get_rag_service)
):
    """Index a video's transcript for RAG search.

    Indexing runs inline: the request waits for
    ``rag_service.index_video_content`` to finish.

    Args:
        request: Video ID, transcript text and optional summary ID.
        background_tasks: Injected by FastAPI but currently unused.
        current_user: Authenticated user, or ``None``; currently unused.
        rag_service: Initialized RAG service instance.

    Returns:
        ``IndexVideoResponse`` built from the service's result dict.

    Raises:
        HTTPException: 500 if indexing fails for any reason.
    """
    try:
        logger.info(f"Indexing video content for {request.video_id}")

        indexing_result = await rag_service.index_video_content(
            video_id=request.video_id,
            transcript=request.transcript,
            summary_id=request.summary_id,
        )
        return IndexVideoResponse(**indexing_result)

    except RAGError as exc:
        logger.error(f"RAG error indexing video: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
    except Exception as exc:
        logger.error(f"Unexpected error indexing video: {exc}")
        raise HTTPException(status_code=500, detail="Failed to index video content")
|
|
|
|
|
|
@router.get("/user/sessions", response_model=List[ChatSessionResponse])
async def get_user_chat_sessions(
    current_user: Optional[User] = Depends(get_current_user_optional),
    limit: int = Query(50, ge=1, le=200, description="Maximum number of sessions")
):
    """Get chat sessions for the current user, most recently active first.

    Args:
        current_user: Authenticated user, or ``None`` in demo mode. When
            ``None``, no owner filter is applied.
        limit: Maximum number of sessions to return (1-200).

    Returns:
        ``ChatSessionResponse`` objects ordered by ``last_message_at``
        descending (sessions without messages last), then ``created_at``
        descending.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        with registry.get_session() as session:
            query = session.query(ChatSession)

            # Filter by owner only when authenticated.
            # NOTE(review): anonymous callers therefore see every user's
            # sessions — confirm this is acceptable outside demo mode.
            if current_user:
                query = query.filter(ChatSession.user_id == str(current_user.id))

            sessions = query.order_by(
                ChatSession.last_message_at.desc().nulls_last(),
                ChatSession.created_at.desc()
            ).limit(limit).all()

            # Load every linked Summary with one IN query instead of issuing
            # a separate query per session (up to ``limit`` round-trips).
            summary_ids = list({s.summary_id for s in sessions if s.summary_id})
            summaries_by_id = {}
            if summary_ids:
                linked = session.query(Summary).filter(
                    Summary.id.in_(summary_ids)
                ).all()
                summaries_by_id = {summary.id: summary for summary in linked}

            session_responses = []
            for chat_session in sessions:
                summary = summaries_by_id.get(chat_session.summary_id)
                video_metadata = None
                if summary:
                    # Same keys as the single-session endpoint returns.
                    video_metadata = {
                        'title': summary.video_title,
                        'channel': getattr(summary, 'channel_name', None),
                        'duration': getattr(summary, 'video_duration', None)
                    }

                session_responses.append(ChatSessionResponse(
                    session_id=chat_session.id,
                    video_id=chat_session.video_id,
                    title=chat_session.title,
                    user_id=chat_session.user_id,
                    message_count=chat_session.message_count or 0,
                    is_active=chat_session.is_active,
                    created_at=chat_session.created_at.isoformat() if chat_session.created_at else "",
                    last_message_at=chat_session.last_message_at.isoformat() if chat_session.last_message_at else None,
                    video_metadata=video_metadata
                ))

            return session_responses

    except Exception as e:
        # logger.exception records the traceback for the failure.
        logger.exception(f"Error getting user sessions: {e}")
        raise HTTPException(status_code=500, detail="Failed to get user sessions") from e
|
|
|
|
|
|
@router.get("/stats")
async def get_chat_stats(
    current_user: Optional[User] = Depends(get_current_user_optional),
    rag_service: RAGService = Depends(get_rag_service)
):
    """Report RAG service statistics.

    Failures are reported in the body (``success: False`` plus the error text)
    rather than raised, so this endpoint answers 200 either way.

    Args:
        current_user: Authenticated user, or ``None``; currently unused.
        rag_service: Initialized RAG service instance.

    Returns:
        ``{"success", "stats", "timestamp"}`` on success, or
        ``{"success", "error", "timestamp"}`` on failure.
    """
    try:
        collected = await rag_service.get_service_stats()
        return dict(
            success=True,
            stats=collected,
            timestamp=datetime.now().isoformat(),
        )

    except Exception as exc:
        logger.error(f"Error getting chat stats: {exc}")
        return dict(
            success=False,
            error=str(exc),
            timestamp=datetime.now().isoformat(),
        )
|
|
|
|
|
|
@router.get("/health")
async def chat_health_check(
    rag_service: RAGService = Depends(get_rag_service)
):
    """Report the health of the chat service.

    The RAG service's own health dict is merged into the response, and its keys
    override ``service``/``timestamp`` if they collide. Any failure yields a
    ``status: "unhealthy"`` body instead of an error response.

    Args:
        rag_service: Initialized RAG service instance.

    Returns:
        Health information dict, always including ``service`` and ``timestamp``.
    """
    try:
        report = await rag_service.health_check()
        return {
            "service": "chat",
            "timestamp": datetime.now().isoformat(),
            **report,
        }

    except Exception as exc:
        logger.error(f"Chat health check failed: {exc}")
        return dict(
            service="chat",
            status="unhealthy",
            error=str(exc),
            timestamp=datetime.now().isoformat(),
        )