// stygian_graph/application/extraction.rs

1//! LLM Extraction Service — orchestrator that uses AIProvider to extract structured data
2//!
3//! Implements `ScrapingService` by delegating to one or more `AIProvider`s.
4//! Supports provider fallback: tries providers in order until one succeeds.
5//!
6//! # Architecture
7//!
8//! ```text
9//! ScrapingService  ←  LlmExtractionService  →  AIProvider (Claude, GPT, Gemini, …)
10//!       ↑                     ↓
11//!  ServiceInput           FallbackChain
12//!  { url, params }     [primary, secondary, …]
13//! ```
14//!
15//! The `params` field of `ServiceInput` must contain:
16//! - `schema`: JSON schema object defining the expected output shape
17//! - `content` (optional): If present, used as-is. Otherwise `data` from a prior
18//!   pipeline stage should be passed via `url`.
19//!
20//! # Example
21//!
22//! ```no_run
23//! use stygian_graph::application::extraction::{LlmExtractionService, ExtractionConfig};
24//! use stygian_graph::ports::{ScrapingService, ServiceInput};
25//! use serde_json::json;
26//!
27//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
28//! // Provider built separately — inject via Arc<dyn AIProvider>
29//! // let service = LlmExtractionService::new(providers, ExtractionConfig::default());
30//! // let output = service.execute(input).await.unwrap();
31//! # });
32//! ```
33
34use std::sync::Arc;
35
36use async_trait::async_trait;
37use serde_json::{Value, json};
38use tracing::{debug, info, warn};
39
40use crate::domain::error::{ProviderError, Result, StygianError};
41use crate::ports::{AIProvider, ScrapingService, ServiceInput, ServiceOutput};
42
43/// Configuration for the LLM extraction service
44#[derive(Debug, Clone)]
45pub struct ExtractionConfig {
46    /// Maximum content length sent to providers (characters).
47    /// Content is truncated at this limit to avoid token overflow.
48    pub max_content_chars: usize,
49    /// Whether to validate the provider output against the schema.
50    /// Currently performs a structural check (is the output a JSON object?).
51    pub validate_output: bool,
52}
53
54impl Default for ExtractionConfig {
55    fn default() -> Self {
56        Self {
57            max_content_chars: 64_000,
58            validate_output: true,
59        }
60    }
61}
62
/// LLM-based structured data extraction service
///
/// Wraps one or more `AIProvider` instances and implements `ScrapingService`.
/// On each `execute()` call the service:
///
/// 1. Reads `schema` (required) and optionally `content` from `ServiceInput.params`.
/// 2. Iterates through the provider list until one returns `Ok`.
/// 3. Returns extracted data in `ServiceOutput.data` (as JSON string).
/// 4. Metadata includes which provider succeeded and elapsed time.
///
/// # Provider Fallback
///
/// Providers are tried **in the order they were added**. The first success
/// short-circuits the chain. Errors from failed providers are logged as
/// warnings; only the *last* error is propagated when every provider fails.
pub struct LlmExtractionService {
    /// Ordered fallback chain of AI providers (tried front to back)
    providers: Vec<Arc<dyn AIProvider>>,
    /// Tuning knobs: content size limit and output validation toggle
    config: ExtractionConfig,
}
83
84impl LlmExtractionService {
85    /// Create a new extraction service with an ordered fallback chain
86    ///
87    /// # Example
88    ///
89    /// ```no_run
90    /// use stygian_graph::application::extraction::{LlmExtractionService, ExtractionConfig};
91    /// use stygian_graph::adapters::ai::ollama::OllamaProvider;
92    /// use std::sync::Arc;
93    ///
94    /// let providers: Vec<Arc<dyn stygian_graph::ports::AIProvider>> = vec![
95    ///     Arc::new(OllamaProvider::new()),
96    /// ];
97    /// let service = LlmExtractionService::new(providers, ExtractionConfig::default());
98    /// ```
99    pub fn new(providers: Vec<Arc<dyn AIProvider>>, config: ExtractionConfig) -> Self {
100        Self { providers, config }
101    }
102
103    /// Resolve the content to extract from.
104    ///
105    /// Priority:
106    /// 1. `params["content"]` if present
107    /// 2. `input.url` as fallback (useful when this node receives raw HTML from prior stage)
108    fn resolve_content(input: &ServiceInput) -> &str {
109        input
110            .params
111            .get("content")
112            .and_then(Value::as_str)
113            .unwrap_or(&input.url)
114    }
115
116    /// Truncate content to the configured character limit
117    fn truncate_content<'a>(&self, content: &'a str) -> &'a str {
118        if content.len() <= self.config.max_content_chars {
119            content
120        } else {
121            warn!(
122                limit = self.config.max_content_chars,
123                actual = content.len(),
124                "Content truncated for LLM extraction"
125            );
126            &content[..self.config.max_content_chars]
127        }
128    }
129
130    /// Extract the `schema` from params, returning an error if missing
131    fn resolve_schema(input: &ServiceInput) -> Result<Value> {
132        input.params.get("schema").cloned().ok_or_else(|| {
133            StygianError::Provider(ProviderError::ApiError(
134                "LlmExtractionService requires 'schema' in ServiceInput.params".to_string(),
135            ))
136        })
137    }
138
139    /// Validate that extracted output is a JSON object (basic schema check)
140    fn validate_output(output: &Value) -> Result<()> {
141        if output.is_object() || output.is_array() {
142            Ok(())
143        } else {
144            Err(StygianError::Provider(ProviderError::ApiError(format!(
145                "Provider returned non-object output: {output}"
146            ))))
147        }
148    }
149}
150
151#[async_trait]
152impl ScrapingService for LlmExtractionService {
153    /// Execute structured extraction via the provider fallback chain
154    ///
155    /// # Example
156    ///
157    /// ```no_run
158    /// use stygian_graph::application::extraction::{LlmExtractionService, ExtractionConfig};
159    /// use stygian_graph::adapters::ai::ollama::OllamaProvider;
160    /// use stygian_graph::ports::{ScrapingService, ServiceInput};
161    /// use serde_json::json;
162    /// use std::sync::Arc;
163    ///
164    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
165    /// let providers: Vec<Arc<dyn stygian_graph::ports::AIProvider>> = vec![
166    ///     Arc::new(OllamaProvider::new()),
167    /// ];
168    /// let service = LlmExtractionService::new(providers, ExtractionConfig::default());
169    /// let input = ServiceInput {
170    ///     url: "<h1>Hello World</h1>".to_string(),
171    ///     params: json!({
172    ///         "schema": {"type": "object", "properties": {"heading": {"type": "string"}}},
173    ///     }),
174    /// };
175    /// // let output = service.execute(input).await.unwrap();
176    /// # });
177    /// ```
178    async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
179        if self.providers.is_empty() {
180            return Err(StygianError::Provider(ProviderError::ApiError(
181                "No AI providers configured in LlmExtractionService".to_string(),
182            )));
183        }
184
185        let schema = Self::resolve_schema(&input)?;
186        let raw_content = Self::resolve_content(&input);
187        let content = self.truncate_content(raw_content).to_string();
188
189        let start = std::time::Instant::now();
190        let mut last_error: Option<StygianError> = None;
191
192        for provider in &self.providers {
193            debug!(provider = provider.name(), "Attempting LLM extraction");
194
195            match provider.extract(content.clone(), schema.clone()).await {
196                Ok(extracted) => {
197                    if self.config.validate_output
198                        && let Err(e) = Self::validate_output(&extracted)
199                    {
200                        warn!(
201                            provider = provider.name(),
202                            error = %e,
203                            "Provider returned invalid output, trying next"
204                        );
205                        last_error = Some(e);
206                        continue;
207                    }
208
209                    let elapsed = start.elapsed();
210                    info!(
211                        provider = provider.name(),
212                        elapsed_ms = elapsed.as_millis(),
213                        "LLM extraction succeeded"
214                    );
215
216                    return Ok(ServiceOutput {
217                        data: extracted.to_string(),
218                        metadata: json!({
219                            "provider": provider.name(),
220                            "elapsed_ms": elapsed.as_millis(),
221                            "content_chars": content.len(),
222                        }),
223                    });
224                }
225                Err(e) => {
226                    warn!(
227                        provider = provider.name(),
228                        error = %e,
229                        "Provider failed, trying next in chain"
230                    );
231                    last_error = Some(e);
232                }
233            }
234        }
235
236        // All providers failed
237        Err(last_error.unwrap_or_else(|| {
238            StygianError::Provider(ProviderError::ApiError(
239                "All AI providers in fallback chain failed".to_string(),
240            ))
241        }))
242    }
243
244    fn name(&self) -> &'static str {
245        "llm-extraction"
246    }
247}
248
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::needless_pass_by_value
)]
mod tests {
    use super::*;
    use crate::ports::ProviderCapabilities;
    use futures::stream::{self, BoxStream};
    use serde_json::json;

    // --- Mock AIProvider for tests ---

    /// Stub provider whose `extract` always returns the canned `response`.
    struct AlwaysSucceed {
        response: Value,
    }

    #[async_trait]
    impl AIProvider for AlwaysSucceed {
        async fn extract(&self, _content: String, _schema: Value) -> Result<Value> {
            Ok(self.response.clone())
        }

        async fn stream_extract(
            &self,
            _content: String,
            _schema: Value,
        ) -> Result<BoxStream<'static, Result<Value>>> {
            Ok(Box::pin(stream::once(async { Ok(json!({})) })))
        }

        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities::default()
        }

        fn name(&self) -> &'static str {
            "mock-succeed"
        }
    }

    /// Stub provider that fails every call with "mock failure".
    struct AlwaysFail;

    #[async_trait]
    impl AIProvider for AlwaysFail {
        async fn extract(&self, _content: String, _schema: Value) -> Result<Value> {
            Err(StygianError::Provider(ProviderError::ApiError(
                "mock failure".to_string(),
            )))
        }

        async fn stream_extract(
            &self,
            _content: String,
            _schema: Value,
        ) -> Result<BoxStream<'static, Result<Value>>> {
            Err(StygianError::Provider(ProviderError::ApiError(
                "mock failure".to_string(),
            )))
        }

        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities::default()
        }

        fn name(&self) -> &'static str {
            "mock-fail"
        }
    }

    /// Build a `ServiceInput` with the given schema and a fixed HTML `url`.
    fn make_input(schema: Value) -> ServiceInput {
        ServiceInput {
            url: "<h1>Hello</h1>".to_string(),
            params: json!({ "schema": schema }),
        }
    }

    // Service reports its stable registry name.
    #[tokio::test]
    async fn test_service_name() {
        let svc = LlmExtractionService::new(vec![], ExtractionConfig::default());
        assert_eq!(svc.name(), "llm-extraction");
    }

    // An empty provider chain must fail fast, before schema resolution.
    #[tokio::test]
    async fn test_no_providers_returns_error() {
        let svc = LlmExtractionService::new(vec![], ExtractionConfig::default());
        let err = svc.execute(make_input(json!({}))).await.unwrap_err();
        assert!(err.to_string().contains("No AI providers"));
    }

    // Missing `schema` key in params is rejected even when a provider exists.
    #[tokio::test]
    async fn test_missing_schema_returns_error() {
        let providers: Vec<Arc<dyn AIProvider>> = vec![Arc::new(AlwaysSucceed {
            response: json!({}),
        })];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let input = ServiceInput {
            url: "some content".to_string(),
            params: json!({}), // no schema key
        };
        let err = svc.execute(input).await.unwrap_err();
        assert!(err.to_string().contains("schema"));
    }

    // Happy path: provider result is serialized into `ServiceOutput.data`
    // and metadata names the provider that succeeded.
    #[tokio::test]
    async fn test_single_succeeding_provider() {
        let providers: Vec<Arc<dyn AIProvider>> = vec![Arc::new(AlwaysSucceed {
            response: json!({"title": "Hello"}),
        })];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let output = svc.execute(make_input(json!({}))).await.unwrap();
        assert_eq!(
            output.metadata["provider"].as_str().unwrap(),
            "mock-succeed"
        );
        let data: Value = serde_json::from_str(&output.data).unwrap();
        assert_eq!(data["title"].as_str().unwrap(), "Hello");
    }

    // First provider fails → chain falls through to the second.
    #[tokio::test]
    async fn test_fallback_to_second_provider() {
        let providers: Vec<Arc<dyn AIProvider>> = vec![
            Arc::new(AlwaysFail),
            Arc::new(AlwaysSucceed {
                response: json!({"score": 42}),
            }),
        ];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let output = svc.execute(make_input(json!({}))).await.unwrap();
        assert_eq!(
            output.metadata["provider"].as_str().unwrap(),
            "mock-succeed"
        );
    }

    // When every provider fails, the last error (here: "mock failure")
    // is propagated, not a generic message.
    #[tokio::test]
    async fn test_all_providers_fail() {
        let providers: Vec<Arc<dyn AIProvider>> = vec![Arc::new(AlwaysFail), Arc::new(AlwaysFail)];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let err = svc.execute(make_input(json!({}))).await.unwrap_err();
        assert!(err.to_string().contains("mock failure"));
    }

    // `params["content"]` takes priority over `url` as extraction input;
    // verified indirectly via the `content_chars` metadata field.
    #[tokio::test]
    async fn test_content_from_params_overrides_url() {
        let providers: Vec<Arc<dyn AIProvider>> = vec![Arc::new(AlwaysSucceed {
            response: json!({"ok": true}),
        })];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let input = ServiceInput {
            url: "should-not-be-used".to_string(),
            params: json!({
                "schema": {"type": "object"},
                "content": "actual content here"
            }),
        };
        let output = svc.execute(input).await.unwrap();
        // Metadata should reflect char count of "actual content here" (19 chars)
        assert_eq!(output.metadata["content_chars"].as_u64().unwrap(), 19);
    }

    // Content at or under the limit is returned untouched (same slice).
    #[test]
    fn test_truncate_content_short() {
        let svc = LlmExtractionService::new(vec![], ExtractionConfig::default());
        let s = "hello";
        assert_eq!(svc.truncate_content(s), s);
    }

    // Over-limit content is cut at `max_content_chars`.
    #[test]
    fn test_truncate_content_long() {
        let svc = LlmExtractionService::new(
            vec![],
            ExtractionConfig {
                max_content_chars: 5,
                ..Default::default()
            },
        );
        assert_eq!(svc.truncate_content("hello world"), "hello");
    }

    // Structural validation: objects and arrays pass, scalars are rejected.
    #[test]
    fn test_validate_output_object_ok() {
        assert!(LlmExtractionService::validate_output(&json!({"k": "v"})).is_ok());
    }

    #[test]
    fn test_validate_output_array_ok() {
        assert!(LlmExtractionService::validate_output(&json!([1, 2, 3])).is_ok());
    }

    #[test]
    fn test_validate_output_scalar_err() {
        assert!(LlmExtractionService::validate_output(&json!("just a string")).is_err());
    }
}