stygian_graph/adapters/
distributed.rs

1//! Distributed execution adapters
2//!
//! Provides [`LocalWorkQueue`] (in-process, for single-node and testing) and
//! [`DistributedDagExecutor`] (wraps any [`WorkQueuePort`](crate::ports::work_queue::WorkQueuePort) to distribute DAG
//! waves across workers).
6//!
7//! # Design
8//!
9//! ```text
10//! DistributedDagExecutor
11//!    │
12//!    ├─ resolve wave N (topological sort already done by DagExecutor)
13//!    ├─ enqueue every node in the wave as a WorkTask
14//!    ├─ spawn worker tasks that call try_dequeue + service.execute
15//!    └─ collect_results when all tasks in wave are Completed
16//! ```
17
18use crate::domain::error::{Result, ServiceError, StygianError};
19use crate::ports::work_queue::{TaskStatus, WorkQueuePort, WorkTask};
20use crate::ports::{ScrapingService, ServiceInput};
21use async_trait::async_trait;
22use dashmap::DashMap;
23use std::collections::VecDeque;
24use std::sync::Arc;
25use tokio::sync::Mutex;
26use tracing::{debug, error, info, warn};
27
28// ─────────────────────────────────────────────────────────────────────────────
29// LocalWorkQueue
30// ─────────────────────────────────────────────────────────────────────────────
31
/// In-memory work queue for single-node deployments and unit tests.
///
/// All state is stored in `Arc`-wrapped structures so the queue can be cheaply
/// cloned and shared across worker tasks.
///
/// # Example
///
/// ```
/// use stygian_graph::adapters::distributed::LocalWorkQueue;
/// use stygian_graph::ports::work_queue::{WorkQueuePort, WorkTask};
/// use serde_json::json;
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let queue = LocalWorkQueue::new();
/// assert_eq!(queue.pending_count().await.unwrap(), 0);
///
/// let task = WorkTask {
///     id: "t-1".to_string(),
///     pipeline_id: "p-1".to_string(),
///     node_name: "fetch".to_string(),
///     input: json!({"url": "https://example.com"}),
///     wave: 0,
///     attempt: 0,
///     idempotency_key: "ik-t1".to_string(),
/// };
/// queue.enqueue(task).await.unwrap();
/// assert_eq!(queue.pending_count().await.unwrap(), 1);
///
/// let dequeued = queue.try_dequeue().await.unwrap().unwrap();
/// assert_eq!(dequeued.node_name, "fetch");
/// # });
/// ```
#[derive(Clone)]
pub struct LocalWorkQueue {
    /// FIFO of tasks awaiting dequeue. Tokio mutex because it is locked from
    /// async contexts (the lock is never held across other awaits here).
    pending: Arc<Mutex<VecDeque<WorkTask>>>,
    /// Latest recorded status per task id; concurrent map shared by clones.
    state: Arc<DashMap<String, TaskStatus>>,
    /// Max retries before a task moves to the dead-letter state
    max_retries: u32,
}
71
72impl LocalWorkQueue {
73    /// Create a new `LocalWorkQueue` with default settings (`max_retries = 3`).
74    pub fn new() -> Self {
75        Self {
76            pending: Arc::new(Mutex::new(VecDeque::new())),
77            state: Arc::new(DashMap::new()),
78            max_retries: 3,
79        }
80    }
81
82    /// Create a `LocalWorkQueue` with a custom retry limit.
83    ///
84    /// # Example
85    ///
86    /// ```
87    /// use stygian_graph::adapters::distributed::LocalWorkQueue;
88    ///
89    /// let queue = LocalWorkQueue::with_max_retries(5);
90    /// ```
91    pub fn with_max_retries(max_retries: u32) -> Self {
92        Self {
93            pending: Arc::new(Mutex::new(VecDeque::new())),
94            state: Arc::new(DashMap::new()),
95            max_retries,
96        }
97    }
98}
99
impl Default for LocalWorkQueue {
    /// Equivalent to [`LocalWorkQueue::new`].
    fn default() -> Self {
        Self::new()
    }
}
105
106#[async_trait]
107impl WorkQueuePort for LocalWorkQueue {
108    async fn enqueue(&self, task: WorkTask) -> Result<()> {
109        debug!(task_id = %task.id, node = %task.node_name, "enqueuing task");
110        self.state.insert(task.id.clone(), TaskStatus::Pending);
111        self.pending.lock().await.push_back(task);
112        Ok(())
113    }
114
115    async fn try_dequeue(&self) -> Result<Option<WorkTask>> {
116        let task = self.pending.lock().await.pop_front();
117        if let Some(ref t) = task {
118            debug!(task_id = %t.id, "dequeued task");
119            self.state.insert(
120                t.id.clone(),
121                TaskStatus::InProgress {
122                    worker_id: "local".to_string(),
123                },
124            );
125        }
126        Ok(task)
127    }
128
129    async fn acknowledge(&self, task_id: &str, output: serde_json::Value) -> Result<()> {
130        info!(task_id = %task_id, "task acknowledged (completed)");
131        self.state
132            .insert(task_id.to_string(), TaskStatus::Completed { output });
133        Ok(())
134    }
135
136    async fn fail(&self, task_id: &str, error: &str) -> Result<()> {
137        let attempt = self
138            .state
139            .get(task_id)
140            .map_or(0, |status| match status.value() {
141                TaskStatus::Failed { attempt, .. } => *attempt,
142                _ => 0,
143            });
144
145        if attempt >= self.max_retries {
146            warn!(task_id = %task_id, %error, "task dead-lettered after max retries");
147            self.state.insert(
148                task_id.to_string(),
149                TaskStatus::DeadLetter {
150                    error: error.to_string(),
151                },
152            );
153        } else {
154            error!(task_id = %task_id, attempt, %error, "task failed, will retry");
155            self.state.insert(
156                task_id.to_string(),
157                TaskStatus::Failed {
158                    error: error.to_string(),
159                    attempt: attempt + 1,
160                },
161            );
162        }
163        Ok(())
164    }
165
166    async fn status(&self, task_id: &str) -> Result<Option<TaskStatus>> {
167        Ok(self.state.get(task_id).map(|s| s.value().clone()))
168    }
169
170    async fn collect_results(&self, pipeline_id: &str) -> Result<Vec<(String, serde_json::Value)>> {
171        // We need to find tasks by pipeline_id — the state map is keyed by
172        // task_id so we collect all Completed entries whose pipeline_id matches.
173        // LocalWorkQueue stores the task in the pending queue; once dequeued
174        // we lose the pipeline_id mapping. We use a secondary index maintained
175        // in the pipeline_tasks map instead.
176        //
177        // For simplicity in the local adapter, we scan all state entries and
178        // match on pipeline_id encoded in the task_id prefix convention
179        // "pipeline_id::node_name::task_id".
180        let mut results = Vec::new();
181        for entry in self.state.iter() {
182            let key = entry.key();
183            // Convention: task_id == "{pipeline_id}::{node_name}::{ulid}"
184            if !key.starts_with(pipeline_id) {
185                continue;
186            }
187            if let TaskStatus::Completed { ref output } = *entry.value() {
188                // Extract node_name from the middle segment
189                let node_name = key.split("::").nth(1).unwrap_or(key).to_string();
190                results.push((node_name, output.clone()));
191            }
192        }
193        Ok(results)
194    }
195
196    async fn pending_count(&self) -> Result<usize> {
197        Ok(self.pending.lock().await.len())
198    }
199}
200
201// ─────────────────────────────────────────────────────────────────────────────
202// DistributedDagExecutor
203// ─────────────────────────────────────────────────────────────────────────────
204
/// Executes a DAG wave using a [`WorkQueuePort`] to distribute node-level tasks
/// across workers.
///
/// Workers are spawned as Tokio tasks that pull from the queue, call the
/// appropriate service, and acknowledge results.  For local development the
/// [`LocalWorkQueue`] is used; in production any queue backend can be plugged
/// in without changing this executor.
///
/// # Example
///
/// ```
/// use stygian_graph::adapters::distributed::{DistributedDagExecutor, LocalWorkQueue};
/// use stygian_graph::ports::work_queue::WorkTask;
///
/// use stygian_graph::adapters::noop::NoopService;
/// use serde_json::json;
/// use std::sync::Arc;
/// use std::collections::HashMap;
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let queue = Arc::new(LocalWorkQueue::new());
/// let executor = DistributedDagExecutor::new(queue, 4);
///
/// let mut services: HashMap<String, Arc<dyn stygian_graph::ports::ScrapingService>> =
///     HashMap::new();
/// services.insert("noop".to_string(), Arc::new(NoopService));
///
/// let tasks = vec![WorkTask {
///     id: "p1::fetch::01".to_string(),
///     pipeline_id: "p1".to_string(),
///     node_name: "fetch".to_string(),
///     input: json!({"url": "https://example.com"}),
///     wave: 0,
///     attempt: 0,
///     idempotency_key: "ik-01".to_string(),
/// }];
///
/// let results = executor.execute_wave("p1", tasks, &services).await.unwrap();
/// assert!(!results.is_empty() || results.is_empty()); // noop returns empty
/// # });
/// ```
pub struct DistributedDagExecutor<Q: WorkQueuePort> {
    /// Shared queue backend that distributes tasks to workers.
    queue: Arc<Q>,
    /// Number of parallel worker tasks spawned per wave (clamped to >= 1 by
    /// the constructor).
    worker_concurrency: usize,
}
250
251impl<Q: WorkQueuePort + 'static> DistributedDagExecutor<Q> {
252    /// Create a new executor with the given work queue and worker concurrency.
253    ///
254    /// `worker_concurrency` controls how many parallel worker tasks drain the
255    /// queue.
256    pub fn new(queue: Arc<Q>, worker_concurrency: usize) -> Self {
257        Self {
258            queue,
259            worker_concurrency: worker_concurrency.max(1),
260        }
261    }
262
263    /// Execute a single wave of tasks, distributing them across workers.
264    ///
265    /// Returns `(node_name, output)` pairs for all tasks in the wave.
266    pub async fn execute_wave(
267        &self,
268        pipeline_id: &str,
269        tasks: Vec<WorkTask>,
270        services: &std::collections::HashMap<String, Arc<dyn ScrapingService>>,
271    ) -> Result<Vec<(String, serde_json::Value)>> {
272        let expected = tasks.len();
273        if expected == 0 {
274            return Ok(Vec::new());
275        }
276
277        // Enqueue all tasks in this wave
278        for task in tasks {
279            self.queue.enqueue(task).await?;
280        }
281
282        // Spawn workers to drain the queue
283        let queue = Arc::clone(&self.queue);
284        let services: Arc<std::collections::HashMap<String, Arc<dyn ScrapingService>>> =
285            Arc::new(services.clone());
286
287        let concurrency = self.worker_concurrency.min(expected);
288        let mut handles = tokio::task::JoinSet::new();
289
290        for _ in 0..concurrency {
291            let q = Arc::clone(&queue);
292            let svcs = Arc::clone(&services);
293            handles.spawn(async move {
294                // Each worker drains the queue until it finds nothing
295                let mut worked = 0usize;
296                loop {
297                    match q.try_dequeue().await {
298                        Ok(Some(task)) => {
299                            let service_input = ServiceInput {
300                                url: task
301                                    .input
302                                    .get("url")
303                                    .and_then(serde_json::Value::as_str)
304                                    .unwrap_or("")
305                                    .to_string(),
306                                params: task.input.clone(),
307                            };
308                            let output = match svcs.get(&task.node_name) {
309                                Some(svc) => svc.execute(service_input.clone()).await,
310                                None => {
311                                    // Fallback: look for a service named "default"
312                                    match svcs.get("default") {
313                                        Some(svc) => svc.execute(service_input).await,
314                                        None => Err(StygianError::Service(
315                                            ServiceError::Unavailable(format!(
316                                                "service '{}' not registered",
317                                                task.node_name
318                                            )),
319                                        )),
320                                    }
321                                }
322                            };
323                            match output {
324                                Ok(out) => {
325                                    let val = serde_json::json!({
326                                        "data": out.data,
327                                        "metadata": out.metadata,
328                                    });
329                                    let _ = q.acknowledge(&task.id, val).await;
330                                }
331                                Err(e) => {
332                                    let _ = q.fail(&task.id, &e.to_string()).await;
333                                }
334                            }
335                            worked += 1;
336                        }
337                        Ok(None) => break, // queue empty
338                        Err(e) => {
339                            error!(error = %e, "worker dequeue error");
340                            break;
341                        }
342                    }
343                }
344                worked
345            });
346        }
347
348        // Wait for all workers
349        while handles.join_next().await.is_some() {}
350
351        // Collect results
352        self.queue.collect_results(pipeline_id).await
353    }
354}
355
356// ─────────────────────────────────────────────────────────────────────────────
357// Tests
358// ─────────────────────────────────────────────────────────────────────────────
359
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a task whose id follows the "{pipeline_id}::{node_name}::{seq}"
    /// convention that `collect_results` relies on.
    fn make_task(pipeline_id: &str, node_name: &str, seq: u32) -> WorkTask {
        WorkTask {
            id: format!("{pipeline_id}::{node_name}::{seq:04}"),
            pipeline_id: pipeline_id.to_string(),
            node_name: node_name.to_string(),
            input: json!({"url": "https://example.com"}),
            wave: 0,
            attempt: 0,
            idempotency_key: format!("ik-{seq}"),
        }
    }

    #[tokio::test]
    async fn enqueue_dequeue_roundtrip() {
        let q = LocalWorkQueue::new();
        assert_eq!(q.pending_count().await.unwrap(), 0);

        for (node, seq) in [("fetch", 1_u32), ("parse", 2_u32)] {
            q.enqueue(make_task("p1", node, seq)).await.unwrap();
        }
        assert_eq!(q.pending_count().await.unwrap(), 2);

        // FIFO ordering: tasks come back in enqueue order.
        let first = q.try_dequeue().await.unwrap().unwrap();
        assert_eq!(first.node_name, "fetch");
        assert_eq!(q.pending_count().await.unwrap(), 1);

        let second = q.try_dequeue().await.unwrap().unwrap();
        assert_eq!(second.node_name, "parse");
        assert_eq!(q.pending_count().await.unwrap(), 0);

        // An empty queue yields None rather than blocking.
        assert!(q.try_dequeue().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn acknowledge_records_completed_status() {
        let q = LocalWorkQueue::new();
        q.enqueue(make_task("p1", "fetch", 1)).await.unwrap();

        let in_flight = q.try_dequeue().await.unwrap().unwrap();
        let payload = json!({"data": "hello", "status": 200});
        q.acknowledge(&in_flight.id, payload).await.unwrap();

        let status = q.status(&in_flight.id).await.unwrap().unwrap();
        assert!(matches!(status, TaskStatus::Completed { .. }));
    }

    #[tokio::test]
    async fn fail_dead_letters_after_max_retries() {
        let q = LocalWorkQueue::with_max_retries(2);
        q.enqueue(make_task("p1", "fetch", 1)).await.unwrap();
        let in_flight = q.try_dequeue().await.unwrap().unwrap();

        // The first two failures stay retryable; the third sees
        // attempt == max_retries and dead-letters the task.
        q.fail(&in_flight.id, "err 1").await.unwrap();
        q.fail(&in_flight.id, "err 2").await.unwrap();
        q.fail(&in_flight.id, "err 3").await.unwrap();

        let status = q.status(&in_flight.id).await.unwrap().unwrap();
        assert!(matches!(status, TaskStatus::DeadLetter { .. }));
    }

    #[tokio::test]
    async fn collect_results_filters_by_pipeline_id() {
        let q = LocalWorkQueue::new();

        // One task per pipeline, both acknowledged.
        q.enqueue(make_task("pipeline-A", "node1", 1)).await.unwrap();
        q.enqueue(make_task("pipeline-B", "node1", 2)).await.unwrap();

        let first = q.try_dequeue().await.unwrap().unwrap();
        let second = q.try_dequeue().await.unwrap().unwrap();

        q.acknowledge(&first.id, json!({"data": "A-result"}))
            .await
            .unwrap();
        q.acknowledge(&second.id, json!({"data": "B-result"}))
            .await
            .unwrap();

        // Each pipeline sees only its own completed task.
        let results_a = q.collect_results("pipeline-A").await.unwrap();
        assert_eq!(results_a.len(), 1);
        assert_eq!(results_a[0].1["data"], "A-result");

        let results_b = q.collect_results("pipeline-B").await.unwrap();
        assert_eq!(results_b.len(), 1);
        assert_eq!(results_b[0].1["data"], "B-result");
    }

    #[tokio::test]
    async fn distributed_executor_runs_tasks() {
        use crate::adapters::noop::NoopService;
        use std::collections::HashMap;

        let queue = Arc::new(LocalWorkQueue::new());
        let executor = DistributedDagExecutor::new(Arc::clone(&queue), 2);

        let mut services: HashMap<String, Arc<dyn ScrapingService>> = HashMap::new();
        services.insert("noop".to_string(), Arc::new(NoopService));

        let tasks: Vec<WorkTask> = (1..=3).map(|seq| make_task("p1", "noop", seq)).collect();

        // The wave must complete without panic/error; NoopService may yield
        // empty data, so only an upper bound on results is asserted.
        let results = executor.execute_wave("p1", tasks, &services).await.unwrap();
        assert!(results.len() <= 3);
    }
}