Skip to main content

stygian_charon/
cache.rs

1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3use std::num::NonZeroUsize;
4use std::sync::Mutex;
5use std::time::{Duration, Instant};
6
7use crate::types::{InvestigationReport, TargetClass};
8
9#[derive(Debug, Clone)]
10struct CacheEntry {
11    report: InvestigationReport,
12    expires_at: Instant,
13}
14
15impl CacheEntry {
16    fn new(report: InvestigationReport, ttl: Duration) -> Self {
17        Self {
18            report,
19            expires_at: Instant::now() + ttl,
20        }
21    }
22
23    fn is_expired(&self) -> bool {
24        Instant::now() >= self.expires_at
25    }
26}
27
28/// Cache abstraction for normalized investigation reports.
29///
30/// Implementations are expected to store cloned [`InvestigationReport`] values keyed by the
31/// hashed HAR payload and target class.
32pub trait InvestigationReportCache: Send + Sync {
33    /// Look up a cached investigation report by cache key.
34    fn get(&self, key: &str) -> Option<InvestigationReport>;
35
36    /// Insert or replace a cached investigation report.
37    fn put(&self, key: String, report: InvestigationReport);
38
39    /// Invalidate a single cache key.
40    fn invalidate(&self, key: &str);
41
42    /// Remove all cached entries.
43    fn clear(&self);
44}
45
46/// Generate a stable cache key from HAR content and target class.
47#[must_use]
48pub fn investigation_cache_key(har_json: &str, target_class: TargetClass) -> String {
49    let mut hasher = DefaultHasher::new();
50    har_json.hash(&mut hasher);
51    target_class.hash(&mut hasher);
52    format!("charon:investigation:{:016x}", hasher.finish())
53}
54
55/// In-memory capacity-bounded LRU cache with TTL for investigation reports.
56pub struct MemoryInvestigationCache {
57    ttl: Duration,
58    inner: Mutex<lru::LruCache<String, CacheEntry>>,
59}
60
61impl MemoryInvestigationCache {
62    /// Create a new in-memory cache.
63    #[must_use]
64    pub fn new(capacity: NonZeroUsize, ttl: Duration) -> Self {
65        Self {
66            ttl,
67            inner: Mutex::new(lru::LruCache::new(capacity)),
68        }
69    }
70}
71
72impl InvestigationReportCache for MemoryInvestigationCache {
73    fn get(&self, key: &str) -> Option<InvestigationReport> {
74        let Ok(mut cache) = self.inner.lock() else {
75            return None;
76        };
77
78        match cache.get(key) {
79            Some(entry) if entry.is_expired() => {
80                cache.pop(key);
81                None
82            }
83            Some(entry) => Some(entry.report.clone()),
84            None => None,
85        }
86    }
87
88    fn put(&self, key: String, report: InvestigationReport) {
89        let Ok(mut cache) = self.inner.lock() else {
90            return;
91        };
92
93        cache.put(key, CacheEntry::new(report, self.ttl));
94    }
95
96    fn invalidate(&self, key: &str) {
97        if let Ok(mut cache) = self.inner.lock() {
98            cache.pop(key);
99        }
100    }
101
102    fn clear(&self) {
103        if let Ok(mut cache) = self.inner.lock() {
104            cache.clear();
105        }
106    }
107}
108
/// Investigation report cache that stores its entries in Redis.
#[cfg(feature = "redis-cache")]
pub struct RedisInvestigationCache {
    /// Client used to open a connection for each operation.
    client: redis::Client,
    /// Lifetime requested for each stored entry.
    ttl: Duration,
    /// Namespace prepended to every key sent to Redis.
    key_prefix: String,
}

#[cfg(feature = "redis-cache")]
impl RedisInvestigationCache {
    /// Build a cache that talks to the Redis server at `redis_url`.
    ///
    /// # Errors
    ///
    /// Returns a Redis error if the client cannot be created from `redis_url`.
    pub fn new(redis_url: &str, ttl: Duration) -> redis::RedisResult<Self> {
        redis::Client::open(redis_url).map(|client| Self {
            client,
            ttl,
            key_prefix: String::from("charon:investigation"),
        })
    }

    /// Join the namespace and `key` with a `:` separator.
    ///
    /// NOTE(review): keys produced by `investigation_cache_key` already start
    /// with `charon:investigation:`, so for those the stored Redis key carries
    /// the prefix twice. Confirm whether that is intended.
    fn prefixed_key(&self, key: &str) -> String {
        let mut full = String::with_capacity(self.key_prefix.len() + 1 + key.len());
        full.push_str(&self.key_prefix);
        full.push(':');
        full.push_str(key);
        full
    }
}
137
/// Redis backend. Every method is best-effort: connection, command and
/// (de)serialization failures are swallowed, so a failed `get` is a miss and a
/// failed write or delete is a no-op.
#[cfg(feature = "redis-cache")]
impl InvestigationReportCache for RedisInvestigationCache {
    /// Fetch and deserialize the JSON payload stored under `key`.
    ///
    /// Returns `None` when the key is absent, the connection or command
    /// fails, or the payload does not deserialize.
    fn get(&self, key: &str) -> Option<InvestigationReport> {
        let mut connection = self.client.get_connection().ok()?;
        let payload: Option<String> = redis::cmd("GET")
            .arg(self.prefixed_key(key))
            .query(&mut connection)
            .ok()?;
        payload.and_then(|value| serde_json::from_str::<InvestigationReport>(&value).ok())
    }

    /// Serialize `report` to JSON and store it under `key` with the
    /// configured TTL.
    ///
    /// The TTL is sent with millisecond precision (`SET … PX`). Truncating to
    /// whole seconds turned any sub-second TTL into `0`, which Redis rejects
    /// as an invalid expire time, so the entry was silently never stored.
    fn put(&self, key: String, report: InvestigationReport) {
        // A zero TTL means the entry would already be expired: nothing to store.
        if self.ttl.is_zero() {
            return;
        }
        // Round sub-millisecond TTLs up to the smallest expiry Redis accepts,
        // and saturate instead of truncating very large ones.
        let ttl_millis = u64::try_from(self.ttl.as_millis())
            .unwrap_or(u64::MAX)
            .max(1);

        let Ok(payload) = serde_json::to_string(&report) else {
            return;
        };
        let Ok(mut connection) = self.client.get_connection() else {
            return;
        };
        let _: redis::RedisResult<()> = redis::cmd("SET")
            .arg(self.prefixed_key(&key))
            .arg(payload)
            .arg("PX")
            .arg(ttl_millis)
            .query(&mut connection);
    }

    /// Delete the entry stored under `key`, if any.
    fn invalidate(&self, key: &str) {
        let Ok(mut connection) = self.client.get_connection() else {
            return;
        };
        let _: redis::RedisResult<()> = redis::cmd("DEL")
            .arg(self.prefixed_key(key))
            .query(&mut connection);
    }

    /// Delete every key under this cache's prefix.
    ///
    /// Keys are discovered with incremental `SCAN` calls rather than `KEYS`,
    /// which walks the whole keyspace in a single blocking command. Each batch
    /// is deleted as soon as it is returned; the first failing command stops
    /// the sweep.
    fn clear(&self) {
        // Hint for how many keys Redis should examine per SCAN call.
        const SCAN_BATCH: usize = 512;

        let pattern = format!("{}:*", self.key_prefix);
        let Ok(mut connection) = self.client.get_connection() else {
            return;
        };

        // SCAN starts at cursor 0 and is complete when it returns cursor 0 again.
        let mut cursor: u64 = 0;
        loop {
            let batch: redis::RedisResult<(u64, Vec<String>)> = redis::cmd("SCAN")
                .arg(cursor)
                .arg("MATCH")
                .arg(&pattern)
                .arg("COUNT")
                .arg(SCAN_BATCH)
                .query(&mut connection);
            let Ok((next_cursor, keys)) = batch else {
                return;
            };

            // A SCAN page may legitimately be empty; DEL needs at least one key.
            if !keys.is_empty() {
                let _: redis::RedisResult<()> =
                    redis::cmd("DEL").arg(keys).query(&mut connection);
            }

            if next_cursor == 0 {
                return;
            }
            cursor = next_cursor;
        }
    }
}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AntiBotProvider, Detection, InvestigationReport};
    use std::collections::BTreeMap;

    /// In-memory cache with room for two entries, each living for `ttl`.
    fn two_slot_cache(ttl: Duration) -> MemoryInvestigationCache {
        let capacity = NonZeroUsize::new(2).unwrap_or(NonZeroUsize::MIN);
        MemoryInvestigationCache::new(capacity, ttl)
    }

    /// Cache key for a minimal HAR payload classified as an API target.
    fn api_key() -> String {
        investigation_cache_key("{\"log\":{}}", TargetClass::Api)
    }

    /// Fixed report fixture used by the cache tests.
    fn sample_report() -> InvestigationReport {
        let aggregate = Detection {
            provider: AntiBotProvider::Unknown,
            confidence: 0.1,
            markers: Vec::new(),
        };

        InvestigationReport {
            page_title: Some("https://example.com".to_string()),
            total_requests: 10,
            blocked_requests: 2,
            status_histogram: BTreeMap::from([(200, 8), (403, 2)]),
            resource_type_histogram: BTreeMap::new(),
            provider_histogram: BTreeMap::new(),
            marker_histogram: BTreeMap::new(),
            top_markers: Vec::new(),
            hosts: Vec::new(),
            suspicious_requests: Vec::new(),
            aggregate,
            target_class: Some(TargetClass::Api),
        }
    }

    /// A stored report comes back unchanged while its TTL has not elapsed.
    #[test]
    fn memory_cache_round_trips_report() {
        let cache = two_slot_cache(Duration::from_mins(1));
        let key = api_key();
        let expected = sample_report();

        cache.put(key.clone(), expected.clone());
        let cached = cache.get(&key);

        assert_eq!(cached, Some(expected));
    }

    /// A report is no longer returned once its TTL has elapsed.
    #[test]
    fn memory_cache_expires_entries() {
        let cache = two_slot_cache(Duration::from_millis(1));
        let key = api_key();

        cache.put(key.clone(), sample_report());
        // Wait comfortably past the 1 ms TTL.
        std::thread::sleep(Duration::from_millis(5));

        assert!(cache.get(&key).is_none());
    }

    /// The same HAR payload yields different keys for different target classes.
    #[test]
    fn cache_key_changes_by_target_class() {
        let har = "{\"log\":{\"entries\":[]}}";

        let api = investigation_cache_key(har, TargetClass::Api);
        let high = investigation_cache_key(har, TargetClass::HighSecurity);

        assert_ne!(api, high);
    }
}