1use crate::domain::error::Result;
4use crate::ports::{CircuitBreaker, CircuitState, RateLimitConfig, RateLimiter};
5use async_trait::async_trait;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11pub struct CircuitBreakerImpl {
36 state: Arc<RwLock<CircuitBreakerState>>,
37 failure_threshold: u32,
38 timeout: Duration,
39}
40
41#[derive(Debug)]
42struct CircuitBreakerState {
43 current: CircuitState,
44 failure_count: u32,
45 last_failure_time: Option<Instant>,
46}
47
48impl CircuitBreakerImpl {
49 #[must_use]
56 pub fn new(failure_threshold: u32, timeout: Duration) -> Self {
57 Self {
58 state: Arc::new(RwLock::new(CircuitBreakerState {
59 current: CircuitState::Closed,
60 failure_count: 0,
61 last_failure_time: None,
62 })),
63 failure_threshold,
64 timeout,
65 }
66 }
67
68 fn should_attempt_reset(&self, state: &CircuitBreakerState) -> bool {
70 if state.current != CircuitState::Open {
71 return false;
72 }
73
74 state
75 .last_failure_time
76 .is_some_and(|last_failure| last_failure.elapsed() >= self.timeout)
77 }
78}
79
80impl CircuitBreaker for CircuitBreakerImpl {
81 fn state(&self) -> CircuitState {
82 let state = self.state.read();
83 state.current
84 }
85
86 fn record_success(&self) {
87 let mut state = self.state.write();
88 state.failure_count = 0;
90 state.current = CircuitState::Closed;
91 state.last_failure_time = None;
92 }
93
94 fn record_failure(&self) {
95 let mut state = self.state.write();
96 state.failure_count += 1;
97 state.last_failure_time = Some(Instant::now());
98
99 if state.failure_count >= self.failure_threshold {
101 state.current = CircuitState::Open;
102 }
103 }
104
105 fn attempt_reset(&self) -> bool {
106 let mut state = self.state.write();
107
108 if self.should_attempt_reset(&state) {
109 state.current = CircuitState::HalfOpen;
110 state.failure_count = 0;
111 true
112 } else {
113 false
114 }
115 }
116}
117
118#[derive(Debug, Default, Clone, Copy)]
134pub struct NoopCircuitBreaker;
135
136impl CircuitBreaker for NoopCircuitBreaker {
137 fn state(&self) -> CircuitState {
138 CircuitState::Closed
139 }
140
141 fn record_success(&self) {
142 }
144
145 fn record_failure(&self) {
146 }
148
149 fn attempt_reset(&self) -> bool {
150 false
151 }
152}
153
154pub struct TokenBucketRateLimiter {
183 config: RateLimitConfig,
184 buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
185}
186
187#[derive(Debug)]
188struct TokenBucket {
189 tokens: u32,
190 last_refill: Instant,
191}
192
193impl TokenBucketRateLimiter {
194 #[must_use]
196 pub fn new(config: RateLimitConfig) -> Self {
197 Self {
198 config,
199 buckets: Arc::new(RwLock::new(HashMap::new())),
200 }
201 }
202
203 fn refill_tokens(&self, bucket: &mut TokenBucket) {
205 let elapsed = bucket.last_refill.elapsed();
206 let refill_rate = f64::from(self.config.max_requests) / self.config.window.as_secs_f64();
207 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
208 let tokens_to_add = (elapsed.as_secs_f64() * refill_rate) as u32;
209
210 if tokens_to_add > 0 {
211 bucket.tokens = (bucket.tokens + tokens_to_add).min(self.config.max_requests);
212 bucket.last_refill = Instant::now();
213 }
214 }
215}
216
217#[async_trait]
218impl RateLimiter for TokenBucketRateLimiter {
219 #[allow(clippy::significant_drop_tightening)]
220 async fn check_rate_limit(&self, key: &str) -> Result<bool> {
221 let has_tokens = {
222 let mut buckets = self.buckets.write();
223 let bucket = buckets
224 .entry(key.to_string())
225 .or_insert_with(|| TokenBucket {
226 tokens: self.config.max_requests,
227 last_refill: Instant::now(),
228 });
229 self.refill_tokens(bucket);
230 bucket.tokens > 0
231 };
232 Ok(has_tokens)
233 }
234
235 async fn record_request(&self, key: &str) -> Result<()> {
236 {
237 let mut buckets = self.buckets.write();
238 if let Some(bucket) = buckets.get_mut(key)
239 && bucket.tokens > 0
240 {
241 bucket.tokens -= 1;
242 }
243 }
244 Ok(())
245 }
246}
247
248#[derive(Debug, Default, Clone, Copy)]
265pub struct NoopRateLimiter;
266
267#[async_trait]
268impl RateLimiter for NoopRateLimiter {
269 async fn check_rate_limit(&self, _key: &str) -> Result<bool> {
270 Ok(true)
271 }
272
273 async fn record_request(&self, _key: &str) -> Result<()> {
274 Ok(())
275 }
276}
277
278#[cfg(test)]
279mod tests {
280 use super::*;
281
282 #[test]
283 fn test_circuit_breaker_closes_on_success() {
284 let cb = CircuitBreakerImpl::new(3, Duration::from_secs(5));
285 cb.record_failure();
286 cb.record_failure();
287 assert_eq!(cb.state(), CircuitState::Closed);
288 cb.record_success();
289 assert_eq!(cb.state(), CircuitState::Closed);
290 }
291
292 #[test]
293 fn test_circuit_breaker_opens_on_threshold() {
294 let cb = CircuitBreakerImpl::new(3, Duration::from_secs(5));
295 cb.record_failure();
296 cb.record_failure();
297 cb.record_failure();
298 assert_eq!(cb.state(), CircuitState::Open);
299 }
300
301 #[test]
302 fn test_noop_circuit_breaker_always_closed() {
303 let cb = NoopCircuitBreaker;
304 assert_eq!(cb.state(), CircuitState::Closed);
305 cb.record_failure();
306 cb.record_failure();
307 cb.record_failure();
308 assert_eq!(cb.state(), CircuitState::Closed);
309 }
310
311 #[tokio::test]
312 async fn test_rate_limiter_allows_within_limit() -> Result<()> {
313 let config = RateLimitConfig {
314 max_requests: 10,
315 window: Duration::from_mins(1),
316 };
317 let limiter = TokenBucketRateLimiter::new(config);
318
319 assert!(limiter.check_rate_limit("test").await?);
320 limiter.record_request("test").await?;
321 assert!(limiter.check_rate_limit("test").await?);
322 Ok(())
323 }
324
325 #[tokio::test]
326 async fn test_noop_rate_limiter_always_allows() -> Result<()> {
327 let limiter = NoopRateLimiter;
328 assert!(limiter.check_rate_limit("any").await?);
329 limiter.record_request("any").await?;
330 assert!(limiter.check_rate_limit("any").await?);
331 Ok(())
332 }
333}
334
335#[derive(Debug, Clone)]
351pub struct RetryPolicy {
352 pub max_attempts: u32,
354 pub base_delay: Duration,
356 pub max_delay: Duration,
358 pub jitter_ms: u64,
360}
361
362impl RetryPolicy {
363 #[must_use]
375 pub const fn new(max_attempts: u32, base_delay: Duration, max_delay: Duration) -> Self {
376 Self {
377 max_attempts,
378 base_delay,
379 max_delay,
380 jitter_ms: 50,
381 }
382 }
383
384 #[must_use]
397 pub const fn with_jitter_ms(mut self, jitter_ms: u64) -> Self {
398 self.jitter_ms = jitter_ms;
399 self
400 }
401
402 #[must_use]
417 pub fn delay_for(&self, attempt: u32) -> Duration {
418 let factor = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
419 #[allow(clippy::cast_possible_truncation)]
420 let base_ms = self.base_delay.as_millis() as u64;
421 let jitter = if self.jitter_ms > 0 {
422 let seed = u64::from(
425 std::time::SystemTime::now()
426 .duration_since(std::time::UNIX_EPOCH)
427 .unwrap_or_default()
428 .subsec_nanos(),
429 );
430 (seed
431 .wrapping_mul(6_364_136_223_846_793_005)
432 .wrapping_add(1_442_695_040_888_963_407)
433 >> 33)
434 % self.jitter_ms
435 } else {
436 0
437 };
438 let ms = base_ms.saturating_mul(factor).saturating_add(jitter);
439 let delay = Duration::from_millis(ms);
440 delay.min(self.max_delay)
441 }
442}
443
444impl Default for RetryPolicy {
445 fn default() -> Self {
446 Self::new(3, Duration::from_millis(200), Duration::from_secs(30))
447 }
448}
449
450pub async fn retry<F, Fut, T, E>(policy: &RetryPolicy, mut f: F) -> std::result::Result<T, E>
487where
488 F: FnMut() -> Fut,
489 Fut: std::future::Future<Output = std::result::Result<T, E>>,
490{
491 let mut result = f().await;
492 for attempt in 1..=policy.max_attempts {
493 if result.is_ok() {
494 return result;
495 }
496 tokio::time::sleep(policy.delay_for(attempt - 1)).await;
497 result = f().await;
498 }
499 result
500}
501
502#[cfg(test)]
503mod retry_tests {
504 use super::*;
505 use std::sync::Arc;
506 use std::sync::atomic::{AtomicU32, Ordering};
507
508 #[test]
509 fn delay_for_doubles() {
510 let p = RetryPolicy::new(4, Duration::from_millis(100), Duration::from_mins(1))
511 .with_jitter_ms(0);
512 assert_eq!(p.delay_for(0), Duration::from_millis(100));
513 assert_eq!(p.delay_for(1), Duration::from_millis(200));
514 assert_eq!(p.delay_for(2), Duration::from_millis(400));
515 assert_eq!(p.delay_for(3), Duration::from_millis(800));
516 }
517
518 #[test]
519 fn delay_capped_at_max() {
520 let p =
521 RetryPolicy::new(10, Duration::from_secs(1), Duration::from_secs(3)).with_jitter_ms(0);
522 assert_eq!(p.delay_for(4), Duration::from_secs(3));
524 }
525
526 #[tokio::test]
527 async fn retry_succeeds_on_first_try() {
528 let policy = RetryPolicy::new(3, Duration::from_millis(1), Duration::from_millis(10))
529 .with_jitter_ms(0);
530 let result: std::result::Result<i32, &str> = retry(&policy, || async { Ok(42) }).await;
531 assert_eq!(result.ok(), Some(42));
532 }
533
534 #[tokio::test]
535 async fn retry_retries_until_success() {
536 let counter = Arc::new(AtomicU32::new(0));
537 let policy = RetryPolicy::new(4, Duration::from_millis(1), Duration::from_millis(50))
538 .with_jitter_ms(0);
539
540 let result: std::result::Result<u32, String> = retry(&policy, || {
541 let c = Arc::clone(&counter);
542 async move {
543 let n = c.fetch_add(1, Ordering::SeqCst);
544 if n < 3 {
545 Err(format!("fail {n}"))
546 } else {
547 Ok(n)
548 }
549 }
550 })
551 .await;
552
553 assert!(result.is_ok());
554 assert_eq!(counter.load(Ordering::SeqCst), 4); }
556
557 #[tokio::test]
558 async fn retry_exhausts_and_returns_last_error() {
559 let policy = RetryPolicy::new(2, Duration::from_millis(1), Duration::from_millis(10))
560 .with_jitter_ms(0);
561 let counter = Arc::new(AtomicU32::new(0));
562
563 let result: std::result::Result<(), String> = retry(&policy, || {
564 let c = Arc::clone(&counter);
565 async move {
566 c.fetch_add(1, Ordering::SeqCst);
567 Err("always fails".to_string())
568 }
569 })
570 .await;
571
572 assert!(result.is_err());
573 assert_eq!(counter.load(Ordering::SeqCst), 3); }
575}