""" Tests for Phase 11: retry with backoff utility. """ import asyncio import time from unittest.mock import MagicMock, patch import pytest from src.utils.retry import ( RateLimitError, async_retry_with_backoff, check_response_retryable, retry_with_backoff, ) class TestRetryWithBackoff: def test_success_on_first_try(self): call_count = {"n": 0} @retry_with_backoff(max_retries=3, base_delay=0.01) def fn(): call_count["n"] += 1 return "ok" result = fn() assert result == "ok" assert call_count["n"] == 1 def test_retries_on_exception(self): call_count = {"n": 0} @retry_with_backoff(max_retries=2, base_delay=0.01) def fn(): call_count["n"] += 1 if call_count["n"] < 3: raise ConnectionError("transient") return "ok" with patch("src.utils.retry.time.sleep"): result = fn() assert result == "ok" assert call_count["n"] == 3 def test_raises_after_max_retries(self): @retry_with_backoff(max_retries=2, base_delay=0.01) def fn(): raise ConnectionError("always fails") with patch("src.utils.retry.time.sleep"): with pytest.raises(ConnectionError): fn() def test_exponential_delay(self): sleeps = [] @retry_with_backoff(max_retries=3, base_delay=1.0) def fn(): raise ValueError("fail") with patch("src.utils.retry.time.sleep", side_effect=lambda d: sleeps.append(d)): with pytest.raises(ValueError): fn() assert len(sleeps) == 3 assert sleeps[0] == 1.0 assert sleeps[1] == 2.0 assert sleeps[2] == 4.0 def test_max_delay_capped(self): sleeps = [] @retry_with_backoff(max_retries=5, base_delay=10.0, max_delay=15.0) def fn(): raise ValueError("fail") with patch("src.utils.retry.time.sleep", side_effect=lambda d: sleeps.append(d)): with pytest.raises(ValueError): fn() assert all(d <= 15.0 for d in sleeps) def test_only_retries_specified_exceptions(self): call_count = {"n": 0} @retry_with_backoff(max_retries=3, base_delay=0.01, retryable_exceptions=(ConnectionError,)) def fn(): call_count["n"] += 1 raise ValueError("not retryable") with pytest.raises(ValueError): fn() assert call_count["n"] == 1 # no retries for ValueError class TestAsyncRetryWithBackoff: def test_async_success_on_first_try(self): call_count = {"n": 0} @async_retry_with_backoff(max_retries=3, base_delay=0.01) async def fn(): call_count["n"] += 1 return "ok" result = asyncio.get_event_loop().run_until_complete(fn()) assert result == "ok" assert call_count["n"] == 1 def test_async_retries_on_exception(self): call_count = {"n": 0} @async_retry_with_backoff(max_retries=2, base_delay=0.01) async def fn(): call_count["n"] += 1 if call_count["n"] < 3: raise ConnectionError("transient") return "ok" with patch("src.utils.retry.asyncio.sleep", new=asyncio.coroutine(lambda d: None)): result = asyncio.get_event_loop().run_until_complete(fn()) assert result == "ok" def test_async_raises_after_max_retries(self): @async_retry_with_backoff(max_retries=1, base_delay=0.01) async def fn(): raise ConnectionError("always fails") with patch("src.utils.retry.asyncio.sleep", new=asyncio.coroutine(lambda d: None)): with pytest.raises(ConnectionError): asyncio.get_event_loop().run_until_complete(fn()) class TestCheckResponseRetryable: def test_429_is_retryable(self): assert check_response_retryable(429) is True def test_503_is_retryable(self): assert check_response_retryable(503) is True def test_200_not_retryable(self): assert check_response_retryable(200) is False def test_400_not_retryable(self): assert check_response_retryable(400) is False def test_404_not_retryable(self): assert check_response_retryable(404) is False