// stygian_graph/adapters/ai/openai.rs
use std::time::Duration;
22
23use async_trait::async_trait;
24use futures::stream::{self, BoxStream};
25use reqwest::Client;
26use serde_json::{Value, json};
27
28use crate::domain::error::{ProviderError, Result, StygianError};
29use crate::ports::{AIProvider, ProviderCapabilities};
30
31const DEFAULT_MODEL: &str = "gpt-4o";
33
34const API_URL: &str = "https://api.openai.com/v1/chat/completions";
36
/// Configuration for the OpenAI chat-completions provider.
#[derive(Debug, Clone)]
pub struct OpenAIConfig {
    // Sent as the `Authorization: Bearer` credential on every request.
    pub api_key: String,
    // Model identifier; defaults to `DEFAULT_MODEL` via `OpenAIConfig::new`.
    pub model: String,
    // Upper bound on tokens the model may generate per response.
    pub max_tokens: u32,
    // Whole-request timeout applied to the HTTP client at build time.
    pub timeout: Duration,
}
49
50impl OpenAIConfig {
51 pub fn new(api_key: String) -> Self {
53 Self {
54 api_key,
55 model: DEFAULT_MODEL.to_string(),
56 max_tokens: 4096,
57 timeout: Duration::from_secs(120),
58 }
59 }
60
61 #[must_use]
63 pub fn with_model(mut self, model: impl Into<String>) -> Self {
64 self.model = model.into();
65 self
66 }
67}
68
/// AI provider adapter backed by the OpenAI chat-completions API.
pub struct OpenAIProvider {
    // Credentials and request parameters.
    config: OpenAIConfig,
    // Reusable HTTP client; the config timeout is baked in at construction.
    client: Client,
}
76
77impl OpenAIProvider {
78 pub fn new(api_key: String) -> Self {
87 Self::with_config(OpenAIConfig::new(api_key))
88 }
89
90 pub fn with_config(config: OpenAIConfig) -> Self {
100 #[allow(clippy::expect_used)]
102 let client = Client::builder()
103 .timeout(config.timeout)
104 .build()
105 .expect("Failed to build HTTP client");
106 Self { config, client }
107 }
108
109 fn build_body(&self, content: &str, schema: &Value) -> Value {
110 let system = "You are a precise data extraction assistant. \
111 Extract structured data from the provided content matching the given JSON schema. \
112 Return ONLY valid JSON matching the schema, no extra text.";
113
114 let user_msg = format!(
115 "Schema: {}\n\nContent:\n{}",
116 serde_json::to_string(schema).unwrap_or_default(),
117 content
118 );
119
120 json!({
121 "model": self.config.model,
122 "max_tokens": self.config.max_tokens,
123 "response_format": {"type": "json_object"},
124 "messages": [
125 {"role": "system", "content": system},
126 {"role": "user", "content": user_msg}
127 ]
128 })
129 }
130
131 fn parse_response(response: &Value) -> Result<Value> {
132 let text = response
133 .pointer("/choices/0/message/content")
134 .and_then(Value::as_str)
135 .ok_or_else(|| {
136 StygianError::Provider(ProviderError::ApiError(
137 "No content in OpenAI response".to_string(),
138 ))
139 })?;
140
141 serde_json::from_str(text).map_err(|e| {
142 StygianError::Provider(ProviderError::ApiError(format!(
143 "Failed to parse OpenAI JSON response: {e}"
144 )))
145 })
146 }
147
148 fn map_http_error(status: u16, body: &str) -> StygianError {
149 match status {
150 401 => StygianError::Provider(ProviderError::InvalidCredentials),
151 429 => StygianError::Provider(ProviderError::ApiError(format!(
152 "OpenAI rate limited: {body}"
153 ))),
154 _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
155 }
156 }
157}
158
159#[async_trait]
160impl AIProvider for OpenAIProvider {
161 async fn extract(&self, content: String, schema: Value) -> Result<Value> {
162 let body = self.build_body(&content, &schema);
163
164 let response = self
165 .client
166 .post(API_URL)
167 .header("Authorization", format!("Bearer {}", &self.config.api_key))
168 .header("Content-Type", "application/json")
169 .json(&body)
170 .send()
171 .await
172 .map_err(|e| {
173 StygianError::Provider(ProviderError::ApiError(format!(
174 "OpenAI request failed: {e}"
175 )))
176 })?;
177
178 let status = response.status().as_u16();
179 let text = response
180 .text()
181 .await
182 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
183
184 if status != 200 {
185 return Err(Self::map_http_error(status, &text));
186 }
187
188 let json_val: Value = serde_json::from_str(&text)
189 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
190
191 Self::parse_response(&json_val)
192 }
193
194 async fn stream_extract(
195 &self,
196 content: String,
197 schema: Value,
198 ) -> Result<BoxStream<'static, Result<Value>>> {
199 let result = self.extract(content, schema).await;
200 Ok(Box::pin(stream::once(async move { result })))
201 }
202
203 fn capabilities(&self) -> ProviderCapabilities {
204 ProviderCapabilities {
205 streaming: true,
206 vision: true,
207 tool_use: true,
208 json_mode: true,
209 }
210 }
211
212 fn name(&self) -> &'static str {
213 "openai"
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Builds a provider with a dummy key; none of these tests hit the network.
    fn provider() -> OpenAIProvider {
        OpenAIProvider::new("k".to_string())
    }

    #[test]
    fn test_name() {
        assert_eq!(provider().name(), "openai");
    }

    #[test]
    fn test_capabilities() {
        let caps = provider().capabilities();
        assert!(caps.json_mode);
        assert!(caps.streaming);
    }

    #[test]
    fn test_build_body_contains_json_format() {
        let body = provider().build_body("content", &json!({"type": "object"}));
        // `Value` indexing yields Null (not a panic) when a key is absent.
        let format_type = body["response_format"]["type"].as_str();
        assert_eq!(format_type, Some("json_object"));
    }

    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let resp = json!({
            "choices": [{"message": {"content": "{\"title\": \"Hello\"}"}}]
        });
        let parsed = OpenAIProvider::parse_response(&resp)?;
        assert_eq!(parsed["title"].as_str(), Some("Hello"));
        Ok(())
    }

    #[test]
    fn test_parse_response_invalid_json() {
        let malformed = json!({"choices": [{"message": {"content": "not json"}}]});
        assert!(OpenAIProvider::parse_response(&malformed).is_err());
    }

    #[test]
    fn test_map_http_error_401() {
        let err = OpenAIProvider::map_http_error(401, "");
        assert!(matches!(
            err,
            StygianError::Provider(ProviderError::InvalidCredentials)
        ));
    }

    #[test]
    fn test_map_http_error_429() {
        match OpenAIProvider::map_http_error(429, "too many") {
            StygianError::Provider(ProviderError::ApiError(msg)) => {
                assert!(msg.contains("rate limited"));
            }
            _ => panic!("expected ApiError variant"),
        }
    }

    #[test]
    fn test_map_http_error_server_error() {
        match OpenAIProvider::map_http_error(500, "internal") {
            StygianError::Provider(ProviderError::ApiError(msg)) => {
                assert!(msg.contains("500"));
            }
            _ => panic!("expected ApiError variant"),
        }
    }

    #[test]
    fn test_parse_response_missing_choices() {
        let resp = json!({"id": "chatcmpl-abc"});
        assert!(OpenAIProvider::parse_response(&resp).is_err());
    }

    #[test]
    fn test_config_with_model() {
        let cfg = OpenAIConfig::new("key".to_string()).with_model("gpt-4-turbo");
        assert_eq!(cfg.model.as_str(), "gpt-4-turbo");
    }
}