1use std::sync::Arc;
35
36use async_trait::async_trait;
37use serde_json::{Value, json};
38
39use crate::domain::error::{ProviderError, Result, ServiceError, StygianError};
40use crate::ports::{AIProvider, ScrapingService, ServiceInput, ServiceOutput};
41
/// Broad content categories the adapter knows how to route.
///
/// Values are produced by [`ContentType::detect`] from a MIME type, file
/// name or extension.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContentType {
    /// Comma-separated values.
    Csv,
    /// JSON documents.
    Json,
    /// XML and HTML markup.
    Xml,
    /// An image; carries the lower-cased input it was detected from.
    Image(String),
    /// PDF documents.
    Pdf,
    /// Input that matched no rule; carries the lower-cased input.
    Unknown(String),
}

impl ContentType {
    /// Classifies a MIME type, file name or extension.
    ///
    /// Matching is case-insensitive and substring based, so MIME parameters
    /// such as `; charset=utf-8` do not prevent a match. Rules are tried in
    /// this order and the first hit wins: CSV, JSON, XML/HTML, image, PDF.
    /// Because of that order `image/svg+xml` is reported as [`Self::Xml`].
    ///
    /// HTML is recognised from the `text/html` MIME type as well as from the
    /// `.html` and `.htm` extensions, so a page is classified the same way
    /// whether the caller passes its MIME type or its file name.
    ///
    /// Anything that matches no rule is returned as [`Self::Unknown`] holding
    /// the lower-cased input.
    // The input is lower-cased before any suffix comparison, so the
    // case-sensitivity concern behind this lint does not apply here.
    #[allow(clippy::case_sensitive_file_extension_comparisons)]
    pub fn detect(mime_or_ext: &str) -> Self {
        /// File-name suffixes treated as images.
        const IMAGE_SUFFIXES: [&str; 5] = [".jpg", ".jpeg", ".png", ".gif", ".webp"];

        let lower = mime_or_ext.to_lowercase();

        // `contains("xml")` already covers `.xml`, `text/xml` and
        // `application/xhtml+xml`; the remaining clauses cover plain HTML.
        let is_markup = lower.contains("xml")
            || lower.contains("text/html")
            || lower.ends_with(".html")
            || lower.ends_with(".htm");
        let is_image = lower.contains("image/")
            || IMAGE_SUFFIXES.iter().any(|suffix| lower.ends_with(*suffix));

        if lower.contains("csv") {
            Self::Csv
        } else if lower.contains("json") {
            Self::Json
        } else if is_markup {
            Self::Xml
        } else if is_image {
            Self::Image(lower)
        } else if lower.contains("pdf") {
            Self::Pdf
        } else {
            Self::Unknown(lower)
        }
    }
}
85
86#[derive(Debug, Clone)]
88pub struct MultiModalConfig {
89 pub max_csv_rows: usize,
91 pub default_image_schema: Value,
93 pub pdf_enabled: bool,
95}
96
97impl Default for MultiModalConfig {
98 fn default() -> Self {
99 Self {
100 max_csv_rows: 10_000,
101 default_image_schema: json!({
102 "type": "object",
103 "properties": {
104 "description": {"type": "string"},
105 "text_content": {"type": "string"},
106 "objects": {"type": "array", "items": {"type": "string"}}
107 }
108 }),
109 pdf_enabled: false,
110 }
111 }
112}
113
114pub struct MultiModalAdapter {
121 config: MultiModalConfig,
122 vision_provider: Option<Arc<dyn AIProvider>>,
124}
125
126impl MultiModalAdapter {
127 pub fn new(config: MultiModalConfig, vision_provider: Option<Arc<dyn AIProvider>>) -> Self {
142 Self {
143 config,
144 vision_provider,
145 }
146 }
147
148 #[allow(clippy::unnecessary_wraps)]
150 fn parse_csv(&self, data: &str) -> Result<Value> {
151 let mut lines = data.lines();
152 let headers: Vec<&str> = match lines.next() {
153 Some(h) => h.split(',').map(str::trim).collect(),
154 None => {
155 return Ok(json!({"rows": [], "row_count": 0}));
156 }
157 };
158
159 let mut rows = Vec::new();
160 for (i, line) in lines.enumerate() {
161 if i >= self.config.max_csv_rows {
162 break;
163 }
164 let values: Vec<&str> = line.split(',').map(str::trim).collect();
165 let mut obj = serde_json::Map::new();
166 for (header, val) in headers.iter().zip(values.iter()) {
167 if let Ok(n) = val.parse::<f64>() {
169 obj.insert((*header).to_string(), json!(n));
170 } else {
171 obj.insert((*header).to_string(), json!(*val));
172 }
173 }
174 rows.push(Value::Object(obj));
175 }
176
177 let row_count = rows.len();
178 Ok(json!({
179 "rows": rows,
180 "row_count": row_count,
181 "columns": headers
182 }))
183 }
184
185 fn parse_json(data: &str) -> Result<Value> {
187 serde_json::from_str(data).map_err(|e| {
188 StygianError::Service(ServiceError::InvalidResponse(format!(
189 "Failed to parse JSON content: {e}"
190 )))
191 })
192 }
193
194 fn parse_xml(data: &str) -> Value {
199 let mut text = String::with_capacity(data.len());
201 let mut in_tag = false;
202 for ch in data.chars() {
203 match ch {
204 '<' => in_tag = true,
205 '>' => in_tag = false,
206 c if !in_tag => text.push(c),
207 _ => {}
208 }
209 }
210
211 let cleaned: String = text.split_whitespace().collect::<Vec<_>>().join(" ");
213 json!({
214 "text_content": cleaned,
215 "raw_length": data.len()
216 })
217 }
218
219 async fn extract_image(&self, data: &str, schema: &Value) -> Result<Value> {
221 match &self.vision_provider {
222 Some(provider) => {
223 if !provider.capabilities().vision {
224 return Err(StygianError::Provider(ProviderError::ApiError(format!(
225 "Configured vision provider '{}' does not support vision",
226 provider.name()
227 ))));
228 }
229 provider.extract(data.to_string(), schema.clone()).await
230 }
231 None => {
232 Ok(json!({
234 "status": "no_vision_provider",
235 "message": "Inject a vision-capable AIProvider to enable image understanding",
236 "data_length": data.len()
237 }))
238 }
239 }
240 }
241
242 fn extract_pdf(data: &str, enabled: bool) -> Value {
244 if enabled {
245 json!({
247 "status": "pdf_extraction_stub",
248 "message": "PDF text extraction requires the 'pdf' feature flag",
249 "data_length": data.len()
250 })
251 } else {
252 json!({
253 "status": "pdf_disabled",
254 "message": "PDF extraction is disabled. Set MultiModalConfig::pdf_enabled = true",
255 "data_length": data.len()
256 })
257 }
258 }
259}
260
261#[async_trait]
262impl ScrapingService for MultiModalAdapter {
263 async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
285 let mime = input
286 .params
287 .get("content_type")
288 .and_then(Value::as_str)
289 .unwrap_or("unknown");
290
291 let content = input
292 .params
293 .get("content")
294 .and_then(Value::as_str)
295 .unwrap_or(&input.url);
296
297 let content_type = ContentType::detect(mime);
298
299 let (extracted, type_name) = match &content_type {
300 ContentType::Csv => (self.parse_csv(content)?, "csv"),
301 ContentType::Json => (Self::parse_json(content)?, "json"),
302 ContentType::Xml => (Self::parse_xml(content), "xml"),
303 ContentType::Image(_) => {
304 let schema = input
305 .params
306 .get("schema")
307 .cloned()
308 .unwrap_or_else(|| self.config.default_image_schema.clone());
309 (self.extract_image(content, &schema).await?, "image")
310 }
311 ContentType::Pdf => (Self::extract_pdf(content, self.config.pdf_enabled), "pdf"),
312 ContentType::Unknown(_) => (json!({"raw": content}), "unknown"),
313 };
314
315 Ok(ServiceOutput {
316 data: extracted.to_string(),
317 metadata: json!({
318 "content_type": mime,
319 "detected_type": type_name,
320 "input_length": content.len(),
321 }),
322 })
323 }
324
325 fn name(&self) -> &'static str {
326 "multimodal"
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Return type shared by the fallible tests.
    type TestResult = crate::domain::error::Result<()>;

    /// Adapter with the default configuration and no vision provider.
    fn adapter() -> MultiModalAdapter {
        MultiModalAdapter::new(MultiModalConfig::default(), None)
    }

    /// Request that carries its payload in `url` and only declares a content
    /// type in `params`.
    fn input(content_type: &str, data: &str) -> ServiceInput {
        let params = json!({ "content_type": content_type });
        ServiceInput {
            url: data.to_string(),
            params,
        }
    }

    /// Executes `request` on a fresh adapter and returns the decoded `data`
    /// together with the metadata.
    async fn run(request: ServiceInput) -> crate::domain::error::Result<(Value, Value)> {
        let service = adapter();
        let output = service.execute(request).await?;
        let decoded: Value = serde_json::from_str(&output.data)
            .map_err(|e| ServiceError::InvalidResponse(e.to_string()))?;
        Ok((decoded, output.metadata))
    }

    /// Reads the `row_count` field of a parsed CSV result.
    fn row_count(parsed: &Value) -> Option<u64> {
        parsed.get("row_count").and_then(Value::as_u64)
    }

    /// Reads a top-level string field.
    fn str_field<'a>(value: &'a Value, key: &str) -> Option<&'a str> {
        value.get(key).and_then(Value::as_str)
    }

    #[test]
    fn test_name() {
        let service = adapter();
        assert_eq!(service.name(), "multimodal");
    }

    // --- content type detection ---

    #[test]
    fn test_detect_csv() {
        for sample in ["text/csv", "file.csv"] {
            assert_eq!(ContentType::detect(sample), ContentType::Csv);
        }
    }

    #[test]
    fn test_detect_json() {
        let detected = ContentType::detect("application/json");
        assert_eq!(detected, ContentType::Json);
    }

    #[test]
    fn test_detect_xml() {
        let detected = ContentType::detect("text/xml");
        assert_eq!(detected, ContentType::Xml);
    }

    #[test]
    fn test_detect_image() {
        for sample in ["image/png", "photo.jpg"] {
            let detected = ContentType::detect(sample);
            assert!(matches!(detected, ContentType::Image(_)));
        }
    }

    #[test]
    fn test_detect_pdf() {
        let detected = ContentType::detect("application/pdf");
        assert_eq!(detected, ContentType::Pdf);
    }

    // --- CSV ---

    // Exact comparison is intended: 30 is exactly representable as f64.
    #[allow(clippy::float_cmp)]
    #[test]
    fn test_parse_csv_basic() -> TestResult {
        let parsed = adapter().parse_csv("name,age\nalice,30\nbob,25")?;
        assert_eq!(row_count(&parsed), Some(2));

        let first_name = parsed.pointer("/rows/0/name").and_then(Value::as_str);
        assert_eq!(first_name, Some("alice"));

        let first_age = parsed.pointer("/rows/0/age").and_then(Value::as_f64);
        assert_eq!(first_age, Some(30.0));
        Ok(())
    }

    #[test]
    fn test_parse_csv_empty() -> TestResult {
        let parsed = adapter().parse_csv("")?;
        assert_eq!(row_count(&parsed), Some(0));
        Ok(())
    }

    #[test]
    fn test_parse_csv_headers_only() -> TestResult {
        let parsed = adapter().parse_csv("col1,col2")?;
        assert_eq!(row_count(&parsed), Some(0));
        Ok(())
    }

    // --- JSON ---

    #[test]
    fn test_parse_json_valid() -> TestResult {
        let parsed = MultiModalAdapter::parse_json(r#"{"hello": "world"}"#)?;
        assert_eq!(str_field(&parsed, "hello"), Some("world"));
        Ok(())
    }

    #[test]
    fn test_parse_json_invalid() {
        let outcome = MultiModalAdapter::parse_json("not json");
        assert!(outcome.is_err());
    }

    // --- XML ---

    #[test]
    fn test_parse_xml_strips_tags() {
        let stripped = MultiModalAdapter::parse_xml("<root><name>Alice</name></root>");
        let text = str_field(&stripped, "text_content").unwrap_or_default();
        assert!(text.contains("Alice"));
        assert!(!text.contains('<'));
    }

    // --- PDF ---

    #[test]
    fn test_pdf_disabled_returns_status() {
        let status = MultiModalAdapter::extract_pdf("data", false);
        assert_eq!(str_field(&status, "status"), Some("pdf_disabled"));
    }

    // --- execute ---

    #[tokio::test]
    async fn test_execute_csv() -> TestResult {
        let (data, metadata) = run(input("text/csv", "x,y\n1,2")).await?;
        assert_eq!(row_count(&data), Some(1));
        assert_eq!(str_field(&metadata, "detected_type"), Some("csv"));
        Ok(())
    }

    #[tokio::test]
    async fn test_execute_json() -> TestResult {
        let (data, _metadata) = run(input("application/json", r#"{"k": "v"}"#)).await?;
        assert_eq!(str_field(&data, "k"), Some("v"));
        Ok(())
    }

    #[tokio::test]
    async fn test_execute_image_no_provider() -> TestResult {
        let (data, _metadata) = run(input("image/png", "binary-data")).await?;
        assert_eq!(str_field(&data, "status"), Some("no_vision_provider"));
        Ok(())
    }

    #[tokio::test]
    async fn test_execute_unknown_passthrough() -> TestResult {
        let (data, _metadata) = run(input("application/octet-stream", "raw")).await?;
        assert_eq!(str_field(&data, "raw"), Some("raw"));
        Ok(())
    }

    #[tokio::test]
    async fn test_content_from_params_overrides_url() -> TestResult {
        let request = ServiceInput {
            url: "should-not-be-used".to_string(),
            params: json!({
                "content_type": "application/json",
                "content": "{\"answer\": 42}"
            }),
        };
        let (data, _metadata) = run(request).await?;
        assert_eq!(data.get("answer").and_then(Value::as_u64), Some(42));
        Ok(())
    }
}