trax/test_mps.py

80 lines
2.5 KiB
Python

#!/usr/bin/env python3
"""Test script to verify PyTorch MPS (Metal Performance Shaders) GPU acceleration."""
import torch
import time
def _time_matmul(a, b, iterations, warmup=0, sync=None):
    """Return the seconds spent computing ``torch.matmul(a, b)`` repeatedly.

    Args:
        a: Left operand tensor.
        b: Right operand tensor, on the same device as ``a``.
        iterations: Number of timed multiplications.
        warmup: Number of untimed multiplications run first so one-off
            start-up costs are not attributed to the timed loop.
        sync: Optional zero-argument callable invoked before the timer starts
            and again before it stops. Pass ``torch.mps.synchronize`` for MPS
            tensors so that queued work is included in the measurement.

    Returns:
        Elapsed time in seconds as measured by ``time.perf_counter``.
    """
    for _ in range(warmup):
        torch.matmul(a, b)
    if sync is not None:
        sync()  # Warm-up work must be finished before the clock starts.

    # perf_counter is monotonic and high resolution; time.time() can jump
    # if the system clock is adjusted during the run.
    start = time.perf_counter()
    for _ in range(iterations):
        torch.matmul(a, b)
    if sync is not None:
        sync()  # Wait for all timed work to complete before stopping the clock.
    return time.perf_counter() - start


def test_mps_acceleration(size=4096, iterations=100):
    """Compare performance between CPU and MPS (GPU) for matrix operations.

    Prints a short report: MPS availability, CPU and MPS timings for repeated
    square matrix multiplication, the resulting speedup, MPS memory counters,
    and the outcome of a CPU -> MPS -> CPU tensor round trip. If MPS is not
    available, a message is printed and the function returns immediately.

    Args:
        size: Side length of the square matrices that are multiplied.
        iterations: Number of timed multiplications per device.

    Returns:
        None. All results are reported on standard output.

    Raises:
        ValueError: If ``size`` or ``iterations`` is smaller than 1.
    """
    if size < 1 or iterations < 1:
        raise ValueError(
            f"size and iterations must be >= 1 (got size={size}, "
            f"iterations={iterations})"
        )

    # Check MPS availability
    if not torch.backends.mps.is_available():
        print("❌ MPS is not available on this system")
        return
    print("✅ MPS (Metal Performance Shaders) is available!")
    print(f"PyTorch version: {torch.__version__}")
    print("-" * 50)

    # Create random matrices
    print(f"\n📊 Testing matrix multiplication ({size}x{size})...")
    a = torch.randn(size, size)
    b = torch.randn(size, size)

    # CPU benchmark (one untimed warm-up so both devices are measured warm)
    print("\n🖥️ CPU Performance:")
    cpu_time = _time_matmul(a, b, iterations, warmup=1)
    print(f" Time: {cpu_time:.3f} seconds")

    # MPS (GPU) benchmark
    print("\n🚀 MPS (GPU) Performance:")
    device = torch.device("mps")
    a_mps = a.to(device)
    b_mps = b.to(device)
    mps_time = _time_matmul(
        a_mps, b_mps, iterations, warmup=10, sync=torch.mps.synchronize
    )
    print(f" Time: {mps_time:.3f} seconds")

    # Results
    # Guard the division: a zero reading would otherwise raise ZeroDivisionError.
    speedup = cpu_time / mps_time if mps_time > 0 else float("inf")
    print("\n📈 Results:")
    if speedup > 1:
        print(f" Speedup: {speedup:.2f}x faster on MPS")
        print(" 🎉 MPS acceleration is working!")
    else:
        # Do not claim "faster" when the measured ratio says otherwise.
        print(f" Speedup: {speedup:.2f}x (MPS was not faster than CPU)")
        print(" ⚠️ No acceleration detected")

    # Memory info
    print("\n💾 MPS Memory Info:")
    print(f" Allocated: {torch.mps.current_allocated_memory() / 1024**2:.2f} MB")
    print(f" Driver: {torch.mps.driver_allocated_memory() / 1024**2:.2f} MB")

    # Test moving tensors between devices
    print("\n🔄 Testing tensor movement between devices...")
    test_tensor = torch.randn(1000, 1000)
    # To MPS
    mps_tensor = test_tensor.to("mps")
    print(f" ✅ Moved to MPS: {mps_tensor.device}")
    # Back to CPU
    cpu_tensor = mps_tensor.cpu()
    print(f" ✅ Moved to CPU: {cpu_tensor.device}")
    # Only report success if the data survived the round trip.
    if not torch.equal(cpu_tensor, test_tensor):
        print(" ❌ Tensor values changed during the CPU -> MPS -> CPU round trip")
        return

    print("\n✨ MPS setup complete and working correctly!")
# Script entry point: run the MPS check only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    test_mps_acceleration()