// stygian_graph/adapters/ai/ollama.rs

1//! Ollama (local LLM) AI provider adapter
2//!
3//! Implements the `AIProvider` port using Ollama's HTTP API for local inference.
4//! Supports any model installed in the local Ollama instance.
5//! JSON output is enforced via `format: "json"` parameter.
6//!
7//! # Example
8//!
9//! ```no_run
10//! use stygian_graph::adapters::ai::ollama::{OllamaProvider, OllamaConfig};
11//! use stygian_graph::ports::AIProvider;
12//! use serde_json::json;
13//!
14//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
15//! let provider = OllamaProvider::new();
16//! let schema = json!({"type": "object", "properties": {"title": {"type": "string"}}});
17//! // let result = provider.extract("<html>Hello</html>".to_string(), schema).await.unwrap();
18//! # });
19//! ```
20
21use std::time::Duration;
22
23use async_trait::async_trait;
24use futures::stream::{self, BoxStream};
25use reqwest::Client;
26use serde_json::{Value, json};
27
28use crate::domain::error::{ProviderError, Result, StygianError};
29use crate::ports::{AIProvider, ProviderCapabilities};
30
/// Default Ollama base URL
const DEFAULT_BASE_URL: &str = "http://localhost:11434";

/// Default model for local inference
const DEFAULT_MODEL: &str = "qwen2.5:32b";

/// Configuration for the Ollama provider
#[derive(Debug, Clone)]
pub struct OllamaConfig {
    /// Ollama API base URL
    pub base_url: String,
    /// Model to use for inference
    pub model: String,
    /// Request timeout (may need to be long for large models)
    pub timeout: Duration,
}

impl OllamaConfig {
    /// Create config with defaults (localhost:11434, qwen2.5:32b, 300 s timeout)
    pub fn new() -> Self {
        Self {
            base_url: DEFAULT_BASE_URL.to_string(),
            model: DEFAULT_MODEL.to_string(),
            // Local inference on large models can take minutes; default generously.
            timeout: Duration::from_secs(300),
        }
    }

    /// Override base URL
    #[must_use]
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Override model
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Override request timeout
    ///
    /// Completes the builder API so the `timeout` field can be set without
    /// struct-literal construction, consistent with the other `with_` methods.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}

impl Default for OllamaConfig {
    fn default() -> Self {
        Self::new()
    }
}
78
/// Ollama local LLM provider adapter
///
/// Holds the provider configuration and a single `reqwest::Client` that is
/// built once in `with_config` and reused for every request.
pub struct OllamaProvider {
    // Base URL, model name, and request timeout (see `OllamaConfig`).
    config: OllamaConfig,
    // HTTP client configured with `config.timeout`.
    client: Client,
}
84
85impl OllamaProvider {
86    /// Create with default configuration (localhost:11434, qwen2.5:32b)
87    ///
88    /// # Example
89    ///
90    /// ```no_run
91    /// use stygian_graph::adapters::ai::ollama::OllamaProvider;
92    /// let p = OllamaProvider::new();
93    /// ```
94    pub fn new() -> Self {
95        Self::with_config(OllamaConfig::new())
96    }
97
98    /// Create with custom configuration
99    ///
100    /// # Example
101    ///
102    /// ```no_run
103    /// use stygian_graph::adapters::ai::ollama::{OllamaProvider, OllamaConfig};
104    /// let config = OllamaConfig::new().with_model("llama3.2:latest");
105    /// let p = OllamaProvider::with_config(config);
106    /// ```
107    pub fn with_config(config: OllamaConfig) -> Self {
108        // SAFETY: TLS backend (rustls) is always available; build() only fails if no TLS backend.
109        #[allow(clippy::expect_used)]
110        let client = Client::builder()
111            .timeout(config.timeout)
112            .build()
113            .expect("Failed to build HTTP client");
114        Self { config, client }
115    }
116
117    fn api_url(&self) -> String {
118        format!("{}/api/generate", self.config.base_url)
119    }
120
121    fn build_body(&self, content: &str, schema: &Value) -> Value {
122        let prompt = format!(
123            "Extract structured data from the following content according to this JSON schema.\n\
124             Return ONLY valid JSON matching the schema, with no markdown, no code blocks, no extra text.\n\
125             Schema: {}\n\nContent:\n{}",
126            serde_json::to_string(schema).unwrap_or_default(),
127            content
128        );
129
130        json!({
131            "model": self.config.model,
132            "prompt": prompt,
133            "stream": false,
134            "format": "json"
135        })
136    }
137
138    fn parse_response(response: &Value) -> Result<Value> {
139        let text = response
140            .get("response")
141            .and_then(Value::as_str)
142            .ok_or_else(|| {
143                StygianError::Provider(ProviderError::ApiError(
144                    "No response field in Ollama output".to_string(),
145                ))
146            })?;
147
148        serde_json::from_str(text).map_err(|e| {
149            StygianError::Provider(ProviderError::ApiError(format!(
150                "Failed to parse Ollama JSON response: {e}"
151            )))
152        })
153    }
154
155    fn map_http_error(status: u16, body: &str) -> StygianError {
156        match status {
157            404 => StygianError::Provider(ProviderError::ModelUnavailable(format!(
158                "Model not found in Ollama: {body}"
159            ))),
160            _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
161        }
162    }
163}
164
165impl Default for OllamaProvider {
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
171#[async_trait]
172impl AIProvider for OllamaProvider {
173    async fn extract(&self, content: String, schema: Value) -> Result<Value> {
174        let body = self.build_body(&content, &schema);
175        let url = self.api_url();
176
177        let response = self
178            .client
179            .post(&url)
180            .header("Content-Type", "application/json")
181            .json(&body)
182            .send()
183            .await
184            .map_err(|e| {
185                StygianError::Provider(ProviderError::ApiError(format!(
186                    "Ollama request failed (is Ollama running?): {e}"
187                )))
188            })?;
189
190        let status = response.status().as_u16();
191        let text = response
192            .text()
193            .await
194            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
195
196        if status != 200 {
197            return Err(Self::map_http_error(status, &text));
198        }
199
200        let json_val: Value = serde_json::from_str(&text)
201            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
202
203        Self::parse_response(&json_val)
204    }
205
206    async fn stream_extract(
207        &self,
208        content: String,
209        schema: Value,
210    ) -> Result<BoxStream<'static, Result<Value>>> {
211        let result = self.extract(content, schema).await;
212        Ok(Box::pin(stream::once(async move { result })))
213    }
214
215    fn capabilities(&self) -> ProviderCapabilities {
216        ProviderCapabilities {
217            streaming: true,
218            vision: false,
219            tool_use: false,
220            json_mode: true,
221        }
222    }
223
224    fn name(&self) -> &'static str {
225        "ollama"
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn name_is_ollama() {
        let provider = OllamaProvider::new();
        assert_eq!(provider.name(), "ollama");
    }

    #[test]
    fn default_uses_builtin_model_and_url() {
        let provider = OllamaProvider::default();
        assert_eq!(provider.config.model, DEFAULT_MODEL);
        assert_eq!(provider.config.base_url, DEFAULT_BASE_URL);
    }

    #[test]
    fn capabilities_advertise_json_but_not_vision() {
        let caps = OllamaProvider::new().capabilities();
        assert!(caps.json_mode);
        assert!(!caps.vision);
    }

    #[test]
    fn api_url_points_at_generate_endpoint() {
        assert_eq!(
            OllamaProvider::new().api_url(),
            "http://localhost:11434/api/generate"
        );
    }

    #[test]
    fn body_disables_streaming_and_forces_json() {
        let body = OllamaProvider::new().build_body("c", &json!({"type": "object"}));
        assert_eq!(body.get("stream"), Some(&json!(false)));
        assert_eq!(body.get("format").and_then(Value::as_str), Some("json"));
    }

    #[test]
    fn parse_response_extracts_inner_json() -> Result<()> {
        let payload = json!({"response": "{\"score\": 42}"});
        let parsed = OllamaProvider::parse_response(&payload)?;
        assert_eq!(parsed.get("score").and_then(Value::as_u64), Some(42));
        Ok(())
    }

    #[test]
    fn http_404_maps_to_model_unavailable() {
        let err = OllamaProvider::map_http_error(404, "not found");
        assert!(matches!(
            err,
            StygianError::Provider(ProviderError::ModelUnavailable(_))
        ));
    }

    #[test]
    fn builder_overrides_model_and_url() {
        let config = OllamaConfig::new()
            .with_base_url("http://192.168.1.10:11434")
            .with_model("llama3:latest");
        assert_eq!(config.model, "llama3:latest");
        assert_eq!(config.base_url, "http://192.168.1.10:11434");
    }
}