1use std::collections::HashMap;
15use std::sync::Arc;
16
17use petgraph::algo::toposort;
18use petgraph::graph::{DiGraph, NodeIndex};
19use serde::{Deserialize, Serialize};
20use tokio::sync::Mutex;
21
22use super::error::{GraphError, StygianError};
23use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
24
/// A single unit of work in a pipeline: one invocation of a named
/// scraping service with a JSON configuration payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
    /// Unique identifier for this node within its pipeline.
    pub id: String,

    /// Key of the service that runs this node (looked up in the service
    /// registry passed to `DagExecutor::execute`).
    pub service: String,

    /// Service-specific configuration; the executor reads an optional
    /// `"url"` string field from it when building the `ServiceInput`.
    pub config: serde_json::Value,

    /// Free-form metadata; defaults to `Null` when absent from serialized input.
    #[serde(default)]
    pub metadata: serde_json::Value,
}
56
57impl Node {
58 pub fn new(
71 id: impl Into<String>,
72 service: impl Into<String>,
73 config: serde_json::Value,
74 ) -> Self {
75 Self {
76 id: id.into(),
77 service: service.into(),
78 config,
79 metadata: serde_json::Value::Null,
80 }
81 }
82
83 pub fn with_metadata(
85 id: impl Into<String>,
86 service: impl Into<String>,
87 config: serde_json::Value,
88 metadata: serde_json::Value,
89 ) -> Self {
90 Self {
91 id: id.into(),
92 service: service.into(),
93 config,
94 metadata,
95 }
96 }
97
98 pub fn validate(&self) -> Result<(), StygianError> {
104 if self.id.is_empty() {
105 return Err(GraphError::InvalidEdge("Node ID cannot be empty".into()).into());
106 }
107 if self.service.is_empty() {
108 return Err(GraphError::InvalidEdge("Node service type cannot be empty".into()).into());
109 }
110 Ok(())
111 }
112}
113
/// A directed dependency between two pipeline nodes: `from` must finish
/// before `to` runs, and `to` receives `from`'s output as upstream data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edge {
    /// ID of the upstream (source) node.
    pub from: String,

    /// ID of the downstream (target) node.
    pub to: String,

    /// Optional edge configuration; defaults to `Null` when absent from
    /// serialized input.
    #[serde(default)]
    pub config: serde_json::Value,
}
138
139impl Edge {
140 pub fn new(from: impl Into<String>, to: impl Into<String>) -> Self {
152 Self {
153 from: from.into(),
154 to: to.into(),
155 config: serde_json::Value::Null,
156 }
157 }
158
159 pub fn with_config(
161 from: impl Into<String>,
162 to: impl Into<String>,
163 config: serde_json::Value,
164 ) -> Self {
165 Self {
166 from: from.into(),
167 to: to.into(),
168 config,
169 }
170 }
171
172 pub fn validate(&self) -> Result<(), StygianError> {
178 if self.from.is_empty() || self.to.is_empty() {
179 return Err(GraphError::InvalidEdge("Edge endpoints cannot be empty".into()).into());
180 }
181 Ok(())
182 }
183}
184
/// A serializable pipeline definition: a set of nodes plus the directed
/// edges between them. Converted into an executable DAG by
/// `DagExecutor::from_pipeline`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Pipeline {
    /// Human-readable pipeline name.
    pub name: String,

    /// All nodes in the pipeline.
    pub nodes: Vec<Node>,

    /// Directed dependencies between nodes, referencing nodes by ID.
    pub edges: Vec<Edge>,

    /// Free-form metadata; defaults to `Null` when absent from serialized input.
    #[serde(default)]
    pub metadata: serde_json::Value,
}
213
214impl Pipeline {
215 pub fn new(name: impl Into<String>) -> Self {
227 Self {
228 name: name.into(),
229 nodes: Vec::new(),
230 edges: Vec::new(),
231 metadata: serde_json::Value::Null,
232 }
233 }
234
235 pub fn add_node(&mut self, node: Node) {
237 self.nodes.push(node);
238 }
239
240 pub fn add_edge(&mut self, edge: Edge) {
242 self.edges.push(edge);
243 }
244
245 pub fn validate(&self) -> Result<(), StygianError> {
251 for node in &self.nodes {
252 node.validate()?;
253 }
254 for edge in &self.edges {
255 edge.validate()?;
256 }
257 Ok(())
258 }
259}
260
/// The outcome of executing a single node.
#[derive(Debug, Clone)]
pub struct NodeResult {
    /// ID of the node that produced this result.
    pub node_id: String,
    /// Output returned by the node's service.
    pub output: ServiceOutput,
}
269
/// Executes a validated pipeline as a directed acyclic graph, running
/// independent nodes concurrently wave by wave.
pub struct DagExecutor {
    // Nodes and their dependency edges; built in `from_pipeline`.
    graph: DiGraph<Node, ()>,
    // Node-ID -> graph-index map. Underscore prefix: currently unused
    // after construction — presumably retained for future lookups.
    _node_indices: HashMap<String, NodeIndex>,
}
279
280impl DagExecutor {
281 pub fn new() -> Self {
291 Self {
292 graph: DiGraph::new(),
293 _node_indices: HashMap::new(),
294 }
295 }
296
297 pub fn from_pipeline(pipeline: &Pipeline) -> Result<Self, StygianError> {
304 pipeline.validate()?;
305
306 let mut graph = DiGraph::new();
307 let mut node_indices = HashMap::new();
308
309 for node in &pipeline.nodes {
311 let idx = graph.add_node(node.clone());
312 node_indices.insert(node.id.clone(), idx);
313 }
314
315 for edge in &pipeline.edges {
317 let from_idx = node_indices
318 .get(&edge.from)
319 .ok_or_else(|| GraphError::NodeNotFound(edge.from.clone()))?;
320 let to_idx = node_indices
321 .get(&edge.to)
322 .ok_or_else(|| GraphError::NodeNotFound(edge.to.clone()))?;
323 graph.add_edge(*from_idx, *to_idx, ());
324 }
325
326 if petgraph::algo::is_cyclic_directed(&graph) {
328 return Err(GraphError::CycleDetected.into());
329 }
330
331 Ok(Self {
332 graph,
333 _node_indices: node_indices,
334 })
335 }
336
337 pub async fn execute(
347 &self,
348 services: &HashMap<String, Arc<dyn ScrapingService>>,
349 ) -> Result<Vec<NodeResult>, StygianError> {
350 let topo_order = toposort(&self.graph, None).map_err(|_| GraphError::CycleDetected)?;
352
353 let waves = self.build_execution_waves(&topo_order);
355
356 let results: Arc<Mutex<HashMap<String, ServiceOutput>>> =
358 Arc::new(Mutex::new(HashMap::new()));
359
360 for wave in waves {
361 let mut handles = Vec::new();
363
364 for node_idx in wave {
365 let node = self.graph[node_idx].clone();
366 let service = services.get(&node.service).cloned().ok_or_else(|| {
367 GraphError::InvalidPipeline(format!(
368 "No service registered for type '{}'",
369 node.service
370 ))
371 })?;
372
373 let upstream_data = {
375 let store = results.lock().await;
376 let mut data = serde_json::Map::new();
377 for pred_idx in self
378 .graph
379 .neighbors_directed(node_idx, petgraph::Direction::Incoming)
380 {
381 let pred_id = &self.graph[pred_idx].id;
382 if let Some(out) = store.get(pred_id) {
383 data.insert(
384 pred_id.clone(),
385 serde_json::Value::String(out.data.clone()),
386 );
387 }
388 }
389 serde_json::Value::Object(data)
390 };
391
392 let input = ServiceInput {
393 url: node
394 .config
395 .get("url")
396 .and_then(|v| v.as_str())
397 .unwrap_or("")
398 .to_string(),
399 params: upstream_data,
400 };
401
402 let results_clone = Arc::clone(&results);
403 let node_id = node.id.clone();
404
405 handles.push(tokio::spawn(async move {
406 let output = service.execute(input).await?;
407 results_clone
408 .lock()
409 .await
410 .insert(node_id.clone(), output.clone());
411 Ok::<NodeResult, StygianError>(NodeResult { node_id, output })
412 }));
413 }
414
415 for handle in handles {
417 handle
418 .await
419 .map_err(|e| GraphError::ExecutionFailed(format!("Task join error: {e}")))??;
420 }
421 }
422
423 let store = results.lock().await;
425 let final_results = topo_order
426 .iter()
427 .filter_map(|idx| {
428 let node_id = &self.graph[*idx].id;
429 store.get(node_id).map(|output| NodeResult {
430 node_id: node_id.clone(),
431 output: output.clone(),
432 })
433 })
434 .collect();
435
436 Ok(final_results)
437 }
438
439 fn build_execution_waves(&self, topo_order: &[NodeIndex]) -> Vec<Vec<NodeIndex>> {
443 let mut level: HashMap<NodeIndex, usize> = HashMap::new();
444
445 for &idx in topo_order {
446 let max_pred_level = self
447 .graph
448 .neighbors_directed(idx, petgraph::Direction::Incoming)
449 .map(|pred| level.get(&pred).copied().unwrap_or(0) + 1)
450 .max()
451 .unwrap_or(0);
452 level.insert(idx, max_pred_level);
453 }
454
455 let max_level = level.values().copied().max().unwrap_or(0);
456 let mut waves: Vec<Vec<NodeIndex>> = vec![Vec::new(); max_level + 1];
457 for (idx, lvl) in level {
458 if let Some(wave) = waves.get_mut(lvl) {
459 wave.push(idx);
460 }
461 }
462 waves
463 }
464}
465
466impl Default for DagExecutor {
467 fn default() -> Self {
468 Self::new()
469 }
470}
471
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::error::Result;

    #[test]
    fn test_node_creation() {
        let config = serde_json::json!({"url": "https://example.com"});
        let node = Node::new("test", "http", config);
        assert_eq!(node.id, "test");
        assert_eq!(node.service, "http");
    }

    #[test]
    fn test_edge_creation() {
        let edge = Edge::new("a", "b");
        assert_eq!(edge.from, "a");
        assert_eq!(edge.to, "b");
    }

    #[test]
    fn test_pipeline_validation() {
        let mut pipeline = Pipeline::new("test");
        pipeline.add_node(Node::new("fetch", "http", serde_json::json!({})));
        pipeline.add_node(Node::new("extract", "ai", serde_json::json!({})));
        pipeline.add_edge(Edge::new("fetch", "extract"));

        assert!(pipeline.validate().is_ok());
    }

    #[test]
    fn test_cycle_detection() {
        // a -> b -> a forms a two-node cycle, which must be rejected.
        let mut pipeline = Pipeline::new("cyclic");
        for id in ["a", "b"] {
            pipeline.add_node(Node::new(id, "http", serde_json::json!({})));
        }
        pipeline.add_edge(Edge::new("a", "b"));
        pipeline.add_edge(Edge::new("b", "a"));

        let result = DagExecutor::from_pipeline(&pipeline);
        assert!(matches!(
            result,
            Err(StygianError::Graph(GraphError::CycleDetected))
        ));
    }

    #[tokio::test]
    async fn test_diamond_concurrent_execution() -> Result<()> {
        use crate::adapters::noop::NoopService;

        // Diamond shape: A fans out to B and C, which both feed D.
        let mut pipeline = Pipeline::new("diamond");
        for id in ["A", "B", "C", "D"] {
            pipeline.add_node(Node::new(id, "noop", serde_json::json!({"url": ""})));
        }
        for (from, to) in [("A", "B"), ("A", "C"), ("B", "D"), ("C", "D")] {
            pipeline.add_edge(Edge::new(from, to));
        }

        let executor = DagExecutor::from_pipeline(&pipeline)?;

        let mut services: HashMap<String, std::sync::Arc<dyn crate::ports::ScrapingService>> =
            HashMap::new();
        services.insert("noop".to_string(), std::sync::Arc::new(NoopService));

        let results = executor.execute(&services).await?;

        assert_eq!(results.len(), 4);
        let ids: Vec<&str> = results.iter().map(|r| r.node_id.as_str()).collect();
        for expected in ["A", "B", "C", "D"] {
            assert!(ids.contains(&expected));
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_missing_service_returns_error() -> Result<()> {
        let mut pipeline = Pipeline::new("test");
        pipeline.add_node(Node::new("fetch", "http", serde_json::json!({})));

        let executor = DagExecutor::from_pipeline(&pipeline)?;
        let services: HashMap<String, std::sync::Arc<dyn crate::ports::ScrapingService>> =
            HashMap::new();

        assert!(executor.execute(&services).await.is_err());
        Ok(())
    }
}