1use std::sync::Arc;
35
36use async_trait::async_trait;
37use serde_json::{Value, json};
38
39use crate::domain::error::{ProviderError, Result, ServiceError, StygianError};
40use crate::ports::{AIProvider, ScrapingService, ServiceInput, ServiceOutput};
41
/// Broad category of a payload, used to choose an extraction strategy.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContentType {
    /// Comma-separated values.
    Csv,
    /// A JSON document.
    Json,
    /// XML or HTML markup.
    Xml,
    /// An image; carries the trimmed, lowercased input it was detected from.
    Image(String),
    /// A PDF document.
    Pdf,
    /// Anything unrecognised; carries the trimmed, lowercased input.
    Unknown(String),
}

impl ContentType {
    /// Classify a MIME type (e.g. `text/csv`) or a file name/extension (e.g. `photo.jpg`).
    ///
    /// Matching ignores case and surrounding whitespace. MIME types are matched by substring,
    /// so parameterised or vendor types such as `text/csv; charset=utf-8` or
    /// `application/ld+json` are recognised.
    ///
    /// Checks run in the order CSV, JSON, XML/HTML, image, PDF, and the first match wins.
    /// For example, `image/svg+xml` is classified as [`ContentType::Xml`].
    ///
    /// HTML is accepted both as the `text/html` MIME type and as a `.html`/`.htm` extension.
    #[allow(clippy::case_sensitive_file_extension_comparisons)] // input is lowercased first
    #[must_use]
    pub fn detect(mime_or_ext: &str) -> Self {
        const IMAGE_EXTENSIONS: [&str; 5] = [".jpg", ".jpeg", ".png", ".gif", ".webp"];

        let lower = mime_or_ext.trim().to_lowercase();
        if lower.contains("csv") {
            Self::Csv
        } else if lower.contains("json") {
            Self::Json
        } else if lower.contains("xml") || lower.contains("html") || lower.ends_with(".htm") {
            Self::Xml
        } else if lower.contains("image/")
            || IMAGE_EXTENSIONS.iter().any(|&ext| lower.ends_with(ext))
        {
            Self::Image(lower)
        } else if lower.contains("pdf") {
            Self::Pdf
        } else {
            Self::Unknown(lower)
        }
    }
}
86
87#[derive(Debug, Clone)]
89pub struct MultiModalConfig {
90 pub max_csv_rows: usize,
92 pub default_image_schema: Value,
94 pub pdf_enabled: bool,
96}
97
98impl Default for MultiModalConfig {
99 fn default() -> Self {
100 Self {
101 max_csv_rows: 10_000,
102 default_image_schema: json!({
103 "type": "object",
104 "properties": {
105 "description": {"type": "string"},
106 "text_content": {"type": "string"},
107 "objects": {"type": "array", "items": {"type": "string"}}
108 }
109 }),
110 pdf_enabled: false,
111 }
112 }
113}
114
115pub struct MultiModalAdapter {
122 config: MultiModalConfig,
123 vision_provider: Option<Arc<dyn AIProvider>>,
125}
126
127impl MultiModalAdapter {
128 #[must_use]
143 pub fn new(config: MultiModalConfig, vision_provider: Option<Arc<dyn AIProvider>>) -> Self {
144 Self {
145 config,
146 vision_provider,
147 }
148 }
149
150 #[allow(clippy::unnecessary_wraps)]
152 fn parse_csv(&self, data: &str) -> Result<Value> {
153 let mut lines = data.lines();
154 let headers: Vec<&str> = match lines.next() {
155 Some(h) => h.split(',').map(str::trim).collect(),
156 None => {
157 return Ok(json!({"rows": [], "row_count": 0}));
158 }
159 };
160
161 let mut rows = Vec::new();
162 for (i, line) in lines.enumerate() {
163 if i >= self.config.max_csv_rows {
164 break;
165 }
166 let values: Vec<&str> = line.split(',').map(str::trim).collect();
167 let mut obj = serde_json::Map::new();
168 for (header, val) in headers.iter().zip(values.iter()) {
169 if let Ok(n) = val.parse::<f64>() {
171 obj.insert((*header).to_string(), json!(n));
172 } else {
173 obj.insert((*header).to_string(), json!(*val));
174 }
175 }
176 rows.push(Value::Object(obj));
177 }
178
179 let row_count = rows.len();
180 Ok(json!({
181 "rows": rows,
182 "row_count": row_count,
183 "columns": headers
184 }))
185 }
186
187 fn parse_json(data: &str) -> Result<Value> {
189 serde_json::from_str(data).map_err(|e| {
190 StygianError::Service(ServiceError::InvalidResponse(format!(
191 "Failed to parse JSON content: {e}"
192 )))
193 })
194 }
195
196 fn parse_xml(data: &str) -> Value {
201 let mut text = String::with_capacity(data.len());
203 let mut in_tag = false;
204 for ch in data.chars() {
205 match ch {
206 '<' => in_tag = true,
207 '>' => in_tag = false,
208 c if !in_tag => text.push(c),
209 _ => {}
210 }
211 }
212
213 let cleaned: String = text.split_whitespace().collect::<Vec<_>>().join(" ");
215 json!({
216 "text_content": cleaned,
217 "raw_length": data.len()
218 })
219 }
220
221 async fn extract_image(&self, data: &str, schema: &Value) -> Result<Value> {
223 match &self.vision_provider {
224 Some(provider) => {
225 if !provider.capabilities().vision {
226 return Err(StygianError::Provider(ProviderError::ApiError(format!(
227 "Configured vision provider '{}' does not support vision",
228 provider.name()
229 ))));
230 }
231 provider.extract(data.to_string(), schema.clone()).await
232 }
233 None => {
234 Ok(json!({
236 "status": "no_vision_provider",
237 "message": "Inject a vision-capable AIProvider to enable image understanding",
238 "data_length": data.len()
239 }))
240 }
241 }
242 }
243
244 fn extract_pdf(data: &str, enabled: bool) -> Value {
246 if enabled {
247 json!({
249 "status": "pdf_extraction_stub",
250 "message": "PDF text extraction requires the 'pdf' feature flag",
251 "data_length": data.len()
252 })
253 } else {
254 json!({
255 "status": "pdf_disabled",
256 "message": "PDF extraction is disabled. Set MultiModalConfig::pdf_enabled = true",
257 "data_length": data.len()
258 })
259 }
260 }
261}
262
263#[async_trait]
264impl ScrapingService for MultiModalAdapter {
265 async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
287 let mime = input
288 .params
289 .get("content_type")
290 .and_then(Value::as_str)
291 .unwrap_or("unknown");
292
293 let content = input
294 .params
295 .get("content")
296 .and_then(Value::as_str)
297 .unwrap_or(&input.url);
298
299 let content_type = ContentType::detect(mime);
300
301 let (extracted, type_name) = match &content_type {
302 ContentType::Csv => (self.parse_csv(content)?, "csv"),
303 ContentType::Json => (Self::parse_json(content)?, "json"),
304 ContentType::Xml => (Self::parse_xml(content), "xml"),
305 ContentType::Image(_) => {
306 let schema = input
307 .params
308 .get("schema")
309 .cloned()
310 .unwrap_or_else(|| self.config.default_image_schema.clone());
311 (self.extract_image(content, &schema).await?, "image")
312 }
313 ContentType::Pdf => (Self::extract_pdf(content, self.config.pdf_enabled), "pdf"),
314 ContentType::Unknown(_) => (json!({"raw": content}), "unknown"),
315 };
316
317 Ok(ServiceOutput {
318 data: extracted.to_string(),
319 metadata: json!({
320 "content_type": mime,
321 "detected_type": type_name,
322 "input_length": content.len(),
323 }),
324 })
325 }
326
327 fn name(&self) -> &'static str {
328 "multimodal"
329 }
330}
331
#[cfg(test)]
mod tests {
    //! Unit tests for content-type detection, the per-format parsers and `execute`.

    use super::*;
    use serde_json::json;

    /// Adapter with the default configuration and no vision provider.
    fn adapter() -> MultiModalAdapter {
        MultiModalAdapter::new(MultiModalConfig::default(), None)
    }

    /// Request whose payload travels in `url`, labelled with `content_type`.
    fn input(content_type: &str, data: &str) -> ServiceInput {
        ServiceInput {
            url: data.to_string(),
            params: json!({ "content_type": content_type }),
        }
    }

    /// Deserialise the JSON string carried in `ServiceOutput::data`.
    fn output_json(out: &ServiceOutput) -> crate::domain::error::Result<Value> {
        let decoded = serde_json::from_str(&out.data)
            .map_err(|e| ServiceError::InvalidResponse(e.to_string()))?;
        Ok(decoded)
    }

    #[test]
    fn test_name() {
        let service = adapter();
        assert_eq!(service.name(), "multimodal");
    }

    #[test]
    fn test_detect_csv() {
        for probe in ["text/csv", "file.csv"] {
            assert_eq!(ContentType::detect(probe), ContentType::Csv);
        }
    }

    #[test]
    fn test_detect_json() {
        let detected = ContentType::detect("application/json");
        assert_eq!(detected, ContentType::Json);
    }

    #[test]
    fn test_detect_xml() {
        let detected = ContentType::detect("text/xml");
        assert_eq!(detected, ContentType::Xml);
    }

    #[test]
    fn test_detect_image() {
        for probe in ["image/png", "photo.jpg"] {
            assert!(matches!(ContentType::detect(probe), ContentType::Image(_)));
        }
    }

    #[test]
    fn test_detect_pdf() {
        let detected = ContentType::detect("application/pdf");
        assert_eq!(detected, ContentType::Pdf);
    }

    #[allow(clippy::float_cmp)] // 30.0 round-trips exactly through parsing and JSON
    #[test]
    fn test_parse_csv_basic() -> crate::domain::error::Result<()> {
        let service = adapter();
        let parsed = service.parse_csv("name,age\nalice,30\nbob,25")?;
        assert_eq!(parsed.get("row_count").and_then(Value::as_u64), Some(2));

        let first_row = parsed.get("rows").and_then(|rows| rows.get(0));
        let name = first_row.and_then(|row| row.get("name")).and_then(Value::as_str);
        let age = first_row.and_then(|row| row.get("age")).and_then(Value::as_f64);
        assert_eq!(name, Some("alice"));
        assert_eq!(age, Some(30.0));
        Ok(())
    }

    #[test]
    fn test_parse_csv_empty() -> crate::domain::error::Result<()> {
        let service = adapter();
        let parsed = service.parse_csv("")?;
        assert_eq!(parsed.get("row_count").and_then(Value::as_u64), Some(0));
        Ok(())
    }

    #[test]
    fn test_parse_csv_headers_only() -> crate::domain::error::Result<()> {
        let service = adapter();
        let parsed = service.parse_csv("col1,col2")?;
        assert_eq!(parsed.get("row_count").and_then(Value::as_u64), Some(0));
        Ok(())
    }

    #[test]
    fn test_parse_json_valid() -> crate::domain::error::Result<()> {
        let parsed = MultiModalAdapter::parse_json(r#"{"hello": "world"}"#)?;
        let greeting = parsed.get("hello").and_then(Value::as_str);
        assert_eq!(greeting, Some("world"));
        Ok(())
    }

    #[test]
    fn test_parse_json_invalid() {
        let outcome = MultiModalAdapter::parse_json("not json");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_parse_xml_strips_tags() {
        let parsed = MultiModalAdapter::parse_xml("<root><name>Alice</name></root>");
        let text = parsed
            .get("text_content")
            .and_then(Value::as_str)
            .unwrap_or_default();
        assert!(text.contains("Alice"));
        assert!(!text.contains('<'));
    }

    #[test]
    fn test_pdf_disabled_returns_status() {
        let report = MultiModalAdapter::extract_pdf("data", false);
        let status = report.get("status").and_then(Value::as_str);
        assert_eq!(status, Some("pdf_disabled"));
    }

    #[tokio::test]
    async fn test_execute_csv() -> crate::domain::error::Result<()> {
        let service = adapter();
        let out = service.execute(input("text/csv", "x,y\n1,2")).await?;
        let data = output_json(&out)?;
        assert_eq!(data.get("row_count").and_then(Value::as_u64), Some(1));

        let detected = out.metadata.get("detected_type").and_then(Value::as_str);
        assert_eq!(detected, Some("csv"));
        Ok(())
    }

    #[tokio::test]
    async fn test_execute_json() -> crate::domain::error::Result<()> {
        let service = adapter();
        let request = input("application/json", r#"{"k": "v"}"#);
        let data = output_json(&service.execute(request).await?)?;
        assert_eq!(data.get("k").and_then(Value::as_str), Some("v"));
        Ok(())
    }

    #[tokio::test]
    async fn test_execute_image_no_provider() -> crate::domain::error::Result<()> {
        let service = adapter();
        let data = output_json(&service.execute(input("image/png", "binary-data")).await?)?;
        let status = data.get("status").and_then(Value::as_str);
        assert_eq!(status, Some("no_vision_provider"));
        Ok(())
    }

    #[tokio::test]
    async fn test_execute_unknown_passthrough() -> crate::domain::error::Result<()> {
        let service = adapter();
        let request = input("application/octet-stream", "raw");
        let data = output_json(&service.execute(request).await?)?;
        assert_eq!(data.get("raw").and_then(Value::as_str), Some("raw"));
        Ok(())
    }

    #[tokio::test]
    async fn test_content_from_params_overrides_url() -> crate::domain::error::Result<()> {
        let service = adapter();
        let request = ServiceInput {
            url: "should-not-be-used".to_string(),
            params: json!({
                "content_type": "application/json",
                "content": "{\"answer\": 42}"
            }),
        };
        let data = output_json(&service.execute(request).await?)?;
        assert_eq!(data.get("answer").and_then(Value::as_u64), Some(42));
        Ok(())
    }
}