stygian_graph/adapters/
resilience.rs

1//! Resilience adapters
2
3use crate::domain::error::Result;
4use crate::ports::{CircuitBreaker, CircuitState, RateLimitConfig, RateLimiter};
5use async_trait::async_trait;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11/// Circuit breaker implementation with configurable thresholds
12///
13/// Implements the circuit breaker pattern to prevent cascading failures.
14/// Tracks failure rate and automatically opens the circuit when threshold is exceeded.
15///
16/// # State Machine
17///
18/// - **Closed**: Normal operation, all requests pass through
19/// - **Open**: Too many failures, all requests fail fast
20/// - **`HalfOpen`**: Testing recovery, limited requests allowed
21///
22/// # Example
23///
24/// ```
25/// use stygian_graph::adapters::resilience::CircuitBreakerImpl;
26/// use stygian_graph::ports::{CircuitBreaker, CircuitState};
27///
28/// let cb = CircuitBreakerImpl::new(5, std::time::Duration::from_secs(30));
29/// // Record some failures
30/// cb.record_failure();
31/// cb.record_failure();
32/// // Check state
33/// assert!(matches!(cb.state(), CircuitState::Closed | CircuitState::Open));
34/// ```
pub struct CircuitBreakerImpl {
    /// State machine data (current state, failure count, last failure time)
    /// behind a lock for concurrent access from multiple callers.
    state: Arc<RwLock<CircuitBreakerState>>,
    /// Number of recorded failures that trips the circuit from Closed to Open.
    failure_threshold: u32,
    /// Minimum time the circuit stays Open before a reset may be attempted.
    timeout: Duration,
}
40
/// Internal mutable state guarded by the lock in `CircuitBreakerImpl`.
#[derive(Debug)]
struct CircuitBreakerState {
    /// Current position in the Closed → Open → HalfOpen state machine.
    current: CircuitState,
    /// Failures recorded since the last success.
    failure_count: u32,
    /// When the most recent failure was recorded; `None` after a success.
    last_failure_time: Option<Instant>,
}
47
48impl CircuitBreakerImpl {
49    /// Create a new circuit breaker
50    ///
51    /// # Arguments
52    ///
53    /// * `failure_threshold` - Number of failures before opening circuit
54    /// * `timeout` - Duration to wait before attempting reset
55    pub fn new(failure_threshold: u32, timeout: Duration) -> Self {
56        Self {
57            state: Arc::new(RwLock::new(CircuitBreakerState {
58                current: CircuitState::Closed,
59                failure_count: 0,
60                last_failure_time: None,
61            })),
62            failure_threshold,
63            timeout,
64        }
65    }
66
67    /// Check if timeout has elapsed and circuit can transition to `HalfOpen`
68    fn should_attempt_reset(&self, state: &CircuitBreakerState) -> bool {
69        if state.current != CircuitState::Open {
70            return false;
71        }
72
73        state
74            .last_failure_time
75            .is_some_and(|last_failure| last_failure.elapsed() >= self.timeout)
76    }
77}
78
79impl CircuitBreaker for CircuitBreakerImpl {
80    fn state(&self) -> CircuitState {
81        let state = self.state.read();
82        state.current
83    }
84
85    fn record_success(&self) {
86        let mut state = self.state.write();
87        // Success resets failures and closes circuit
88        state.failure_count = 0;
89        state.current = CircuitState::Closed;
90        state.last_failure_time = None;
91    }
92
93    fn record_failure(&self) {
94        let mut state = self.state.write();
95        state.failure_count += 1;
96        state.last_failure_time = Some(Instant::now());
97
98        // Open circuit if threshold exceeded
99        if state.failure_count >= self.failure_threshold {
100            state.current = CircuitState::Open;
101        }
102    }
103
104    fn attempt_reset(&self) -> bool {
105        let mut state = self.state.write();
106
107        if self.should_attempt_reset(&state) {
108            state.current = CircuitState::HalfOpen;
109            state.failure_count = 0;
110            true
111        } else {
112            false
113        }
114    }
115}
116
117/// No-op circuit breaker for testing
118///
119/// Always reports Closed state and ignores all state transitions.
120/// Useful for testing scenarios where circuit breaker behavior should be disabled.
121///
122/// # Example
123///
124/// ```
125/// use stygian_graph::adapters::resilience::NoopCircuitBreaker;
126/// use stygian_graph::ports::{CircuitBreaker, CircuitState};
127///
128/// let cb = NoopCircuitBreaker;
129/// cb.record_failure();
130/// assert_eq!(cb.state(), CircuitState::Closed);
131/// ```
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopCircuitBreaker;

impl CircuitBreaker for NoopCircuitBreaker {
    /// Always reports `Closed`; the noop breaker never trips.
    fn state(&self) -> CircuitState {
        CircuitState::Closed
    }

    fn record_success(&self) {
        // No-op: no state is tracked.
    }

    fn record_failure(&self) {
        // No-op: failures are deliberately ignored.
    }

    fn attempt_reset(&self) -> bool {
        // Never transitions; there is no Open state to reset from.
        false
    }
}
152
153/// Token bucket rate limiter implementation
154///
155/// Implements rate limiting using the token bucket algorithm.
156/// Supports per-key rate limiting for multi-tenant scenarios.
157///
158/// # Algorithm
159///
160/// - Each key has a bucket with a maximum number of tokens
161/// - Tokens are consumed on each request
162/// - Tokens regenerate over time based on the configured window
163/// - Requests are rejected when bucket is empty
164///
165/// # Example
166///
167/// ```
168/// use stygian_graph::adapters::resilience::TokenBucketRateLimiter;
169/// use stygian_graph::ports::{RateLimiter, RateLimitConfig};
170/// use std::time::Duration;
171///
172/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
173/// let config = RateLimitConfig {
174///     max_requests: 10,
175///     window: Duration::from_secs(60),
176/// };
177/// let limiter = TokenBucketRateLimiter::new(config);
178/// assert!(limiter.check_rate_limit("api:test").await.unwrap());
179/// # });
180/// ```
pub struct TokenBucketRateLimiter {
    /// Limits applied to every key: bucket capacity and refill window.
    config: RateLimitConfig,
    /// One token bucket per rate-limit key, created lazily on first check.
    buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
}
185
/// Per-key bucket state for the token bucket rate limiter.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available; each recorded request consumes one.
    tokens: u32,
    /// When tokens were last credited; drives time-based refill.
    last_refill: Instant,
}
191
192impl TokenBucketRateLimiter {
193    /// Create a new token bucket rate limiter
194    pub fn new(config: RateLimitConfig) -> Self {
195        Self {
196            config,
197            buckets: Arc::new(RwLock::new(HashMap::new())),
198        }
199    }
200
201    /// Refill tokens based on elapsed time
202    fn refill_tokens(&self, bucket: &mut TokenBucket) {
203        let elapsed = bucket.last_refill.elapsed();
204        let refill_rate = f64::from(self.config.max_requests) / self.config.window.as_secs_f64();
205        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
206        let tokens_to_add = (elapsed.as_secs_f64() * refill_rate) as u32;
207
208        if tokens_to_add > 0 {
209            bucket.tokens = (bucket.tokens + tokens_to_add).min(self.config.max_requests);
210            bucket.last_refill = Instant::now();
211        }
212    }
213}
214
215#[async_trait]
216impl RateLimiter for TokenBucketRateLimiter {
217    #[allow(clippy::significant_drop_tightening)]
218    async fn check_rate_limit(&self, key: &str) -> Result<bool> {
219        let has_tokens = {
220            let mut buckets = self.buckets.write();
221            let bucket = buckets
222                .entry(key.to_string())
223                .or_insert_with(|| TokenBucket {
224                    tokens: self.config.max_requests,
225                    last_refill: Instant::now(),
226                });
227            self.refill_tokens(bucket);
228            bucket.tokens > 0
229        };
230        Ok(has_tokens)
231    }
232
233    async fn record_request(&self, key: &str) -> Result<()> {
234        {
235            let mut buckets = self.buckets.write();
236            if let Some(bucket) = buckets.get_mut(key)
237                && bucket.tokens > 0
238            {
239                bucket.tokens -= 1;
240            }
241        }
242        Ok(())
243    }
244}
245
246/// No-op rate limiter for testing
247///
248/// Always allows requests and ignores all rate limit tracking.
249/// Useful for testing scenarios where rate limiting should be disabled.
250///
251/// # Example
252///
253/// ```
254/// use stygian_graph::adapters::resilience::NoopRateLimiter;
255/// use stygian_graph::ports::RateLimiter;
256///
257/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
258/// let limiter = NoopRateLimiter;
259/// assert!(limiter.check_rate_limit("any_key").await.unwrap());
260/// # });
261/// ```
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopRateLimiter;

#[async_trait]
impl RateLimiter for NoopRateLimiter {
    /// Always allows the request; no per-key state is consulted.
    async fn check_rate_limit(&self, _key: &str) -> Result<bool> {
        Ok(true)
    }

    /// No-op: requests are not tracked.
    async fn record_request(&self, _key: &str) -> Result<()> {
        Ok(())
    }
}
275
#[cfg(test)]
mod tests {
    use super::*;

    // A success clears accumulated failures and keeps the circuit Closed.
    #[test]
    fn test_circuit_breaker_closes_on_success() {
        let cb = CircuitBreakerImpl::new(3, Duration::from_secs(5));
        cb.record_failure();
        cb.record_failure();
        // Two failures is below the threshold of three — still Closed.
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    // Reaching the failure threshold flips the circuit to Open.
    #[test]
    fn test_circuit_breaker_opens_on_threshold() {
        let cb = CircuitBreakerImpl::new(3, Duration::from_secs(5));
        cb.record_failure();
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    // The noop breaker ignores failures entirely and stays Closed.
    #[test]
    fn test_noop_circuit_breaker_always_closed() {
        let cb = NoopCircuitBreaker;
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_failure();
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    // A fresh bucket allows requests, and one consumed token out of ten
    // still leaves capacity.
    #[tokio::test]
    async fn test_rate_limiter_allows_within_limit() -> Result<()> {
        let config = RateLimitConfig {
            max_requests: 10,
            window: Duration::from_secs(60),
        };
        let limiter = TokenBucketRateLimiter::new(config);

        assert!(limiter.check_rate_limit("test").await?);
        limiter.record_request("test").await?;
        assert!(limiter.check_rate_limit("test").await?);
        Ok(())
    }

    // The noop limiter allows everything regardless of recorded requests.
    #[tokio::test]
    async fn test_noop_rate_limiter_always_allows() -> Result<()> {
        let limiter = NoopRateLimiter;
        assert!(limiter.check_rate_limit("any").await?);
        limiter.record_request("any").await?;
        assert!(limiter.check_rate_limit("any").await?);
        Ok(())
    }
}
332
333// ─── Exponential Backoff Retry ────────────────────────────────────────────────
334
/// Policy controlling exponential backoff retry behaviour.
///
/// The delay before retry `attempt` (0-based) is
/// `base_delay * 2^attempt + rand(0..jitter_ms)`, capped at `max_delay`.
///
/// # Example
///
/// ```
/// use stygian_graph::adapters::resilience::RetryPolicy;
/// use std::time::Duration;
///
/// let policy = RetryPolicy::new(3, Duration::from_millis(100), Duration::from_secs(10));
/// ```
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of retry attempts (not counting the initial call)
    pub max_attempts: u32,
    /// Base delay for the first retry
    pub base_delay: Duration,
    /// Maximum delay cap
    pub max_delay: Duration,
    /// Additional random jitter ceiling (milliseconds)
    pub jitter_ms: u64,
}

impl RetryPolicy {
    /// Create a policy with the given attempt budget and delay bounds.
    ///
    /// Jitter defaults to 50 ms; override it with [`Self::with_jitter_ms`].
    ///
    /// # Example
    ///
    /// ```
    /// use stygian_graph::adapters::resilience::RetryPolicy;
    /// use std::time::Duration;
    ///
    /// let p = RetryPolicy::new(5, Duration::from_millis(200), Duration::from_secs(30));
    /// assert_eq!(p.max_attempts, 5);
    /// ```
    pub const fn new(max_attempts: u32, base_delay: Duration, max_delay: Duration) -> Self {
        Self {
            max_attempts,
            base_delay,
            max_delay,
            jitter_ms: 50,
        }
    }

    /// Replace the jitter ceiling (milliseconds), builder-style.
    ///
    /// # Example
    ///
    /// ```
    /// use stygian_graph::adapters::resilience::RetryPolicy;
    /// use std::time::Duration;
    ///
    /// let p = RetryPolicy::new(3, Duration::from_millis(100), Duration::from_secs(5))
    ///     .with_jitter_ms(100);
    /// assert_eq!(p.jitter_ms, 100);
    /// ```
    #[must_use]
    pub const fn with_jitter_ms(mut self, jitter_ms: u64) -> Self {
        self.jitter_ms = jitter_ms;
        self
    }

    /// Compute the sleep duration for a given attempt index (0-based).
    ///
    /// All arithmetic saturates, so very large attempt indices or base delays
    /// simply hit the `max_delay` cap rather than overflowing.
    ///
    /// # Example
    ///
    /// ```
    /// use stygian_graph::adapters::resilience::RetryPolicy;
    /// use std::time::Duration;
    ///
    /// let p = RetryPolicy::new(3, Duration::from_millis(100), Duration::from_secs(10))
    ///     .with_jitter_ms(0);
    /// // attempt 0 → 100 ms, attempt 1 → 200 ms, attempt 2 → 400 ms
    /// assert_eq!(p.delay_for(0), Duration::from_millis(100));
    /// assert_eq!(p.delay_for(1), Duration::from_millis(200));
    /// ```
    pub fn delay_for(&self, attempt: u32) -> Duration {
        // 2^attempt, saturating once the shift would overflow a u64.
        let multiplier = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
        #[allow(clippy::cast_possible_truncation)]
        let base_ms = self.base_delay.as_millis() as u64;
        let total_ms = base_ms
            .saturating_mul(multiplier)
            .saturating_add(self.sample_jitter_ms());
        Duration::from_millis(total_ms).min(self.max_delay)
    }

    /// Draw a pseudo-random jitter value in `0..jitter_ms` (0 when disabled).
    ///
    /// Runs one LCG step over the wall clock's sub-second nanoseconds as a
    /// low-cost entropy source, avoiding a dependency on `rand`.
    fn sample_jitter_ms(&self) -> u64 {
        if self.jitter_ms == 0 {
            return 0;
        }
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .subsec_nanos();
        let mixed = u64::from(nanos)
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        (mixed >> 33) % self.jitter_ms
    }
}

impl Default for RetryPolicy {
    /// Three retries, 200 ms base delay, 30 s cap (and the 50 ms default jitter).
    fn default() -> Self {
        Self::new(3, Duration::from_millis(200), Duration::from_secs(30))
    }
}
445
446/// Execute an async operation with automatic retry according to a [`RetryPolicy`].
447///
448/// Returns the first `Ok` value, or the last `Err` after all attempts are exhausted.
449/// Each retry sleeps for an exponentially increasing delay with jitter.
450///
451/// # Example
452///
453/// ```
454/// use stygian_graph::adapters::resilience::{RetryPolicy, retry};
455/// use std::sync::atomic::{AtomicU32, Ordering};
456/// use std::sync::Arc;
457/// use std::time::Duration;
458///
459/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
460/// let attempts = Arc::new(AtomicU32::new(0));
461/// let policy = RetryPolicy::new(3, Duration::from_millis(1), Duration::from_millis(10))
462///     .with_jitter_ms(0);
463///
464/// let result = retry(&policy, || {
465///     let counter = Arc::clone(&attempts);
466///     async move {
467///         let n = counter.fetch_add(1, Ordering::SeqCst);
468///         if n < 2 { Err("not yet".to_string()) } else { Ok(n) }
469///     }
470/// }).await;
471///
472/// assert!(result.is_ok());
473/// assert_eq!(attempts.load(Ordering::SeqCst), 3);
474/// # });
475/// ```
476pub async fn retry<F, Fut, T, E>(policy: &RetryPolicy, mut f: F) -> std::result::Result<T, E>
477where
478    F: FnMut() -> Fut,
479    Fut: std::future::Future<Output = std::result::Result<T, E>>,
480{
481    let mut result = f().await;
482    for attempt in 1..=policy.max_attempts {
483        if result.is_ok() {
484            return result;
485        }
486        tokio::time::sleep(policy.delay_for(attempt - 1)).await;
487        result = f().await;
488    }
489    result
490}
491
#[cfg(test)]
mod retry_tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // With jitter disabled, delays follow base * 2^attempt exactly.
    #[test]
    fn delay_for_doubles() {
        let p = RetryPolicy::new(4, Duration::from_millis(100), Duration::from_secs(60))
            .with_jitter_ms(0);
        assert_eq!(p.delay_for(0), Duration::from_millis(100));
        assert_eq!(p.delay_for(1), Duration::from_millis(200));
        assert_eq!(p.delay_for(2), Duration::from_millis(400));
        assert_eq!(p.delay_for(3), Duration::from_millis(800));
    }

    // The computed delay never exceeds max_delay.
    #[test]
    fn delay_capped_at_max() {
        let p = RetryPolicy::new(10, Duration::from_millis(1000), Duration::from_secs(3))
            .with_jitter_ms(0);
        // 1000 * 2^4 = 16_000 ms, capped at 3_000 ms
        assert_eq!(p.delay_for(4), Duration::from_secs(3));
    }

    // An immediately-successful operation is called exactly once.
    #[tokio::test]
    async fn retry_succeeds_on_first_try() {
        let policy = RetryPolicy::new(3, Duration::from_millis(1), Duration::from_millis(10))
            .with_jitter_ms(0);
        let result: std::result::Result<i32, &str> = retry(&policy, || async { Ok(42) }).await;
        assert_eq!(result.ok(), Some(42));
    }

    // Transient failures are retried until the operation succeeds.
    #[tokio::test]
    async fn retry_retries_until_success() {
        let counter = Arc::new(AtomicU32::new(0));
        let policy = RetryPolicy::new(4, Duration::from_millis(1), Duration::from_millis(50))
            .with_jitter_ms(0);

        let result: std::result::Result<u32, String> = retry(&policy, || {
            let c = Arc::clone(&counter);
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 3 {
                    Err(format!("fail {n}"))
                } else {
                    Ok(n)
                }
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(counter.load(Ordering::SeqCst), 4); // 3 failures + 1 success
    }

    // Once the retry budget is exhausted, the last error is surfaced.
    #[tokio::test]
    async fn retry_exhausts_and_returns_last_error() {
        let policy = RetryPolicy::new(2, Duration::from_millis(1), Duration::from_millis(10))
            .with_jitter_ms(0);
        let counter = Arc::new(AtomicU32::new(0));

        let result: std::result::Result<(), String> = retry(&policy, || {
            let c = Arc::clone(&counter);
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err("always fails".to_string())
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 3); // initial + 2 retries
    }
}