stygian_graph/adapters/ai/
ollama.rs1use std::time::Duration;
22
23use async_trait::async_trait;
24use futures::stream::{self, BoxStream};
25use reqwest::Client;
26use serde_json::{Value, json};
27
28use crate::domain::error::{ProviderError, Result, StygianError};
29use crate::ports::{AIProvider, ProviderCapabilities};
30
/// Default base URL of a locally running Ollama server.
const DEFAULT_BASE_URL: &str = "http://localhost:11434";

/// Model used when none is configured explicitly.
const DEFAULT_MODEL: &str = "qwen2.5:32b";

/// Default per-request timeout (five minutes).
///
/// Spelled with `from_secs` (a `const fn` available on every stable toolchain)
/// rather than the much more recently stabilized `Duration::from_mins`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// Connection and model settings for an Ollama-backed provider.
#[derive(Debug, Clone)]
pub struct OllamaConfig {
    /// Base URL of the Ollama server, e.g. `http://localhost:11434`.
    pub base_url: String,
    /// Name of the Ollama model to run, e.g. `qwen2.5:32b`.
    pub model: String,
    /// Timeout applied by the HTTP client to each request.
    pub timeout: Duration,
}

impl OllamaConfig {
    /// Default configuration: [`DEFAULT_BASE_URL`], [`DEFAULT_MODEL`] and a
    /// five-minute request timeout.
    #[must_use]
    pub fn new() -> Self {
        Self {
            base_url: DEFAULT_BASE_URL.to_string(),
            model: DEFAULT_MODEL.to_string(),
            timeout: DEFAULT_TIMEOUT,
        }
    }

    /// Replace the server base URL.
    #[must_use]
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Replace the model name.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Replace the per-request timeout.
    ///
    /// Counterpart to the other `with_*` setters; previously the timeout could
    /// only be changed by assigning the public field directly.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}

impl Default for OllamaConfig {
    fn default() -> Self {
        Self::new()
    }
}
79
80pub struct OllamaProvider {
82 config: OllamaConfig,
83 client: Client,
84}
85
86impl OllamaProvider {
87 #[must_use]
96 pub fn new() -> Self {
97 Self::with_config(OllamaConfig::new())
98 }
99
100 #[must_use]
116 pub fn with_config(config: OllamaConfig) -> Self {
117 #[allow(clippy::expect_used)]
119 let client = Client::builder()
120 .timeout(config.timeout)
121 .build()
122 .expect("Failed to build HTTP client");
123 Self { config, client }
124 }
125
126 fn api_url(&self) -> String {
127 format!("{}/api/generate", self.config.base_url)
128 }
129
130 fn build_body(&self, content: &str, schema: &Value) -> Value {
131 let prompt = format!(
132 "Extract structured data from the following content according to this JSON schema.\n\
133 Return ONLY valid JSON matching the schema, with no markdown, no code blocks, no extra text.\n\
134 Schema: {}\n\nContent:\n{}",
135 serde_json::to_string(schema).unwrap_or_default(),
136 content
137 );
138
139 json!({
140 "model": self.config.model,
141 "prompt": prompt,
142 "stream": false,
143 "format": "json"
144 })
145 }
146
147 fn parse_response(response: &Value) -> Result<Value> {
148 let text = response
149 .get("response")
150 .and_then(Value::as_str)
151 .ok_or_else(|| {
152 StygianError::Provider(ProviderError::ApiError(
153 "No response field in Ollama output".to_string(),
154 ))
155 })?;
156
157 serde_json::from_str(text).map_err(|e| {
158 StygianError::Provider(ProviderError::ApiError(format!(
159 "Failed to parse Ollama JSON response: {e}"
160 )))
161 })
162 }
163
164 fn map_http_error(status: u16, body: &str) -> StygianError {
165 match status {
166 404 => StygianError::Provider(ProviderError::ModelUnavailable(format!(
167 "Model not found in Ollama: {body}"
168 ))),
169 _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
170 }
171 }
172}
173
174impl Default for OllamaProvider {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
180#[async_trait]
181impl AIProvider for OllamaProvider {
182 async fn extract(&self, content: String, schema: Value) -> Result<Value> {
183 let body = self.build_body(&content, &schema);
184 let url = self.api_url();
185
186 let response = self
187 .client
188 .post(&url)
189 .header("Content-Type", "application/json")
190 .json(&body)
191 .send()
192 .await
193 .map_err(|e| {
194 StygianError::Provider(ProviderError::ApiError(format!(
195 "Ollama request failed (is Ollama running?): {e}"
196 )))
197 })?;
198
199 let status = response.status().as_u16();
200 let text = response
201 .text()
202 .await
203 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
204
205 if status != 200 {
206 return Err(Self::map_http_error(status, &text));
207 }
208
209 let json_val: Value = serde_json::from_str(&text)
210 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
211
212 Self::parse_response(&json_val)
213 }
214
215 async fn stream_extract(
216 &self,
217 content: String,
218 schema: Value,
219 ) -> Result<BoxStream<'static, Result<Value>>> {
220 let result = self.extract(content, schema).await;
221 Ok(Box::pin(stream::once(async move { result })))
222 }
223
224 fn capabilities(&self) -> ProviderCapabilities {
225 ProviderCapabilities {
226 streaming: true,
227 vision: false,
228 tool_use: false,
229 json_mode: true,
230 }
231 }
232
233 fn name(&self) -> &'static str {
234 "ollama"
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_name() {
        assert_eq!(OllamaProvider::new().name(), "ollama");
    }

    #[test]
    fn test_default() {
        let p = OllamaProvider::default();
        assert_eq!(p.config.model, DEFAULT_MODEL);
        assert_eq!(p.config.base_url, DEFAULT_BASE_URL);
    }

    #[test]
    fn test_capabilities_json_mode() {
        let caps = OllamaProvider::new().capabilities();
        assert!(caps.json_mode);
        assert!(!caps.vision);
    }

    #[test]
    fn test_api_url() {
        let p = OllamaProvider::new();
        assert_eq!(p.api_url(), "http://localhost:11434/api/generate");
    }

    #[test]
    fn test_build_body_stream_false() {
        let p = OllamaProvider::new();
        let body = p.build_body("c", &json!({"type": "object"}));
        assert_eq!(body.get("stream"), Some(&json!(false)));
        assert_eq!(body.get("format").and_then(Value::as_str), Some("json"));
    }

    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let resp = json!({"response": "{\"score\": 42}"});
        let val = OllamaProvider::parse_response(&resp)?;
        assert_eq!(val.get("score").and_then(Value::as_u64), Some(42));
        Ok(())
    }

    /// An envelope without a `response` string is reported as an API error.
    #[test]
    fn test_parse_response_missing_field() {
        let resp = json!({"done": true});
        assert!(matches!(
            OllamaProvider::parse_response(&resp),
            Err(StygianError::Provider(ProviderError::ApiError(_)))
        ));
    }

    /// Model output that is not JSON is reported as an API error.
    #[test]
    fn test_parse_response_invalid_json() {
        let resp = json!({"response": "definitely not json"});
        assert!(matches!(
            OllamaProvider::parse_response(&resp),
            Err(StygianError::Provider(ProviderError::ApiError(_)))
        ));
    }

    #[test]
    fn test_map_http_error_404() {
        assert!(matches!(
            OllamaProvider::map_http_error(404, "not found"),
            StygianError::Provider(ProviderError::ModelUnavailable(_))
        ));
    }

    /// Statuses other than 404 become a generic API error carrying status and body.
    #[test]
    fn test_map_http_error_other_status() {
        let err = OllamaProvider::map_http_error(500, "boom");
        assert!(matches!(
            &err,
            StygianError::Provider(ProviderError::ApiError(msg)) if msg == "HTTP 500: boom"
        ));
    }

    #[test]
    fn test_config_builder() {
        let config = OllamaConfig::new()
            .with_model("llama3:latest")
            .with_base_url("http://192.168.1.10:11434");
        assert_eq!(config.model, "llama3:latest");
        assert_eq!(config.base_url, "http://192.168.1.10:11434");
    }
}