stygian_charon/challenge_feedback/
memory.rs1use std::num::NonZeroUsize;
2use std::time::{Duration, SystemTime, UNIX_EPOCH};
3
4use serde::{Deserialize, Serialize};
5
6use crate::cache::LruTtlStore;
7use crate::challenge_feedback::ChallengeOutcome;
8use crate::types::TargetClass;
9
10pub const DEFAULT_CHALLENGE_TTL: Duration = Duration::from_mins(10);
18
/// Default number of entries a [`ChallengeMemory`] is created with: 64.
///
/// `NonZeroUsize::new` returns an `Option`, so the value is extracted with a
/// `match`. The literal `64` is non-zero, so the `None` arm is never taken; it
/// exists only so the constant needs no `unwrap`/`expect`, which is also why
/// no lint suppression is required here.
pub const DEFAULT_CHALLENGE_CAPACITY: NonZeroUsize = match NonZeroUsize::new(64) {
    Some(value) => value,
    None => NonZeroUsize::MIN,
};

/// Timestamp reported by `current_unix_secs` when the system clock reads
/// earlier than the Unix epoch.
const ZERO_FALLBACK_UNIX_SECS: u64 = 0;
33
34#[must_use]
47pub fn challenge_memory_key(domain: &str, target_class: TargetClass) -> String {
48 format!(
49 "charon:challenge:{}:{}",
50 domain.to_ascii_lowercase(),
51 target_class_label(target_class)
52 )
53}
54
55const fn target_class_label(c: TargetClass) -> &'static str {
56 match c {
57 TargetClass::Api => "api",
58 TargetClass::ContentSite => "content_site",
59 TargetClass::HighSecurity => "high_security",
60 TargetClass::Unknown => "unknown",
61 }
62}
63
64#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
91pub struct ChallengeMemoryEntry {
92 pub domain: String,
94 pub target_class: TargetClass,
96 pub last_outcome: ChallengeOutcome,
98 pub observation_count: u32,
101 pub recorded_at_unix_secs: u64,
103}
104
105impl ChallengeMemoryEntry {
106 #[must_use]
112 pub const fn risk_delta(&self) -> f64 {
113 self.last_outcome.risk_delta()
114 }
115}
116
117pub struct ChallengeMemory {
141 store: LruTtlStore<ChallengeMemoryEntry>,
142}
143
144impl ChallengeMemory {
145 #[must_use]
147 pub fn new(capacity: NonZeroUsize, ttl: Duration) -> Self {
148 Self {
149 store: LruTtlStore::new(capacity, ttl),
150 }
151 }
152
153 #[must_use]
157 pub fn with_default_ttl(capacity: NonZeroUsize) -> Self {
158 Self::new(capacity, DEFAULT_CHALLENGE_TTL)
159 }
160
161 #[must_use]
164 pub fn with_defaults() -> Self {
165 Self::new(DEFAULT_CHALLENGE_CAPACITY, DEFAULT_CHALLENGE_TTL)
166 }
167
168 pub fn record(&self, domain: &str, target_class: TargetClass, outcome: ChallengeOutcome) {
186 let key = challenge_memory_key(domain, target_class);
187 let lower = domain.to_ascii_lowercase();
188 let next_count = self
189 .store
190 .peek(&key)
191 .map_or(1, |existing| existing.observation_count.saturating_add(1));
192 let entry = ChallengeMemoryEntry {
193 domain: lower,
194 target_class,
195 last_outcome: outcome,
196 observation_count: next_count,
197 recorded_at_unix_secs: current_unix_secs(),
198 };
199 self.store.put(key, entry);
200 }
201
202 #[must_use]
215 pub fn lookup(&self, domain: &str, target_class: TargetClass) -> Option<ChallengeMemoryEntry> {
216 self.store.get(&challenge_memory_key(domain, target_class))
217 }
218
219 #[must_use]
221 pub fn len(&self) -> usize {
222 self.store.len()
223 }
224
225 #[must_use]
227 pub fn is_empty(&self) -> bool {
228 self.store.is_empty()
229 }
230
231 pub fn clear(&self) {
233 self.store.clear();
234 }
235
236 pub fn invalidate(&self, domain: &str, target_class: TargetClass) {
238 self.store
239 .invalidate(&challenge_memory_key(domain, target_class));
240 }
241}
242
243fn current_unix_secs() -> u64 {
244 SystemTime::now()
245 .duration_since(UNIX_EPOCH)
246 .map_or(ZERO_FALLBACK_UNIX_SECS, |duration| duration.as_secs())
247}
248
#[cfg(test)]
// Test code may unwrap, expect, panic and index: a failure should abort the test.
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::indexing_slicing
)]
mod tests {
    use super::*;
    use std::thread;

    /// TTL used by every test that is not about expiry.
    const ONE_MINUTE: Duration = Duration::from_mins(1);

    /// Builds a memory with the given capacity and TTL.
    fn memory_with(capacity: usize, ttl: Duration) -> ChallengeMemory {
        ChallengeMemory::new(NonZeroUsize::new(capacity).unwrap(), ttl)
    }

    #[test]
    fn record_overwrites_last_outcome_and_increments_count() {
        let memory = memory_with(4, ONE_MINUTE);
        let (domain, class) = ("example.com", TargetClass::ContentSite);

        let outcomes = [
            ChallengeOutcome::Pass,
            ChallengeOutcome::HardChallenge,
            ChallengeOutcome::Captcha,
        ];
        for outcome in outcomes {
            memory.record(domain, class, outcome);
        }

        let entry = memory.lookup(domain, class).expect("entry present");
        assert_eq!(entry.last_outcome, ChallengeOutcome::Captcha);
        assert_eq!(entry.observation_count, 3);
        assert_eq!(entry.domain, "example.com");
        assert_eq!(entry.target_class, TargetClass::ContentSite);
    }

    #[test]
    fn entries_decay_after_ttl() {
        let memory = memory_with(4, Duration::from_millis(1));
        memory.record("example.com", TargetClass::Api, ChallengeOutcome::Blocked);

        // Sleep well past the 1 ms TTL before looking the entry up again.
        thread::sleep(Duration::from_millis(5));

        let expired = memory.lookup("example.com", TargetClass::Api);
        assert!(expired.is_none());
    }

    #[test]
    fn distinct_target_classes_keep_distinct_entries() {
        let memory = memory_with(8, ONE_MINUTE);
        let domain = "example.com";

        memory.record(domain, TargetClass::Api, ChallengeOutcome::Pass);
        memory.record(domain, TargetClass::ContentSite, ChallengeOutcome::Captcha);

        let api = memory.lookup(domain, TargetClass::Api).unwrap();
        let content = memory.lookup(domain, TargetClass::ContentSite).unwrap();

        assert_eq!(api.last_outcome, ChallengeOutcome::Pass);
        assert_eq!(content.last_outcome, ChallengeOutcome::Captcha);
    }

    #[test]
    fn clear_drops_everything() {
        let memory = memory_with(4, ONE_MINUTE);
        memory.record("example.com", TargetClass::Api, ChallengeOutcome::Pass);
        memory.record("other.example", TargetClass::Api, ChallengeOutcome::Blocked);
        assert_eq!(memory.len(), 2);

        memory.clear();

        assert!(memory.is_empty());
    }

    #[test]
    fn domain_is_normalised_to_lower_case() {
        let memory = memory_with(4, ONE_MINUTE);
        memory.record("Example.COM", TargetClass::Api, ChallengeOutcome::SoftChallenge);

        // A differently-cased spelling must resolve to the same entry.
        let entry = memory.lookup("EXAMPLE.com", TargetClass::Api).unwrap();

        assert_eq!(entry.domain, "example.com");
        assert_eq!(entry.last_outcome, ChallengeOutcome::SoftChallenge);
    }

    #[test]
    fn risk_delta_uses_last_outcome() {
        let memory = memory_with(4, ONE_MINUTE);
        memory.record("example.com", TargetClass::Api, ChallengeOutcome::HardChallenge);

        let entry = memory.lookup("example.com", TargetClass::Api).unwrap();
        let actual = entry.risk_delta();
        let expected = ChallengeOutcome::HardChallenge.risk_delta();

        assert!((actual - expected).abs() < 1e-9);
    }

    #[test]
    fn lru_capacity_is_respected() {
        let memory = memory_with(2, ONE_MINUTE);

        for domain in ["a.example", "b.example", "c.example"] {
            memory.record(domain, TargetClass::Api, ChallengeOutcome::Pass);
        }

        assert!(memory.len() <= 2);
    }
}