1use std::collections::HashMap;
45use std::sync::{Arc, RwLock};
46use std::time::{Duration, Instant};
47
48use async_trait::async_trait;
49
50use crate::domain::error::{Result, ServiceError, StygianError};
51use crate::ports::escalation::{EscalationPolicy, EscalationTier, ResponseContext};
52use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
53
54#[derive(Debug, Clone)]
58pub struct EscalationConfig {
59 pub max_tier: EscalationTier,
61 pub base_tier: EscalationTier,
63 pub cache_ttl: Duration,
65}
66
67impl Default for EscalationConfig {
68 fn default() -> Self {
69 Self {
70 max_tier: EscalationTier::BrowserAdvanced,
71 base_tier: EscalationTier::HttpPlain,
72 cache_ttl: Duration::from_hours(1),
73 }
74 }
75}
76
/// Returns `true` when `body` contains any substring this module treats as a
/// sign of a Cloudflare challenge page. Matching is case-sensitive.
fn is_cloudflare_challenge(body: &str) -> bool {
    const MARKERS: [&str; 4] = [
        "Just a moment",
        "cf-browser-verification",
        "__cf_bm",
        "Checking if the site connection is secure",
    ];
    MARKERS.iter().any(|&marker| body.contains(marker))
}
86
/// Returns `true` when `body` contains one of the substrings this module
/// associates with a DataDome interstitial. Matching is case-sensitive.
fn is_datadome_interstitial(body: &str) -> bool {
    const MARKERS: [&str; 2] = ["datadome", "dd_referrer"];
    MARKERS.iter().any(|&marker| body.contains(marker))
}
91
/// Returns `true` when `body` contains one of the substrings this module
/// associates with a PerimeterX challenge. Matching is case-sensitive.
fn is_perimeterx_challenge(body: &str) -> bool {
    const MARKERS: [&str; 3] = ["_pxParam", "_px.js", "blockScript"];
    MARKERS.iter().any(|&marker| body.contains(marker))
}
96
/// Returns `true` when `body` mentions one of the CAPTCHA widget names checked
/// here (`recaptcha`, `hcaptcha`, `turnstile`). Matching is case-sensitive.
fn has_captcha_marker(body: &str) -> bool {
    const MARKERS: [&str; 3] = ["recaptcha", "hcaptcha", "turnstile"];
    MARKERS.iter().any(|&marker| body.contains(marker))
}
101
102type CacheEntry = (EscalationTier, Instant);
106
107#[derive(Clone)]
111pub struct DefaultEscalationPolicy {
112 config: EscalationConfig,
113 cache: Arc<RwLock<HashMap<String, CacheEntry>>>,
115}
116
117impl DefaultEscalationPolicy {
118 #[must_use]
120 pub fn new(config: EscalationConfig) -> Self {
121 Self {
122 config,
123 cache: Arc::new(RwLock::new(HashMap::new())),
124 }
125 }
126
127 #[must_use]
133 pub fn context_from_body(status: u16, body: &str) -> ResponseContext {
134 ResponseContext {
135 status,
136 body_empty: body.trim().is_empty(),
137 has_cloudflare_challenge: is_cloudflare_challenge(body)
138 || is_datadome_interstitial(body)
139 || is_perimeterx_challenge(body),
140 has_captcha: has_captcha_marker(body),
141 }
142 }
143
144 pub fn initial_tier_for_domain(&self, domain: &str) -> EscalationTier {
149 let result = {
150 let cache = self
151 .cache
152 .read()
153 .unwrap_or_else(std::sync::PoisonError::into_inner);
154 cache.get(domain).copied()
155 };
156 if let Some((tier, expires_at)) = result
157 && Instant::now() < expires_at
158 {
159 tracing::debug!(domain, tier = %tier, "using cached initial escalation tier");
160 return tier;
161 }
162 self.config.base_tier
163 }
164
165 pub fn record_tier_success(&self, domain: &str, tier: EscalationTier) {
171 if tier <= self.config.base_tier {
172 return; }
174 let expires_at = Instant::now() + self.config.cache_ttl;
175 let mut cache = self
176 .cache
177 .write()
178 .unwrap_or_else(std::sync::PoisonError::into_inner);
179 let should_insert = cache.get(domain).is_none_or(|(cached, _)| tier >= *cached);
180 if should_insert {
181 tracing::info!(domain, tier = %tier, "caching successful escalation tier");
182 cache.insert(domain.to_string(), (tier, expires_at));
183 }
184 }
185
186 pub fn purge_expired_cache(&self) -> usize {
191 let mut cache = self
192 .cache
193 .write()
194 .unwrap_or_else(std::sync::PoisonError::into_inner);
195 let now = Instant::now();
196 let before = cache.len();
197 cache.retain(|_, (_, expires_at)| now < *expires_at);
198 before - cache.len()
199 }
200}
201
202impl EscalationPolicy for DefaultEscalationPolicy {
203 fn initial_tier(&self) -> EscalationTier {
204 self.config.base_tier
205 }
206
207 fn should_escalate(
208 &self,
209 ctx: &ResponseContext,
210 current: EscalationTier,
211 ) -> Option<EscalationTier> {
212 if current >= self.max_tier() {
213 return None;
214 }
215
216 let needs_escalation = ctx.status == 403
217 || ctx.status == 429
218 || ctx.has_cloudflare_challenge
219 || ctx.has_captcha
220 || (ctx.body_empty && current >= EscalationTier::HttpTlsProfiled);
221
222 if needs_escalation {
223 let next = current.next()?;
224 tracing::info!(
225 status = ctx.status,
226 current_tier = %current,
227 next_tier = %next,
228 "escalating request to higher tier"
229 );
230 Some(next)
231 } else {
232 None
233 }
234 }
235
236 fn max_tier(&self) -> EscalationTier {
237 self.config.max_tier
238 }
239}
240
/// Extracts the host portion of `url` for use as a per-domain cache key.
///
/// Steps, in order:
/// 1. strip a leading `https://` or `http://` (other inputs are used as-is);
/// 2. cut at the first `/`, `?`, or `#`, so a query or fragment that follows
///    the host directly is not mistaken for part of it;
/// 3. drop any `userinfo@` prefix;
/// 4. drop a trailing `:port`. A bracketed IPv6 literal is returned with its
///    brackets, and the colons inside it are not treated as a port separator.
///
/// The result borrows from `url`; no case normalisation is performed.
fn domain_from_url(url: &str) -> &str {
    let after_scheme = url
        .strip_prefix("https://")
        .or_else(|| url.strip_prefix("http://"))
        .unwrap_or(url);

    // The authority ends where the path, query, or fragment begins.
    let authority = after_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(after_scheme);

    // `user:password@host` — keep only what follows the last '@'.
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host)| host);

    if host_port.starts_with('[') {
        // IPv6 literal: keep everything through the closing bracket. If the
        // bracket is never closed, the whole remainder is returned unchanged.
        return host_port.split_inclusive(']').next().unwrap_or(host_port);
    }

    host_port
        .split_once(':')
        .map_or(host_port, |(host, _)| host)
}
258
259pub struct EscalatingScrapingService {
286 tier_services: HashMap<EscalationTier, Arc<dyn ScrapingService>>,
287 policy: DefaultEscalationPolicy,
288}
289
290impl EscalatingScrapingService {
291 #[must_use]
295 pub fn new(policy: DefaultEscalationPolicy) -> Self {
296 Self {
297 tier_services: HashMap::new(),
298 policy,
299 }
300 }
301
302 #[must_use]
304 pub fn with_tier(mut self, tier: EscalationTier, service: Arc<dyn ScrapingService>) -> Self {
305 self.tier_services.insert(tier, service);
306 self
307 }
308
309 fn service_at_or_above(
311 &self,
312 tier: EscalationTier,
313 ) -> Option<(EscalationTier, &Arc<dyn ScrapingService>)> {
314 let mut current = Some(tier);
315 while let Some(t) = current {
316 if let Some(svc) = self.tier_services.get(&t) {
317 return Some((t, svc));
318 }
319 current = t.next();
320 }
321 None
322 }
323}
324
325#[async_trait]
326impl ScrapingService for EscalatingScrapingService {
327 async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
328 let host = domain_from_url(&input.url).to_string();
329 let mut current_tier = self.policy.initial_tier_for_domain(&host);
330 let mut escalation_path: Vec<EscalationTier> = Vec::new();
331
332 loop {
333 let (actual_tier, service) =
335 self.service_at_or_above(current_tier).ok_or_else(|| {
336 StygianError::Service(ServiceError::Unavailable(format!(
337 "no service configured for escalation tier '{current_tier}' or above"
338 )))
339 })?;
340
341 if actual_tier != current_tier {
342 tracing::debug!(
343 requested = %current_tier,
344 resolved = %actual_tier,
345 "no service at requested tier, using next available"
346 );
347 current_tier = actual_tier;
348 }
349
350 match service.execute(input.clone()).await {
351 Ok(output) => {
352 let status = output
353 .metadata
354 .get("status_code")
355 .and_then(serde_json::Value::as_u64)
356 .map_or(200_u16, |s| u16::try_from(s).unwrap_or(200_u16));
357 let ctx = DefaultEscalationPolicy::context_from_body(status, &output.data);
358
359 if let Some(next_tier) = self.policy.should_escalate(&ctx, current_tier) {
360 escalation_path.push(current_tier);
361 current_tier = next_tier;
362 continue;
363 }
364
365 self.policy.record_tier_success(&host, current_tier);
367
368 let mut metadata = output.metadata;
369 if let Some(obj) = metadata.as_object_mut() {
370 obj.insert(
371 "escalation_tier".to_string(),
372 serde_json::Value::String(current_tier.to_string()),
373 );
374 obj.insert(
375 "escalation_path".to_string(),
376 serde_json::Value::Array(
377 escalation_path
378 .iter()
379 .map(|t| serde_json::Value::String(t.to_string()))
380 .collect(),
381 ),
382 );
383 }
384
385 return Ok(ServiceOutput {
386 data: output.data,
387 metadata,
388 });
389 }
390
391 Err(e) => {
392 match current_tier.next().filter(|&t| t <= self.policy.max_tier()) {
394 Some(next_tier) => {
395 tracing::info!(
396 tier = %current_tier,
397 next = %next_tier,
398 error = %e,
399 "service error, escalating to next tier"
400 );
401 escalation_path.push(current_tier);
402 current_tier = next_tier;
403 }
404 None => return Err(e),
405 }
406 }
407 }
408 }
409 }
410
411 fn name(&self) -> &'static str {
412 "http_escalating"
413 }
414}
415
#[cfg(test)]
// Tests unwrap freely: a failed unwrap is itself a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Policy built from `EscalationConfig::default()`.
    fn default_policy() -> DefaultEscalationPolicy {
        DefaultEscalationPolicy::new(EscalationConfig::default())
    }

    /// Context with the given status and no blocking signals set.
    fn ok_ctx(status: u16) -> ResponseContext {
        ResponseContext {
            status,
            body_empty: false,
            has_cloudflare_challenge: false,
            has_captcha: false,
        }
    }

    /// Reads a string-valued metadata field, panicking if absent or not a string.
    fn metadata_str<'a>(output: &'a ServiceOutput, key: &str) -> &'a str {
        output
            .metadata
            .get(key)
            .and_then(serde_json::Value::as_str)
            .unwrap()
    }

    // --- should_escalate -------------------------------------------------

    #[test]
    fn initial_tier_returns_base() {
        assert_eq!(default_policy().initial_tier(), EscalationTier::HttpPlain);
    }

    #[test]
    fn status_200_no_markers_does_not_escalate() {
        let policy = default_policy();
        assert!(
            policy
                .should_escalate(&ok_ctx(200), EscalationTier::HttpPlain)
                .is_none()
        );
    }

    #[test]
    fn status_403_triggers_escalation() {
        let policy = default_policy();
        assert_eq!(
            policy.should_escalate(&ok_ctx(403), EscalationTier::HttpPlain),
            Some(EscalationTier::HttpTlsProfiled),
        );
    }

    #[test]
    fn status_429_triggers_escalation() {
        let policy = default_policy();
        assert_eq!(
            policy.should_escalate(&ok_ctx(429), EscalationTier::HttpPlain),
            Some(EscalationTier::HttpTlsProfiled),
        );
    }

    #[test]
    fn cloudflare_challenge_escalates_from_tls_profiled() {
        let policy = default_policy();
        let ctx = ResponseContext {
            status: 200,
            body_empty: false,
            has_cloudflare_challenge: true,
            has_captcha: false,
        };
        assert_eq!(
            policy.should_escalate(&ctx, EscalationTier::HttpTlsProfiled),
            Some(EscalationTier::BrowserBasic),
        );
    }

    #[test]
    fn captcha_escalates_from_browser_basic() {
        let policy = default_policy();
        let ctx = ResponseContext {
            status: 200,
            body_empty: false,
            has_cloudflare_challenge: false,
            has_captcha: true,
        };
        assert_eq!(
            policy.should_escalate(&ctx, EscalationTier::BrowserBasic),
            Some(EscalationTier::BrowserAdvanced),
        );
    }

    #[test]
    fn max_tier_cap_prevents_further_escalation() {
        let policy = DefaultEscalationPolicy::new(EscalationConfig {
            max_tier: EscalationTier::BrowserBasic,
            ..EscalationConfig::default()
        });
        assert!(
            policy
                .should_escalate(&ok_ctx(403), EscalationTier::BrowserBasic)
                .is_none()
        );
    }

    #[test]
    fn empty_body_at_http_plain_does_not_escalate() {
        let policy = default_policy();
        let ctx = ResponseContext {
            status: 200,
            body_empty: true,
            has_cloudflare_challenge: false,
            has_captcha: false,
        };
        assert!(
            policy
                .should_escalate(&ctx, EscalationTier::HttpPlain)
                .is_none()
        );
    }

    #[test]
    fn empty_body_at_tls_profiled_triggers_escalation() {
        let policy = default_policy();
        let ctx = ResponseContext {
            status: 200,
            body_empty: true,
            has_cloudflare_challenge: false,
            has_captcha: false,
        };
        assert_eq!(
            policy.should_escalate(&ctx, EscalationTier::HttpTlsProfiled),
            Some(EscalationTier::BrowserBasic),
        );
    }

    // --- per-domain cache ------------------------------------------------

    #[test]
    fn domain_cache_starts_at_base_tier() {
        let policy = default_policy();
        assert_eq!(
            policy.initial_tier_for_domain("example.com"),
            EscalationTier::HttpPlain
        );
    }

    #[test]
    fn domain_cache_returns_recorded_tier() {
        let policy = default_policy();
        policy.record_tier_success("guarded.io", EscalationTier::BrowserBasic);
        assert_eq!(
            policy.initial_tier_for_domain("guarded.io"),
            EscalationTier::BrowserBasic
        );
    }

    #[test]
    fn domain_cache_does_not_regress() {
        let policy = default_policy();
        policy.record_tier_success("strict.io", EscalationTier::BrowserAdvanced);
        // A later success at a lower tier must not overwrite the live entry.
        policy.record_tier_success("strict.io", EscalationTier::BrowserBasic);
        assert_eq!(
            policy.initial_tier_for_domain("strict.io"),
            EscalationTier::BrowserAdvanced
        );
    }

    #[test]
    fn record_base_tier_does_not_pollute_cache() {
        let policy = default_policy();
        policy.record_tier_success("plain.io", EscalationTier::HttpPlain);
        assert_eq!(
            policy.initial_tier_for_domain("plain.io"),
            EscalationTier::HttpPlain
        );
    }

    #[test]
    fn purge_expired_removes_entries() {
        let policy = DefaultEscalationPolicy::new(EscalationConfig {
            cache_ttl: Duration::from_millis(1),
            ..EscalationConfig::default()
        });
        policy.record_tier_success("fast-expiry.com", EscalationTier::BrowserBasic);
        // Wait well past the 1 ms TTL so the entry is certainly expired.
        std::thread::sleep(Duration::from_millis(10));
        let removed = policy.purge_expired_cache();
        assert_eq!(removed, 1);
        assert_eq!(
            policy.initial_tier_for_domain("fast-expiry.com"),
            EscalationTier::HttpPlain
        );
    }

    // --- context_from_body -----------------------------------------------

    #[test]
    fn context_from_body_detects_cloudflare() {
        let body = "<html><title>Just a moment...</title></html>";
        let ctx = DefaultEscalationPolicy::context_from_body(403, body);
        assert!(ctx.has_cloudflare_challenge);
        assert_eq!(ctx.status, 403);
        assert!(!ctx.body_empty);
    }

    #[test]
    fn context_from_body_detects_perimeterx() {
        let body = r#"<script src="/_px.js"></script>"#;
        let ctx = DefaultEscalationPolicy::context_from_body(200, body);
        assert!(ctx.has_cloudflare_challenge);
    }

    #[test]
    fn context_from_body_detects_datadome() {
        let body = r#"<meta name="datadome" content="protected">"#;
        let ctx = DefaultEscalationPolicy::context_from_body(200, body);
        assert!(ctx.has_cloudflare_challenge);
    }

    #[test]
    fn context_from_body_detects_captcha() {
        let body = r#"<script src="hcaptcha.com/1/api.js"></script>"#;
        let ctx = DefaultEscalationPolicy::context_from_body(200, body);
        assert!(ctx.has_captcha);
        assert!(!ctx.has_cloudflare_challenge);
    }

    #[test]
    fn context_from_body_empty_whitespace() {
        let ctx = DefaultEscalationPolicy::context_from_body(200, " \n ");
        assert!(ctx.body_empty);
    }

    // --- marker helpers --------------------------------------------------

    #[test]
    fn detection_helpers_match_markers() {
        assert!(is_cloudflare_challenge("Just a moment..."));
        assert!(is_cloudflare_challenge("cf-browser-verification token"));
        assert!(is_datadome_interstitial("window.datadome = {}"));
        assert!(is_perimeterx_challenge("var _pxParam1 = 'abc'"));
        assert!(has_captcha_marker("www.google.com/recaptcha/api.js"));
        assert!(has_captcha_marker("turnstile.cloudflare.com"));
    }

    // --- domain_from_url -------------------------------------------------

    #[test]
    fn domain_from_url_strips_scheme_and_path() {
        assert_eq!(
            domain_from_url("https://example.com/path?q=1"),
            "example.com"
        );
        assert_eq!(
            domain_from_url("http://sub.example.com/"),
            "sub.example.com"
        );
    }

    #[test]
    fn domain_from_url_strips_port() {
        assert_eq!(
            domain_from_url("https://example.com:8443/api"),
            "example.com"
        );
    }

    #[test]
    fn domain_from_url_no_scheme_passes_through() {
        // Pin the exact host: merely checking that "http" is absent would
        // accept almost any output.
        assert_eq!(domain_from_url("example.com/path"), "example.com");
    }

    // --- service doubles -------------------------------------------------

    /// Always succeeds with a fixed body and `status_code` metadata.
    struct MockService {
        body: &'static str,
        status: u16,
    }

    #[async_trait]
    impl ScrapingService for MockService {
        async fn execute(
            &self,
            _input: ServiceInput,
        ) -> crate::domain::error::Result<ServiceOutput> {
            Ok(ServiceOutput {
                data: self.body.to_string(),
                metadata: serde_json::json!({ "status_code": self.status }),
            })
        }
        fn name(&self) -> &'static str {
            "mock"
        }
    }

    /// Always fails with `ServiceError::Unavailable("blocked")`.
    struct FailingService;

    #[async_trait]
    impl ScrapingService for FailingService {
        async fn execute(
            &self,
            _input: ServiceInput,
        ) -> crate::domain::error::Result<ServiceOutput> {
            Err(StygianError::Service(ServiceError::Unavailable(
                "blocked".into(),
            )))
        }
        fn name(&self) -> &'static str {
            "failing"
        }
    }

    fn test_input() -> ServiceInput {
        ServiceInput {
            url: "https://example.com/data".to_string(),
            params: serde_json::Value::Null,
        }
    }

    // --- EscalatingScrapingService ---------------------------------------

    #[tokio::test]
    async fn escalating_service_returns_ok_on_clean_response() {
        let policy = DefaultEscalationPolicy::new(EscalationConfig::default());
        let svc = EscalatingScrapingService::new(policy).with_tier(
            EscalationTier::HttpPlain,
            Arc::new(MockService {
                body: "<html>hello</html>",
                status: 200,
            }),
        );
        let output = svc.execute(test_input()).await.unwrap();
        assert_eq!(metadata_str(&output, "escalation_tier"), "http_plain");
        let path = output
            .metadata
            .get("escalation_path")
            .and_then(serde_json::Value::as_array)
            .unwrap();
        assert!(path.is_empty());
    }

    #[tokio::test]
    async fn escalating_service_escalates_on_cf_challenge() {
        let policy = DefaultEscalationPolicy::new(EscalationConfig::default());
        let svc = EscalatingScrapingService::new(policy)
            .with_tier(
                EscalationTier::HttpPlain,
                Arc::new(MockService {
                    body: "<html><title>Just a moment...</title></html>",
                    status: 200,
                }),
            )
            .with_tier(
                EscalationTier::HttpTlsProfiled,
                Arc::new(MockService {
                    body: "<html>real content</html>",
                    status: 200,
                }),
            );
        let output = svc.execute(test_input()).await.unwrap();
        assert_eq!(
            metadata_str(&output, "escalation_tier"),
            "http_tls_profiled"
        );
        let path = output
            .metadata
            .get("escalation_path")
            .and_then(serde_json::Value::as_array)
            .unwrap();
        assert_eq!(path.len(), 1);
        assert_eq!(
            path.first().and_then(serde_json::Value::as_str).unwrap(),
            "http_plain"
        );
    }

    #[tokio::test]
    async fn escalating_service_escalates_on_service_error() {
        let policy = DefaultEscalationPolicy::new(EscalationConfig::default());
        let svc = EscalatingScrapingService::new(policy)
            .with_tier(EscalationTier::HttpPlain, Arc::new(FailingService))
            .with_tier(
                EscalationTier::BrowserBasic,
                Arc::new(MockService {
                    body: "<html>recovered</html>",
                    status: 200,
                }),
            );
        let output = svc.execute(test_input()).await.unwrap();
        assert_eq!(metadata_str(&output, "escalation_tier"), "browser_basic");
    }

    #[tokio::test]
    async fn escalating_service_returns_error_when_all_tiers_fail() {
        let policy = DefaultEscalationPolicy::new(EscalationConfig {
            max_tier: EscalationTier::BrowserBasic,
            ..EscalationConfig::default()
        });
        let svc = EscalatingScrapingService::new(policy)
            .with_tier(EscalationTier::HttpPlain, Arc::new(FailingService))
            .with_tier(EscalationTier::BrowserBasic, Arc::new(FailingService));

        assert!(svc.execute(test_input()).await.is_err());
    }

    #[tokio::test]
    async fn escalating_service_no_services_returns_error() {
        let policy = DefaultEscalationPolicy::new(EscalationConfig::default());
        let svc = EscalatingScrapingService::new(policy);
        assert!(svc.execute(test_input()).await.is_err());
    }

    #[tokio::test]
    async fn escalating_service_updates_domain_cache_on_success() {
        let policy = DefaultEscalationPolicy::new(EscalationConfig::default());
        // The clone shares the cache, so `policy` observes the service's writes.
        let svc = EscalatingScrapingService::new(policy.clone())
            .with_tier(
                EscalationTier::HttpPlain,
                Arc::new(MockService {
                    body: "<html><title>Just a moment...</title></html>",
                    status: 200,
                }),
            )
            .with_tier(
                EscalationTier::HttpTlsProfiled,
                Arc::new(MockService {
                    body: "<html>ok</html>",
                    status: 200,
                }),
            );

        svc.execute(test_input()).await.unwrap();

        assert_eq!(
            policy.initial_tier_for_domain("example.com"),
            EscalationTier::HttpTlsProfiled
        );
    }
}