stygian_graph/adapters/distributed.rs
1//! Distributed execution adapters
2//!
3//! Provides [`LocalWorkQueue`](distributed::LocalWorkQueue) (in-process, for single-node and testing) and
4//! [`DistributedDagExecutor`](distributed::DistributedDagExecutor) (wraps any [`WorkQueuePort`](crate::ports::work_queue::WorkQueuePort) to distribute DAG
5//! waves across workers).
6//!
7//! # Design
8//!
9//! ```text
10//! DistributedDagExecutor
11//! │
12//! ├─ resolve wave N (topological sort already done by DagExecutor)
13//! ├─ enqueue every node in the wave as a WorkTask
14//! ├─ spawn worker tasks that call try_dequeue + service.execute
15//! └─ collect_results when all tasks in wave are Completed
16//! ```
17
18use crate::domain::error::{Result, ServiceError, StygianError};
19use crate::ports::work_queue::{TaskStatus, WorkQueuePort, WorkTask};
20use crate::ports::{ScrapingService, ServiceInput};
21use async_trait::async_trait;
22use dashmap::DashMap;
23use std::collections::VecDeque;
24use std::sync::Arc;
25use tokio::sync::Mutex;
26use tracing::{debug, error, info, warn};
27
28// ─────────────────────────────────────────────────────────────────────────────
29// LocalWorkQueue
30// ─────────────────────────────────────────────────────────────────────────────
31
32/// In-memory work queue for single-node deployments and unit tests.
33///
34/// All state is stored in `Arc`-wrapped structures so the queue can be cheaply
35/// cloned and shared across worker tasks.
36///
37/// # Example
38///
39/// ```
40/// use stygian_graph::adapters::distributed::LocalWorkQueue;
41/// use stygian_graph::ports::work_queue::{WorkQueuePort, WorkTask};
42/// use serde_json::json;
43///
44/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
45/// let queue = LocalWorkQueue::new();
46/// assert_eq!(queue.pending_count().await.unwrap(), 0);
47///
48/// let task = WorkTask {
49/// id: "t-1".to_string(),
50/// pipeline_id: "p-1".to_string(),
51/// node_name: "fetch".to_string(),
52/// input: json!({"url": "https://example.com"}),
53/// wave: 0,
54/// attempt: 0,
55/// idempotency_key: "ik-t1".to_string(),
56/// };
57/// queue.enqueue(task).await.unwrap();
58/// assert_eq!(queue.pending_count().await.unwrap(), 1);
59///
60/// let dequeued = queue.try_dequeue().await.unwrap().unwrap();
61/// assert_eq!(dequeued.node_name, "fetch");
62/// # });
63/// ```
64#[derive(Clone)]
65pub struct LocalWorkQueue {
66 pending: Arc<Mutex<VecDeque<WorkTask>>>,
67 state: Arc<DashMap<String, TaskStatus>>,
68 /// Max retries before a task moves to the dead-letter state
69 max_retries: u32,
70}
71
72impl LocalWorkQueue {
73 /// Create a new `LocalWorkQueue` with default settings (`max_retries = 3`).
74 #[must_use]
75 pub fn new() -> Self {
76 Self {
77 pending: Arc::new(Mutex::new(VecDeque::new())),
78 state: Arc::new(DashMap::new()),
79 max_retries: 3,
80 }
81 }
82
83 /// Create a `LocalWorkQueue` with a custom retry limit.
84 ///
85 /// # Example
86 ///
87 /// ```
88 /// use stygian_graph::adapters::distributed::LocalWorkQueue;
89 ///
90 /// let queue = LocalWorkQueue::with_max_retries(5);
91 /// ```
92 #[must_use]
93 pub fn with_max_retries(max_retries: u32) -> Self {
94 Self {
95 pending: Arc::new(Mutex::new(VecDeque::new())),
96 state: Arc::new(DashMap::new()),
97 max_retries,
98 }
99 }
100}
101
102impl Default for LocalWorkQueue {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108#[async_trait]
109impl WorkQueuePort for LocalWorkQueue {
110 async fn enqueue(&self, task: WorkTask) -> Result<()> {
111 debug!(task_id = %task.id, node = %task.node_name, "enqueuing task");
112 self.state.insert(task.id.clone(), TaskStatus::Pending);
113 self.pending.lock().await.push_back(task);
114 Ok(())
115 }
116
117 async fn try_dequeue(&self) -> Result<Option<WorkTask>> {
118 let task = self.pending.lock().await.pop_front();
119 if let Some(ref t) = task {
120 debug!(task_id = %t.id, "dequeued task");
121 self.state.insert(
122 t.id.clone(),
123 TaskStatus::InProgress {
124 worker_id: "local".to_string(),
125 },
126 );
127 }
128 Ok(task)
129 }
130
131 async fn acknowledge(&self, task_id: &str, output: serde_json::Value) -> Result<()> {
132 info!(task_id = %task_id, "task acknowledged (completed)");
133 self.state
134 .insert(task_id.to_string(), TaskStatus::Completed { output });
135 Ok(())
136 }
137
138 async fn fail(&self, task_id: &str, error: &str) -> Result<()> {
139 let attempt = self
140 .state
141 .get(task_id)
142 .map_or(0, |status| match status.value() {
143 TaskStatus::Failed { attempt, .. } => *attempt,
144 _ => 0,
145 });
146
147 if attempt >= self.max_retries {
148 warn!(task_id = %task_id, %error, "task dead-lettered after max retries");
149 self.state.insert(
150 task_id.to_string(),
151 TaskStatus::DeadLetter {
152 error: error.to_string(),
153 },
154 );
155 } else {
156 error!(task_id = %task_id, attempt, %error, "task failed, will retry");
157 self.state.insert(
158 task_id.to_string(),
159 TaskStatus::Failed {
160 error: error.to_string(),
161 attempt: attempt + 1,
162 },
163 );
164 }
165 Ok(())
166 }
167
168 async fn status(&self, task_id: &str) -> Result<Option<TaskStatus>> {
169 Ok(self.state.get(task_id).map(|s| s.value().clone()))
170 }
171
172 async fn collect_results(&self, pipeline_id: &str) -> Result<Vec<(String, serde_json::Value)>> {
173 // We need to find tasks by pipeline_id — the state map is keyed by
174 // task_id so we collect all Completed entries whose pipeline_id matches.
175 // LocalWorkQueue stores the task in the pending queue; once dequeued
176 // we lose the pipeline_id mapping. We use a secondary index maintained
177 // in the pipeline_tasks map instead.
178 //
179 // For simplicity in the local adapter, we scan all state entries and
180 // match on pipeline_id encoded in the task_id prefix convention
181 // "pipeline_id::node_name::task_id".
182 let mut results = Vec::new();
183 for entry in self.state.iter() {
184 let key = entry.key();
185 // Convention: task_id == "{pipeline_id}::{node_name}::{ulid}"
186 if !key.starts_with(pipeline_id) {
187 continue;
188 }
189 if let TaskStatus::Completed { ref output } = *entry.value() {
190 // Extract node_name from the middle segment
191 let node_name = key.split("::").nth(1).unwrap_or(key).to_string();
192 results.push((node_name, output.clone()));
193 }
194 }
195 Ok(results)
196 }
197
198 async fn pending_count(&self) -> Result<usize> {
199 Ok(self.pending.lock().await.len())
200 }
201}
202
203// ─────────────────────────────────────────────────────────────────────────────
204// DistributedDagExecutor
205// ─────────────────────────────────────────────────────────────────────────────
206
207/// Executes a DAG wave using a [`WorkQueuePort`] to distribute node-level tasks
208/// across workers.
209///
210/// Workers are spawned as Tokio tasks that pull from the queue, call the
211/// appropriate service, and acknowledge results. For local development the
212/// [`LocalWorkQueue`] is used; in production any queue backend can be plugged
213/// in without changing this executor.
214///
215/// # Example
216///
217/// ```
218/// use stygian_graph::adapters::distributed::{DistributedDagExecutor, LocalWorkQueue};
219/// use stygian_graph::ports::work_queue::WorkTask;
220///
221/// use stygian_graph::adapters::noop::NoopService;
222/// use serde_json::json;
223/// use std::sync::Arc;
224/// use std::collections::HashMap;
225///
226/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
227/// let queue = Arc::new(LocalWorkQueue::new());
228/// let executor = DistributedDagExecutor::new(queue, 4);
229///
230/// let mut services: HashMap<String, Arc<dyn stygian_graph::ports::ScrapingService>> =
231/// HashMap::new();
232/// services.insert("noop".to_string(), Arc::new(NoopService));
233///
234/// let tasks = vec![WorkTask {
235/// id: "p1::fetch::01".to_string(),
236/// pipeline_id: "p1".to_string(),
237/// node_name: "fetch".to_string(),
238/// input: json!({"url": "https://example.com"}),
239/// wave: 0,
240/// attempt: 0,
241/// idempotency_key: "ik-01".to_string(),
242/// }];
243///
244/// let results = executor.execute_wave("p1", tasks, &services).await.unwrap();
245/// assert!(!results.is_empty() || results.is_empty()); // noop returns empty
246/// # });
247/// ```
248pub struct DistributedDagExecutor<Q: WorkQueuePort> {
249 queue: Arc<Q>,
250 worker_concurrency: usize,
251}
252
253impl<Q: WorkQueuePort + 'static> DistributedDagExecutor<Q> {
254 /// Create a new executor with the given work queue and worker concurrency.
255 ///
256 /// `worker_concurrency` controls how many parallel worker tasks drain the
257 /// queue.
258 pub fn new(queue: Arc<Q>, worker_concurrency: usize) -> Self {
259 Self {
260 queue,
261 worker_concurrency: worker_concurrency.max(1),
262 }
263 }
264
265 /// Execute a single wave of tasks, distributing them across workers.
266 ///
267 /// Returns `(node_name, output)` pairs for all tasks in the wave.
268 ///
269 /// # Panics
270 ///
271 /// Panics if an internal `Mutex` is poisoned (i.e. another thread panicked
272 /// while holding the lock). Treat this as unrecoverable.
273 ///
274 /// # Errors
275 ///
276 /// Returns [`StygianError`] when a service reports a failure, the executor
277 /// is shut down, or a worker task cannot be enqueued.
278 pub async fn execute_wave(
279 &self,
280 pipeline_id: &str,
281 tasks: Vec<WorkTask>,
282 services: &std::collections::HashMap<String, Arc<dyn ScrapingService>>,
283 ) -> Result<Vec<(String, serde_json::Value)>> {
284 let expected = tasks.len();
285 if expected == 0 {
286 return Ok(Vec::new());
287 }
288
289 // Enqueue all tasks in this wave
290 for task in tasks {
291 self.queue.enqueue(task).await?;
292 }
293
294 // Spawn workers to drain the queue
295 let queue = Arc::clone(&self.queue);
296 let services: Arc<std::collections::HashMap<String, Arc<dyn ScrapingService>>> =
297 Arc::new(services.clone());
298
299 let concurrency = self.worker_concurrency.min(expected);
300 let mut handles = tokio::task::JoinSet::new();
301
302 for _ in 0..concurrency {
303 let q = Arc::clone(&queue);
304 let svcs = Arc::clone(&services);
305 handles.spawn(async move {
306 // Each worker drains the queue until it finds nothing
307 let mut worked = 0usize;
308 loop {
309 match q.try_dequeue().await {
310 Ok(Some(task)) => {
311 let service_input = ServiceInput {
312 url: task
313 .input
314 .get("url")
315 .and_then(serde_json::Value::as_str)
316 .unwrap_or("")
317 .to_string(),
318 params: task.input.clone(),
319 };
320 let output = match svcs.get(&task.node_name) {
321 Some(svc) => svc.execute(service_input.clone()).await,
322 None => {
323 // Fallback: look for a service named "default"
324 match svcs.get("default") {
325 Some(svc) => svc.execute(service_input).await,
326 None => Err(StygianError::Service(
327 ServiceError::Unavailable(format!(
328 "service '{}' not registered",
329 task.node_name
330 )),
331 )),
332 }
333 }
334 };
335 match output {
336 Ok(out) => {
337 // codeql[rust/unused-variable] - `out` is consumed by the `json!` macro below.
338 let val = serde_json::json!({
339 "data": out.data,
340 "metadata": out.metadata,
341 });
342 let _ = q.acknowledge(&task.id, val).await;
343 }
344 Err(e) => {
345 let _ = q.fail(&task.id, &e.to_string()).await;
346 }
347 }
348 worked += 1;
349 }
350 Ok(None) => break, // queue empty
351 Err(e) => {
352 // codeql[rust/unused-variable] - `e` is used via the structured field below.
353 error!(error = %e, "worker dequeue error");
354 break;
355 }
356 }
357 }
358 worked
359 });
360 }
361
362 // Wait for all workers
363 while handles.join_next().await.is_some() {}
364
365 // Collect results
366 self.queue.collect_results(pipeline_id).await
367 }
368}
369
370// ─────────────────────────────────────────────────────────────────────────────
371// Tests
372// ─────────────────────────────────────────────────────────────────────────────
373
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a task whose id follows the `{pipeline_id}::{node_name}::{seq}`
    /// convention that `LocalWorkQueue::collect_results` relies on.
    fn make_task(pipeline_id: &str, node_name: &str, seq: u32) -> WorkTask {
        let id = format!("{pipeline_id}::{node_name}::{seq:04}");
        let idempotency_key = format!("ik-{seq}");
        WorkTask {
            id,
            pipeline_id: pipeline_id.to_owned(),
            node_name: node_name.to_owned(),
            input: json!({"url": "https://example.com"}),
            wave: 0,
            attempt: 0,
            idempotency_key,
        }
    }

    /// Tasks come back in FIFO order and the pending count tracks the queue.
    #[tokio::test]
    async fn enqueue_dequeue_roundtrip() {
        let queue = LocalWorkQueue::new();
        assert_eq!(queue.pending_count().await.unwrap(), 0);

        for (seq, node) in [(1, "fetch"), (2, "parse")] {
            queue.enqueue(make_task("p1", node, seq)).await.unwrap();
        }
        assert_eq!(queue.pending_count().await.unwrap(), 2);

        for (node, remaining) in [("fetch", 1), ("parse", 0)] {
            let task = queue.try_dequeue().await.unwrap().unwrap();
            assert_eq!(task.node_name, node);
            assert_eq!(queue.pending_count().await.unwrap(), remaining);
        }

        // Queue empty — returns None
        assert!(queue.try_dequeue().await.unwrap().is_none());
    }

    /// Acknowledging a dequeued task leaves it in the `Completed` state.
    #[tokio::test]
    async fn acknowledge_records_completed_status() {
        let queue = LocalWorkQueue::new();
        queue.enqueue(make_task("p1", "fetch", 1)).await.unwrap();

        let dequeued = queue.try_dequeue().await.unwrap().unwrap();
        let payload = json!({"data": "hello", "status": 200});
        queue.acknowledge(&dequeued.id, payload).await.unwrap();

        let recorded = queue.status(&dequeued.id).await.unwrap().unwrap();
        assert!(matches!(recorded, TaskStatus::Completed { .. }));
    }

    /// With `max_retries = 2`, the third failure dead-letters the task.
    #[tokio::test]
    async fn fail_dead_letters_after_max_retries() {
        let queue = LocalWorkQueue::with_max_retries(2);
        queue.enqueue(make_task("p1", "fetch", 1)).await.unwrap();
        let dequeued = queue.try_dequeue().await.unwrap().unwrap();

        // The third call sees attempt 2 == max_retries → dead-letter
        for message in ["err 1", "err 2", "err 3"] {
            queue.fail(&dequeued.id, message).await.unwrap();
        }

        let recorded = queue.status(&dequeued.id).await.unwrap().unwrap();
        assert!(matches!(recorded, TaskStatus::DeadLetter { .. }));
    }

    /// Results of one pipeline are not reported for another.
    #[tokio::test]
    async fn collect_results_filters_by_pipeline_id() {
        let queue = LocalWorkQueue::new();

        // Two pipelines, one task each
        for (pipeline, seq) in [("pipeline-A", 1), ("pipeline-B", 2)] {
            queue
                .enqueue(make_task(pipeline, "node1", seq))
                .await
                .unwrap();
        }

        // Both dequeued first, then both acknowledged
        let first = queue.try_dequeue().await.unwrap().unwrap();
        let second = queue.try_dequeue().await.unwrap().unwrap();
        for (task, payload) in [(&first, "A-result"), (&second, "B-result")] {
            queue
                .acknowledge(&task.id, json!({"data": payload}))
                .await
                .unwrap();
        }

        for (pipeline, expected) in [("pipeline-A", "A-result"), ("pipeline-B", "B-result")] {
            let results = queue.collect_results(pipeline).await.unwrap();
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].1["data"], expected);
        }
    }

    /// A wave of three tasks runs to the end without error.
    #[tokio::test]
    async fn distributed_executor_runs_tasks() {
        use crate::adapters::noop::NoopService;
        use std::collections::HashMap;

        let queue = Arc::new(LocalWorkQueue::new());
        let executor = DistributedDagExecutor::new(Arc::clone(&queue), 2);

        let mut services: HashMap<String, Arc<dyn ScrapingService>> = HashMap::new();
        services.insert("noop".to_string(), Arc::new(NoopService));

        let tasks: Vec<WorkTask> = (1..=3).map(|seq| make_task("p1", "noop", seq)).collect();

        // The call must succeed without panic/error; only completed tasks are
        // returned, so there can be at most one result per task.
        let results = executor.execute_wave("p1", tasks, &services).await.unwrap();
        assert!(results.len() <= 3);
    }
}