use crate::domain::error::{Result, ServiceError, StygianError};
use crate::ports::work_queue::{TaskStatus, WorkQueuePort, WorkTask};
use crate::ports::{ScrapingService, ServiceInput};
use async_trait::async_trait;
use dashmap::DashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};

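/// In-memory work queue for single-process runs.
///
/// Pending tasks are held in a `Mutex<VecDeque>` in FIFO order; per-task
/// status lives in a `DashMap` keyed by task id. A task that fails more than
/// `max_retries` times is moved to `TaskStatus::DeadLetter`.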
#[derive(Clone)]
pub struct LocalWorkQueue {
    pending: Arc<Mutex<VecDeque<WorkTask>>>,
    state: Arc<DashMap<String, TaskStatus>>,
    max_retries: u32,
}

impl LocalWorkQueue {
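    /// Creates an empty queue with the default retry limit of 3.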
    pub fn new() -> Self {
        Self {
            pending: Arc::new(Mutex::new(VecDeque::new())),
            state: Arc::new(DashMap::new()),
            max_retries: 3,
        }
    }

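    /// Creates an empty queue with a custom retry limit.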
    pub fn with_max_retries(max_retries: u32) -> Self {
        Self {
            pending: Arc::new(Mutex::new(VecDeque::new())),
            state: Arc::new(DashMap::new()),
            max_retries,
        }
    }
}

impl Default for LocalWorkQueue {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl WorkQueuePort for LocalWorkQueue {
    async fn enqueue(&self, task: WorkTask) -> Result<()> {
        debug!(task_id = %task.id, node = %task.node_name, "enqueuing task");
        self.state.insert(task.id.clone(), TaskStatus::Pending);
        self.pending.lock().await.push_back(task);
        Ok(())
    }

    async fn try_dequeue(&self) -> Result<Option<WorkTask>> {
        let task = self.pending.lock().await.pop_front();
        if let Some(ref t) = task {
            debug!(task_id = %t.id, "dequeued task");
            self.state.insert(
                t.id.clone(),
                TaskStatus::InProgress {
                    worker_id: "local".to_string(),
                },
            );
        }
        Ok(task)
    }

    async fn acknowledge(&self, task_id: &str, output: serde_json::Value) -> Result<()> {
        info!(task_id = %task_id, "task acknowledged (completed)");
        self.state
            .insert(task_id.to_string(), TaskStatus::Completed { output });
        Ok(())
    }

    async fn fail(&self, task_id: &str, error: &str) -> Result<()> {
        // Read the current attempt count; any status other than `Failed`
        // counts as a first attempt.
        let attempt = self
            .state
            .get(task_id)
            .map_or(0, |status| match status.value() {
                TaskStatus::Failed { attempt, .. } => *attempt,
                _ => 0,
            });

        if attempt >= self.max_retries {
            warn!(task_id = %task_id, %error, "task dead-lettered after max retries");
            self.state.insert(
                task_id.to_string(),
                TaskStatus::DeadLetter {
                    error: error.to_string(),
                },
            );
        } else {
            error!(task_id = %task_id, attempt, %error, "task failed, will retry");
            self.state.insert(
                task_id.to_string(),
                TaskStatus::Failed {
                    error: error.to_string(),
                    attempt: attempt + 1,
                },
            );
        }
        Ok(())
    }

    async fn status(&self, task_id: &str) -> Result<Option<TaskStatus>> {
        Ok(self.state.get(task_id).map(|s| s.value().clone()))
    }

    async fn collect_results(&self, pipeline_id: &str) -> Result<Vec<(String, serde_json::Value)>> {
        let mut results = Vec::new();
        // Task ids are formatted as "{pipeline_id}::{node_name}::{seq}", so a
        // prefix match selects this pipeline's tasks and the second segment is
        // the node name.
        for entry in self.state.iter() {
            let key = entry.key();
            if !key.starts_with(pipeline_id) {
                continue;
            }
            if let TaskStatus::Completed { ref output } = *entry.value() {
                let node_name = key.split("::").nth(1).unwrap_or(key).to_string();
                results.push((node_name, output.clone()));
            }
        }
        Ok(results)
    }

    async fn pending_count(&self) -> Result<usize> {
        Ok(self.pending.lock().await.len())
    }
}

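/// Executes one wave of DAG tasks against a `WorkQueuePort` using a bounded
/// pool of worker loops. With `LocalWorkQueue` everything runs in process;
/// any other `WorkQueuePort` implementation can back it instead.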
pub struct DistributedDagExecutor<Q: WorkQueuePort> {
    queue: Arc<Q>,
    worker_concurrency: usize,
}

impl<Q: WorkQueuePort + 'static> DistributedDagExecutor<Q> {
    pub fn new(queue: Arc<Q>, worker_concurrency: usize) -> Self {
        Self {
            queue,
            worker_concurrency: worker_concurrency.max(1),
        }
    }

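    /// Enqueues every task in the wave, spawns up to `worker_concurrency`
    /// workers that drain the queue until it is empty, then collects the
    /// completed results for `pipeline_id`.
    ///
    /// Failed tasks are reported via `fail` but are not re-enqueued within
    /// this call; only tasks that reached `Completed` appear in the output.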
    pub async fn execute_wave(
        &self,
        pipeline_id: &str,
        tasks: Vec<WorkTask>,
        services: &std::collections::HashMap<String, Arc<dyn ScrapingService>>,
    ) -> Result<Vec<(String, serde_json::Value)>> {
        let expected = tasks.len();
        if expected == 0 {
            return Ok(Vec::new());
        }

        for task in tasks {
            self.queue.enqueue(task).await?;
        }

        let queue = Arc::clone(&self.queue);
        let services: Arc<std::collections::HashMap<String, Arc<dyn ScrapingService>>> =
            Arc::new(services.clone());

        // Never spawn more workers than there are tasks in the wave.
        let concurrency = self.worker_concurrency.min(expected);
        let mut handles = tokio::task::JoinSet::new();

        for _ in 0..concurrency {
            let q = Arc::clone(&queue);
            let svcs = Arc::clone(&services);
            handles.spawn(async move {
                let mut worked = 0usize;
                loop {
                    match q.try_dequeue().await {
                        Ok(Some(task)) => {
                            let service_input = ServiceInput {
                                url: task
                                    .input
                                    .get("url")
                                    .and_then(serde_json::Value::as_str)
                                    .unwrap_or("")
                                    .to_string(),
                                params: task.input.clone(),
                            };
                            // Route by node name, falling back to a "default"
                            // service if one is registered.
                            let output = match svcs.get(&task.node_name) {
                                Some(svc) => svc.execute(service_input.clone()).await,
                                None => match svcs.get("default") {
                                    Some(svc) => svc.execute(service_input).await,
                                    None => Err(StygianError::Service(
                                        ServiceError::Unavailable(format!(
                                            "service '{}' not registered",
                                            task.node_name
                                        )),
                                    )),
                                },
                            };
                            match output {
                                Ok(out) => {
                                    let val = serde_json::json!({
                                        "data": out.data,
                                        "metadata": out.metadata,
                                    });
                                    let _ = q.acknowledge(&task.id, val).await;
                                }
                                Err(e) => {
                                    let _ = q.fail(&task.id, &e.to_string()).await;
                                }
                            }
                            worked += 1;
                        }
                        Ok(None) => break,
                        Err(e) => {
                            error!(error = %e, "worker dequeue error");
                            break;
                        }
                    }
                }
                worked
            });
        }

        while handles.join_next().await.is_some() {}

        self.queue.collect_results(pipeline_id).await
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use serde_json::json;

    fn make_task(pipeline_id: &str, node_name: &str, seq: u32) -> WorkTask {
        WorkTask {
            id: format!("{pipeline_id}::{node_name}::{seq:04}"),
            pipeline_id: pipeline_id.to_string(),
            node_name: node_name.to_string(),
            input: json!({"url": "https://example.com"}),
            wave: 0,
            attempt: 0,
            idempotency_key: format!("ik-{seq}"),
        }
    }

    #[tokio::test]
    async fn enqueue_dequeue_roundtrip() {
        let queue = LocalWorkQueue::new();
        assert_eq!(queue.pending_count().await.unwrap(), 0);

        queue.enqueue(make_task("p1", "fetch", 1)).await.unwrap();
        queue.enqueue(make_task("p1", "parse", 2)).await.unwrap();
        assert_eq!(queue.pending_count().await.unwrap(), 2);

        let t1 = queue.try_dequeue().await.unwrap().unwrap();
        assert_eq!(t1.node_name, "fetch");
        assert_eq!(queue.pending_count().await.unwrap(), 1);

        let t2 = queue.try_dequeue().await.unwrap().unwrap();
        assert_eq!(t2.node_name, "parse");
        assert_eq!(queue.pending_count().await.unwrap(), 0);

        let empty = queue.try_dequeue().await.unwrap();
        assert!(empty.is_none());
    }

    #[tokio::test]
    async fn acknowledge_records_completed_status() {
        let queue = LocalWorkQueue::new();
        queue.enqueue(make_task("p1", "fetch", 1)).await.unwrap();
        let task = queue.try_dequeue().await.unwrap().unwrap();
        queue
            .acknowledge(&task.id, json!({"data": "hello", "status": 200}))
            .await
            .unwrap();

        let status = queue.status(&task.id).await.unwrap().unwrap();
        assert!(matches!(status, TaskStatus::Completed { .. }));
    }

    #[tokio::test]
    async fn fail_dead_letters_after_max_retries() {
        let queue = LocalWorkQueue::with_max_retries(2);
        queue.enqueue(make_task("p1", "fetch", 1)).await.unwrap();
        let task = queue.try_dequeue().await.unwrap().unwrap();

        queue.fail(&task.id, "err 1").await.unwrap();
        queue.fail(&task.id, "err 2").await.unwrap();
        queue.fail(&task.id, "err 3").await.unwrap();

        let status = queue.status(&task.id).await.unwrap().unwrap();
        assert!(matches!(status, TaskStatus::DeadLetter { .. }));
    }

    #[tokio::test]
    async fn collect_results_filters_by_pipeline_id() {
        let queue = LocalWorkQueue::new();

        let t1 = make_task("pipeline-A", "node1", 1);
        let t2 = make_task("pipeline-B", "node1", 2);

        queue.enqueue(t1.clone()).await.unwrap();
        queue.enqueue(t2.clone()).await.unwrap();

        let deq1 = queue.try_dequeue().await.unwrap().unwrap();
        let deq2 = queue.try_dequeue().await.unwrap().unwrap();

        queue
            .acknowledge(&deq1.id, json!({"data": "A-result"}))
            .await
            .unwrap();
        queue
            .acknowledge(&deq2.id, json!({"data": "B-result"}))
            .await
            .unwrap();

        let results_a = queue.collect_results("pipeline-A").await.unwrap();
        assert_eq!(results_a.len(), 1);
        assert_eq!(results_a[0].1["data"], "A-result");

        let results_b = queue.collect_results("pipeline-B").await.unwrap();
        assert_eq!(results_b.len(), 1);
        assert_eq!(results_b[0].1["data"], "B-result");
    }

    #[tokio::test]
    async fn distributed_executor_runs_tasks() {
        use crate::adapters::noop::NoopService;
        use std::collections::HashMap;

        let queue = Arc::new(LocalWorkQueue::new());
        let executor = DistributedDagExecutor::new(Arc::clone(&queue), 2);

        let mut services: HashMap<String, Arc<dyn ScrapingService>> = HashMap::new();
        services.insert("noop".to_string(), Arc::new(NoopService));

        let tasks = vec![
            make_task("p1", "noop", 1),
            make_task("p1", "noop", 2),
            make_task("p1", "noop", 3),
        ];

        let results = executor.execute_wave("p1", tasks, &services).await.unwrap();
        assert!(results.len() <= 3);
    }
}
485}