Skip to main content

stygian_graph/adapters/ai/
ollama.rs

1//! Ollama (local LLM) AI provider adapter
2//!
3//! Implements the `AIProvider` port using Ollama's HTTP API for local inference.
4//! Supports any model installed in the local Ollama instance.
5//! JSON output is enforced via `format: "json"` parameter.
6//!
7//! # Example
8//!
9//! ```no_run
10//! use stygian_graph::adapters::ai::ollama::{OllamaProvider, OllamaConfig};
11//! use stygian_graph::ports::AIProvider;
12//! use serde_json::json;
13//!
14//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
15//! let provider = OllamaProvider::new();
16//! let schema = json!({"type": "object", "properties": {"title": {"type": "string"}}});
17//! // let result = provider.extract("<html>Hello</html>".to_string(), schema).await.unwrap();
18//! # });
19//! ```
20
21use std::time::Duration;
22
23use async_trait::async_trait;
24use futures::stream::{self, BoxStream};
25use reqwest::Client;
26use serde_json::{Value, json};
27
28use crate::domain::error::{ProviderError, Result, StygianError};
29use crate::ports::{AIProvider, ProviderCapabilities};
30
/// Default Ollama base URL (scheme, host and port only).
///
/// Used by `OllamaConfig::new`; the `/api/generate` endpoint path is appended
/// separately when a request URL is built.
const DEFAULT_BASE_URL: &str = "http://localhost:11434";

/// Default model tag for local inference, used by `OllamaConfig::new`.
const DEFAULT_MODEL: &str = "qwen2.5:32b";
36
/// Configuration for the Ollama provider.
///
/// All fields are public, so a value can be built either with a struct
/// literal or through the constructor and builder methods on this type.
/// `PartialEq`/`Eq` are derived so configurations can be compared directly
/// (for example in tests or when detecting a configuration change).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OllamaConfig {
    /// Ollama API base URL (the endpoint path is appended by the provider)
    pub base_url: String,
    /// Model to use for inference
    pub model: String,
    /// Request timeout (may need to be long for large models)
    pub timeout: Duration,
}
47
48impl OllamaConfig {
49    /// Create config with defaults
50    #[must_use]
51    pub fn new() -> Self {
52        Self {
53            base_url: DEFAULT_BASE_URL.to_string(),
54            model: DEFAULT_MODEL.to_string(),
55            timeout: Duration::from_mins(5),
56        }
57    }
58
59    /// Override base URL
60    #[must_use]
61    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
62        self.base_url = url.into();
63        self
64    }
65
66    /// Override model
67    #[must_use]
68    pub fn with_model(mut self, model: impl Into<String>) -> Self {
69        self.model = model.into();
70        self
71    }
72}
73
/// `Default` delegates to [`OllamaConfig::new`], so both paths always yield
/// the same values.
impl Default for OllamaConfig {
    fn default() -> Self {
        Self::new()
    }
}
79
80/// Ollama local LLM provider adapter
81pub struct OllamaProvider {
82    config: OllamaConfig,
83    client: Client,
84}
85
86impl OllamaProvider {
87    /// Create with default configuration (localhost:11434, qwen2.5:32b)
88    ///
89    /// # Example
90    ///
91    /// ```no_run
92    /// use stygian_graph::adapters::ai::ollama::OllamaProvider;
93    /// let p = OllamaProvider::new();
94    /// ```
95    #[must_use]
96    pub fn new() -> Self {
97        Self::with_config(OllamaConfig::new())
98    }
99
100    /// Create with custom configuration
101    ///
102    /// # Panics
103    ///
104    /// Panics if the underlying HTTP client fails to build. With `rustls` as the
105    /// TLS backend this is unreachable in practice (build only fails when no TLS
106    /// backend is configured).
107    ///
108    /// # Example
109    ///
110    /// ```no_run
111    /// use stygian_graph::adapters::ai::ollama::{OllamaProvider, OllamaConfig};
112    /// let config = OllamaConfig::new().with_model("llama3.2:latest");
113    /// let p = OllamaProvider::with_config(config);
114    /// ```
115    #[must_use]
116    pub fn with_config(config: OllamaConfig) -> Self {
117        // SAFETY: TLS backend (rustls) is always available; build() only fails if no TLS backend.
118        #[allow(clippy::expect_used)]
119        let client = Client::builder()
120            .timeout(config.timeout)
121            .build()
122            .expect("Failed to build HTTP client");
123        Self { config, client }
124    }
125
126    fn api_url(&self) -> String {
127        format!("{}/api/generate", self.config.base_url)
128    }
129
130    fn build_body(&self, content: &str, schema: &Value) -> Value {
131        let prompt = format!(
132            "Extract structured data from the following content according to this JSON schema.\n\
133             Return ONLY valid JSON matching the schema, with no markdown, no code blocks, no extra text.\n\
134             Schema: {}\n\nContent:\n{}",
135            serde_json::to_string(schema).unwrap_or_default(),
136            content
137        );
138
139        json!({
140            "model": self.config.model,
141            "prompt": prompt,
142            "stream": false,
143            "format": "json"
144        })
145    }
146
147    fn parse_response(response: &Value) -> Result<Value> {
148        let text = response
149            .get("response")
150            .and_then(Value::as_str)
151            .ok_or_else(|| {
152                StygianError::Provider(ProviderError::ApiError(
153                    "No response field in Ollama output".to_string(),
154                ))
155            })?;
156
157        serde_json::from_str(text).map_err(|e| {
158            StygianError::Provider(ProviderError::ApiError(format!(
159                "Failed to parse Ollama JSON response: {e}"
160            )))
161        })
162    }
163
164    fn map_http_error(status: u16, body: &str) -> StygianError {
165        match status {
166            404 => StygianError::Provider(ProviderError::ModelUnavailable(format!(
167                "Model not found in Ollama: {body}"
168            ))),
169            _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
170        }
171    }
172}
173
/// `Default` delegates to [`OllamaProvider::new`], i.e. a provider built from
/// the default configuration.
impl Default for OllamaProvider {
    fn default() -> Self {
        Self::new()
    }
}
179
180#[async_trait]
181impl AIProvider for OllamaProvider {
182    async fn extract(&self, content: String, schema: Value) -> Result<Value> {
183        let body = self.build_body(&content, &schema);
184        let url = self.api_url();
185
186        let response = self
187            .client
188            .post(&url)
189            .header("Content-Type", "application/json")
190            .json(&body)
191            .send()
192            .await
193            .map_err(|e| {
194                StygianError::Provider(ProviderError::ApiError(format!(
195                    "Ollama request failed (is Ollama running?): {e}"
196                )))
197            })?;
198
199        let status = response.status().as_u16();
200        let text = response
201            .text()
202            .await
203            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
204
205        if status != 200 {
206            return Err(Self::map_http_error(status, &text));
207        }
208
209        let json_val: Value = serde_json::from_str(&text)
210            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
211
212        Self::parse_response(&json_val)
213    }
214
215    async fn stream_extract(
216        &self,
217        content: String,
218        schema: Value,
219    ) -> Result<BoxStream<'static, Result<Value>>> {
220        let result = self.extract(content, schema).await;
221        Ok(Box::pin(stream::once(async move { result })))
222    }
223
224    fn capabilities(&self) -> ProviderCapabilities {
225        ProviderCapabilities {
226            streaming: true,
227            vision: false,
228            tool_use: false,
229            json_mode: true,
230        }
231    }
232
233    fn name(&self) -> &'static str {
234        "ollama"
235    }
236}
237
/// Unit tests for the pure (network-free) parts of the adapter.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_name() {
        assert_eq!(OllamaProvider::new().name(), "ollama");
    }

    #[test]
    fn test_default() {
        let p = OllamaProvider::default();
        assert_eq!(p.config.model, DEFAULT_MODEL);
        assert_eq!(p.config.base_url, DEFAULT_BASE_URL);
    }

    #[test]
    fn test_default_timeout_is_five_minutes() {
        assert_eq!(OllamaConfig::new().timeout.as_secs(), 300);
    }

    #[test]
    fn test_capabilities_json_mode() {
        let caps = OllamaProvider::new().capabilities();
        assert!(caps.json_mode);
        assert!(!caps.vision);
    }

    #[test]
    fn test_api_url() {
        let p = OllamaProvider::new();
        assert_eq!(p.api_url(), "http://localhost:11434/api/generate");
    }

    #[test]
    fn test_build_body_stream_false() {
        let p = OllamaProvider::new();
        let body = p.build_body("c", &json!({"type": "object"}));
        assert_eq!(body.get("stream"), Some(&json!(false)));
        assert_eq!(body.get("format").and_then(Value::as_str), Some("json"));
    }

    #[test]
    fn test_build_body_includes_model_schema_and_content() {
        let p = OllamaProvider::new();
        let body = p.build_body("hello-content", &json!({"type": "object"}));
        assert_eq!(
            body.get("model").and_then(Value::as_str),
            Some(DEFAULT_MODEL)
        );
        let prompt = body.get("prompt").and_then(Value::as_str);
        assert!(prompt.is_some_and(|text| text.contains("hello-content")));
        assert!(prompt.is_some_and(|text| text.contains("{\"type\":\"object\"}")));
    }

    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let resp = json!({"response": "{\"score\": 42}"});
        let val = OllamaProvider::parse_response(&resp)?;
        assert_eq!(val.get("score").and_then(Value::as_u64), Some(42));
        Ok(())
    }

    #[test]
    fn test_parse_response_missing_field() {
        // No `response` key at all.
        assert!(matches!(
            OllamaProvider::parse_response(&json!({"done": true})),
            Err(StygianError::Provider(ProviderError::ApiError(_)))
        ));
        // `response` present but not a string.
        assert!(matches!(
            OllamaProvider::parse_response(&json!({"response": 7})),
            Err(StygianError::Provider(ProviderError::ApiError(_)))
        ));
    }

    #[test]
    fn test_parse_response_invalid_json() {
        assert!(matches!(
            OllamaProvider::parse_response(&json!({"response": "not json"})),
            Err(StygianError::Provider(ProviderError::ApiError(_)))
        ));
    }

    #[test]
    fn test_map_http_error_404() {
        assert!(matches!(
            OllamaProvider::map_http_error(404, "not found"),
            StygianError::Provider(ProviderError::ModelUnavailable(_))
        ));
    }

    #[test]
    fn test_map_http_error_other_status() {
        assert!(matches!(
            OllamaProvider::map_http_error(500, "boom"),
            StygianError::Provider(ProviderError::ApiError(_))
        ));
    }

    #[test]
    fn test_config_builder() {
        let config = OllamaConfig::new()
            .with_model("llama3:latest")
            .with_base_url("http://192.168.1.10:11434");
        assert_eq!(config.model, "llama3:latest");
        assert_eq!(config.base_url, "http://192.168.1.10:11434");
    }
}
300}