1use crate::domain::error::Result;
4use crate::ports::{CircuitBreaker, CircuitState, RateLimitConfig, RateLimiter};
5use async_trait::async_trait;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
/// Failure-counting circuit breaker: opens after a configurable number of
/// consecutive failures and allows a half-open probe once a cool-down
/// period has elapsed.
pub struct CircuitBreakerImpl {
    // Mutable breaker state behind a parking_lot RwLock.
    // NOTE(review): the Arc is only needed if this handle is cloned or
    // shared elsewhere — not visible in this file; a bare RwLock may do.
    state: Arc<RwLock<CircuitBreakerState>>,
    // Consecutive failures required before the circuit opens.
    failure_threshold: u32,
    // Cool-down after the last failure before a reset may be attempted.
    timeout: Duration,
}
40
/// Mutable portion of the breaker, kept behind the lock.
#[derive(Debug)]
struct CircuitBreakerState {
    // Current circuit state (Closed / Open / HalfOpen).
    current: CircuitState,
    // Consecutive failures since the last success or half-open reset.
    failure_count: u32,
    // Instant of the most recent failure; None until a failure occurs.
    last_failure_time: Option<Instant>,
}
47
48impl CircuitBreakerImpl {
49 pub fn new(failure_threshold: u32, timeout: Duration) -> Self {
56 Self {
57 state: Arc::new(RwLock::new(CircuitBreakerState {
58 current: CircuitState::Closed,
59 failure_count: 0,
60 last_failure_time: None,
61 })),
62 failure_threshold,
63 timeout,
64 }
65 }
66
67 fn should_attempt_reset(&self, state: &CircuitBreakerState) -> bool {
69 if state.current != CircuitState::Open {
70 return false;
71 }
72
73 state
74 .last_failure_time
75 .is_some_and(|last_failure| last_failure.elapsed() >= self.timeout)
76 }
77}
78
79impl CircuitBreaker for CircuitBreakerImpl {
80 fn state(&self) -> CircuitState {
81 let state = self.state.read();
82 state.current
83 }
84
85 fn record_success(&self) {
86 let mut state = self.state.write();
87 state.failure_count = 0;
89 state.current = CircuitState::Closed;
90 state.last_failure_time = None;
91 }
92
93 fn record_failure(&self) {
94 let mut state = self.state.write();
95 state.failure_count += 1;
96 state.last_failure_time = Some(Instant::now());
97
98 if state.failure_count >= self.failure_threshold {
100 state.current = CircuitState::Open;
101 }
102 }
103
104 fn attempt_reset(&self) -> bool {
105 let mut state = self.state.write();
106
107 if self.should_attempt_reset(&state) {
108 state.current = CircuitState::HalfOpen;
109 state.failure_count = 0;
110 true
111 } else {
112 false
113 }
114 }
115}
116
/// A `CircuitBreaker` that never opens; use when circuit breaking is
/// disabled (e.g. tests or configurations without fault tolerance).
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopCircuitBreaker;
134
135impl CircuitBreaker for NoopCircuitBreaker {
136 fn state(&self) -> CircuitState {
137 CircuitState::Closed
138 }
139
140 fn record_success(&self) {
141 }
143
144 fn record_failure(&self) {
145 }
147
148 fn attempt_reset(&self) -> bool {
149 false
150 }
151}
152
/// Per-key token-bucket rate limiter.
///
/// Each key lazily receives its own bucket holding up to
/// `config.max_requests` tokens, refilled continuously over `config.window`.
pub struct TokenBucketRateLimiter {
    // Shared limit settings (max requests per time window).
    config: RateLimitConfig,
    // One bucket per key, created on first use.
    // NOTE(review): entries are never evicted, so an unbounded key space
    // grows this map without limit — confirm callers use bounded keys.
    buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
}
185
/// State of a single key's bucket.
#[derive(Debug)]
struct TokenBucket {
    // Tokens currently available; each recorded request consumes one.
    tokens: u32,
    // When tokens were last added; basis for the next refill computation.
    last_refill: Instant,
}
191
192impl TokenBucketRateLimiter {
193 pub fn new(config: RateLimitConfig) -> Self {
195 Self {
196 config,
197 buckets: Arc::new(RwLock::new(HashMap::new())),
198 }
199 }
200
201 fn refill_tokens(&self, bucket: &mut TokenBucket) {
203 let elapsed = bucket.last_refill.elapsed();
204 let refill_rate = f64::from(self.config.max_requests) / self.config.window.as_secs_f64();
205 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
206 let tokens_to_add = (elapsed.as_secs_f64() * refill_rate) as u32;
207
208 if tokens_to_add > 0 {
209 bucket.tokens = (bucket.tokens + tokens_to_add).min(self.config.max_requests);
210 bucket.last_refill = Instant::now();
211 }
212 }
213}
214
215#[async_trait]
216impl RateLimiter for TokenBucketRateLimiter {
217 #[allow(clippy::significant_drop_tightening)]
218 async fn check_rate_limit(&self, key: &str) -> Result<bool> {
219 let has_tokens = {
220 let mut buckets = self.buckets.write();
221 let bucket = buckets
222 .entry(key.to_string())
223 .or_insert_with(|| TokenBucket {
224 tokens: self.config.max_requests,
225 last_refill: Instant::now(),
226 });
227 self.refill_tokens(bucket);
228 bucket.tokens > 0
229 };
230 Ok(has_tokens)
231 }
232
233 async fn record_request(&self, key: &str) -> Result<()> {
234 {
235 let mut buckets = self.buckets.write();
236 if let Some(bucket) = buckets.get_mut(key)
237 && bucket.tokens > 0
238 {
239 bucket.tokens -= 1;
240 }
241 }
242 Ok(())
243 }
244}
245
/// A `RateLimiter` that never limits: every check passes and recording a
/// request is a no-op.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopRateLimiter;
264
265#[async_trait]
266impl RateLimiter for NoopRateLimiter {
267 async fn check_rate_limit(&self, _key: &str) -> Result<bool> {
268 Ok(true)
269 }
270
271 async fn record_request(&self, _key: &str) -> Result<()> {
272 Ok(())
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_circuit_breaker_closes_on_success() {
        let breaker = CircuitBreakerImpl::new(3, Duration::from_secs(5));
        // Two failures stay below the threshold of three.
        for _ in 0..2 {
            breaker.record_failure();
        }
        assert_eq!(breaker.state(), CircuitState::Closed);
        // A success clears the count and keeps the circuit closed.
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_opens_on_threshold() {
        let breaker = CircuitBreakerImpl::new(3, Duration::from_secs(5));
        // Hitting the threshold exactly must open the circuit.
        for _ in 0..3 {
            breaker.record_failure();
        }
        assert_eq!(breaker.state(), CircuitState::Open);
    }

    #[test]
    fn test_noop_circuit_breaker_always_closed() {
        let breaker = NoopCircuitBreaker;
        assert_eq!(breaker.state(), CircuitState::Closed);
        // Failures are ignored entirely.
        for _ in 0..3 {
            breaker.record_failure();
        }
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_rate_limiter_allows_within_limit() -> Result<()> {
        let limiter = TokenBucketRateLimiter::new(RateLimitConfig {
            max_requests: 10,
            window: Duration::from_secs(60),
        });

        // One consumed token out of ten leaves plenty of headroom.
        assert!(limiter.check_rate_limit("test").await?);
        limiter.record_request("test").await?;
        assert!(limiter.check_rate_limit("test").await?);
        Ok(())
    }

    #[tokio::test]
    async fn test_noop_rate_limiter_always_allows() -> Result<()> {
        let limiter = NoopRateLimiter;
        assert!(limiter.check_rate_limit("any").await?);
        limiter.record_request("any").await?;
        assert!(limiter.check_rate_limit("any").await?);
        Ok(())
    }
}
332
/// Exponential back-off schedule with an optional bounded jitter.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of retries performed after the initial attempt.
    pub max_attempts: u32,
    /// Delay before the first retry; doubles with each subsequent retry.
    pub base_delay: Duration,
    /// Hard upper bound applied to every computed delay.
    pub max_delay: Duration,
    /// Exclusive upper bound, in milliseconds, for the random jitter added
    /// to each delay; 0 disables jitter.
    pub jitter_ms: u64,
}

impl RetryPolicy {
    /// Builds a policy with the given retry count and delay bounds and a
    /// default jitter bound of 50 ms.
    pub const fn new(max_attempts: u32, base_delay: Duration, max_delay: Duration) -> Self {
        Self {
            max_attempts,
            base_delay,
            max_delay,
            jitter_ms: 50,
        }
    }

    /// Returns the policy with its jitter bound replaced.
    #[must_use]
    pub const fn with_jitter_ms(mut self, jitter_ms: u64) -> Self {
        self.jitter_ms = jitter_ms;
        self
    }

    /// Computes the delay before retry number `attempt` (0-based):
    /// `base_delay * 2^attempt` plus jitter, clipped to `max_delay`.
    pub fn delay_for(&self, attempt: u32) -> Duration {
        // 2^attempt, saturating to u64::MAX once the shift would overflow.
        let multiplier = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
        #[allow(clippy::cast_possible_truncation)]
        let base_ms = self.base_delay.as_millis() as u64;
        let total_ms = base_ms
            .saturating_mul(multiplier)
            .saturating_add(self.jitter());
        Duration::from_millis(total_ms).min(self.max_delay)
    }

    /// Pseudo-random jitter in `0..jitter_ms` (0 when jitter is disabled).
    /// One step of a 64-bit LCG seeded from the clock's sub-second nanos;
    /// cheap and deterministic enough for back-off, not cryptographic.
    fn jitter(&self) -> u64 {
        if self.jitter_ms == 0 {
            return 0;
        }
        let seed = u64::from(
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .subsec_nanos(),
        );
        (seed
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407)
            >> 33)
            % self.jitter_ms
    }
}

impl Default for RetryPolicy {
    /// Three retries, 200 ms base delay, 30 s cap (plus the default jitter).
    fn default() -> Self {
        Self::new(3, Duration::from_millis(200), Duration::from_secs(30))
    }
}
445
446pub async fn retry<F, Fut, T, E>(policy: &RetryPolicy, mut f: F) -> std::result::Result<T, E>
477where
478 F: FnMut() -> Fut,
479 Fut: std::future::Future<Output = std::result::Result<T, E>>,
480{
481 let mut result = f().await;
482 for attempt in 1..=policy.max_attempts {
483 if result.is_ok() {
484 return result;
485 }
486 tokio::time::sleep(policy.delay_for(attempt - 1)).await;
487 result = f().await;
488 }
489 result
490}
491
#[cfg(test)]
mod retry_tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[test]
    fn delay_for_doubles() {
        let policy = RetryPolicy::new(4, Duration::from_millis(100), Duration::from_secs(60))
            .with_jitter_ms(0);
        for (attempt, expected_ms) in [(0, 100), (1, 200), (2, 400), (3, 800)] {
            assert_eq!(policy.delay_for(attempt), Duration::from_millis(expected_ms));
        }
    }

    #[test]
    fn delay_capped_at_max() {
        let policy = RetryPolicy::new(10, Duration::from_millis(1000), Duration::from_secs(3))
            .with_jitter_ms(0);
        // 1000 ms * 2^4 = 16 s, clipped to the 3 s cap.
        assert_eq!(policy.delay_for(4), Duration::from_secs(3));
    }

    #[tokio::test]
    async fn retry_succeeds_on_first_try() {
        let policy = RetryPolicy::new(3, Duration::from_millis(1), Duration::from_millis(10))
            .with_jitter_ms(0);
        let outcome: std::result::Result<i32, &str> = retry(&policy, || async { Ok(42) }).await;
        assert_eq!(outcome.ok(), Some(42));
    }

    #[tokio::test]
    async fn retry_retries_until_success() {
        let calls = Arc::new(AtomicU32::new(0));
        let policy = RetryPolicy::new(4, Duration::from_millis(1), Duration::from_millis(50))
            .with_jitter_ms(0);

        let outcome: std::result::Result<u32, String> = retry(&policy, || {
            let calls = Arc::clone(&calls);
            async move {
                let n = calls.fetch_add(1, Ordering::SeqCst);
                if n < 3 {
                    Err(format!("fail {n}"))
                } else {
                    Ok(n)
                }
            }
        })
        .await;

        assert!(outcome.is_ok());
        // Three failures plus the final success.
        assert_eq!(calls.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn retry_exhausts_and_returns_last_error() {
        let policy = RetryPolicy::new(2, Duration::from_millis(1), Duration::from_millis(10))
            .with_jitter_ms(0);
        let calls = Arc::new(AtomicU32::new(0));

        let outcome: std::result::Result<(), String> = retry(&policy, || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err("always fails".to_string())
            }
        })
        .await;

        assert!(outcome.is_err());
        // Initial attempt plus two retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }
}