clean-tracks/src/core/audio_handler.py

291 lines
9.5 KiB
Python

"""
Audio file handling module for Clean-Tracks.
This module provides functionality for loading, saving, and manipulating
audio files in various formats.
"""
import os
from pathlib import Path
from typing import Optional, Union, Dict, Any
from enum import Enum
import logging
import numpy as np
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError
# Module-level logger named after this module (via __name__), so callers can
# configure or filter this module's log output independently.
logger = logging.getLogger(__name__)
class AudioFormat(Enum):
    """Audio file formats supported by Clean-Tracks.

    Each member's value is the lowercase file extension without the dot.
    """

    MP3 = "mp3"
    WAV = "wav"
    FLAC = "flac"
    M4A = "m4a"
    OGG = "ogg"
    WMA = "wma"
    AAC = "aac"

    @classmethod
    def from_extension(cls, extension: str) -> Optional['AudioFormat']:
        """Map a file extension to its AudioFormat member.

        Args:
            extension: The extension in any letter case, with or without
                leading dots (e.g. ".MP3" or "flac").

        Returns:
            The matching member, or None when the extension is not supported.
        """
        normalized = extension.lower().lstrip('.')
        try:
            # Enum call syntax looks a member up by its value.
            return cls(normalized)
        except ValueError:
            return None
class AudioFile:
    """
    Represents an audio file with methods for loading, saving, and processing.

    Attributes:
        file_path: Path to the audio file
        audio_segment: PyDub AudioSegment object (None until the audio is loaded)
        format: Audio format detected from the file extension
        metadata: File metadata dictionary; audio properties (duration,
            channels, sample rate, sample width, bitrate) are added by load()
    """

    # Container names handed to ffmpeg when they differ from the extension.
    # ffmpeg has no demuxer called "wma"; WMA audio is stored in ASF.
    _FFMPEG_INPUT_FORMATS: Dict[AudioFormat, str] = {
        AudioFormat.WMA: 'asf',
    }
    # ffmpeg has no "m4a", "aac" or "wma" muxers; these are the muxers that
    # write those containers.
    _FFMPEG_OUTPUT_FORMATS: Dict[AudioFormat, str] = {
        AudioFormat.M4A: 'ipod',
        AudioFormat.AAC: 'adts',
        AudioFormat.WMA: 'asf',
    }

    # Set by load(lazy=True); _require_audio() performs the real load on first
    # use. Class-level default so the flag always exists.
    _lazy_pending: bool = False

    def __init__(self, file_path: Union[str, Path]):
        """
        Initialize AudioFile with a file path.

        Only file-system metadata is read here; the audio itself is decoded by
        load().

        Args:
            file_path: Path to the audio file

        Raises:
            FileNotFoundError: If the path does not exist or is not a regular file
            ValueError: If the file format is not supported
        """
        self.file_path = Path(file_path)
        # is_file() rather than exists(): a directory named "x.mp3" is not audio.
        if not self.file_path.is_file():
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        self.format = self._detect_format()
        if self.format is None:
            raise ValueError(f"Unsupported audio format: {self.file_path.suffix}")

        self.audio_segment: Optional[AudioSegment] = None
        self.metadata: Dict[str, Any] = {}
        self._lazy_pending = False
        self._load_metadata()

    def _detect_format(self) -> Optional[AudioFormat]:
        """Detect the audio format from the file extension (None if unsupported)."""
        return AudioFormat.from_extension(self.file_path.suffix)

    def _load_metadata(self) -> None:
        """Populate metadata with file-system facts (name, format, size, absolute path)."""
        self.metadata = {
            'filename': self.file_path.name,
            'format': self.format.value,
            'size_bytes': self.file_path.stat().st_size,
            'path': str(self.file_path.absolute()),
        }

    def load(self, lazy: bool = False) -> 'AudioFile':
        """
        Load the audio file into memory.

        Args:
            lazy: If True, defer decoding; the audio is loaded automatically the
                first time an accessor (save, get_audio_array, slice, ...) needs it

        Returns:
            Self for method chaining

        Raises:
            CouldntDecodeError: If the file cannot be decoded
        """
        if lazy:
            logger.debug("Lazy loading enabled for %s", self.file_path)
            self._lazy_pending = True
            return self

        ffmpeg_format = self._FFMPEG_INPUT_FORMATS.get(self.format, self.format.value)
        logger.info("Loading audio file: %s", self.file_path)
        try:
            segment = AudioSegment.from_file(str(self.file_path), format=ffmpeg_format)
        except CouldntDecodeError as e:
            logger.error("Failed to decode audio file: %s", e)
            raise
        except Exception as e:
            logger.error("Unexpected error loading audio file: %s", e)
            raise

        self.audio_segment = segment
        self._lazy_pending = False

        # Update metadata with audio properties
        duration_ms = len(segment)
        self.metadata.update({
            'duration_ms': duration_ms,
            'duration_seconds': duration_ms / 1000.0,
            'channels': segment.channels,
            'sample_rate': segment.frame_rate,
            'sample_width': segment.sample_width,
            'bitrate': self._estimate_bitrate(),
        })
        logger.info("Successfully loaded %s: %.2fs, %sHz",
                    self.file_path.name,
                    self.metadata['duration_seconds'],
                    self.metadata['sample_rate'])
        return self

    def _require_audio(self) -> AudioSegment:
        """
        Return the loaded AudioSegment, performing a pending lazy load first.

        Raises:
            RuntimeError: If the audio was never loaded and no lazy load is pending
            CouldntDecodeError: If a pending lazy load fails to decode the file
        """
        if self.audio_segment is None and self._lazy_pending:
            self.load()
        if self.audio_segment is None:
            raise RuntimeError("Audio not loaded. Call load() first.")
        return self.audio_segment

    def _estimate_bitrate(self) -> Optional[int]:
        """
        Estimate the average bitrate in bits per second from file size and duration.

        The on-disk size is used, so container overhead and embedded tags are
        included. Returns None if no audio is loaded or its duration is zero.
        """
        # "is None", not truthiness: AudioSegment defines __len__.
        if self.audio_segment is None:
            return None
        duration_seconds = len(self.audio_segment) / 1000.0
        if duration_seconds <= 0:
            return None
        return int(self.metadata['size_bytes'] * 8 / duration_seconds)

    def save(self,
             output_path: Union[str, Path],
             format: Optional[AudioFormat] = None,
             parameters: Optional[Dict[str, Any]] = None) -> Path:
        """
        Save the audio file to disk.

        Args:
            output_path: Path where the file should be saved; its suffix is
                replaced if it does not match the output format
            format: Output format (uses original format if not specified)
            parameters: Additional export parameters (bitrate, codec, etc.);
                these override the defaults, and None values are dropped

        Returns:
            Path to the saved file

        Raises:
            RuntimeError: If audio_segment is not loaded (and no lazy load is pending)
        """
        audio = self._require_audio()

        output_path = Path(output_path)
        output_format = format or self.format

        # Ensure output path has correct extension
        extension = f".{output_format.value}"
        if output_path.suffix.lower() != extension:
            output_path = output_path.with_suffix(extension)

        # Create output directory if it doesn't exist
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Defaults; the ffmpeg container name can differ from the extension.
        export_params: Dict[str, Any] = {
            'format': self._FFMPEG_OUTPUT_FORMATS.get(output_format, output_format.value),
            'bitrate': '192k' if output_format == AudioFormat.MP3 else None,
        }
        if parameters:
            export_params.update(parameters)
        export_params = {k: v for k, v in export_params.items() if v is not None}

        logger.info("Saving audio to %s as %s", output_path, output_format.value)
        try:
            handle = audio.export(str(output_path), **export_params)
        except Exception as e:
            logger.error("Failed to save audio file: %s", e)
            raise
        # pydub's export() returns the file object it wrote to; close it rather
        # than leaving the descriptor to the garbage collector.
        if handle is not None:
            handle.close()
        logger.info("Successfully saved to %s", output_path)
        return output_path

    def get_audio_array(self) -> np.ndarray:
        """
        Get the audio data as a mono numpy array of floats in [-1, 1).

        Multi-channel audio is mixed down with AudioSegment.set_channels(1).

        Returns:
            Numpy array of normalized audio samples

        Raises:
            RuntimeError: If audio_segment is not loaded (and no lazy load is pending)
        """
        mono_audio = self._require_audio().set_channels(1)
        samples = np.array(mono_audio.get_array_of_samples())
        # pydub exposes samples as *signed* integers of sample_width bytes (its
        # loader biases unsigned 8-bit WAV data to signed), so full scale is
        # 2 ** (bits - 1) for every width. The old "/ 128.0 - 1.0" for 8-bit
        # audio shifted the result into [-2, 0).
        full_scale = float(1 << (8 * mono_audio.sample_width - 1))
        return samples / full_scale

    def get_sample_rate(self) -> int:
        """Get the sample rate of the audio in Hz."""
        return self._require_audio().frame_rate

    def get_duration_seconds(self) -> float:
        """Get the duration of the audio in seconds."""
        return len(self._require_audio()) / 1000.0

    def slice(self, start_ms: int, end_ms: int) -> AudioSegment:
        """
        Get a slice of the audio.

        Args:
            start_ms: Start time in milliseconds
            end_ms: End time in milliseconds

        Returns:
            AudioSegment of the sliced audio (the stored audio is not modified)
        """
        return self._require_audio()[start_ms:end_ms]

    def replace_segment(self,
                        start_ms: int,
                        end_ms: int,
                        replacement: AudioSegment) -> None:
        """
        Replace a segment of the audio in place.

        The result is audio[:start_ms] + replacement + audio[end_ms:]. The
        replacement may have a different length than the span it replaces.
        If start_ms > end_ms, the audio between them appears twice. The
        metadata dictionary is not updated.

        Args:
            start_ms: Start time in milliseconds
            end_ms: End time in milliseconds
            replacement: AudioSegment to insert
        """
        audio = self._require_audio()
        before = audio[:start_ms]
        after = audio[end_ms:]
        self.audio_segment = before + replacement + after

    def __repr__(self) -> str:
        """String representation of AudioFile."""
        return (f"AudioFile(path={self.file_path.name}, "
                f"format={self.format.value if self.format else 'unknown'}, "
                f"loaded={self.audio_segment is not None})")

    def __len__(self) -> int:
        """Get the length of the loaded audio in milliseconds (0 when not loaded)."""
        if self.audio_segment is None:
            return 0
        return len(self.audio_segment)