stygian_graph/domain/
executor.rs1use std::sync::Arc;
18
19use tokio::sync::{Mutex, mpsc};
20use tokio::task::JoinSet;
21use tokio_util::sync::CancellationToken;
22
23use crate::domain::error::{GraphError, Result, StygianError};
24use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
25
26struct WorkItem {
28 service: Arc<dyn ScrapingService>,
30 input: ServiceInput,
32 reply: tokio::sync::oneshot::Sender<Result<ServiceOutput>>,
34}
35
36pub struct WorkerPool {
44 tx: mpsc::Sender<WorkItem>,
45 cancel: CancellationToken,
46 workers: Arc<Mutex<JoinSet<()>>>,
47}
48
49impl WorkerPool {
50 #[allow(clippy::significant_drop_tightening)]
63 #[must_use]
64 pub fn new(concurrency: usize, queue_depth: usize) -> Self {
65 let (tx, rx) = mpsc::channel::<WorkItem>(queue_depth);
66 let rx = Arc::new(Mutex::new(rx));
67 let cancel = CancellationToken::new();
68 let mut join_set = JoinSet::new();
69
70 for _ in 0..concurrency {
71 let rx_clone = Arc::clone(&rx);
72 let cancel_clone = cancel.clone();
73
74 join_set.spawn(async move {
75 loop {
76 if cancel_clone.is_cancelled() {
78 break;
79 }
80
81 let item = {
82 #[allow(clippy::significant_drop_tightening)]
83 let mut guard = rx_clone.lock().await;
84 tokio::select! {
85 biased;
86 () = cancel_clone.cancelled() => break,
87 item = guard.recv() => {
88 match item {
89 Some(item) => item,
90 None => break, }
92 }
93 }
94 };
95
96 let result = item.service.execute(item.input).await;
97 let _ = item.reply.send(result);
99 }
100 });
101 }
102
103 Self {
104 tx,
105 cancel,
106 workers: Arc::new(Mutex::new(join_set)),
107 }
108 }
109
110 pub async fn submit(
118 &self,
119 service: Arc<dyn ScrapingService>,
120 input: ServiceInput,
121 ) -> Result<ServiceOutput> {
122 let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
123
124 self.tx
125 .send(WorkItem {
126 service,
127 input,
128 reply: reply_tx,
129 })
130 .await
131 .map_err(|_| {
132 StygianError::Graph(GraphError::ExecutionFailed(
133 "Worker pool is shut down".into(),
134 ))
135 })?;
136
137 reply_rx.await.map_err(|_| {
138 StygianError::Graph(GraphError::ExecutionFailed(
139 "Worker task dropped reply channel".into(),
140 ))
141 })?
142 }
143
144 pub async fn shutdown(self) {
149 self.cancel.cancel();
150 drop(self.tx); let mut workers = self.workers.lock().await;
153 while workers.join_next().await.is_some() {}
154 }
155
156 #[must_use]
160 pub fn is_saturated(&self) -> bool {
161 self.tx.capacity() == 0
162 }
163
164 #[must_use]
166 pub fn available_capacity(&self) -> usize {
167 self.tx.capacity()
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::noop::NoopService;

    /// Build a `ServiceInput` for `url` with empty JSON parameters.
    fn input_for(url: impl Into<String>) -> ServiceInput {
        ServiceInput {
            url: url.into(),
            params: serde_json::json!({}),
        }
    }

    /// A no-op service behind the trait object the pool accepts.
    fn noop_service() -> Arc<dyn ScrapingService> {
        Arc::new(NoopService)
    }

    #[tokio::test]
    async fn test_worker_pool_basic_execution() {
        let pool = WorkerPool::new(2, 10);

        let outcome = pool
            .submit(noop_service(), input_for("https://example.com"))
            .await;
        assert!(outcome.is_ok());

        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_worker_pool_concurrent_tasks()
    -> std::result::Result<(), Box<dyn std::error::Error>> {
        let pool = Arc::new(WorkerPool::new(4, 20));
        let shared_svc = noop_service();

        // Fire ten submissions from independent tasks so they contend for the pool.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let pool = Arc::clone(&pool);
                let svc = Arc::clone(&shared_svc);
                tokio::spawn(async move {
                    pool.submit(svc, input_for(format!("https://example.com/{i}")))
                        .await
                })
            })
            .collect();

        for handle in handles {
            let result = handle.await?;
            assert!(result.is_ok(), "Task failed: {result:?}");
        }

        // All spawned tasks have completed, so this should be the last reference.
        if let Some(owned) = Arc::into_inner(pool) {
            owned.shutdown().await;
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_worker_pool_backpressure() {
        let pool = WorkerPool::new(1, 1);
        // With a queue depth of 1, exactly one slot is free before any submit.
        assert_eq!(pool.available_capacity(), 1);

        let outcome = pool
            .submit(noop_service(), input_for("https://example.com"))
            .await;
        assert!(outcome.is_ok());

        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_worker_pool_graceful_shutdown() {
        WorkerPool::new(2, 10).shutdown().await;
    }
}