stygian_graph/domain/
executor.rs1use std::sync::Arc;
18
19use tokio::sync::{Mutex, mpsc};
20use tokio::task::JoinSet;
21use tokio_util::sync::CancellationToken;
22
23use crate::domain::error::{GraphError, Result, StygianError};
24use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
25
26struct WorkItem {
28 service: Arc<dyn ScrapingService>,
30 input: ServiceInput,
32 reply: tokio::sync::oneshot::Sender<Result<ServiceOutput>>,
34}
35
36pub struct WorkerPool {
44 tx: mpsc::Sender<WorkItem>,
45 cancel: CancellationToken,
46 workers: Arc<Mutex<JoinSet<()>>>,
47}
48
49impl WorkerPool {
50 #[allow(clippy::significant_drop_tightening)]
63 pub fn new(concurrency: usize, queue_depth: usize) -> Self {
64 let (tx, rx) = mpsc::channel::<WorkItem>(queue_depth);
65 let rx = Arc::new(Mutex::new(rx));
66 let cancel = CancellationToken::new();
67 let mut join_set = JoinSet::new();
68
69 for _ in 0..concurrency {
70 let rx_clone = Arc::clone(&rx);
71 let cancel_clone = cancel.clone();
72
73 join_set.spawn(async move {
74 loop {
75 if cancel_clone.is_cancelled() {
77 break;
78 }
79
80 let item = {
81 #[allow(clippy::significant_drop_tightening)]
82 let mut guard = rx_clone.lock().await;
83 tokio::select! {
84 biased;
85 () = cancel_clone.cancelled() => break,
86 item = guard.recv() => {
87 match item {
88 Some(item) => item,
89 None => break, }
91 }
92 }
93 };
94
95 let result = item.service.execute(item.input).await;
96 let _ = item.reply.send(result);
98 }
99 });
100 }
101
102 Self {
103 tx,
104 cancel,
105 workers: Arc::new(Mutex::new(join_set)),
106 }
107 }
108
109 pub async fn submit(
117 &self,
118 service: Arc<dyn ScrapingService>,
119 input: ServiceInput,
120 ) -> Result<ServiceOutput> {
121 let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
122
123 self.tx
124 .send(WorkItem {
125 service,
126 input,
127 reply: reply_tx,
128 })
129 .await
130 .map_err(|_| {
131 StygianError::Graph(GraphError::ExecutionFailed(
132 "Worker pool is shut down".into(),
133 ))
134 })?;
135
136 reply_rx.await.map_err(|_| {
137 StygianError::Graph(GraphError::ExecutionFailed(
138 "Worker task dropped reply channel".into(),
139 ))
140 })?
141 }
142
143 pub async fn shutdown(self) {
148 self.cancel.cancel();
149 drop(self.tx); let mut workers = self.workers.lock().await;
152 while workers.join_next().await.is_some() {}
153 }
154
155 #[must_use]
159 pub fn is_saturated(&self) -> bool {
160 self.tx.capacity() == 0
161 }
162
163 pub fn available_capacity(&self) -> usize {
165 self.tx.capacity()
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapters::noop::NoopService;

    /// Builds a `ServiceInput` for `url` with empty JSON params.
    fn input_for(url: impl Into<String>) -> ServiceInput {
        ServiceInput {
            url: url.into(),
            params: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn test_worker_pool_basic_execution() {
        let pool = WorkerPool::new(2, 10);
        let svc: Arc<dyn ScrapingService> = Arc::new(NoopService);

        let result = pool.submit(svc, input_for("https://example.com")).await;
        assert!(result.is_ok(), "submit failed: {result:?}");

        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_worker_pool_concurrent_tasks()
    -> std::result::Result<(), Box<dyn std::error::Error>> {
        let pool = Arc::new(WorkerPool::new(4, 20));
        let svc: Arc<dyn ScrapingService> = Arc::new(NoopService);

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let pool = Arc::clone(&pool);
                let svc = Arc::clone(&svc);
                tokio::spawn(async move {
                    pool.submit(svc, input_for(format!("https://example.com/{i}")))
                        .await
                })
            })
            .collect();

        for handle in handles {
            let result = handle.await?;
            assert!(result.is_ok(), "Task failed: {result:?}");
        }

        // Every spawned task has completed and released its clone, so this must
        // be the last reference. Fail loudly instead of silently skipping shutdown.
        let pool = Arc::into_inner(pool)
            .ok_or("worker pool still shared after all tasks completed")?;
        pool.shutdown().await;
        Ok(())
    }

    #[tokio::test]
    async fn test_worker_pool_backpressure() {
        let pool = WorkerPool::new(1, 1);
        assert_eq!(pool.available_capacity(), 1);
        assert!(!pool.is_saturated());

        let svc: Arc<dyn ScrapingService> = Arc::new(NoopService);
        let result = pool.submit(svc, input_for("https://example.com")).await;
        assert!(result.is_ok(), "submit failed: {result:?}");

        // The worker replied, so it has already received the item and the
        // queue slot has been released again.
        assert_eq!(pool.available_capacity(), 1);
        assert!(!pool.is_saturated());

        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_worker_pool_graceful_shutdown() {
        let pool = WorkerPool::new(2, 10);
        pool.shutdown().await;
    }
}