Skip to main content

stygian_graph/application/
extraction.rs

1//! LLM Extraction Service — orchestrator that uses AIProvider to extract structured data
2//!
3//! Implements `ScrapingService` by delegating to one or more `AIProvider`s.
4//! Supports provider fallback: tries providers in order until one succeeds.
5//!
6//! # Architecture
7//!
8//! ```text
9//! ScrapingService  ←  LlmExtractionService  →  AIProvider (Claude, GPT, Gemini, …)
10//!       ↑                     ↓
11//!  ServiceInput           FallbackChain
12//!  { url, params }     [primary, secondary, …]
13//! ```
14//!
15//! The `params` field of `ServiceInput` must contain:
16//! - `schema`: JSON schema object defining the expected output shape
17//! - `content` (optional): If present, used as-is. Otherwise `data` from a prior
18//!   pipeline stage should be passed via `url`.
19//!
20//! # Example
21//!
22//! ```no_run
23//! use stygian_graph::application::extraction::{LlmExtractionService, ExtractionConfig};
24//! use stygian_graph::ports::{ScrapingService, ServiceInput};
25//! use serde_json::json;
26//!
27//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
28//! // Provider built separately — inject via Arc<dyn AIProvider>
29//! // let service = LlmExtractionService::new(providers, ExtractionConfig::default());
30//! // let output = service.execute(input).await.unwrap();
31//! # });
32//! ```
33
34use std::sync::Arc;
35
36use async_trait::async_trait;
37use serde_json::{Value, json};
38use tracing::{debug, info, warn};
39
40use crate::domain::error::{ProviderError, Result, StygianError};
41use crate::ports::{AIProvider, ScrapingService, ServiceInput, ServiceOutput};
42
/// Tunable settings for `LlmExtractionService`.
///
/// Use `ExtractionConfig::default()` for the stock settings, or override
/// individual fields with struct-update syntax.
#[derive(Clone, Debug)]
pub struct ExtractionConfig {
    /// Upper bound on how much content is forwarded to a provider.
    ///
    /// Anything beyond this limit is cut off before the provider call so the
    /// request does not overflow the model's token budget.
    pub max_content_chars: usize,
    /// When `true`, a provider response is only accepted if it passes the
    /// service's structural check; otherwise the next provider is tried.
    ///
    /// The check is shallow: it looks at the top-level JSON type of the
    /// response (object or array) and does not validate against the schema.
    pub validate_output: bool,
}
53
54impl Default for ExtractionConfig {
55    fn default() -> Self {
56        Self {
57            max_content_chars: 64_000,
58            validate_output: true,
59        }
60    }
61}
62
63/// LLM-based structured data extraction service
64///
65/// Wraps one or more `AIProvider` instances and implements `ScrapingService`.
66/// On each `execute()` call the service:
67///
68/// 1. Reads `schema` and optionally `content` from `ServiceInput.params`.
69/// 2. Iterates through the provider list until one returns `Ok`.
70/// 3. Returns extracted data in `ServiceOutput.data` (as JSON string).
71/// 4. Metadata includes which provider succeeded and elapsed time.
72///
73/// # Provider Fallback
74///
75/// Providers are tried **in the order they were added**. The first success
76/// short-circuits the chain. Errors from skipped providers are logged as
77/// warnings, not propagated.
78pub struct LlmExtractionService {
79    /// Ordered fallback chain of AI providers
80    providers: Vec<Arc<dyn AIProvider>>,
81    config: ExtractionConfig,
82}
83
84impl LlmExtractionService {
85    /// Create a new extraction service with an ordered fallback chain
86    ///
87    /// # Example
88    ///
89    /// ```no_run
90    /// use stygian_graph::application::extraction::{LlmExtractionService, ExtractionConfig};
91    /// use stygian_graph::adapters::ai::ollama::OllamaProvider;
92    /// use std::sync::Arc;
93    ///
94    /// let providers: Vec<Arc<dyn stygian_graph::ports::AIProvider>> = vec![
95    ///     Arc::new(OllamaProvider::new()),
96    /// ];
97    /// let service = LlmExtractionService::new(providers, ExtractionConfig::default());
98    /// ```
99    #[must_use]
100    pub fn new(providers: Vec<Arc<dyn AIProvider>>, config: ExtractionConfig) -> Self {
101        Self { providers, config }
102    }
103
104    /// Resolve the content to extract from.
105    ///
106    /// Priority:
107    /// 1. `params["content"]` if present
108    /// 2. `input.url` as fallback (useful when this node receives raw HTML from prior stage)
109    fn resolve_content(input: &ServiceInput) -> &str {
110        input
111            .params
112            .get("content")
113            .and_then(Value::as_str)
114            .unwrap_or(&input.url)
115    }
116
117    /// Truncate content to the configured character limit
118    fn truncate_content<'a>(&self, content: &'a str) -> &'a str {
119        if content.len() <= self.config.max_content_chars {
120            content
121        } else {
122            warn!(
123                limit = self.config.max_content_chars,
124                actual = content.len(),
125                "Content truncated for LLM extraction"
126            );
127            &content[..self.config.max_content_chars]
128        }
129    }
130
131    /// Extract the `schema` from params, returning an error if missing
132    fn resolve_schema(input: &ServiceInput) -> Result<Value> {
133        input.params.get("schema").cloned().ok_or_else(|| {
134            StygianError::Provider(ProviderError::ApiError(
135                "LlmExtractionService requires 'schema' in ServiceInput.params".to_string(),
136            ))
137        })
138    }
139
140    /// Validate that extracted output is a JSON object (basic schema check)
141    fn validate_output(output: &Value) -> Result<()> {
142        if output.is_object() || output.is_array() {
143            Ok(())
144        } else {
145            Err(StygianError::Provider(ProviderError::ApiError(format!(
146                "Provider returned non-object output: {output}"
147            ))))
148        }
149    }
150}
151
152#[async_trait]
153impl ScrapingService for LlmExtractionService {
154    /// Execute structured extraction via the provider fallback chain
155    ///
156    /// # Example
157    ///
158    /// ```no_run
159    /// use stygian_graph::application::extraction::{LlmExtractionService, ExtractionConfig};
160    /// use stygian_graph::adapters::ai::ollama::OllamaProvider;
161    /// use stygian_graph::ports::{ScrapingService, ServiceInput};
162    /// use serde_json::json;
163    /// use std::sync::Arc;
164    ///
165    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
166    /// let providers: Vec<Arc<dyn stygian_graph::ports::AIProvider>> = vec![
167    ///     Arc::new(OllamaProvider::new()),
168    /// ];
169    /// let service = LlmExtractionService::new(providers, ExtractionConfig::default());
170    /// let input = ServiceInput {
171    ///     url: "<h1>Hello World</h1>".to_string(),
172    ///     params: json!({
173    ///         "schema": {"type": "object", "properties": {"heading": {"type": "string"}}},
174    ///     }),
175    /// };
176    /// // let output = service.execute(input).await.unwrap();
177    /// # });
178    /// ```
179    async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
180        if self.providers.is_empty() {
181            return Err(StygianError::Provider(ProviderError::ApiError(
182                "No AI providers configured in LlmExtractionService".to_string(),
183            )));
184        }
185
186        let schema = Self::resolve_schema(&input)?;
187        let raw_content = Self::resolve_content(&input);
188        let content = self.truncate_content(raw_content).to_string();
189
190        let start = std::time::Instant::now();
191        let mut last_error: Option<StygianError> = None;
192
193        for provider in &self.providers {
194            debug!(provider = provider.name(), "Attempting LLM extraction");
195
196            match provider.extract(content.clone(), schema.clone()).await {
197                Ok(extracted) => {
198                    if self.config.validate_output
199                        && let Err(e) = Self::validate_output(&extracted)
200                    {
201                        warn!(
202                            provider = provider.name(),
203                            error = %e,
204                            "Provider returned invalid output, trying next"
205                        );
206                        last_error = Some(e);
207                        continue;
208                    }
209
210                    let elapsed = start.elapsed();
211                    info!(
212                        provider = provider.name(),
213                        elapsed_ms = elapsed.as_millis(),
214                        "LLM extraction succeeded"
215                    );
216
217                    return Ok(ServiceOutput {
218                        data: extracted.to_string(),
219                        metadata: json!({
220                            "provider": provider.name(),
221                            "elapsed_ms": elapsed.as_millis(),
222                            "content_chars": content.len(),
223                        }),
224                    });
225                }
226                Err(e) => {
227                    warn!(
228                        provider = provider.name(),
229                        error = %e,
230                        "Provider failed, trying next in chain"
231                    );
232                    last_error = Some(e);
233                }
234            }
235        }
236
237        // All providers failed
238        Err(last_error.unwrap_or_else(|| {
239            StygianError::Provider(ProviderError::ApiError(
240                "All AI providers in fallback chain failed".to_string(),
241            ))
242        }))
243    }
244
245    fn name(&self) -> &'static str {
246        "llm-extraction"
247    }
248}
249
// Unit tests for the extraction service. Providers are replaced by in-file
// mocks so no network or model access is needed.
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::needless_pass_by_value
)]
mod tests {
    use super::*;
    use crate::ports::ProviderCapabilities;
    use futures::stream::{self, BoxStream};
    use serde_json::json;

    // --- Mock AIProvider for tests ---

    /// Provider that always returns `Ok` with a fixed, caller-chosen value.
    struct AlwaysSucceed {
        response: Value,
    }

    #[async_trait]
    impl AIProvider for AlwaysSucceed {
        async fn extract(&self, _content: String, _schema: Value) -> Result<Value> {
            Ok(self.response.clone())
        }

        async fn stream_extract(
            &self,
            _content: String,
            _schema: Value,
        ) -> Result<BoxStream<'static, Result<Value>>> {
            Ok(Box::pin(stream::once(async { Ok(json!({})) })))
        }

        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities::default()
        }

        fn name(&self) -> &'static str {
            "mock-succeed"
        }
    }

    /// Provider that always returns an `ApiError("mock failure")`.
    struct AlwaysFail;

    #[async_trait]
    impl AIProvider for AlwaysFail {
        async fn extract(&self, _content: String, _schema: Value) -> Result<Value> {
            Err(StygianError::Provider(ProviderError::ApiError(
                "mock failure".to_string(),
            )))
        }

        async fn stream_extract(
            &self,
            _content: String,
            _schema: Value,
        ) -> Result<BoxStream<'static, Result<Value>>> {
            Err(StygianError::Provider(ProviderError::ApiError(
                "mock failure".to_string(),
            )))
        }

        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities::default()
        }

        fn name(&self) -> &'static str {
            "mock-fail"
        }
    }

    /// Build an input whose content travels in `url` and whose params hold
    /// only the given schema.
    fn make_input(schema: Value) -> ServiceInput {
        ServiceInput {
            url: "<h1>Hello</h1>".to_string(),
            params: json!({ "schema": schema }),
        }
    }

    /// Shorthand for a type-erased `AlwaysSucceed` provider.
    fn succeed(response: Value) -> Arc<dyn AIProvider> {
        Arc::new(AlwaysSucceed { response })
    }

    #[tokio::test]
    async fn test_service_name() {
        let svc = LlmExtractionService::new(vec![], ExtractionConfig::default());
        assert_eq!(svc.name(), "llm-extraction");
    }

    #[tokio::test]
    async fn test_no_providers_returns_error() {
        let svc = LlmExtractionService::new(vec![], ExtractionConfig::default());
        let err = svc.execute(make_input(json!({}))).await.unwrap_err();
        assert!(err.to_string().contains("No AI providers"));
    }

    #[tokio::test]
    async fn test_missing_schema_returns_error() {
        let providers: Vec<Arc<dyn AIProvider>> = vec![Arc::new(AlwaysSucceed {
            response: json!({}),
        })];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let input = ServiceInput {
            url: "some content".to_string(),
            params: json!({}), // no schema key
        };
        let err = svc.execute(input).await.unwrap_err();
        assert!(err.to_string().contains("schema"));
    }

    #[tokio::test]
    async fn test_single_succeeding_provider() {
        let providers: Vec<Arc<dyn AIProvider>> = vec![Arc::new(AlwaysSucceed {
            response: json!({"title": "Hello"}),
        })];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let output = svc.execute(make_input(json!({}))).await.unwrap();
        assert_eq!(
            output.metadata["provider"].as_str().unwrap(),
            "mock-succeed"
        );
        let data: Value = serde_json::from_str(&output.data).unwrap();
        assert_eq!(data["title"].as_str().unwrap(), "Hello");
    }

    #[tokio::test]
    async fn test_fallback_to_second_provider() {
        let providers: Vec<Arc<dyn AIProvider>> = vec![
            Arc::new(AlwaysFail),
            Arc::new(AlwaysSucceed {
                response: json!({"score": 42}),
            }),
        ];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let output = svc.execute(make_input(json!({}))).await.unwrap();
        assert_eq!(
            output.metadata["provider"].as_str().unwrap(),
            "mock-succeed"
        );
    }

    #[tokio::test]
    async fn test_all_providers_fail() {
        let providers: Vec<Arc<dyn AIProvider>> = vec![Arc::new(AlwaysFail), Arc::new(AlwaysFail)];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let err = svc.execute(make_input(json!({}))).await.unwrap_err();
        assert!(err.to_string().contains("mock failure"));
    }

    #[tokio::test]
    async fn test_invalid_output_falls_back_to_next_provider() {
        // The first provider answers `Ok` but with a scalar, which validation
        // rejects; the chain must continue to the second provider.
        let providers = vec![
            succeed(json!("just a string")),
            succeed(json!({"ok": true})),
        ];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let output = svc.execute(make_input(json!({}))).await.unwrap();
        let data: Value = serde_json::from_str(&output.data).unwrap();
        assert_eq!(data, json!({"ok": true}));
    }

    #[tokio::test]
    async fn test_invalid_output_from_every_provider_is_an_error() {
        let providers = vec![succeed(json!(42))];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        assert!(svc.execute(make_input(json!({}))).await.is_err());
    }

    #[tokio::test]
    async fn test_validation_can_be_disabled() {
        let svc = LlmExtractionService::new(
            vec![succeed(json!("just a string"))],
            ExtractionConfig {
                validate_output: false,
                ..Default::default()
            },
        );
        let output = svc.execute(make_input(json!({}))).await.unwrap();
        let data: Value = serde_json::from_str(&output.data).unwrap();
        assert_eq!(data, json!("just a string"));
    }

    #[tokio::test]
    async fn test_content_from_params_overrides_url() {
        let providers: Vec<Arc<dyn AIProvider>> = vec![Arc::new(AlwaysSucceed {
            response: json!({"ok": true}),
        })];
        let svc = LlmExtractionService::new(providers, ExtractionConfig::default());
        let input = ServiceInput {
            url: "should-not-be-used".to_string(),
            params: json!({
                "schema": {"type": "object"},
                "content": "actual content here"
            }),
        };
        let output = svc.execute(input).await.unwrap();
        // Metadata should reflect char count of "actual content here" (19 chars)
        assert_eq!(output.metadata["content_chars"].as_u64().unwrap(), 19);
    }

    #[test]
    fn test_truncate_content_short() {
        let svc = LlmExtractionService::new(vec![], ExtractionConfig::default());
        let s = "hello";
        assert_eq!(svc.truncate_content(s), s);
    }

    #[test]
    fn test_truncate_content_long() {
        let svc = LlmExtractionService::new(
            vec![],
            ExtractionConfig {
                max_content_chars: 5,
                ..Default::default()
            },
        );
        assert_eq!(svc.truncate_content("hello world"), "hello");
    }

    #[test]
    fn test_validate_output_object_ok() {
        assert!(LlmExtractionService::validate_output(&json!({"k": "v"})).is_ok());
    }

    #[test]
    fn test_validate_output_array_ok() {
        assert!(LlmExtractionService::validate_output(&json!([1, 2, 3])).is_ok());
    }

    #[test]
    fn test_validate_output_scalar_err() {
        assert!(LlmExtractionService::validate_output(&json!("just a string")).is_err());
    }
}