1use crate::domain::error::{CacheError, Result, StygianError};
21use crate::ports::work_queue::{TaskStatus, WorkQueuePort, WorkTask};
22use async_trait::async_trait;
23use deadpool_redis::{Config as PoolConfig, Pool, Runtime};
24use redis::AsyncCommands;
25use tracing::{debug, error, info, warn};
26
/// Connection, naming and timing settings for a [`RedisWorkQueue`].
#[derive(Debug, Clone)]
pub struct RedisWorkQueueConfig {
    /// Redis connection URL used to build the connection pool.
    pub url: String,
    /// Stream key tasks are appended to; also the prefix of the per-task
    /// metadata keys, result keys and the dead-letter stream.
    pub stream_name: String,
    /// Consumer group through which workers read the stream.
    pub group_name: String,
    /// This worker's consumer name inside the group; also recorded as the
    /// `worker_id` of the tasks it dequeues.
    pub consumer_name: String,
    /// Maximum number of pooled Redis connections.
    pub pool_size: usize,
    /// Recorded attempt count at which a failing task is dead-lettered
    /// instead of being marked for retry.
    pub max_retries: u32,
    /// How long `XREADGROUP` blocks waiting for a new entry, in milliseconds.
    pub block_timeout_ms: usize,
    /// Minimum idle time, in milliseconds, before a delivered but
    /// unacknowledged entry is eligible for reclaiming.
    pub idle_threshold_ms: usize,
}

impl Default for RedisWorkQueueConfig {
    /// Local-development defaults: Redis on `127.0.0.1:6379`, stream
    /// `stygian:tasks`, group `stygian-workers`, and a consumer name of the
    /// form `<HOSTNAME>:<pid>` (`local` when `HOSTNAME` is unset or invalid).
    fn default() -> Self {
        // Host + PID keeps several worker processes on one machine distinct
        // within the consumer group.
        let hostname = match std::env::var("HOSTNAME") {
            Ok(name) => name,
            Err(_) => String::from("local"),
        };
        let pid = std::process::id();

        Self {
            url: String::from("redis://127.0.0.1:6379"),
            stream_name: String::from("stygian:tasks"),
            group_name: String::from("stygian-workers"),
            consumer_name: format!("{hostname}:{pid}"),
            pool_size: 8,
            max_retries: 3,
            block_timeout_ms: 1_000,
            idle_threshold_ms: 30_000,
        }
    }
}
67
68pub struct RedisWorkQueue {
75 pool: Pool,
76 config: RedisWorkQueueConfig,
77}
78
79impl RedisWorkQueue {
80 pub async fn new(config: RedisWorkQueueConfig) -> Result<Self> {
86 let pool_cfg = PoolConfig::from_url(&config.url);
87 let pool = pool_cfg
88 .builder()
89 .map(|b| b.max_size(config.pool_size))
90 .map_err(|e| {
91 StygianError::Cache(CacheError::WriteFailed(format!(
92 "failed to build Redis pool: {e}"
93 )))
94 })?
95 .runtime(Runtime::Tokio1)
96 .build()
97 .map_err(|e| {
98 StygianError::Cache(CacheError::WriteFailed(format!(
99 "failed to build Redis pool: {e}"
100 )))
101 })?;
102
103 let queue = Self { pool, config };
105 queue.ensure_consumer_group().await?;
106 Ok(queue)
107 }
108
109 pub async fn from_pool(pool: Pool, config: RedisWorkQueueConfig) -> Result<Self> {
116 let queue = Self { pool, config };
117 queue.ensure_consumer_group().await?;
118 Ok(queue)
119 }
120
121 async fn ensure_consumer_group(&self) -> Result<()> {
123 let mut conn = self.pool.get().await.map_err(|e| {
124 StygianError::Cache(CacheError::WriteFailed(format!("Redis pool error: {e}")))
125 })?;
126
127 let result: redis::RedisResult<String> = redis::cmd("XGROUP")
129 .arg("CREATE")
130 .arg(&self.config.stream_name)
131 .arg(&self.config.group_name)
132 .arg("$")
133 .arg("MKSTREAM")
134 .query_async(&mut *conn)
135 .await;
136
137 match result {
138 Ok(_) => {
139 debug!(
140 stream = %self.config.stream_name,
141 group = %self.config.group_name,
142 "created consumer group"
143 );
144 }
145 Err(e) if e.to_string().contains("BUSYGROUP") => {
146 debug!(
147 stream = %self.config.stream_name,
148 group = %self.config.group_name,
149 "consumer group already exists"
150 );
151 }
152 Err(e) => {
153 return Err(StygianError::Cache(CacheError::WriteFailed(format!(
154 "XGROUP CREATE failed: {e}"
155 ))));
156 }
157 }
158
159 Ok(())
160 }
161
162 fn task_meta_key(&self, task_id: &str) -> String {
164 format!("{}:tasks:{}", self.config.stream_name, task_id)
165 }
166
167 fn result_key(&self, task_id: &str) -> String {
169 format!("{}:results:{}", self.config.stream_name, task_id)
170 }
171
172 fn dlq_stream(&self) -> String {
174 format!("{}:dlq", self.config.stream_name)
175 }
176
177 pub async fn reclaim_stuck_tasks(&self) -> Result<Vec<WorkTask>> {
184 let mut conn = self.pool.get().await.map_err(|e| {
185 StygianError::Cache(CacheError::ReadFailed(format!("Redis pool error: {e}")))
186 })?;
187
188 let pending: Vec<Vec<redis::Value>> = redis::cmd("XPENDING")
190 .arg(&self.config.stream_name)
191 .arg(&self.config.group_name)
192 .arg("-")
193 .arg("+")
194 .arg(100_i64)
195 .query_async(&mut *conn)
196 .await
197 .map_err(|e| {
198 StygianError::Cache(CacheError::ReadFailed(format!("XPENDING failed: {e}")))
199 })?;
200
201 let mut reclaimed = Vec::new();
202
203 for entry in &pending {
204 if entry.len() < 3 {
206 continue;
207 }
208 let Some(raw_msg_id) = entry.first() else {
209 continue;
210 };
211 let redis::Value::BulkString(b) = raw_msg_id else {
212 continue;
213 };
214 let msg_id = String::from_utf8_lossy(b.as_slice()).to_string();
215 let idle_ms: usize = match entry.get(2) {
216 Some(redis::Value::Int(n)) => match usize::try_from(*n) {
217 Ok(v) => v,
218 Err(_) => continue,
219 },
220 _ => continue,
221 };
222
223 if idle_ms < self.config.idle_threshold_ms {
224 continue;
225 }
226
227 let claimed: redis::RedisResult<Vec<redis::Value>> = redis::cmd("XCLAIM")
229 .arg(&self.config.stream_name)
230 .arg(&self.config.group_name)
231 .arg(&self.config.consumer_name)
232 .arg(self.config.idle_threshold_ms)
233 .arg(&msg_id)
234 .query_async(&mut *conn)
235 .await;
236
237 if let Ok(messages) = claimed {
238 for msg in &messages {
239 if let Some(task) = Self::parse_stream_message(msg) {
240 info!(task_id = %task.id, idle_ms, "reclaimed stuck task");
241 reclaimed.push(task);
242 }
243 }
244 }
245 }
246
247 Ok(reclaimed)
248 }
249
250 fn parse_stream_message(msg: &redis::Value) -> Option<WorkTask> {
252 let redis::Value::Array(arr) = msg else {
254 return None;
255 };
256 if arr.len() < 2 {
257 return None;
258 }
259 let Some(redis::Value::Array(fields)) = arr.get(1) else {
260 return None;
261 };
262
263 let mut payload: Option<&[u8]> = None;
265 let mut i = 0;
266 while i + 1 < fields.len() {
267 if let Some(redis::Value::BulkString(key)) = fields.get(i)
268 && key == b"payload"
269 && let Some(redis::Value::BulkString(val)) = fields.get(i + 1)
270 {
271 payload = Some(val);
272 }
273 i += 2;
274 }
275
276 let payload = payload?;
277 serde_json::from_slice(payload).ok()
278 }
279}
280
281#[async_trait]
284impl WorkQueuePort for RedisWorkQueue {
285 async fn enqueue(&self, task: WorkTask) -> Result<()> {
286 let mut conn = self.pool.get().await.map_err(|e| {
287 StygianError::Cache(CacheError::WriteFailed(format!("Redis pool error: {e}")))
288 })?;
289
290 let payload = serde_json::to_string(&task).map_err(|e| {
291 StygianError::Cache(CacheError::WriteFailed(format!(
292 "task serialisation failed: {e}"
293 )))
294 })?;
295
296 let _msg_id: String = redis::cmd("XADD")
298 .arg(&self.config.stream_name)
299 .arg("*")
300 .arg("payload")
301 .arg(&payload)
302 .query_async(&mut *conn)
303 .await
304 .map_err(|e| {
305 StygianError::Cache(CacheError::WriteFailed(format!("XADD failed: {e}")))
306 })?;
307
308 let meta_key = self.task_meta_key(&task.id);
310 let meta = serde_json::json!({
311 "pipeline_id": task.pipeline_id,
312 "node_name": task.node_name,
313 "attempt": task.attempt,
314 "status": "pending",
315 });
316 conn.set::<_, _, ()>(&meta_key, meta.to_string())
317 .await
318 .map_err(|e| {
319 StygianError::Cache(CacheError::WriteFailed(format!(
320 "SET task meta failed: {e}"
321 )))
322 })?;
323
324 debug!(task_id = %task.id, node = %task.node_name, "enqueued task to Redis stream");
325 Ok(())
326 }
327
328 async fn try_dequeue(&self) -> Result<Option<WorkTask>> {
329 let mut conn = self.pool.get().await.map_err(|e| {
330 StygianError::Cache(CacheError::ReadFailed(format!("Redis pool error: {e}")))
331 })?;
332
333 let result: redis::RedisResult<redis::Value> = redis::cmd("XREADGROUP")
335 .arg("GROUP")
336 .arg(&self.config.group_name)
337 .arg(&self.config.consumer_name)
338 .arg("COUNT")
339 .arg(1_i64)
340 .arg("BLOCK")
341 .arg(self.config.block_timeout_ms)
342 .arg("STREAMS")
343 .arg(&self.config.stream_name)
344 .arg(">")
345 .query_async(&mut *conn)
346 .await;
347
348 let value = match result {
349 Ok(v) => v,
350 Err(e) => {
351 if e.to_string().contains("nil") {
353 return Ok(None);
354 }
355 return Err(StygianError::Cache(CacheError::ReadFailed(format!(
356 "XREADGROUP failed: {e}"
357 ))));
358 }
359 };
360
361 let streams = match &value {
363 redis::Value::Array(s) if !s.is_empty() => s,
364 _ => return Ok(None),
365 };
366
367 let stream_data = match streams.first() {
368 Some(redis::Value::Array(s)) if s.len() >= 2 => s,
369 _ => return Ok(None),
370 };
371
372 let messages = match stream_data.get(1) {
373 Some(redis::Value::Array(m)) if !m.is_empty() => m,
374 _ => return Ok(None),
375 };
376
377 if let Some(first_message) = messages.first()
378 && let Some(task) = Self::parse_stream_message(first_message)
379 {
380 let meta_key = self.task_meta_key(&task.id);
382 let meta = serde_json::json!({
383 "pipeline_id": task.pipeline_id,
384 "node_name": task.node_name,
385 "attempt": task.attempt,
386 "status": "in_progress",
387 "worker_id": self.config.consumer_name,
388 });
389 let _ = conn.set::<_, _, ()>(&meta_key, meta.to_string()).await;
390
391 debug!(task_id = %task.id, consumer = %self.config.consumer_name, "dequeued task");
392 return Ok(Some(task));
393 }
394
395 Ok(None)
396 }
397
398 async fn acknowledge(&self, task_id: &str, output: serde_json::Value) -> Result<()> {
399 let mut conn = self.pool.get().await.map_err(|e| {
400 StygianError::Cache(CacheError::WriteFailed(format!("Redis pool error: {e}")))
401 })?;
402
403 let result_key = self.result_key(task_id);
405 let output_str = output.to_string();
406 conn.set::<_, _, ()>(&result_key, &output_str)
407 .await
408 .map_err(|e| {
409 StygianError::Cache(CacheError::WriteFailed(format!("SET result failed: {e}")))
410 })?;
411
412 let meta_key = self.task_meta_key(task_id);
414 let meta_raw: Option<String> = conn.get(&meta_key).await.unwrap_or(None);
415 if let Some(raw) = meta_raw
416 && let Ok(mut meta) = serde_json::from_str::<serde_json::Value>(&raw)
417 {
418 if let Some(obj) = meta.as_object_mut() {
419 obj.insert("status".to_string(), serde_json::json!("completed"));
420 }
421 let _ = conn.set::<_, _, ()>(&meta_key, meta.to_string()).await;
422 }
423
424 info!(task_id = %task_id, "task acknowledged (completed)");
425 Ok(())
426 }
427
428 async fn fail(&self, task_id: &str, error_msg: &str) -> Result<()> {
429 let mut conn = self.pool.get().await.map_err(|e| {
430 StygianError::Cache(CacheError::WriteFailed(format!("Redis pool error: {e}")))
431 })?;
432
433 let meta_key = self.task_meta_key(task_id);
435 let meta_raw: Option<String> = conn.get(&meta_key).await.unwrap_or(None);
436
437 let attempt = meta_raw
438 .as_ref()
439 .and_then(|raw| serde_json::from_str::<serde_json::Value>(raw).ok())
440 .and_then(|m| m.get("attempt").and_then(serde_json::Value::as_u64))
441 .and_then(|n| u32::try_from(n).ok())
442 .unwrap_or(0);
443
444 if attempt >= self.config.max_retries {
445 let dlq = self.dlq_stream();
447 let dlq_payload = serde_json::json!({
448 "task_id": task_id,
449 "error": error_msg,
450 "attempt": attempt,
451 });
452 let _: redis::RedisResult<String> = redis::cmd("XADD")
453 .arg(&dlq)
454 .arg("*")
455 .arg("payload")
456 .arg(dlq_payload.to_string())
457 .query_async(&mut *conn)
458 .await;
459
460 let meta = serde_json::json!({
462 "status": "dead_letter",
463 "error": error_msg,
464 "attempt": attempt,
465 });
466 let _ = conn.set::<_, _, ()>(&meta_key, meta.to_string()).await;
467
468 warn!(task_id = %task_id, %error_msg, attempt, "task dead-lettered after max retries");
469 } else {
470 let meta = serde_json::json!({
472 "status": "failed",
473 "error": error_msg,
474 "attempt": attempt + 1,
475 });
476 let _ = conn.set::<_, _, ()>(&meta_key, meta.to_string()).await;
477
478 error!(task_id = %task_id, attempt = attempt + 1, %error_msg, "task failed, will retry");
479 }
480
481 Ok(())
482 }
483
484 async fn status(&self, task_id: &str) -> Result<Option<TaskStatus>> {
485 let mut conn = self.pool.get().await.map_err(|e| {
486 StygianError::Cache(CacheError::ReadFailed(format!("Redis pool error: {e}")))
487 })?;
488
489 let meta_key = self.task_meta_key(task_id);
490 let meta_raw: Option<String> = conn.get(&meta_key).await.unwrap_or(None);
491
492 let Some(raw) = meta_raw else {
493 return Ok(None);
494 };
495
496 let meta: serde_json::Value = serde_json::from_str(&raw).map_err(|e| {
497 StygianError::Cache(CacheError::ReadFailed(format!(
498 "task meta deserialise failed: {e}"
499 )))
500 })?;
501
502 let status_str = meta
503 .get("status")
504 .and_then(serde_json::Value::as_str)
505 .unwrap_or("pending");
506
507 let status = match status_str {
508 "in_progress" => TaskStatus::InProgress {
509 worker_id: meta
510 .get("worker_id")
511 .and_then(serde_json::Value::as_str)
512 .unwrap_or("unknown")
513 .to_string(),
514 },
515 "completed" => {
516 let result_key = self.result_key(task_id);
518 let output_raw: Option<String> = conn.get(&result_key).await.unwrap_or(None);
519 let output = output_raw
520 .and_then(|r| serde_json::from_str(&r).ok())
521 .unwrap_or(serde_json::Value::Null);
522 TaskStatus::Completed { output }
523 }
524 "failed" => TaskStatus::Failed {
525 error: meta
526 .get("error")
527 .and_then(serde_json::Value::as_str)
528 .unwrap_or("")
529 .to_string(),
530 attempt: meta
531 .get("attempt")
532 .and_then(serde_json::Value::as_u64)
533 .and_then(|n| u32::try_from(n).ok())
534 .unwrap_or(0),
535 },
536 "dead_letter" => TaskStatus::DeadLetter {
537 error: meta
538 .get("error")
539 .and_then(serde_json::Value::as_str)
540 .unwrap_or("")
541 .to_string(),
542 },
543 _ => TaskStatus::Pending,
544 };
545
546 Ok(Some(status))
547 }
548
549 async fn collect_results(&self, pipeline_id: &str) -> Result<Vec<(String, serde_json::Value)>> {
550 let mut conn = self.pool.get().await.map_err(|e| {
551 StygianError::Cache(CacheError::ReadFailed(format!("Redis pool error: {e}")))
552 })?;
553
554 let pattern = format!("{}:tasks:*", self.config.stream_name);
556 let keys: Vec<String> = redis::cmd("KEYS")
557 .arg(&pattern)
558 .query_async(&mut *conn)
559 .await
560 .map_err(|e| {
561 StygianError::Cache(CacheError::ReadFailed(format!("KEYS scan failed: {e}")))
562 })?;
563
564 let mut results = Vec::new();
565
566 for key in &keys {
567 let meta_raw: Option<String> = conn.get(key).await.unwrap_or(None);
568 let Some(raw) = meta_raw else { continue };
569 let Ok(meta) = serde_json::from_str::<serde_json::Value>(&raw) else {
570 continue;
571 };
572
573 if meta.get("pipeline_id").and_then(serde_json::Value::as_str) != Some(pipeline_id) {
575 continue;
576 }
577 if meta.get("status").and_then(serde_json::Value::as_str) != Some("completed") {
578 continue;
579 }
580
581 let node_name = meta
582 .get("node_name")
583 .and_then(serde_json::Value::as_str)
584 .unwrap_or("")
585 .to_string();
586
587 let task_id = key.rsplit(':').next().unwrap_or("");
589 let result_key = self.result_key(task_id);
590 let output_raw: Option<String> = conn.get(&result_key).await.unwrap_or(None);
591 let output = output_raw
592 .and_then(|r| serde_json::from_str(&r).ok())
593 .unwrap_or(serde_json::Value::Null);
594
595 results.push((node_name, output));
596 }
597
598 Ok(results)
599 }
600
601 async fn pending_count(&self) -> Result<usize> {
602 let mut conn = self.pool.get().await.map_err(|e| {
603 StygianError::Cache(CacheError::ReadFailed(format!("Redis pool error: {e}")))
604 })?;
605
606 let len: usize = redis::cmd("XLEN")
608 .arg(&self.config.stream_name)
609 .query_async(&mut *conn)
610 .await
611 .map_err(|e| {
612 StygianError::Cache(CacheError::ReadFailed(format!("XLEN failed: {e}")))
613 })?;
614
615 Ok(len)
616 }
617}
618
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    /// A representative task shared by the serialisation and parsing tests.
    fn sample_task() -> WorkTask {
        WorkTask {
            id: "t-1".to_string(),
            pipeline_id: "p-1".to_string(),
            node_name: "fetch".to_string(),
            input: json!({"url": "https://example.com"}),
            wave: 0,
            attempt: 0,
            idempotency_key: "ik-1".to_string(),
        }
    }

    /// Wrap `fields` in the `[id, [field, value, ...]]` shape of a stream entry.
    fn stream_entry(fields: Vec<redis::Value>) -> redis::Value {
        redis::Value::Array(vec![
            redis::Value::BulkString(b"1234-0".to_vec()),
            redis::Value::Array(fields),
        ])
    }

    /// Build a queue around a pool that has not opened any connection, so the
    /// key helpers can be exercised without a Redis server.
    fn offline_queue() -> std::result::Result<RedisWorkQueue, std::io::Error> {
        let config = RedisWorkQueueConfig::default();
        let pool = PoolConfig::from_url(&config.url)
            .builder()
            .map_err(|e| std::io::Error::other(e.to_string()))?
            .runtime(Runtime::Tokio1)
            .build()
            .map_err(|e| std::io::Error::other(e.to_string()))?;
        Ok(RedisWorkQueue { pool, config })
    }

    #[test]
    fn test_task_serialisation_roundtrip() -> TestResult {
        let task = sample_task();

        let serialised = serde_json::to_string(&task)?;
        let deserialised: WorkTask = serde_json::from_str(&serialised)?;

        assert_eq!(deserialised.id, task.id);
        assert_eq!(deserialised.pipeline_id, task.pipeline_id);
        assert_eq!(deserialised.node_name, task.node_name);
        assert_eq!(deserialised.input, task.input);
        assert_eq!(deserialised.wave, task.wave);
        assert_eq!(deserialised.attempt, task.attempt);
        assert_eq!(deserialised.idempotency_key, task.idempotency_key);
        Ok(())
    }

    #[test]
    fn test_default_config() {
        let cfg = RedisWorkQueueConfig::default();
        assert_eq!(cfg.url, "redis://127.0.0.1:6379");
        assert_eq!(cfg.stream_name, "stygian:tasks");
        assert_eq!(cfg.group_name, "stygian-workers");
        assert_eq!(cfg.max_retries, 3);
        assert_eq!(cfg.block_timeout_ms, 1000);
        assert_eq!(cfg.idle_threshold_ms, 30_000);
        assert!(!cfg.consumer_name.is_empty());
    }

    /// Exercise the real key helpers rather than re-deriving their format.
    #[test]
    fn test_key_generation() -> TestResult {
        let queue = offline_queue()?;
        assert_eq!(queue.task_meta_key("abc-123"), "stygian:tasks:tasks:abc-123");
        assert_eq!(queue.result_key("abc-123"), "stygian:tasks:results:abc-123");
        assert_eq!(queue.dlq_stream(), "stygian:tasks:dlq");
        Ok(())
    }

    #[test]
    fn test_parse_stream_message_empty() {
        let msg = redis::Value::Nil;
        assert!(RedisWorkQueue::parse_stream_message(&msg).is_none());
    }

    #[test]
    fn test_parse_stream_message_valid() -> TestResult {
        let payload = serde_json::to_vec(&sample_task())?;
        let msg = stream_entry(vec![
            redis::Value::BulkString(b"payload".to_vec()),
            redis::Value::BulkString(payload),
        ]);

        let parsed = RedisWorkQueue::parse_stream_message(&msg)
            .ok_or_else(|| std::io::Error::other("expected parse_stream_message to return task"))?;
        assert_eq!(parsed.id, "t-1");
        assert_eq!(parsed.pipeline_id, "p-1");
        assert_eq!(parsed.node_name, "fetch");
        Ok(())
    }

    #[test]
    fn test_parse_stream_message_skips_other_fields() -> TestResult {
        let payload = serde_json::to_vec(&sample_task())?;
        let msg = stream_entry(vec![
            redis::Value::BulkString(b"trace".to_vec()),
            redis::Value::BulkString(b"abc".to_vec()),
            redis::Value::BulkString(b"payload".to_vec()),
            redis::Value::BulkString(payload),
        ]);

        let parsed = RedisWorkQueue::parse_stream_message(&msg)
            .ok_or_else(|| std::io::Error::other("expected payload field to be found"))?;
        assert_eq!(parsed.id, "t-1");
        Ok(())
    }

    #[test]
    fn test_parse_stream_message_missing_payload() {
        let msg = stream_entry(vec![
            redis::Value::BulkString(b"other".to_vec()),
            redis::Value::BulkString(b"{}".to_vec()),
        ]);
        assert!(RedisWorkQueue::parse_stream_message(&msg).is_none());
    }

    #[test]
    fn test_parse_stream_message_invalid_json() {
        let msg = stream_entry(vec![
            redis::Value::BulkString(b"payload".to_vec()),
            redis::Value::BulkString(b"not json".to_vec()),
        ]);
        assert!(RedisWorkQueue::parse_stream_message(&msg).is_none());
    }

    /// The default consumer name embeds the PID so co-located workers differ.
    #[test]
    fn test_consumer_name_includes_pid() {
        let cfg = RedisWorkQueueConfig::default();
        assert!(cfg.consumer_name.contains(&std::process::id().to_string()));
    }
}
718}