stygian_graph/adapters/ai/
ollama.rs1use std::time::Duration;
22
23use async_trait::async_trait;
24use futures::stream::{self, BoxStream};
25use reqwest::Client;
26use serde_json::{Value, json};
27
28use crate::domain::error::{ProviderError, Result, StygianError};
29use crate::ports::{AIProvider, ProviderCapabilities};
30
/// Base URL used when no server address is configured.
const DEFAULT_BASE_URL: &str = "http://localhost:11434";

/// Model tag requested when no model is configured.
const DEFAULT_MODEL: &str = "qwen2.5:32b";

/// Per-request timeout used when none is configured (5 minutes).
///
/// NOTE(review): kept long presumably because local generation with large models can be
/// slow — tune per deployment.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);

/// Connection and model settings for [`OllamaProvider`].
#[derive(Debug, Clone)]
pub struct OllamaConfig {
    /// Server base URL without the `/api/...` suffix, e.g. `http://localhost:11434`.
    pub base_url: String,
    /// Model tag sent as `model` in every request, e.g. `llama3:latest`.
    pub model: String,
    /// Timeout handed to the HTTP client for each request.
    pub timeout: Duration,
}

impl OllamaConfig {
    /// Creates a configuration targeting `DEFAULT_BASE_URL` with `DEFAULT_MODEL` and a
    /// 300-second request timeout.
    #[must_use]
    pub fn new() -> Self {
        Self {
            base_url: DEFAULT_BASE_URL.to_string(),
            model: DEFAULT_MODEL.to_string(),
            timeout: DEFAULT_TIMEOUT,
        }
    }

    /// Overrides the server base URL.
    #[must_use]
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Overrides the model tag.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Overrides the per-request timeout.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}

impl Default for OllamaConfig {
    /// Same as [`OllamaConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
78
79pub struct OllamaProvider {
81 config: OllamaConfig,
82 client: Client,
83}
84
85impl OllamaProvider {
86 pub fn new() -> Self {
95 Self::with_config(OllamaConfig::new())
96 }
97
98 pub fn with_config(config: OllamaConfig) -> Self {
108 #[allow(clippy::expect_used)]
110 let client = Client::builder()
111 .timeout(config.timeout)
112 .build()
113 .expect("Failed to build HTTP client");
114 Self { config, client }
115 }
116
117 fn api_url(&self) -> String {
118 format!("{}/api/generate", self.config.base_url)
119 }
120
121 fn build_body(&self, content: &str, schema: &Value) -> Value {
122 let prompt = format!(
123 "Extract structured data from the following content according to this JSON schema.\n\
124 Return ONLY valid JSON matching the schema, with no markdown, no code blocks, no extra text.\n\
125 Schema: {}\n\nContent:\n{}",
126 serde_json::to_string(schema).unwrap_or_default(),
127 content
128 );
129
130 json!({
131 "model": self.config.model,
132 "prompt": prompt,
133 "stream": false,
134 "format": "json"
135 })
136 }
137
138 fn parse_response(response: &Value) -> Result<Value> {
139 let text = response
140 .get("response")
141 .and_then(Value::as_str)
142 .ok_or_else(|| {
143 StygianError::Provider(ProviderError::ApiError(
144 "No response field in Ollama output".to_string(),
145 ))
146 })?;
147
148 serde_json::from_str(text).map_err(|e| {
149 StygianError::Provider(ProviderError::ApiError(format!(
150 "Failed to parse Ollama JSON response: {e}"
151 )))
152 })
153 }
154
155 fn map_http_error(status: u16, body: &str) -> StygianError {
156 match status {
157 404 => StygianError::Provider(ProviderError::ModelUnavailable(format!(
158 "Model not found in Ollama: {body}"
159 ))),
160 _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
161 }
162 }
163}
164
165impl Default for OllamaProvider {
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
171#[async_trait]
172impl AIProvider for OllamaProvider {
173 async fn extract(&self, content: String, schema: Value) -> Result<Value> {
174 let body = self.build_body(&content, &schema);
175 let url = self.api_url();
176
177 let response = self
178 .client
179 .post(&url)
180 .header("Content-Type", "application/json")
181 .json(&body)
182 .send()
183 .await
184 .map_err(|e| {
185 StygianError::Provider(ProviderError::ApiError(format!(
186 "Ollama request failed (is Ollama running?): {e}"
187 )))
188 })?;
189
190 let status = response.status().as_u16();
191 let text = response
192 .text()
193 .await
194 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
195
196 if status != 200 {
197 return Err(Self::map_http_error(status, &text));
198 }
199
200 let json_val: Value = serde_json::from_str(&text)
201 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
202
203 Self::parse_response(&json_val)
204 }
205
206 async fn stream_extract(
207 &self,
208 content: String,
209 schema: Value,
210 ) -> Result<BoxStream<'static, Result<Value>>> {
211 let result = self.extract(content, schema).await;
212 Ok(Box::pin(stream::once(async move { result })))
213 }
214
215 fn capabilities(&self) -> ProviderCapabilities {
216 ProviderCapabilities {
217 streaming: true,
218 vision: false,
219 tool_use: false,
220 json_mode: true,
221 }
222 }
223
224 fn name(&self) -> &'static str {
225 "ollama"
226 }
227}
228
#[cfg(test)]
mod tests {
    //! Unit tests for configuration, request building and response/error mapping.
    //! None of these tests contact a real Ollama server.

    use super::*;
    use serde_json::json;

    #[test]
    fn test_name() {
        assert_eq!(OllamaProvider::new().name(), "ollama");
    }

    #[test]
    fn test_default() {
        let p = OllamaProvider::default();
        assert_eq!(p.config.model, DEFAULT_MODEL);
        assert_eq!(p.config.base_url, DEFAULT_BASE_URL);
    }

    #[test]
    fn test_default_timeout() {
        assert_eq!(OllamaConfig::default().timeout, Duration::from_secs(300));
    }

    #[test]
    fn test_capabilities_json_mode() {
        let caps = OllamaProvider::new().capabilities();
        assert!(caps.json_mode);
        assert!(!caps.vision);
    }

    #[test]
    fn test_api_url() {
        let p = OllamaProvider::new();
        assert_eq!(p.api_url(), "http://localhost:11434/api/generate");
    }

    #[test]
    fn test_build_body_stream_false() {
        let p = OllamaProvider::new();
        let body = p.build_body("c", &json!({"type": "object"}));
        assert_eq!(body.get("stream"), Some(&json!(false)));
        assert_eq!(body.get("format").and_then(Value::as_str), Some("json"));
    }

    #[test]
    fn test_build_body_embeds_model_schema_and_content() {
        let p = OllamaProvider::new();
        let body = p.build_body("hello world", &json!({"type": "object"}));
        assert_eq!(body.get("model").and_then(Value::as_str), Some(DEFAULT_MODEL));
        let prompt = body
            .get("prompt")
            .and_then(Value::as_str)
            .unwrap_or_default();
        assert!(prompt.contains(r#"Schema: {"type":"object"}"#));
        assert!(prompt.ends_with("Content:\nhello world"));
    }

    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let resp = json!({"response": "{\"score\": 42}"});
        let val = OllamaProvider::parse_response(&resp)?;
        assert_eq!(val.get("score").and_then(Value::as_u64), Some(42));
        Ok(())
    }

    #[test]
    fn test_parse_response_missing_field() {
        let resp = json!({"done": true});
        assert!(matches!(
            OllamaProvider::parse_response(&resp),
            Err(StygianError::Provider(ProviderError::ApiError(_)))
        ));
    }

    #[test]
    fn test_parse_response_invalid_json() {
        let resp = json!({"response": "not json at all"});
        assert!(matches!(
            OllamaProvider::parse_response(&resp),
            Err(StygianError::Provider(ProviderError::ApiError(_)))
        ));
    }

    #[test]
    fn test_map_http_error_404() {
        assert!(matches!(
            OllamaProvider::map_http_error(404, "not found"),
            StygianError::Provider(ProviderError::ModelUnavailable(_))
        ));
    }

    #[test]
    fn test_map_http_error_other_status() {
        assert!(matches!(
            OllamaProvider::map_http_error(500, "boom"),
            StygianError::Provider(ProviderError::ApiError(ref msg)) if msg.contains("HTTP 500")
        ));
    }

    #[test]
    fn test_config_builder() {
        let config = OllamaConfig::new()
            .with_model("llama3:latest")
            .with_base_url("http://192.168.1.10:11434");
        assert_eq!(config.model, "llama3:latest");
        assert_eq!(config.base_url, "http://192.168.1.10:11434");
    }
}