Skip to main content

stygian_graph/domain/
executor.rs

//! Worker pool executor with backpressure
//!
//! Provides a bounded worker pool for running `ScrapingService` tasks
//! with adaptive backpressure via tokio bounded channels.
//!
//! # Example
//!
//! ```no_run
//! use stygian_graph::domain::executor::WorkerPool;
//!
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let pool = WorkerPool::new(4, 32);
//! pool.shutdown().await;
//! # });
//! ```
16
17use std::sync::Arc;
18
19use tokio::sync::{Mutex, mpsc};
20use tokio::task::JoinSet;
21use tokio_util::sync::CancellationToken;
22
23use crate::domain::error::{GraphError, Result, StygianError};
24use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
25
26/// A work item sent to a pool worker
27struct WorkItem {
28    /// Service to invoke
29    service: Arc<dyn ScrapingService>,
30    /// Input for this invocation
31    input: ServiceInput,
32    /// One-shot channel to return the result
33    reply: tokio::sync::oneshot::Sender<Result<ServiceOutput>>,
34}
35
36/// High-performance worker pool with bounded backpressure.
37///
38/// Distributes `ScrapingService` tasks across a fixed number of worker
39/// Tokio tasks. When the internal channel is full, callers block until
40/// a slot is available (backpressure).
41///
42/// Supports graceful shutdown via a `CancellationToken`.
43pub struct WorkerPool {
44    tx: mpsc::Sender<WorkItem>,
45    cancel: CancellationToken,
46    workers: Arc<Mutex<JoinSet<()>>>,
47}
48
49impl WorkerPool {
50    /// Create a new worker pool.
51    ///
52    /// - `concurrency`: number of parallel worker tasks
53    /// - `queue_depth`: bounded channel capacity (backpressure threshold)
54    ///
55    /// # Example
56    ///
57    /// ```no_run
58    /// use stygian_graph::domain::executor::WorkerPool;
59    ///
60    /// let pool = WorkerPool::new(4, 32);
61    /// ```
62    #[allow(clippy::significant_drop_tightening)]
63    #[must_use]
64    pub fn new(concurrency: usize, queue_depth: usize) -> Self {
65        let (tx, rx) = mpsc::channel::<WorkItem>(queue_depth);
66        let rx = Arc::new(Mutex::new(rx));
67        let cancel = CancellationToken::new();
68        let mut join_set = JoinSet::new();
69
70        for _ in 0..concurrency {
71            let rx_clone = Arc::clone(&rx);
72            let cancel_clone = cancel.clone();
73
74            join_set.spawn(async move {
75                loop {
76                    // Check for cancellation before locking
77                    if cancel_clone.is_cancelled() {
78                        break;
79                    }
80
81                    let item = {
82                        #[allow(clippy::significant_drop_tightening)]
83                        let mut guard = rx_clone.lock().await;
84                        tokio::select! {
85                            biased;
86                            () = cancel_clone.cancelled() => break,
87                            item = guard.recv() => {
88                                match item {
89                                    Some(item) => item,
90                                    None => break, // Channel closed
91                                }
92                            }
93                        }
94                    };
95
96                    let result = item.service.execute(item.input).await;
97                    // Ignore send error — caller may have dropped the receiver
98                    let _ = item.reply.send(result);
99                }
100            });
101        }
102
103        Self {
104            tx,
105            cancel,
106            workers: Arc::new(Mutex::new(join_set)),
107        }
108    }
109
110    /// Submit a task to the pool.
111    ///
112    /// Blocks (async) if the internal queue is full (backpressure).
113    ///
114    /// # Errors
115    ///
116    /// Returns `GraphError::ExecutionFailed` if the pool has been shut down.
117    pub async fn submit(
118        &self,
119        service: Arc<dyn ScrapingService>,
120        input: ServiceInput,
121    ) -> Result<ServiceOutput> {
122        let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
123
124        self.tx
125            .send(WorkItem {
126                service,
127                input,
128                reply: reply_tx,
129            })
130            .await
131            .map_err(|_| {
132                StygianError::Graph(GraphError::ExecutionFailed(
133                    "Worker pool is shut down".into(),
134                ))
135            })?;
136
137        reply_rx.await.map_err(|_| {
138            StygianError::Graph(GraphError::ExecutionFailed(
139                "Worker task dropped reply channel".into(),
140            ))
141        })?
142    }
143
144    /// Gracefully shut down the worker pool.
145    ///
146    /// Signals all workers to stop after their current task and waits
147    /// for all worker tasks to complete.
148    pub async fn shutdown(self) {
149        self.cancel.cancel();
150        drop(self.tx); // Close sender so workers exit their recv loops
151
152        let mut workers = self.workers.lock().await;
153        while workers.join_next().await.is_some() {}
154    }
155
156    /// Returns the current backpressure state.
157    ///
158    /// `true` if the queue is at capacity and submitting will block.
159    #[must_use]
160    pub fn is_saturated(&self) -> bool {
161        self.tx.capacity() == 0
162    }
163
164    /// Available capacity in the queue.
165    #[must_use]
166    pub fn available_capacity(&self) -> usize {
167        self.tx.capacity()
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::noop::NoopService;

    /// Shared no-op service used as the workload in every test.
    fn noop_service() -> Arc<dyn ScrapingService> {
        Arc::new(NoopService)
    }

    /// Build a `ServiceInput` for `url` with an empty JSON object as params.
    fn input_for(url: String) -> ServiceInput {
        ServiceInput {
            url,
            params: serde_json::json!({}),
        }
    }

    /// A single submitted task completes successfully.
    #[tokio::test]
    async fn test_worker_pool_basic_execution() {
        let pool = WorkerPool::new(2, 10);

        let outcome = pool
            .submit(noop_service(), input_for("https://example.com".to_string()))
            .await;
        assert!(outcome.is_ok());

        pool.shutdown().await;
    }

    /// Ten tasks submitted from separate Tokio tasks all succeed.
    #[tokio::test]
    async fn test_worker_pool_concurrent_tasks()
    -> std::result::Result<(), Box<dyn std::error::Error>> {
        let pool = Arc::new(WorkerPool::new(4, 20));
        let svc = noop_service();

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let shared_pool = Arc::clone(&pool);
                let shared_svc = Arc::clone(&svc);
                tokio::spawn(async move {
                    let request = input_for(format!("https://example.com/{i}"));
                    shared_pool.submit(shared_svc, request).await
                })
            })
            .collect();

        for handle in handles {
            let result = handle.await?;
            assert!(result.is_ok(), "Task failed: {result:?}");
        }

        // All spawned tasks have finished, so the pool can be taken back
        // out of the `Arc` and shut down.
        if let Some(owned) = Arc::into_inner(pool) {
            owned.shutdown().await;
        }
        Ok(())
    }

    /// A pool with one worker and a one-slot queue reports its capacity
    /// and still completes a task.
    ///
    /// Only one task is submitted here, so this does not force a submit
    /// to wait on a full queue.
    #[tokio::test]
    async fn test_worker_pool_backpressure() {
        let pool = WorkerPool::new(1, 1);
        assert_eq!(pool.available_capacity(), 1);

        let outcome = pool
            .submit(noop_service(), input_for("https://example.com".to_string()))
            .await;
        assert!(outcome.is_ok());

        pool.shutdown().await;
    }

    /// Shutting down an idle pool finishes without panicking.
    #[tokio::test]
    async fn test_worker_pool_graceful_shutdown() {
        let idle_pool = WorkerPool::new(2, 10);
        idle_pool.shutdown().await;
    }
}