1use crate::domain::error::{CacheError, Result, StygianError};
21use crate::ports::work_queue::{TaskStatus, WorkQueuePort, WorkTask};
22use async_trait::async_trait;
23use deadpool_redis::{Config as PoolConfig, Pool, Runtime};
24use redis::AsyncCommands;
25use tracing::{debug, error, info, warn};
26
/// Configuration for [`RedisWorkQueue`].
#[derive(Debug, Clone)]
pub struct RedisWorkQueueConfig {
    /// Redis connection URL, e.g. `redis://127.0.0.1:6379`.
    pub url: String,
    /// Name of the Redis stream tasks are published to; also used as the
    /// prefix for per-task metadata/result keys and the dead-letter stream.
    pub stream_name: String,
    /// Consumer-group name used with `XREADGROUP`.
    pub group_name: String,
    /// This worker's consumer name within the group; must be unique per
    /// worker (the `Default` impl derives it from hostname + pid).
    pub consumer_name: String,
    /// Maximum number of pooled Redis connections.
    pub pool_size: usize,
    /// Failure count at which a task is dead-lettered instead of retried.
    pub max_retries: u32,
    /// How long a dequeue blocks waiting for a new message, in milliseconds.
    pub block_timeout_ms: usize,
    /// Pending-entry idle time (ms) after which a task is considered stuck
    /// and eligible for reclaiming by another worker.
    pub idle_threshold_ms: usize,
}
50
51impl Default for RedisWorkQueueConfig {
52 fn default() -> Self {
53 let host = std::env::var("HOSTNAME").unwrap_or_else(|_| "local".to_string());
54 let consumer_name = format!("{}:{}", host, std::process::id());
55 Self {
56 url: "redis://127.0.0.1:6379".into(),
57 stream_name: "stygian:tasks".into(),
58 group_name: "stygian-workers".into(),
59 consumer_name,
60 pool_size: 8,
61 max_retries: 3,
62 block_timeout_ms: 1000,
63 idle_threshold_ms: 30_000,
64 }
65 }
66}
67
/// Redis-Streams-backed work queue.
///
/// Tasks are published with `XADD`, consumed through a consumer group
/// with `XREADGROUP`, and per-task metadata/results are kept in plain
/// string keys derived from the stream name.
pub struct RedisWorkQueue {
    // Shared connection pool; every operation checks out a connection.
    pool: Pool,
    config: RedisWorkQueueConfig,
}
impl RedisWorkQueue {
    /// Builds a dedicated connection pool from `config.url`, then ensures
    /// the consumer group exists before returning the queue.
    ///
    /// # Errors
    /// Returns `StygianError::Cache(CacheError::WriteFailed)` if the pool
    /// cannot be constructed, plus any error from
    /// [`Self::ensure_consumer_group`].
    pub async fn new(config: RedisWorkQueueConfig) -> Result<Self> {
        let pool_cfg = PoolConfig::from_url(&config.url);
        let pool = pool_cfg
            .builder()
            .map(|b| b.max_size(config.pool_size))
            .map_err(|e| {
                StygianError::Cache(CacheError::WriteFailed(format!(
                    "failed to build Redis pool: {e}"
                )))
            })?
            .runtime(Runtime::Tokio1)
            .build()
            .map_err(|e| {
                StygianError::Cache(CacheError::WriteFailed(format!(
                    "failed to build Redis pool: {e}"
                )))
            })?;

        let queue = Self { pool, config };
        queue.ensure_consumer_group().await?;
        Ok(queue)
    }

    /// Wraps an externally managed pool (e.g. one shared with other
    /// adapters) and ensures the consumer group exists.
    pub async fn from_pool(pool: Pool, config: RedisWorkQueueConfig) -> Result<Self> {
        let queue = Self { pool, config };
        queue.ensure_consumer_group().await?;
        Ok(queue)
    }

    /// Idempotently creates the task stream and its consumer group via
    /// `XGROUP CREATE <stream> <group> $ MKSTREAM`.
    ///
    /// `MKSTREAM` creates the stream if absent; `$` starts the group at
    /// the stream tail, so only entries added after group creation are
    /// delivered to it. A `BUSYGROUP` reply (group already exists) is
    /// treated as success.
    async fn ensure_consumer_group(&self) -> Result<()> {
        let mut conn = self.pool.get().await.map_err(|e| {
            StygianError::Cache(CacheError::WriteFailed(format!("Redis pool error: {e}")))
        })?;

        let result: redis::RedisResult<String> = redis::cmd("XGROUP")
            .arg("CREATE")
            .arg(&self.config.stream_name)
            .arg(&self.config.group_name)
            .arg("$")
            .arg("MKSTREAM")
            .query_async(&mut *conn)
            .await;

        match result {
            Ok(_) => {
                debug!(
                    stream = %self.config.stream_name,
                    group = %self.config.group_name,
                    "created consumer group"
                );
            }
            // BUSYGROUP means another worker created the group first —
            // the expected steady state, not an error.
            Err(e) if e.to_string().contains("BUSYGROUP") => {
                debug!(
                    stream = %self.config.stream_name,
                    group = %self.config.group_name,
                    "consumer group already exists"
                );
            }
            Err(e) => {
                return Err(StygianError::Cache(CacheError::WriteFailed(format!(
                    "XGROUP CREATE failed: {e}"
                ))));
            }
        }

        Ok(())
    }

    /// Key holding the JSON metadata (status/attempt/…) for one task.
    fn task_meta_key(&self, task_id: &str) -> String {
        format!("{}:tasks:{}", self.config.stream_name, task_id)
    }

    /// Key holding the JSON output of a completed task.
    fn result_key(&self, task_id: &str) -> String {
        format!("{}:results:{}", self.config.stream_name, task_id)
    }

    /// Name of the dead-letter stream for tasks that exhausted retries.
    fn dlq_stream(&self) -> String {
        format!("{}:dlq", self.config.stream_name)
    }

    /// Claims pending stream entries whose idle time exceeds
    /// `idle_threshold_ms` (read by a consumer but never acknowledged,
    /// presumably because that worker died) and returns the re-parsed
    /// tasks, now owned by this consumer.
    ///
    /// Inspects at most 100 pending entries per call; callers are
    /// expected to invoke this periodically.
    pub async fn reclaim_stuck_tasks(&self) -> Result<Vec<WorkTask>> {
        let mut conn = self.pool.get().await.map_err(|e| {
            StygianError::Cache(CacheError::ReadFailed(format!("Redis pool error: {e}")))
        })?;

        // Extended XPENDING form: each reply entry is an array of
        // [message-id, consumer-name, idle-ms, delivery-count].
        let pending: Vec<Vec<redis::Value>> = redis::cmd("XPENDING")
            .arg(&self.config.stream_name)
            .arg(&self.config.group_name)
            .arg("-")
            .arg("+")
            .arg(100_i64)
            .query_async(&mut *conn)
            .await
            .map_err(|e| {
                StygianError::Cache(CacheError::ReadFailed(format!("XPENDING failed: {e}")))
            })?;

        let mut reclaimed = Vec::new();

        for entry in &pending {
            // Malformed or truncated entries are skipped rather than
            // failing the whole sweep.
            if entry.len() < 3 {
                continue;
            }
            let Some(raw_msg_id) = entry.first() else {
                continue;
            };
            let redis::Value::BulkString(b) = raw_msg_id else {
                continue;
            };
            let msg_id = String::from_utf8_lossy(b.as_slice()).to_string();
            // Idle time is the third element of the XPENDING entry.
            let idle_ms: usize = match entry.get(2) {
                Some(redis::Value::Int(n)) => match usize::try_from(*n) {
                    Ok(v) => v,
                    Err(_) => continue,
                },
                _ => continue,
            };

            if idle_ms < self.config.idle_threshold_ms {
                continue;
            }

            // XCLAIM transfers ownership to this consumer, but only if
            // the entry is still idle for at least `idle_threshold_ms`;
            // that min-idle argument guards against racing another
            // reclaimer for the same entry.
            let claimed: redis::RedisResult<Vec<redis::Value>> = redis::cmd("XCLAIM")
                .arg(&self.config.stream_name)
                .arg(&self.config.group_name)
                .arg(&self.config.consumer_name)
                .arg(self.config.idle_threshold_ms)
                .arg(&msg_id)
                .query_async(&mut *conn)
                .await;

            // NOTE(review): XCLAIM errors are silently dropped here; the
            // entry simply stays pending and is retried on a later sweep.
            // Confirm that losing the error detail is acceptable.
            if let Ok(messages) = claimed {
                for msg in &messages {
                    if let Some(task) = Self::parse_stream_message(msg) {
                        info!(task_id = %task.id, idle_ms, "reclaimed stuck task");
                        reclaimed.push(task);
                    }
                }
            }
        }

        Ok(reclaimed)
    }

    /// Decodes one stream entry of the shape
    /// `[message-id, [field, value, field, value, …]]`, finds the
    /// `payload` field, and deserialises it as a [`WorkTask`].
    ///
    /// Returns `None` for any entry that does not match this shape or
    /// whose payload is not valid task JSON.
    fn parse_stream_message(msg: &redis::Value) -> Option<WorkTask> {
        let redis::Value::Array(arr) = msg else {
            return None;
        };
        if arr.len() < 2 {
            return None;
        }
        let Some(redis::Value::Array(fields)) = arr.get(1) else {
            return None;
        };

        // Walk the flat field/value pairs; if "payload" appears more than
        // once the last occurrence wins.
        let mut payload: Option<&[u8]> = None;
        let mut i = 0;
        while i + 1 < fields.len() {
            if let Some(redis::Value::BulkString(key)) = fields.get(i)
                && key == b"payload"
                && let Some(redis::Value::BulkString(val)) = fields.get(i + 1)
            {
                payload = Some(val);
            }
            i += 2;
        }

        let payload = payload?;
        serde_json::from_slice(payload).ok()
    }
}
270
271#[async_trait]
274impl WorkQueuePort for RedisWorkQueue {
275 async fn enqueue(&self, task: WorkTask) -> Result<()> {
276 let mut conn = self.pool.get().await.map_err(|e| {
277 StygianError::Cache(CacheError::WriteFailed(format!("Redis pool error: {e}")))
278 })?;
279
280 let payload = serde_json::to_string(&task).map_err(|e| {
281 StygianError::Cache(CacheError::WriteFailed(format!(
282 "task serialisation failed: {e}"
283 )))
284 })?;
285
286 let _msg_id: String = redis::cmd("XADD")
288 .arg(&self.config.stream_name)
289 .arg("*")
290 .arg("payload")
291 .arg(&payload)
292 .query_async(&mut *conn)
293 .await
294 .map_err(|e| {
295 StygianError::Cache(CacheError::WriteFailed(format!("XADD failed: {e}")))
296 })?;
297
298 let meta_key = self.task_meta_key(&task.id);
300 let meta = serde_json::json!({
301 "pipeline_id": task.pipeline_id,
302 "node_name": task.node_name,
303 "attempt": task.attempt,
304 "status": "pending",
305 });
306 conn.set::<_, _, ()>(&meta_key, meta.to_string())
307 .await
308 .map_err(|e| {
309 StygianError::Cache(CacheError::WriteFailed(format!(
310 "SET task meta failed: {e}"
311 )))
312 })?;
313
314 debug!(task_id = %task.id, node = %task.node_name, "enqueued task to Redis stream");
315 Ok(())
316 }
317
318 async fn try_dequeue(&self) -> Result<Option<WorkTask>> {
319 let mut conn = self.pool.get().await.map_err(|e| {
320 StygianError::Cache(CacheError::ReadFailed(format!("Redis pool error: {e}")))
321 })?;
322
323 let result: redis::RedisResult<redis::Value> = redis::cmd("XREADGROUP")
325 .arg("GROUP")
326 .arg(&self.config.group_name)
327 .arg(&self.config.consumer_name)
328 .arg("COUNT")
329 .arg(1_i64)
330 .arg("BLOCK")
331 .arg(self.config.block_timeout_ms)
332 .arg("STREAMS")
333 .arg(&self.config.stream_name)
334 .arg(">")
335 .query_async(&mut *conn)
336 .await;
337
338 let value = match result {
339 Ok(v) => v,
340 Err(e) => {
341 if e.to_string().contains("nil") {
343 return Ok(None);
344 }
345 return Err(StygianError::Cache(CacheError::ReadFailed(format!(
346 "XREADGROUP failed: {e}"
347 ))));
348 }
349 };
350
351 let streams = match &value {
353 redis::Value::Array(s) if !s.is_empty() => s,
354 _ => return Ok(None),
355 };
356
357 let stream_data = match streams.first() {
358 Some(redis::Value::Array(s)) if s.len() >= 2 => s,
359 _ => return Ok(None),
360 };
361
362 let messages = match stream_data.get(1) {
363 Some(redis::Value::Array(m)) if !m.is_empty() => m,
364 _ => return Ok(None),
365 };
366
367 if let Some(first_message) = messages.first()
368 && let Some(task) = Self::parse_stream_message(first_message)
369 {
370 let meta_key = self.task_meta_key(&task.id);
372 let meta = serde_json::json!({
373 "pipeline_id": task.pipeline_id,
374 "node_name": task.node_name,
375 "attempt": task.attempt,
376 "status": "in_progress",
377 "worker_id": self.config.consumer_name,
378 });
379 let _ = conn.set::<_, _, ()>(&meta_key, meta.to_string()).await;
380
381 debug!(task_id = %task.id, consumer = %self.config.consumer_name, "dequeued task");
382 return Ok(Some(task));
383 }
384
385 Ok(None)
386 }
387
388 async fn acknowledge(&self, task_id: &str, output: serde_json::Value) -> Result<()> {
389 let mut conn = self.pool.get().await.map_err(|e| {
390 StygianError::Cache(CacheError::WriteFailed(format!("Redis pool error: {e}")))
391 })?;
392
393 let result_key = self.result_key(task_id);
395 let output_str = output.to_string();
396 conn.set::<_, _, ()>(&result_key, &output_str)
397 .await
398 .map_err(|e| {
399 StygianError::Cache(CacheError::WriteFailed(format!("SET result failed: {e}")))
400 })?;
401
402 let meta_key = self.task_meta_key(task_id);
404 let meta_raw: Option<String> = conn.get(&meta_key).await.unwrap_or(None);
405 if let Some(raw) = meta_raw
406 && let Ok(mut meta) = serde_json::from_str::<serde_json::Value>(&raw)
407 {
408 if let Some(obj) = meta.as_object_mut() {
409 obj.insert("status".to_string(), serde_json::json!("completed"));
410 }
411 let _ = conn.set::<_, _, ()>(&meta_key, meta.to_string()).await;
412 }
413
414 info!(task_id = %task_id, "task acknowledged (completed)");
415 Ok(())
416 }
417
418 async fn fail(&self, task_id: &str, error_msg: &str) -> Result<()> {
419 let mut conn = self.pool.get().await.map_err(|e| {
420 StygianError::Cache(CacheError::WriteFailed(format!("Redis pool error: {e}")))
421 })?;
422
423 let meta_key = self.task_meta_key(task_id);
425 let meta_raw: Option<String> = conn.get(&meta_key).await.unwrap_or(None);
426
427 let attempt = meta_raw
428 .as_ref()
429 .and_then(|raw| serde_json::from_str::<serde_json::Value>(raw).ok())
430 .and_then(|m| m.get("attempt").and_then(serde_json::Value::as_u64))
431 .and_then(|n| u32::try_from(n).ok())
432 .unwrap_or(0);
433
434 if attempt >= self.config.max_retries {
435 let dlq = self.dlq_stream();
437 let dlq_payload = serde_json::json!({
438 "task_id": task_id,
439 "error": error_msg,
440 "attempt": attempt,
441 });
442 let _: redis::RedisResult<String> = redis::cmd("XADD")
443 .arg(&dlq)
444 .arg("*")
445 .arg("payload")
446 .arg(dlq_payload.to_string())
447 .query_async(&mut *conn)
448 .await;
449
450 let meta = serde_json::json!({
452 "status": "dead_letter",
453 "error": error_msg,
454 "attempt": attempt,
455 });
456 let _ = conn.set::<_, _, ()>(&meta_key, meta.to_string()).await;
457
458 warn!(task_id = %task_id, %error_msg, attempt, "task dead-lettered after max retries");
459 } else {
460 let meta = serde_json::json!({
462 "status": "failed",
463 "error": error_msg,
464 "attempt": attempt + 1,
465 });
466 let _ = conn.set::<_, _, ()>(&meta_key, meta.to_string()).await;
467
468 error!(task_id = %task_id, attempt = attempt + 1, %error_msg, "task failed, will retry");
469 }
470
471 Ok(())
472 }
473
474 async fn status(&self, task_id: &str) -> Result<Option<TaskStatus>> {
475 let mut conn = self.pool.get().await.map_err(|e| {
476 StygianError::Cache(CacheError::ReadFailed(format!("Redis pool error: {e}")))
477 })?;
478
479 let meta_key = self.task_meta_key(task_id);
480 let meta_raw: Option<String> = conn.get(&meta_key).await.unwrap_or(None);
481
482 let Some(raw) = meta_raw else {
483 return Ok(None);
484 };
485
486 let meta: serde_json::Value = serde_json::from_str(&raw).map_err(|e| {
487 StygianError::Cache(CacheError::ReadFailed(format!(
488 "task meta deserialise failed: {e}"
489 )))
490 })?;
491
492 let status_str = meta
493 .get("status")
494 .and_then(serde_json::Value::as_str)
495 .unwrap_or("pending");
496
497 let status = match status_str {
498 "in_progress" => TaskStatus::InProgress {
499 worker_id: meta
500 .get("worker_id")
501 .and_then(serde_json::Value::as_str)
502 .unwrap_or("unknown")
503 .to_string(),
504 },
505 "completed" => {
506 let result_key = self.result_key(task_id);
508 let output_raw: Option<String> = conn.get(&result_key).await.unwrap_or(None);
509 let output = output_raw
510 .and_then(|r| serde_json::from_str(&r).ok())
511 .unwrap_or(serde_json::Value::Null);
512 TaskStatus::Completed { output }
513 }
514 "failed" => TaskStatus::Failed {
515 error: meta
516 .get("error")
517 .and_then(serde_json::Value::as_str)
518 .unwrap_or("")
519 .to_string(),
520 attempt: meta
521 .get("attempt")
522 .and_then(serde_json::Value::as_u64)
523 .and_then(|n| u32::try_from(n).ok())
524 .unwrap_or(0),
525 },
526 "dead_letter" => TaskStatus::DeadLetter {
527 error: meta
528 .get("error")
529 .and_then(serde_json::Value::as_str)
530 .unwrap_or("")
531 .to_string(),
532 },
533 _ => TaskStatus::Pending,
534 };
535
536 Ok(Some(status))
537 }
538
539 async fn collect_results(&self, pipeline_id: &str) -> Result<Vec<(String, serde_json::Value)>> {
540 let mut conn = self.pool.get().await.map_err(|e| {
541 StygianError::Cache(CacheError::ReadFailed(format!("Redis pool error: {e}")))
542 })?;
543
544 let pattern = format!("{}:tasks:*", self.config.stream_name);
546 let keys: Vec<String> = redis::cmd("KEYS")
547 .arg(&pattern)
548 .query_async(&mut *conn)
549 .await
550 .map_err(|e| {
551 StygianError::Cache(CacheError::ReadFailed(format!("KEYS scan failed: {e}")))
552 })?;
553
554 let mut results = Vec::new();
555
556 for key in &keys {
557 let meta_raw: Option<String> = conn.get(key).await.unwrap_or(None);
558 let Some(raw) = meta_raw else { continue };
559 let Ok(meta) = serde_json::from_str::<serde_json::Value>(&raw) else {
560 continue;
561 };
562
563 if meta.get("pipeline_id").and_then(serde_json::Value::as_str) != Some(pipeline_id) {
565 continue;
566 }
567 if meta.get("status").and_then(serde_json::Value::as_str) != Some("completed") {
568 continue;
569 }
570
571 let node_name = meta
572 .get("node_name")
573 .and_then(serde_json::Value::as_str)
574 .unwrap_or("")
575 .to_string();
576
577 let task_id = key.rsplit(':').next().unwrap_or("");
579 let result_key = self.result_key(task_id);
580 let output_raw: Option<String> = conn.get(&result_key).await.unwrap_or(None);
581 let output = output_raw
582 .and_then(|r| serde_json::from_str(&r).ok())
583 .unwrap_or(serde_json::Value::Null);
584
585 results.push((node_name, output));
586 }
587
588 Ok(results)
589 }
590
591 async fn pending_count(&self) -> Result<usize> {
592 let mut conn = self.pool.get().await.map_err(|e| {
593 StygianError::Cache(CacheError::ReadFailed(format!("Redis pool error: {e}")))
594 })?;
595
596 let len: usize = redis::cmd("XLEN")
598 .arg(&self.config.stream_name)
599 .query_async(&mut *conn)
600 .await
601 .map_err(|e| {
602 StygianError::Cache(CacheError::ReadFailed(format!("XLEN failed: {e}")))
603 })?;
604
605 Ok(len)
606 }
607}
608
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Shared fixture used by the serialisation and stream-parsing tests.
    fn sample_task() -> WorkTask {
        WorkTask {
            id: "t-1".to_string(),
            pipeline_id: "p-1".to_string(),
            node_name: "fetch".to_string(),
            input: json!({"url": "https://example.com"}),
            wave: 0,
            attempt: 0,
            idempotency_key: "ik-1".to_string(),
        }
    }

    /// A task survives a JSON encode/decode round trip field-for-field.
    #[test]
    fn test_task_serialisation_roundtrip() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let original = sample_task();

        let wire = serde_json::to_string(&original)?;
        let restored: WorkTask = serde_json::from_str(&wire)?;

        assert_eq!(restored.id, original.id);
        assert_eq!(restored.pipeline_id, original.pipeline_id);
        assert_eq!(restored.node_name, original.node_name);
        assert_eq!(restored.input, original.input);
        assert_eq!(restored.wave, original.wave);
        assert_eq!(restored.attempt, original.attempt);
        assert_eq!(restored.idempotency_key, original.idempotency_key);
        Ok(())
    }

    /// Default configuration matches the documented local-dev values.
    #[test]
    fn test_default_config() {
        let cfg = RedisWorkQueueConfig::default();

        assert_eq!(cfg.url, "redis://127.0.0.1:6379");
        assert_eq!(cfg.stream_name, "stygian:tasks");
        assert_eq!(cfg.group_name, "stygian-workers");
        assert_eq!(cfg.max_retries, 3);
        assert_eq!(cfg.block_timeout_ms, 1000);
        assert_eq!(cfg.idle_threshold_ms, 30_000);
        assert!(!cfg.consumer_name.is_empty());
    }

    /// Documents the key-layout convention the private helpers implement.
    #[test]
    fn test_key_generation() {
        let stream_name = "stygian:tasks";
        let task_id = "abc-123";

        let meta_key = format!("{stream_name}:tasks:{task_id}");
        let result_key = format!("{stream_name}:results:{task_id}");
        let dlq = format!("{stream_name}:dlq");

        assert_eq!(meta_key, "stygian:tasks:tasks:abc-123");
        assert_eq!(result_key, "stygian:tasks:results:abc-123");
        assert_eq!(dlq, "stygian:tasks:dlq");
    }

    /// A nil value is not a stream entry and must parse to None.
    #[test]
    fn test_parse_stream_message_empty() {
        assert!(RedisWorkQueue::parse_stream_message(&redis::Value::Nil).is_none());
    }

    /// A well-formed [id, [field, value]] entry parses back into the task.
    #[test]
    fn test_parse_stream_message_valid() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let task = sample_task();
        let payload = serde_json::to_vec(&task)?;

        let entry = redis::Value::Array(vec![
            redis::Value::BulkString(b"1234-0".to_vec()),
            redis::Value::Array(vec![
                redis::Value::BulkString(b"payload".to_vec()),
                redis::Value::BulkString(payload),
            ]),
        ]);

        let parsed = RedisWorkQueue::parse_stream_message(&entry)
            .ok_or("expected parse_stream_message to return task")?;
        assert_eq!(parsed.id, "t-1");
        assert_eq!(parsed.node_name, "fetch");
        Ok(())
    }

    /// The default consumer name embeds this process's pid.
    #[test]
    fn test_consumer_name_is_unique() {
        let pid = std::process::id().to_string();
        assert!(RedisWorkQueueConfig::default().consumer_name.contains(&pid));
    }
}
708}