1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3use std::num::NonZeroUsize;
4use std::sync::Mutex;
5use std::time::{Duration, Instant};
6
7use crate::types::{InvestigationReport, TargetClass};
8
/// A cached value paired with the deadline after which it is considered stale.
#[derive(Debug, Clone)]
struct TtlEntry<V> {
    /// The cached payload.
    value: V,
    /// Instant at which the entry becomes stale.
    ///
    /// `None` means the deadline could not be represented as an `Instant`
    /// (the requested TTL was too large), so the entry never expires.
    expires_at: Option<Instant>,
}

impl<V> TtlEntry<V> {
    /// Wraps `value` so that it expires `ttl` after the current instant.
    ///
    /// Uses `Instant::checked_add` rather than `+`: the operator panics when
    /// the sum overflows (for example with `Duration::MAX`), whereas an
    /// unrepresentable deadline is simply recorded as "never expires".
    fn new(value: V, ttl: Duration) -> Self {
        Self {
            value,
            expires_at: Instant::now().checked_add(ttl),
        }
    }

    /// Returns `true` once the current instant has reached the deadline.
    ///
    /// Entries without a deadline are never expired.
    fn is_expired(&self) -> bool {
        self.expires_at
            .is_some_and(|deadline| Instant::now() >= deadline)
    }
}
28
29pub(crate) struct LruTtlStore<V> {
44 ttl: Duration,
45 inner: Mutex<lru::LruCache<String, TtlEntry<V>>>,
46}
47
48impl<V: Clone> LruTtlStore<V> {
49 #[must_use]
51 pub(crate) fn new(capacity: NonZeroUsize, ttl: Duration) -> Self {
52 Self {
53 ttl,
54 inner: Mutex::new(lru::LruCache::new(capacity)),
55 }
56 }
57
58 #[must_use]
60 pub(crate) const fn ttl(&self) -> Duration {
61 self.ttl
62 }
63
64 pub(crate) fn get(&self, key: &str) -> Option<V> {
67 let Ok(mut cache) = self.inner.lock() else {
68 return None;
69 };
70
71 match cache.get(key) {
72 Some(entry) if entry.is_expired() => {
73 cache.pop(key);
74 None
75 }
76 Some(entry) => Some(entry.value.clone()),
77 None => None,
78 }
79 }
80
81 #[allow(dead_code)]
85 pub(crate) fn peek(&self, key: &str) -> Option<V> {
86 let Ok(cache) = self.inner.lock() else {
87 return None;
88 };
89
90 match cache.peek(key) {
91 Some(entry) if entry.is_expired() => None,
92 Some(entry) => Some(entry.value.clone()),
93 None => None,
94 }
95 }
96
97 pub(crate) fn put(&self, key: String, value: V) {
99 let Ok(mut cache) = self.inner.lock() else {
100 return;
101 };
102
103 cache.put(key, TtlEntry::new(value, self.ttl));
104 }
105
106 pub(crate) fn invalidate(&self, key: &str) {
108 if let Ok(mut cache) = self.inner.lock() {
109 cache.pop(key);
110 }
111 }
112
113 pub(crate) fn clear(&self) {
115 if let Ok(mut cache) = self.inner.lock() {
116 cache.clear();
117 }
118 }
119
120 #[allow(dead_code)]
123 pub(crate) fn len(&self) -> usize {
124 self.inner.lock().map_or(0, |cache| cache.len())
125 }
126
127 #[allow(dead_code)]
129 pub(crate) fn is_empty(&self) -> bool {
130 self.len() == 0
131 }
132}
133
/// Storage backend for [`InvestigationReport`]s addressed by string keys.
///
/// Implementations must be `Send + Sync`, and every method takes `&self`.
/// None of the methods returns an error to the caller.
pub trait InvestigationReportCache: Send + Sync {
    /// Looks up the report stored under `key`; `None` means no report is available.
    fn get(&self, key: &str) -> Option<InvestigationReport>;

    /// Stores `report` under `key`.
    fn put(&self, key: String, report: InvestigationReport);

    /// Removes the report stored under `key`, if any.
    fn invalidate(&self, key: &str);

    /// Removes every report held by this cache.
    fn clear(&self);
}
151
152#[must_use]
154pub fn investigation_cache_key(har_json: &str, target_class: TargetClass) -> String {
155 let mut hasher = DefaultHasher::new();
156 har_json.hash(&mut hasher);
157 target_class.hash(&mut hasher);
158 format!("charon:investigation:{:016x}", hasher.finish())
159}
160
161pub struct MemoryInvestigationCache {
163 store: LruTtlStore<InvestigationReport>,
164}
165
166impl MemoryInvestigationCache {
167 #[must_use]
169 pub fn new(capacity: NonZeroUsize, ttl: Duration) -> Self {
170 Self {
171 store: LruTtlStore::new(capacity, ttl),
172 }
173 }
174
175 #[must_use]
177 pub fn len(&self) -> usize {
178 self.store.len()
179 }
180
181 #[must_use]
183 pub fn is_empty(&self) -> bool {
184 self.store.is_empty()
185 }
186}
187
188impl InvestigationReportCache for MemoryInvestigationCache {
189 fn get(&self, key: &str) -> Option<InvestigationReport> {
190 self.store.get(key)
191 }
192
193 fn put(&self, key: String, report: InvestigationReport) {
194 self.store.put(key, report);
195 }
196
197 fn invalidate(&self, key: &str) {
198 self.store.invalidate(key);
199 }
200
201 fn clear(&self) {
202 self.store.clear();
203 }
204}
205
/// Report cache that keeps serialized reports in Redis.
#[cfg(feature = "redis-cache")]
pub struct RedisInvestigationCache {
    /// Client from which connections are obtained.
    client: redis::Client,
    /// Expiry applied to stored reports.
    ttl: Duration,
    /// Namespace prepended (with a `:` separator) to every key.
    key_prefix: String,
}

#[cfg(feature = "redis-cache")]
impl RedisInvestigationCache {
    /// Builds a cache for the server at `redis_url` whose entries live for `ttl`.
    ///
    /// # Errors
    ///
    /// Returns the error from `redis::Client::open` when it rejects `redis_url`.
    pub fn new(redis_url: &str, ttl: Duration) -> redis::RedisResult<Self> {
        redis::Client::open(redis_url).map(|client| Self {
            client,
            ttl,
            key_prefix: String::from("charon:investigation"),
        })
    }

    /// Returns `key` namespaced as `<key_prefix>:<key>`.
    ///
    /// NOTE(review): keys built by `investigation_cache_key` already begin
    /// with `charon:investigation:`, so such keys carry the prefix twice —
    /// confirm whether that is intended.
    fn prefixed_key(&self, key: &str) -> String {
        [self.key_prefix.as_str(), key].join(":")
    }
}
234
/// Redis-backed implementation of the report cache.
///
/// Each operation obtains its own connection from the client and is
/// best-effort: connection, command and (de)serialization failures are
/// swallowed, turning reads into misses and writes into no-ops.
#[cfg(feature = "redis-cache")]
impl InvestigationReportCache for RedisInvestigationCache {
    /// Fetches and deserializes the report stored under `key`.
    ///
    /// Returns `None` when the key is absent, Redis cannot be reached, or the
    /// stored payload does not deserialize into an `InvestigationReport`.
    fn get(&self, key: &str) -> Option<InvestigationReport> {
        let mut connection = self.client.get_connection().ok()?;
        let payload: Option<String> = redis::cmd("GET")
            .arg(self.prefixed_key(key))
            .query(&mut connection)
            .ok()?;
        payload.and_then(|value| serde_json::from_str::<InvestigationReport>(&value).ok())
    }

    /// Serializes `report` to JSON and stores it under `key` with the configured TTL.
    ///
    /// `SETEX` only accepts a positive whole number of seconds. A zero TTL
    /// therefore skips the write (the entry would be expired immediately
    /// anyway), and a sub-second TTL is rounded up to one second instead of
    /// being truncated to `0`, which Redis would reject.
    fn put(&self, key: String, report: InvestigationReport) {
        if self.ttl.is_zero() {
            return;
        }
        let ttl_seconds = self.ttl.as_secs().max(1);

        let Ok(payload) = serde_json::to_string(&report) else {
            return;
        };
        let Ok(mut connection) = self.client.get_connection() else {
            return;
        };
        let _: redis::RedisResult<()> = redis::cmd("SETEX")
            .arg(self.prefixed_key(&key))
            .arg(ttl_seconds)
            .arg(payload)
            .query(&mut connection);
    }

    /// Deletes the report stored under `key`, if any.
    fn invalidate(&self, key: &str) {
        let Ok(mut connection) = self.client.get_connection() else {
            return;
        };
        let _: redis::RedisResult<()> = redis::cmd("DEL")
            .arg(self.prefixed_key(key))
            .query(&mut connection);
    }

    /// Deletes every key under this cache's prefix.
    ///
    /// Walks the keyspace with cursor-based `SCAN` rather than `KEYS`, which
    /// blocks the server while it visits every key. Matching keys are deleted
    /// one page at a time; the walk stops at the first failed `SCAN` or when
    /// the server returns cursor `0`.
    fn clear(&self) {
        // Hint for how many keys Redis should examine per SCAN call.
        const SCAN_BATCH: usize = 256;

        let pattern = format!("{}:*", self.key_prefix);
        let Ok(mut connection) = self.client.get_connection() else {
            return;
        };

        let mut cursor: u64 = 0;
        loop {
            let page: redis::RedisResult<(u64, Vec<String>)> = redis::cmd("SCAN")
                .arg(cursor)
                .arg("MATCH")
                .arg(pattern.as_str())
                .arg("COUNT")
                .arg(SCAN_BATCH)
                .query(&mut connection);
            let Ok((next_cursor, keys)) = page else {
                return;
            };

            // A page may legitimately be empty while the scan is still in progress.
            if !keys.is_empty() {
                let _: redis::RedisResult<()> =
                    redis::cmd("DEL").arg(keys).query(&mut connection);
            }

            if next_cursor == 0 {
                return;
            }
            cursor = next_cursor;
        }
    }
}
284
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::indexing_slicing
)]
mod tests {
    use super::*;
    use crate::types::{AntiBotProvider, Detection, InvestigationReport};
    use std::collections::BTreeMap;

    /// HAR payload shared by the memory-cache tests.
    const EMPTY_LOG_HAR: &str = "{\"log\":{}}";

    /// Builds a two-entry in-memory cache whose entries live for `ttl`.
    fn two_slot_cache(ttl: Duration) -> MemoryInvestigationCache {
        let capacity = NonZeroUsize::new(2).unwrap_or(NonZeroUsize::MIN);
        MemoryInvestigationCache::new(capacity, ttl)
    }

    /// Small fixed report used as the cached payload.
    fn sample_report() -> InvestigationReport {
        let aggregate = Detection {
            provider: AntiBotProvider::Unknown,
            confidence: 0.1,
            markers: Vec::new(),
        };

        InvestigationReport {
            page_title: Some(String::from("https://example.com")),
            total_requests: 10,
            blocked_requests: 2,
            status_histogram: BTreeMap::from([(200, 8), (403, 2)]),
            resource_type_histogram: BTreeMap::new(),
            provider_histogram: BTreeMap::new(),
            marker_histogram: BTreeMap::new(),
            top_markers: Vec::new(),
            hosts: Vec::new(),
            suspicious_requests: Vec::new(),
            aggregate,
            target_class: Some(TargetClass::Api),
        }
    }

    /// A stored report comes back unchanged while its TTL has not elapsed.
    #[test]
    fn memory_cache_round_trips_report() {
        let cache = two_slot_cache(Duration::from_mins(1));
        let key = investigation_cache_key(EMPTY_LOG_HAR, TargetClass::Api);
        let report = sample_report();

        cache.put(key.clone(), report.clone());
        let fetched = cache.get(&key);

        assert_eq!(fetched, Some(report));
    }

    /// A report is no longer returned once its TTL has elapsed.
    #[test]
    fn memory_cache_expires_entries() {
        let cache = two_slot_cache(Duration::from_millis(1));
        let key = investigation_cache_key(EMPTY_LOG_HAR, TargetClass::Api);

        cache.put(key.clone(), sample_report());
        // Wait comfortably longer than the 1 ms TTL.
        std::thread::sleep(Duration::from_millis(5));
        let lookup = cache.get(&key);

        assert!(lookup.is_none());
    }

    /// The same HAR document yields distinct keys for distinct target classes.
    #[test]
    fn cache_key_changes_by_target_class() {
        let har = "{\"log\":{\"entries\":[]}}";

        let api = investigation_cache_key(har, TargetClass::Api);
        let high = investigation_cache_key(har, TargetClass::HighSecurity);

        assert_ne!(api, high);
    }
}