# youtube-summarizer/backend/services/rag_chat_service.py
"""RAG-powered chat service for interactive Q&A with video content."""
import asyncio
import logging
import uuid
from typing import Dict, List, Optional, Any, Tuple
from datetime import datetime
from dataclasses import dataclass
from enum import Enum
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel
from ..core.exceptions import ServiceError
from .deepseek_service import DeepSeekService
logger = logging.getLogger(__name__)
class MessageType(str, Enum):
"""Chat message types."""
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
class SourceReference(BaseModel):
"""Reference to source content with timestamp."""
chunk_id: str
timestamp: int # seconds
timestamp_formatted: str # [HH:MM:SS]
youtube_link: str
chunk_text: str
relevance_score: float
class ChatMessage(BaseModel):
"""Individual chat message."""
id: str
message_type: MessageType
content: str
sources: List[SourceReference]
processing_time_seconds: float
created_at: datetime
class ChatSession(BaseModel):
"""Chat session for a video."""
id: str
user_id: str
video_id: str
summary_id: str
session_name: str
messages: List[ChatMessage]
total_messages: int
is_active: bool
created_at: datetime
updated_at: datetime
class ChatRequest(BaseModel):
"""Request to ask a question."""
video_id: str
question: str
session_id: Optional[str] = None
include_context: bool = True
max_sources: int = 5
class ChatResponse(BaseModel):
"""Response from chat service."""
session_id: str
message: ChatMessage
follow_up_suggestions: List[str]
context_retrieved: bool
total_chunks_searched: int
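# Example request (illustrative values):
#   ChatRequest(video_id="abc123xyz00", question="How does the speaker define RAG?",
#               max_sources=3)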
@dataclass
class TranscriptChunk:
"""Chunk of transcript with metadata."""
chunk_id: str
video_id: str
chunk_text: str
start_timestamp: int
end_timestamp: int
chunk_index: int
word_count: int
class RAGChatService:
"""Service for RAG-powered chat with video content."""
def __init__(
self,
ai_service: Optional[DeepSeekService] = None,
chromadb_path: str = "./data/chromadb_rag"
):
"""Initialize RAG chat service.
Args:
ai_service: DeepSeek AI service for response generation
chromadb_path: Path to ChromaDB persistent storage
"""
self.ai_service = ai_service or DeepSeekService()
self.chromadb_path = chromadb_path
# Initialize embedding model (local, no API required)
logger.info("Loading sentence transformer model...")
self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Initialize ChromaDB client
self.chroma_client = chromadb.PersistentClient(
path=chromadb_path,
settings=Settings(
anonymized_telemetry=False,
allow_reset=True
)
)
# Chat session storage (in-memory for now, could be database)
self.chat_sessions: Dict[str, ChatSession] = {}
logger.info(f"RAG Chat Service initialized with ChromaDB at {chromadb_path}")
async def process_video_for_rag(
self,
video_id: str,
transcript: str,
video_title: str = ""
) -> bool:
"""Process video transcript for RAG by creating embeddings.
Args:
video_id: YouTube video ID
transcript: Video transcript text
video_title: Video title for context
Returns:
True if processing successful
"""
if not transcript or len(transcript.strip()) < 50:
raise ServiceError("Transcript too short for RAG processing")
logger.info(f"Processing video {video_id} for RAG with {len(transcript)} characters")
try:
# 1. Chunk the transcript
chunks = self._chunk_transcript(transcript, video_id)
logger.info(f"Created {len(chunks)} chunks for video {video_id}")
# 2. Generate embeddings for chunks
chunk_texts = [chunk.chunk_text for chunk in chunks]
logger.info("Generating embeddings...")
            # encode() is CPU-bound; run it off the event loop
            embeddings = await asyncio.to_thread(
                self.embedding_model.encode, chunk_texts, convert_to_tensor=False
            )
# 3. Store in ChromaDB
collection_name = f"video_{video_id}"
            # Recreate the collection so re-processing starts from a clean slate.
            # Version-dependent: chromadb raises ValueError (older) or NotFoundError
            # (newer) when the collection doesn't exist, so catch broadly.
            try:
                self.chroma_client.delete_collection(collection_name)
                logger.info(f"Cleared existing data for video {video_id}")
            except Exception:
                pass
            collection = self.chroma_client.create_collection(
                name=collection_name,
                metadata={
                    "video_id": video_id,
                    "video_title": video_title,
                    # Cosine distance makes 1 - distance a true similarity at query time
                    "hnsw:space": "cosine"
                }
            )
# Prepare data for ChromaDB
chunk_ids = [chunk.chunk_id for chunk in chunks]
metadatas = [
{
"video_id": chunk.video_id,
"start_timestamp": chunk.start_timestamp,
"end_timestamp": chunk.end_timestamp,
"chunk_index": chunk.chunk_index,
"word_count": chunk.word_count
}
for chunk in chunks
]
# Add to collection
collection.add(
embeddings=embeddings.tolist(),
documents=chunk_texts,
metadatas=metadatas,
ids=chunk_ids
)
logger.info(f"Successfully stored {len(chunks)} chunks in ChromaDB for video {video_id}")
return True
except Exception as e:
logger.error(f"Error processing video {video_id} for RAG: {e}")
raise ServiceError(f"RAG processing failed: {str(e)}")
def _chunk_transcript(self, transcript: str, video_id: str) -> List[TranscriptChunk]:
"""Chunk transcript into semantically meaningful segments.
Args:
transcript: Full transcript text
video_id: Video ID for chunk IDs
Returns:
List of transcript chunks
"""
        # Simple chunking strategy: accumulate paragraphs up to a target word
        # count, carrying a short sentence-level overlap between chunks
        paragraphs = [p.strip() for p in transcript.split('\n\n') if p.strip()]
        chunks = []
        chunk_size = 300  # Target words per chunk
        current_chunk = ""
        current_word_count = 0
        chunk_index = 0
        estimated_timestamp = 0
        words_per_minute = 150  # Average speaking rate
        last_index = len(paragraphs) - 1
        for i, paragraph in enumerate(paragraphs):
            paragraph_words = len(paragraph.split())
            # Add paragraph to current chunk
            current_chunk += paragraph + "\n\n"
            current_word_count += paragraph_words
            # Create chunk once we've reached the target size or hit the end
            # (an index check, so duplicate paragraph text can't end the loop early)
            if current_word_count >= chunk_size or i == last_index:
                if current_chunk.strip():
                    # Estimate timestamps from the average speaking rate
                    chunk_duration = (current_word_count / words_per_minute) * 60
                    start_timestamp = estimated_timestamp
                    end_timestamp = estimated_timestamp + int(chunk_duration)
                    chunk = TranscriptChunk(
                        chunk_id=f"{video_id}_chunk_{chunk_index}",
                        video_id=video_id,
                        chunk_text=current_chunk.strip(),
                        start_timestamp=start_timestamp,
                        end_timestamp=end_timestamp,
                        chunk_index=chunk_index,
                        word_count=current_word_count
                    )
                    chunks.append(chunk)
                if i == last_index:
                    break
                # Keep the last two sentences as overlap for continuity
                sentences = current_chunk.strip().split('.')
                if len(sentences) > 2:
                    overlap_text = '. '.join(sentences[-2:]).strip()
                    overlap_words = len(overlap_text.split())
                    current_chunk = overlap_text + ".\n\n"
                    current_word_count = overlap_words
                    estimated_timestamp = end_timestamp - int(overlap_words / words_per_minute * 60)
                else:
                    current_chunk = ""
                    current_word_count = 0
                    estimated_timestamp = end_timestamp
                chunk_index += 1
        return chunks
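    # Worked example of the chunking arithmetic above: a ~900-word transcript with
    # chunk_size=300 yields ~3 chunks, each estimated to span (300 / 150) * 60
    # = 120 seconds, so chunk start timestamps land near 0:00, 2:00, and 4:00.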
async def ask_question(
self,
request: ChatRequest,
user_id: str = "anonymous"
) -> ChatResponse:
"""Ask a question about video content using RAG.
Args:
request: Chat request with question and video ID
user_id: User ID for session management
Returns:
Chat response with answer and sources
"""
if not request.question or len(request.question.strip()) < 3:
raise ServiceError("Question is too short")
start_time = datetime.now()
logger.info(f"Processing question for video {request.video_id}: {request.question[:100]}...")
try:
# 1. Get or create chat session
session = await self._get_or_create_session(
request.session_id, user_id, request.video_id
)
# 2. Retrieve relevant chunks
relevant_chunks, total_searched = await self._retrieve_relevant_chunks(
request.video_id, request.question, request.max_sources
)
# 3. Generate response using RAG
response_content = await self._generate_rag_response(
request.question, relevant_chunks, session.messages[-5:] if session.messages else []
)
# 4. Create source references
source_refs = self._create_source_references(relevant_chunks, request.video_id)
# 5. Generate follow-up suggestions
follow_ups = await self._generate_follow_up_suggestions(
request.question, response_content, relevant_chunks
)
# 6. Create chat message
processing_time = (datetime.now() - start_time).total_seconds()
message = ChatMessage(
id=str(uuid.uuid4()),
message_type=MessageType.ASSISTANT,
content=response_content,
sources=source_refs,
processing_time_seconds=processing_time,
created_at=start_time
)
# 7. Add to session
user_message = ChatMessage(
id=str(uuid.uuid4()),
message_type=MessageType.USER,
content=request.question,
sources=[],
processing_time_seconds=0,
created_at=start_time
)
session.messages.extend([user_message, message])
session.total_messages += 2
session.updated_at = datetime.now()
# 8. Store session
self.chat_sessions[session.id] = session
response = ChatResponse(
session_id=session.id,
message=message,
follow_up_suggestions=follow_ups,
context_retrieved=len(relevant_chunks) > 0,
total_chunks_searched=total_searched
)
logger.info(f"Question answered in {processing_time:.2f}s with {len(source_refs)} sources")
return response
except Exception as e:
logger.error(f"Error answering question for video {request.video_id}: {e}")
raise ServiceError(f"Failed to answer question: {str(e)}")
async def _retrieve_relevant_chunks(
self,
video_id: str,
question: str,
max_results: int = 5
) -> Tuple[List[Dict[str, Any]], int]:
"""Retrieve relevant chunks using semantic search.
Args:
video_id: Video ID to search
question: User question
max_results: Maximum chunks to return
Returns:
Tuple of (relevant chunks, total searched)
"""
collection_name = f"video_{video_id}"
try:
collection = self.chroma_client.get_collection(collection_name)
            # Generate the question embedding off the event loop (encode() is CPU-bound)
            question_embedding = await asyncio.to_thread(
                self.embedding_model.encode, [question], convert_to_tensor=False
            )
# Search for relevant chunks
results = collection.query(
query_embeddings=question_embedding.tolist(),
n_results=max_results,
include=['documents', 'metadatas', 'distances']
)
# Process results
relevant_chunks = []
if results['documents'] and len(results['documents'][0]) > 0:
documents = results['documents'][0]
metadatas = results['metadatas'][0]
distances = results['distances'][0]
for i, (doc, metadata, distance) in enumerate(zip(documents, metadatas, distances)):
                    # With hnsw:space="cosine", distance is the cosine distance, so
                    # 1 - distance is the cosine similarity (clamped at 0)
                    relevance_score = max(0, 1 - distance)
chunk_data = {
'chunk_text': doc,
'metadata': metadata,
'relevance_score': relevance_score,
'rank': i + 1
}
relevant_chunks.append(chunk_data)
# Get total count from collection
total_count = collection.count()
logger.info(f"Retrieved {len(relevant_chunks)} relevant chunks from {total_count} total chunks")
return relevant_chunks, total_count
        except ValueError:
            # Older chromadb raises ValueError for a missing collection; newer
            # versions raise their own NotFoundError, caught by the handler below
            logger.warning(f"No collection found for video {video_id}")
return [], 0
except Exception as e:
logger.error(f"Error retrieving chunks for video {video_id}: {e}")
return [], 0
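    # Retrieval sketch (illustrative numbers): for a 40-chunk video, query()
    # returns the max_results nearest chunks; a cosine distance of 0.35 maps to
    # relevance_score = 0.65 via the conversion above.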
async def _generate_rag_response(
self,
question: str,
relevant_chunks: List[Dict[str, Any]],
chat_history: List[ChatMessage]
) -> str:
"""Generate response using retrieved chunks and chat history.
Args:
question: User question
relevant_chunks: Retrieved relevant chunks
chat_history: Recent chat history
Returns:
Generated response
"""
if not relevant_chunks:
return "I couldn't find relevant information in the video to answer your question. Could you please rephrase or ask about something else covered in the content?"
# Build context from chunks
context_parts = []
for i, chunk in enumerate(relevant_chunks[:5], 1):
timestamp = self._format_timestamp(chunk['metadata']['start_timestamp'])
context_parts.append(f"[Context {i} - {timestamp}]: {chunk['chunk_text'][:400]}")
context = "\n\n".join(context_parts)
# Build chat history context
history_context = ""
if chat_history:
recent_messages = []
for msg in chat_history[-4:]: # Last 4 messages
if msg.message_type == MessageType.USER:
recent_messages.append(f"User: {msg.content}")
elif msg.message_type == MessageType.ASSISTANT:
recent_messages.append(f"Assistant: {msg.content[:200]}...")
if recent_messages:
history_context = f"\n\nRecent conversation:\n{chr(10).join(recent_messages)}"
system_prompt = """You are a helpful AI assistant that answers questions about video content.
You have access to relevant sections of the video transcript with timestamps.
Instructions:
- Answer the user's question based on the provided context
- Include timestamp references like [05:23] when referencing specific parts
- If the context doesn't contain enough information, say so clearly
- Keep responses conversational but informative
- Don't make up information not in the context
- If multiple contexts are relevant, synthesize information from them
"""
prompt = f"""Based on the video content below, please answer this question: "{question}"
Video Content:
{context}
{history_context}
Please provide a helpful response that references specific timestamps when possible."""
try:
response = await self.ai_service.generate_response(
prompt=prompt,
system_prompt=system_prompt,
temperature=0.4, # Slightly creative but grounded
max_tokens=800
)
# Add timestamp formatting to response if not present
response = self._enhance_response_with_timestamps(response, relevant_chunks)
return response
except Exception as e:
logger.error(f"Error generating RAG response: {e}")
return "I encountered an error generating a response. Please try asking your question again."
def _enhance_response_with_timestamps(
self,
response: str,
relevant_chunks: List[Dict[str, Any]]
) -> str:
"""Enhance response with timestamp references.
Args:
response: Generated response
relevant_chunks: Source chunks with timestamps
Returns:
Enhanced response with timestamps
"""
        # If the model omitted timestamps, prepend one from the top-ranked chunk
        if response and '[' not in response and relevant_chunks:
            most_relevant = relevant_chunks[0]
            timestamp = self._format_timestamp(most_relevant['metadata']['start_timestamp'])
            # Add timestamp reference to the beginning
            response = f"According to the video at [{timestamp}], {response[0].lower()}{response[1:]}"
return response
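    # Example: with a top-ranked chunk starting at 323s, the response
    # "The speaker argues X." becomes
    # "According to the video at [05:23], the speaker argues X."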
def _create_source_references(
self,
relevant_chunks: List[Dict[str, Any]],
video_id: str
) -> List[SourceReference]:
"""Create source references from relevant chunks.
Args:
relevant_chunks: Retrieved chunks
video_id: Video ID for YouTube links
Returns:
List of source references
"""
source_refs = []
for chunk in relevant_chunks:
metadata = chunk['metadata']
start_timestamp = metadata['start_timestamp']
source_ref = SourceReference(
chunk_id=f"{video_id}_chunk_{metadata['chunk_index']}",
timestamp=start_timestamp,
timestamp_formatted=f"[{self._format_timestamp(start_timestamp)}]",
youtube_link=f"https://youtube.com/watch?v={video_id}&t={start_timestamp}s",
                chunk_text=(chunk['chunk_text'][:200] + "..." if len(chunk['chunk_text']) > 200 else chunk['chunk_text']),
relevance_score=round(chunk['relevance_score'], 3)
)
source_refs.append(source_ref)
return source_refs
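    # Example: a chunk starting at 323s produces timestamp_formatted "[05:23]" and
    # youtube_link "https://youtube.com/watch?v=<video_id>&t=323s", which opens
    # the video at that offset.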
async def _generate_follow_up_suggestions(
self,
question: str,
response: str,
relevant_chunks: List[Dict[str, Any]]
) -> List[str]:
"""Generate follow-up question suggestions.
Args:
question: Original question
response: Generated response
relevant_chunks: Source chunks
Returns:
List of follow-up suggestions
"""
if not relevant_chunks:
return []
try:
# Extract topics from chunks for follow-up suggestions
chunk_topics = []
for chunk in relevant_chunks[:3]:
text = chunk['chunk_text'][:300]
chunk_topics.append(text)
context = " ".join(chunk_topics)
system_prompt = """Generate 3 relevant follow-up questions based on the video content.
Questions should be natural, specific, and encourage deeper exploration of the topic.
Return only the questions, one per line, without numbering."""
prompt = f"""Based on this video content and the user's interest in "{question}", suggest follow-up questions:
{context[:1000]}
Generate 3 specific follow-up questions that would help the user learn more about this topic."""
suggestions_response = await self.ai_service.generate_response(
prompt=prompt,
system_prompt=system_prompt,
temperature=0.6, # More creative for suggestions
max_tokens=200
)
            # Parse suggestions: strip any bullet/number prefix the model added
            # despite instructions, then keep lines that look like questions
            suggestions = []
            for line in suggestions_response.split('\n'):
                line = line.strip().lstrip('1234567890.-* ')
                if len(line) > 10 and '?' in line:
                    suggestions.append(line)
            return suggestions[:3]  # Limit to 3 suggestions
except Exception as e:
logger.error(f"Error generating follow-up suggestions: {e}")
return []
async def _get_or_create_session(
self,
session_id: Optional[str],
user_id: str,
video_id: str
) -> ChatSession:
"""Get existing session or create new one.
Args:
session_id: Optional existing session ID
user_id: User ID
video_id: Video ID
Returns:
Chat session
"""
if session_id and session_id in self.chat_sessions:
session = self.chat_sessions[session_id]
if session.video_id == video_id:
return session
# Create new session
new_session = ChatSession(
id=str(uuid.uuid4()),
user_id=user_id,
video_id=video_id,
summary_id="", # Will be set when linked to summary
session_name=f"Chat - {datetime.now().strftime('%Y-%m-%d %H:%M')}",
messages=[],
total_messages=0,
is_active=True,
created_at=datetime.now(),
updated_at=datetime.now()
)
self.chat_sessions[new_session.id] = new_session
return new_session
def _format_timestamp(self, seconds: int) -> str:
"""Format seconds as MM:SS or HH:MM:SS.
Args:
seconds: Time in seconds
Returns:
Formatted timestamp
"""
hours = seconds // 3600
minutes = (seconds % 3600) // 60
secs = seconds % 60
if hours > 0:
return f"{hours:02d}:{minutes:02d}:{secs:02d}"
else:
return f"{minutes:02d}:{secs:02d}"
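    # e.g. _format_timestamp(75) -> "01:15"; _format_timestamp(3671) -> "01:01:11"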
async def get_chat_session(self, session_id: str) -> Optional[ChatSession]:
"""Get chat session by ID.
Args:
session_id: Session ID
Returns:
Chat session or None if not found
"""
return self.chat_sessions.get(session_id)
async def list_user_sessions(self, user_id: str, video_id: Optional[str] = None) -> List[ChatSession]:
"""List chat sessions for a user.
Args:
user_id: User ID
video_id: Optional video ID filter
Returns:
List of user's chat sessions
"""
sessions = []
for session in self.chat_sessions.values():
if session.user_id == user_id:
if video_id is None or session.video_id == video_id:
sessions.append(session)
# Sort by most recent
sessions.sort(key=lambda s: s.updated_at, reverse=True)
return sessions
async def delete_session(self, session_id: str, user_id: str) -> bool:
"""Delete a chat session.
Args:
session_id: Session ID to delete
user_id: User ID for authorization
Returns:
True if deleted successfully
"""
if session_id in self.chat_sessions:
session = self.chat_sessions[session_id]
if session.user_id == user_id:
del self.chat_sessions[session_id]
return True
return False
async def export_session(self, session_id: str, user_id: str) -> Optional[str]:
"""Export chat session as markdown.
Args:
session_id: Session ID
user_id: User ID for authorization
Returns:
Markdown export or None if not found
"""
session = self.chat_sessions.get(session_id)
if not session or session.user_id != user_id:
return None
lines = [
f"# Chat Session: {session.session_name}",
"",
f"**Video ID:** {session.video_id}",
f"**Created:** {session.created_at.strftime('%Y-%m-%d %H:%M:%S')}",
f"**Total Messages:** {session.total_messages}",
"",
"---",
""
]
for message in session.messages:
if message.message_type == MessageType.USER:
lines.extend([
                    "## 👤 User",
"",
message.content,
""
])
elif message.message_type == MessageType.ASSISTANT:
lines.extend([
                    "## 🤖 Assistant",
"",
message.content,
""
])
if message.sources:
lines.extend([
"**Sources:**",
""
])
for source in message.sources:
lines.append(f"- {source.timestamp_formatted} [Jump to video]({source.youtube_link})")
lines.append("")
lines.extend(["---", ""])
return "\n".join(lines)
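    # The export renders roughly as (values are illustrative):
    #   # Chat Session: Chat - 2024-01-01 12:00
    #   **Video ID:** abc123xyz00
    #   ## 👤 User
    #   <question>
    #   ## 🤖 Assistant
    #   <answer with [MM:SS] references>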
async def get_service_health(self) -> Dict[str, Any]:
"""Get RAG chat service health status.
Returns:
Service health information
"""
health = {
"service": "rag_chat",
"status": "healthy",
"timestamp": datetime.now().isoformat()
}
try:
# Test ChromaDB
collections = self.chroma_client.list_collections()
health["chromadb_status"] = "connected"
health["collections_count"] = len(collections)
# Test embedding model
test_embedding = self.embedding_model.encode(["test"], convert_to_tensor=False)
health["embedding_model_status"] = "loaded"
health["embedding_dimension"] = len(test_embedding[0])
# Active sessions
health["active_sessions"] = len(self.chat_sessions)
# Test AI service
if self.ai_service:
ai_health = await self.ai_service.test_connection()
health["ai_service_status"] = ai_health["status"]
else:
health["ai_service_status"] = "not_configured"
health["status"] = "degraded"
except Exception as e:
health["status"] = "error"
health["error"] = str(e)
return health