// stygian_graph/application/extraction.rs
use std::sync::Arc;

use async_trait::async_trait;
use serde_json::{Value, json};
use tracing::{debug, info, warn};

use crate::domain::error::{ProviderError, Result, StygianError};
use crate::ports::{AIProvider, ScrapingService, ServiceInput, ServiceOutput};

/// Tunable settings for [`LlmExtractionService`].
#[derive(Debug, Clone)]
pub struct ExtractionConfig {
    /// Upper bound on the size of the content handed to a provider.
    ///
    /// The service compares this against `str::len()` of the input, so the
    /// limit is effectively measured in UTF-8 bytes; longer input is cut
    /// down before extraction.
    pub max_content_chars: usize,
    /// When `true`, a provider result is accepted only if it is a JSON
    /// object or array; otherwise the next provider in the chain is tried.
    pub validate_output: bool,
}
53
54impl Default for ExtractionConfig {
55 fn default() -> Self {
56 Self {
57 max_content_chars: 64_000,
58 validate_output: true,
59 }
60 }
61}
62
63pub struct LlmExtractionService {
79 providers: Vec<Arc<dyn AIProvider>>,
81 config: ExtractionConfig,
82}
83
84impl LlmExtractionService {
85 pub fn new(providers: Vec<Arc<dyn AIProvider>>, config: ExtractionConfig) -> Self {
100 Self { providers, config }
101 }
102
103 fn resolve_content(input: &ServiceInput) -> &str {
109 input
110 .params
111 .get("content")
112 .and_then(Value::as_str)
113 .unwrap_or(&input.url)
114 }
115
116 fn truncate_content<'a>(&self, content: &'a str) -> &'a str {
118 if content.len() <= self.config.max_content_chars {
119 content
120 } else {
121 warn!(
122 limit = self.config.max_content_chars,
123 actual = content.len(),
124 "Content truncated for LLM extraction"
125 );
126 &content[..self.config.max_content_chars]
127 }
128 }
129
130 fn resolve_schema(input: &ServiceInput) -> Result<Value> {
132 input.params.get("schema").cloned().ok_or_else(|| {
133 StygianError::Provider(ProviderError::ApiError(
134 "LlmExtractionService requires 'schema' in ServiceInput.params".to_string(),
135 ))
136 })
137 }
138
139 fn validate_output(output: &Value) -> Result<()> {
141 if output.is_object() || output.is_array() {
142 Ok(())
143 } else {
144 Err(StygianError::Provider(ProviderError::ApiError(format!(
145 "Provider returned non-object output: {output}"
146 ))))
147 }
148 }
149}
150
151#[async_trait]
152impl ScrapingService for LlmExtractionService {
153 async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
179 if self.providers.is_empty() {
180 return Err(StygianError::Provider(ProviderError::ApiError(
181 "No AI providers configured in LlmExtractionService".to_string(),
182 )));
183 }
184
185 let schema = Self::resolve_schema(&input)?;
186 let raw_content = Self::resolve_content(&input);
187 let content = self.truncate_content(raw_content).to_string();
188
189 let start = std::time::Instant::now();
190 let mut last_error: Option<StygianError> = None;
191
192 for provider in &self.providers {
193 debug!(provider = provider.name(), "Attempting LLM extraction");
194
195 match provider.extract(content.clone(), schema.clone()).await {
196 Ok(extracted) => {
197 if self.config.validate_output
198 && let Err(e) = Self::validate_output(&extracted)
199 {
200 warn!(
201 provider = provider.name(),
202 error = %e,
203 "Provider returned invalid output, trying next"
204 );
205 last_error = Some(e);
206 continue;
207 }
208
209 let elapsed = start.elapsed();
210 info!(
211 provider = provider.name(),
212 elapsed_ms = elapsed.as_millis(),
213 "LLM extraction succeeded"
214 );
215
216 return Ok(ServiceOutput {
217 data: extracted.to_string(),
218 metadata: json!({
219 "provider": provider.name(),
220 "elapsed_ms": elapsed.as_millis(),
221 "content_chars": content.len(),
222 }),
223 });
224 }
225 Err(e) => {
226 warn!(
227 provider = provider.name(),
228 error = %e,
229 "Provider failed, trying next in chain"
230 );
231 last_error = Some(e);
232 }
233 }
234 }
235
236 Err(last_error.unwrap_or_else(|| {
238 StygianError::Provider(ProviderError::ApiError(
239 "All AI providers in fallback chain failed".to_string(),
240 ))
241 }))
242 }
243
244 fn name(&self) -> &'static str {
245 "llm-extraction"
246 }
247}
248
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::needless_pass_by_value
)]
mod tests {
    use super::*;
    use crate::ports::ProviderCapabilities;
    use futures::stream::{self, BoxStream};
    use serde_json::json;

    /// Mock provider whose `extract` always yields a copy of `response`.
    struct AlwaysSucceed {
        response: Value,
    }

    #[async_trait]
    impl AIProvider for AlwaysSucceed {
        async fn extract(&self, _content: String, _schema: Value) -> Result<Value> {
            Ok(self.response.clone())
        }

        async fn stream_extract(
            &self,
            _content: String,
            _schema: Value,
        ) -> Result<BoxStream<'static, Result<Value>>> {
            Ok(Box::pin(stream::once(async { Ok(json!({})) })))
        }

        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities::default()
        }

        fn name(&self) -> &'static str {
            "mock-succeed"
        }
    }

    /// Mock provider that rejects every request with "mock failure".
    struct AlwaysFail;

    /// The error every `AlwaysFail` call produces.
    fn mock_failure() -> StygianError {
        StygianError::Provider(ProviderError::ApiError("mock failure".to_string()))
    }

    #[async_trait]
    impl AIProvider for AlwaysFail {
        async fn extract(&self, _content: String, _schema: Value) -> Result<Value> {
            Err(mock_failure())
        }

        async fn stream_extract(
            &self,
            _content: String,
            _schema: Value,
        ) -> Result<BoxStream<'static, Result<Value>>> {
            Err(mock_failure())
        }

        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities::default()
        }

        fn name(&self) -> &'static str {
            "mock-fail"
        }
    }

    /// Wraps a canned response in a succeeding mock provider.
    fn succeed_with(response: Value) -> Arc<dyn AIProvider> {
        Arc::new(AlwaysSucceed { response })
    }

    /// Builds a service over `providers` using the default configuration.
    fn service_with(providers: Vec<Arc<dyn AIProvider>>) -> LlmExtractionService {
        LlmExtractionService::new(providers, ExtractionConfig::default())
    }

    /// Input carrying only a schema; the content falls back to the `url` field.
    fn make_input(schema: Value) -> ServiceInput {
        ServiceInput {
            url: "<h1>Hello</h1>".to_string(),
            params: json!({ "schema": schema }),
        }
    }

    #[tokio::test]
    async fn test_service_name() {
        let service = service_with(vec![]);
        assert_eq!(service.name(), "llm-extraction");
    }

    #[tokio::test]
    async fn test_no_providers_returns_error() {
        let service = service_with(vec![]);
        let failure = service.execute(make_input(json!({}))).await.unwrap_err();
        assert!(failure.to_string().contains("No AI providers"));
    }

    #[tokio::test]
    async fn test_missing_schema_returns_error() {
        let service = service_with(vec![succeed_with(json!({}))]);
        // No "schema" key in params at all.
        let request = ServiceInput {
            url: "some content".to_string(),
            params: json!({}),
        };
        let failure = service.execute(request).await.unwrap_err();
        assert!(failure.to_string().contains("schema"));
    }

    #[tokio::test]
    async fn test_single_succeeding_provider() {
        let service = service_with(vec![succeed_with(json!({"title": "Hello"}))]);
        let result = service.execute(make_input(json!({}))).await.unwrap();
        assert_eq!(
            result.metadata["provider"].as_str().unwrap(),
            "mock-succeed"
        );
        let parsed: Value = serde_json::from_str(&result.data).unwrap();
        assert_eq!(parsed["title"].as_str().unwrap(), "Hello");
    }

    #[tokio::test]
    async fn test_fallback_to_second_provider() {
        let chain: Vec<Arc<dyn AIProvider>> =
            vec![Arc::new(AlwaysFail), succeed_with(json!({"score": 42}))];
        let service = service_with(chain);
        let result = service.execute(make_input(json!({}))).await.unwrap();
        assert_eq!(
            result.metadata["provider"].as_str().unwrap(),
            "mock-succeed"
        );
    }

    #[tokio::test]
    async fn test_all_providers_fail() {
        let chain: Vec<Arc<dyn AIProvider>> = vec![Arc::new(AlwaysFail), Arc::new(AlwaysFail)];
        let service = service_with(chain);
        let failure = service.execute(make_input(json!({}))).await.unwrap_err();
        assert!(failure.to_string().contains("mock failure"));
    }

    #[tokio::test]
    async fn test_content_from_params_overrides_url() {
        let service = service_with(vec![succeed_with(json!({"ok": true}))]);
        let request = ServiceInput {
            url: "should-not-be-used".to_string(),
            params: json!({
                "schema": {"type": "object"},
                "content": "actual content here"
            }),
        };
        let result = service.execute(request).await.unwrap();
        // 19 is the length of the "content" param, not of the url.
        assert_eq!(result.metadata["content_chars"].as_u64().unwrap(), 19);
    }

    #[test]
    fn test_truncate_content_short() {
        let service = service_with(vec![]);
        let text = "hello";
        assert_eq!(service.truncate_content(text), text);
    }

    #[test]
    fn test_truncate_content_long() {
        let tight_limit = ExtractionConfig {
            max_content_chars: 5,
            ..Default::default()
        };
        let service = LlmExtractionService::new(vec![], tight_limit);
        assert_eq!(service.truncate_content("hello world"), "hello");
    }

    #[test]
    fn test_validate_output_object_ok() {
        let verdict = LlmExtractionService::validate_output(&json!({"k": "v"}));
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_validate_output_array_ok() {
        let verdict = LlmExtractionService::validate_output(&json!([1, 2, 3]));
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_validate_output_scalar_err() {
        let verdict = LlmExtractionService::validate_output(&json!("just a string"));
        assert!(verdict.is_err());
    }
}