1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3use std::num::NonZeroUsize;
4use std::sync::Mutex;
5use std::time::{Duration, Instant};
6
7use crate::types::{InvestigationReport, TargetClass};
8
9#[derive(Debug, Clone)]
10struct CacheEntry {
11 report: InvestigationReport,
12 expires_at: Instant,
13}
14
15impl CacheEntry {
16 fn new(report: InvestigationReport, ttl: Duration) -> Self {
17 Self {
18 report,
19 expires_at: Instant::now() + ttl,
20 }
21 }
22
23 fn is_expired(&self) -> bool {
24 Instant::now() >= self.expires_at
25 }
26}
27
/// Keyed store for [`InvestigationReport`] values.
///
/// Every method takes `&self`, and the `Send + Sync` supertraits require that an
/// implementation can be shared between threads. No method returns an error:
/// the implementations visible in this file treat the cache as best-effort and
/// turn backend failures into a miss or a no-op.
pub trait InvestigationReportCache: Send + Sync {
    /// Returns a copy of the report stored under `key`, or `None` when there is
    /// nothing to serve.
    fn get(&self, key: &str) -> Option<InvestigationReport>;

    /// Stores `report` under `key`.
    fn put(&self, key: String, report: InvestigationReport);

    /// Removes the entry stored under `key`, if any.
    fn invalidate(&self, key: &str);

    /// Removes every entry held by this cache.
    fn clear(&self);
}
45
46#[must_use]
48pub fn investigation_cache_key(har_json: &str, target_class: TargetClass) -> String {
49 let mut hasher = DefaultHasher::new();
50 har_json.hash(&mut hasher);
51 target_class.hash(&mut hasher);
52 format!("charon:investigation:{:016x}", hasher.finish())
53}
54
55pub struct MemoryInvestigationCache {
57 ttl: Duration,
58 inner: Mutex<lru::LruCache<String, CacheEntry>>,
59}
60
61impl MemoryInvestigationCache {
62 #[must_use]
64 pub fn new(capacity: NonZeroUsize, ttl: Duration) -> Self {
65 Self {
66 ttl,
67 inner: Mutex::new(lru::LruCache::new(capacity)),
68 }
69 }
70}
71
72impl InvestigationReportCache for MemoryInvestigationCache {
73 fn get(&self, key: &str) -> Option<InvestigationReport> {
74 let Ok(mut cache) = self.inner.lock() else {
75 return None;
76 };
77
78 match cache.get(key) {
79 Some(entry) if entry.is_expired() => {
80 cache.pop(key);
81 None
82 }
83 Some(entry) => Some(entry.report.clone()),
84 None => None,
85 }
86 }
87
88 fn put(&self, key: String, report: InvestigationReport) {
89 let Ok(mut cache) = self.inner.lock() else {
90 return;
91 };
92
93 cache.put(key, CacheEntry::new(report, self.ttl));
94 }
95
96 fn invalidate(&self, key: &str) {
97 if let Ok(mut cache) = self.inner.lock() {
98 cache.pop(key);
99 }
100 }
101
102 fn clear(&self) {
103 if let Ok(mut cache) = self.inner.lock() {
104 cache.clear();
105 }
106 }
107}
108
/// Redis-backed report cache, compiled only with the `redis-cache` feature.
#[cfg(feature = "redis-cache")]
pub struct RedisInvestigationCache {
    /// Client from which connections are obtained.
    client: redis::Client,
    /// Time-to-live configured for this cache.
    ttl: Duration,
    /// Namespace prepended (with a `:` separator) to every key.
    key_prefix: String,
}

#[cfg(feature = "redis-cache")]
impl RedisInvestigationCache {
    /// Builds a cache that talks to the server described by `redis_url`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `redis::Client::open` reports for `redis_url`.
    pub fn new(redis_url: &str, ttl: Duration) -> redis::RedisResult<Self> {
        redis::Client::open(redis_url).map(|client| Self {
            client,
            ttl,
            key_prefix: String::from("charon:investigation"),
        })
    }

    /// Returns `key` namespaced as `<key_prefix>:<key>`.
    ///
    /// NOTE(review): keys produced by `investigation_cache_key` already start
    /// with the same literal, so such keys end up carrying it twice — harmless,
    /// but confirm it is intended.
    fn prefixed_key(&self, key: &str) -> String {
        let mut namespaced = String::with_capacity(self.key_prefix.len() + 1 + key.len());
        namespaced.push_str(&self.key_prefix);
        namespaced.push(':');
        namespaced.push_str(key);
        namespaced
    }
}
137
#[cfg(feature = "redis-cache")]
impl InvestigationReportCache for RedisInvestigationCache {
    /// Fetches and deserializes the report stored under `key`.
    ///
    /// Returns `None` when the connection cannot be obtained, the `GET` fails,
    /// the key is absent, or the stored payload is not valid JSON for an
    /// [`InvestigationReport`].
    fn get(&self, key: &str) -> Option<InvestigationReport> {
        let mut connection = self.client.get_connection().ok()?;
        let payload: Option<String> = redis::cmd("GET")
            .arg(self.prefixed_key(key))
            .query(&mut connection)
            .ok()?;
        payload.and_then(|value| serde_json::from_str::<InvestigationReport>(&value).ok())
    }

    /// Serializes `report` to JSON and stores it under `key` with the configured TTL.
    ///
    /// The expiry is sent in milliseconds (`PSETEX`) so that sub-second TTLs are
    /// honoured; truncating to whole seconds would turn them into `0`, which
    /// Redis rejects, silently dropping the write. A TTL below one millisecond
    /// stores nothing, since the entry would be expired at once. All failures
    /// are swallowed: the cache is best-effort.
    fn put(&self, key: String, report: InvestigationReport) {
        // Saturate rather than wrap for absurdly large TTLs; Redis rejects
        // out-of-range expiries and the write is then skipped like any other error.
        let ttl_millis = u64::try_from(self.ttl.as_millis()).unwrap_or(u64::MAX);
        if ttl_millis == 0 {
            return;
        }
        let Ok(payload) = serde_json::to_string(&report) else {
            return;
        };
        let Ok(mut connection) = self.client.get_connection() else {
            return;
        };
        let _: redis::RedisResult<()> = redis::cmd("PSETEX")
            .arg(self.prefixed_key(&key))
            .arg(ttl_millis)
            .arg(payload)
            .query(&mut connection);
    }

    /// Deletes the entry stored under `key`; failures are ignored.
    fn invalidate(&self, key: &str) {
        let Ok(mut connection) = self.client.get_connection() else {
            return;
        };
        let _: redis::RedisResult<()> = redis::cmd("DEL")
            .arg(self.prefixed_key(key))
            .query(&mut connection);
    }

    /// Deletes every key under this cache's prefix.
    ///
    /// Keys are discovered with cursor-based `SCAN` and deleted page by page,
    /// instead of a single `KEYS` call, which blocks the server while it walks
    /// the whole keyspace. Stops quietly on the first failed `SCAN`.
    fn clear(&self) {
        // Hint for how many keys Redis examines per SCAN call.
        const SCAN_PAGE_SIZE: usize = 512;

        let pattern = format!("{}:*", self.key_prefix);
        let Ok(mut connection) = self.client.get_connection() else {
            return;
        };

        let mut cursor: u64 = 0;
        loop {
            let page: redis::RedisResult<(u64, Vec<String>)> = redis::cmd("SCAN")
                .arg(cursor)
                .arg("MATCH")
                .arg(&pattern)
                .arg("COUNT")
                .arg(SCAN_PAGE_SIZE)
                .query(&mut connection);
            let Ok((next_cursor, keys)) = page else {
                return;
            };

            // A page may legitimately be empty while the scan is still in progress.
            if !keys.is_empty() {
                let _: redis::RedisResult<()> =
                    redis::cmd("DEL").arg(keys).query(&mut connection);
            }

            // Redis signals the end of the iteration by returning cursor 0.
            if next_cursor == 0 {
                break;
            }
            cursor = next_cursor;
        }
    }
}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AntiBotProvider, Detection, InvestigationReport};
    use std::collections::BTreeMap;

    /// Minimal HAR document used as key material by the memory-cache tests.
    const EMPTY_LOG_HAR: &str = "{\"log\":{}}";

    /// Builds a small, fully populated report used as the cached value.
    fn sample_report() -> InvestigationReport {
        let aggregate = Detection {
            provider: AntiBotProvider::Unknown,
            confidence: 0.1,
            markers: Vec::new(),
        };

        InvestigationReport {
            page_title: Some("https://example.com".to_string()),
            total_requests: 10,
            blocked_requests: 2,
            status_histogram: BTreeMap::from([(200, 8), (403, 2)]),
            resource_type_histogram: BTreeMap::new(),
            provider_histogram: BTreeMap::new(),
            marker_histogram: BTreeMap::new(),
            top_markers: Vec::new(),
            hosts: Vec::new(),
            suspicious_requests: Vec::new(),
            aggregate,
            target_class: Some(TargetClass::Api),
        }
    }

    /// Two-slot in-memory cache whose entries live for `ttl`.
    fn two_slot_cache(ttl: Duration) -> MemoryInvestigationCache {
        let capacity = NonZeroUsize::new(2).unwrap_or(NonZeroUsize::MIN);
        MemoryInvestigationCache::new(capacity, ttl)
    }

    #[test]
    fn memory_cache_round_trips_report() {
        let cache = two_slot_cache(Duration::from_mins(1));
        let key = investigation_cache_key(EMPTY_LOG_HAR, TargetClass::Api);
        let stored = sample_report();

        cache.put(key.clone(), stored.clone());

        assert_eq!(cache.get(&key), Some(stored));
    }

    #[test]
    fn memory_cache_expires_entries() {
        let cache = two_slot_cache(Duration::from_millis(1));
        let key = investigation_cache_key(EMPTY_LOG_HAR, TargetClass::Api);

        cache.put(key.clone(), sample_report());
        // Wait comfortably past the 1 ms TTL before reading back.
        std::thread::sleep(Duration::from_millis(5));

        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn cache_key_changes_by_target_class() {
        let har = "{\"log\":{\"entries\":[]}}";
        let api_key = investigation_cache_key(har, TargetClass::Api);
        let high_security_key = investigation_cache_key(har, TargetClass::HighSecurity);
        assert_ne!(api_key, high_security_key);
    }
}