// stygian_graph/adapters/ai/openai.rs

//! OpenAI (ChatGPT) AI provider adapter
//!
//! Implements the `AIProvider` port using OpenAI's Chat Completions API.
//! Supports GPT-4o, GPT-4, and o1-series models, using native JSON mode
//! (`response_format: json_object`) plus a schema-bearing prompt for
//! structured extraction.
//!
//! # Example
//!
//! ```no_run
//! use stygian_graph::adapters::ai::openai::{OpenAIProvider, OpenAIConfig};
//! use stygian_graph::ports::AIProvider;
//! use serde_json::json;
//!
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let provider = OpenAIProvider::new("sk-...".to_string());
//! let schema = json!({"type": "object", "properties": {"title": {"type": "string"}}});
//! // let result = provider.extract("<html>Hello</html>".to_string(), schema).await.unwrap();
//! # });
//! ```
20
use std::time::Duration;

use async_trait::async_trait;
use futures::stream::{self, BoxStream};
use reqwest::Client;
use serde_json::{Value, json};

use crate::domain::error::{ProviderError, Result, StygianError};
use crate::ports::{AIProvider, ProviderCapabilities};
30
/// Default model identifier used when none is set via [`OpenAIConfig::with_model`]
const DEFAULT_MODEL: &str = "gpt-4o";

/// Chat Completions endpoint — the only OpenAI endpoint this adapter calls
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
36
/// Configuration for the `OpenAI` provider
#[derive(Debug, Clone)]
pub struct OpenAIConfig {
    /// `OpenAI` API key, sent as a `Bearer` token on every request
    pub api_key: String,
    /// Model identifier, e.g. "gpt-4o" (the default)
    pub model: String,
    /// Maximum response tokens (`max_tokens` field of the request body)
    pub max_tokens: u32,
    /// Per-request timeout applied to the underlying HTTP client
    pub timeout: Duration,
}
49
50impl OpenAIConfig {
51    /// Create config with API key and defaults
52    pub fn new(api_key: String) -> Self {
53        Self {
54            api_key,
55            model: DEFAULT_MODEL.to_string(),
56            max_tokens: 4096,
57            timeout: Duration::from_secs(120),
58        }
59    }
60
61    /// Override model
62    #[must_use]
63    pub fn with_model(mut self, model: impl Into<String>) -> Self {
64        self.model = model.into();
65        self
66    }
67}
68
/// `OpenAI` provider adapter
///
/// Sends extraction requests to the Chat Completions API using
/// `response_format: json_object` plus a schema-bearing prompt to obtain
/// schema-conformant JSON. (No `tools`/function-calling payload is sent.)
pub struct OpenAIProvider {
    // Connection settings: API key, model, token limit, timeout.
    config: OpenAIConfig,
    // Reusable HTTP client, built once with `config.timeout`.
    client: Client,
}
76
77impl OpenAIProvider {
78    /// Create with API key and defaults
79    ///
80    /// # Example
81    ///
82    /// ```no_run
83    /// use stygian_graph::adapters::ai::openai::OpenAIProvider;
84    /// let p = OpenAIProvider::new("sk-...".to_string());
85    /// ```
86    pub fn new(api_key: String) -> Self {
87        Self::with_config(OpenAIConfig::new(api_key))
88    }
89
90    /// Create with custom configuration
91    ///
92    /// # Example
93    ///
94    /// ```no_run
95    /// use stygian_graph::adapters::ai::openai::{OpenAIProvider, OpenAIConfig};
96    /// let config = OpenAIConfig::new("sk-...".to_string()).with_model("gpt-4");
97    /// let p = OpenAIProvider::with_config(config);
98    /// ```
99    pub fn with_config(config: OpenAIConfig) -> Self {
100        // SAFETY: TLS backend (rustls) is always available; build() only fails if no TLS backend.
101        #[allow(clippy::expect_used)]
102        let client = Client::builder()
103            .timeout(config.timeout)
104            .build()
105            .expect("Failed to build HTTP client");
106        Self { config, client }
107    }
108
109    fn build_body(&self, content: &str, schema: &Value) -> Value {
110        let system = "You are a precise data extraction assistant. \
111            Extract structured data from the provided content matching the given JSON schema. \
112            Return ONLY valid JSON matching the schema, no extra text.";
113
114        let user_msg = format!(
115            "Schema: {}\n\nContent:\n{}",
116            serde_json::to_string(schema).unwrap_or_default(),
117            content
118        );
119
120        json!({
121            "model": self.config.model,
122            "max_tokens": self.config.max_tokens,
123            "response_format": {"type": "json_object"},
124            "messages": [
125                {"role": "system", "content": system},
126                {"role": "user", "content": user_msg}
127            ]
128        })
129    }
130
131    fn parse_response(response: &Value) -> Result<Value> {
132        let text = response
133            .pointer("/choices/0/message/content")
134            .and_then(Value::as_str)
135            .ok_or_else(|| {
136                StygianError::Provider(ProviderError::ApiError(
137                    "No content in OpenAI response".to_string(),
138                ))
139            })?;
140
141        serde_json::from_str(text).map_err(|e| {
142            StygianError::Provider(ProviderError::ApiError(format!(
143                "Failed to parse OpenAI JSON response: {e}"
144            )))
145        })
146    }
147
148    fn map_http_error(status: u16, body: &str) -> StygianError {
149        match status {
150            401 => StygianError::Provider(ProviderError::InvalidCredentials),
151            429 => StygianError::Provider(ProviderError::ApiError(format!(
152                "OpenAI rate limited: {body}"
153            ))),
154            _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
155        }
156    }
157}
158
159#[async_trait]
160impl AIProvider for OpenAIProvider {
161    async fn extract(&self, content: String, schema: Value) -> Result<Value> {
162        let body = self.build_body(&content, &schema);
163
164        let response = self
165            .client
166            .post(API_URL)
167            .header("Authorization", format!("Bearer {}", &self.config.api_key))
168            .header("Content-Type", "application/json")
169            .json(&body)
170            .send()
171            .await
172            .map_err(|e| {
173                StygianError::Provider(ProviderError::ApiError(format!(
174                    "OpenAI request failed: {e}"
175                )))
176            })?;
177
178        let status = response.status().as_u16();
179        let text = response
180            .text()
181            .await
182            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
183
184        if status != 200 {
185            return Err(Self::map_http_error(status, &text));
186        }
187
188        let json_val: Value = serde_json::from_str(&text)
189            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
190
191        Self::parse_response(&json_val)
192    }
193
194    async fn stream_extract(
195        &self,
196        content: String,
197        schema: Value,
198    ) -> Result<BoxStream<'static, Result<Value>>> {
199        let result = self.extract(content, schema).await;
200        Ok(Box::pin(stream::once(async move { result })))
201    }
202
203    fn capabilities(&self) -> ProviderCapabilities {
204        ProviderCapabilities {
205            streaming: true,
206            vision: true,
207            tool_use: true,
208            json_mode: true,
209        }
210    }
211
212    fn name(&self) -> &'static str {
213        "openai"
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_name() {
        assert_eq!(OpenAIProvider::new("k".to_string()).name(), "openai");
    }

    #[test]
    fn test_capabilities() {
        let caps = OpenAIProvider::new("k".to_string()).capabilities();
        assert!(caps.json_mode);
        assert!(caps.streaming);
    }

    #[test]
    fn test_build_body_contains_json_format() {
        let p = OpenAIProvider::new("k".to_string());
        let body = p.build_body("content", &json!({"type": "object"}));
        assert_eq!(
            body.get("response_format")
                .and_then(|rf| rf.get("type"))
                .and_then(Value::as_str),
            Some("json_object")
        );
    }

    // Coverage: the schema and content must actually reach the user message.
    #[test]
    fn test_build_body_embeds_schema_and_content() {
        let p = OpenAIProvider::new("k".to_string());
        let body = p.build_body("the content", &json!({"type": "object"}));
        let user = body
            .pointer("/messages/1/content")
            .and_then(Value::as_str)
            .unwrap_or_default();
        assert!(user.contains("\"type\":\"object\""));
        assert!(user.contains("the content"));
    }

    // Coverage: defaults promised by `OpenAIConfig::new`.
    #[test]
    fn test_config_defaults() {
        let cfg = OpenAIConfig::new("key".to_string());
        assert_eq!(cfg.model, DEFAULT_MODEL);
        assert_eq!(cfg.max_tokens, 4096);
        assert_eq!(cfg.timeout, Duration::from_secs(120));
    }

    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let resp = json!({
            "choices": [{"message": {"content": "{\"title\": \"Hello\"}"}}]
        });
        let val = OpenAIProvider::parse_response(&resp)?;
        assert_eq!(val.get("title").and_then(Value::as_str), Some("Hello"));
        Ok(())
    }

    #[test]
    fn test_parse_response_invalid_json() {
        let resp = json!({"choices": [{"message": {"content": "not json"}}]});
        assert!(OpenAIProvider::parse_response(&resp).is_err());
    }

    #[test]
    fn test_map_http_error_401() {
        assert!(matches!(
            OpenAIProvider::map_http_error(401, ""),
            StygianError::Provider(ProviderError::InvalidCredentials)
        ));
    }

    #[test]
    fn test_map_http_error_429() {
        let err = OpenAIProvider::map_http_error(429, "too many");
        assert!(
            matches!(err, StygianError::Provider(ProviderError::ApiError(ref msg)) if msg.contains("rate limited"))
        );
    }

    #[test]
    fn test_map_http_error_server_error() {
        let err = OpenAIProvider::map_http_error(500, "internal");
        assert!(
            matches!(err, StygianError::Provider(ProviderError::ApiError(ref msg)) if msg.contains("500"))
        );
    }

    #[test]
    fn test_parse_response_missing_choices() {
        let resp = serde_json::json!({"id": "chatcmpl-abc"});
        assert!(OpenAIProvider::parse_response(&resp).is_err());
    }

    #[test]
    fn test_config_with_model() {
        let cfg = OpenAIConfig::new("key".to_string()).with_model("gpt-4-turbo");
        assert_eq!(cfg.model, "gpt-4-turbo");
    }
}