stygian_graph/adapters/ai/
gemini.rs

1//! Google Gemini AI provider adapter
2//!
3//! Implements the `AIProvider` port using Google's Generative Language API.
4//! Supports Gemini 1.5 Pro and Gemini 2.0 Flash with response schema enforcement.
5//!
6//! # Example
7//!
8//! ```no_run
9//! use stygian_graph::adapters::ai::gemini::{GeminiProvider, GeminiConfig};
10//! use stygian_graph::ports::AIProvider;
11//! use serde_json::json;
12//!
13//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
14//! let provider = GeminiProvider::new("AIza...".to_string());
15//! let schema = json!({"type": "object", "properties": {"title": {"type": "string"}}});
16//! // let result = provider.extract("<html>Hello</html>".to_string(), schema).await.unwrap();
17//! # });
18//! ```
19
20use std::time::Duration;
21
22use async_trait::async_trait;
23use futures::stream::{self, BoxStream};
24use reqwest::Client;
25use serde_json::{Value, json};
26
27use crate::domain::error::{ProviderError, Result, StygianError};
28use crate::ports::{AIProvider, ProviderCapabilities};
29
/// Model identifier that `GeminiConfig::new` stores when the caller does not
/// override it with `GeminiConfig::with_model`.
const DEFAULT_MODEL: &str = "gemini-2.0-flash";

/// Base URL of the Google Generative Language API `models` collection (`v1beta`).
///
/// `GeminiProvider::api_url` appends `/{model}:generateContent?key={api_key}` to it.
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models";
35
/// Configuration for the Gemini provider
///
/// `Debug` is implemented by hand rather than derived so that formatting a
/// config with `{:?}` (for example in a log line) never prints the API key.
#[derive(Clone)]
pub struct GeminiConfig {
    /// Google AI API key
    pub api_key: String,
    /// Model identifier
    pub model: String,
    /// Maximum output tokens
    pub max_tokens: u32,
    /// Request timeout
    pub timeout: Duration,
}

impl std::fmt::Debug for GeminiConfig {
    /// Formats every field except `api_key`, which is replaced by a placeholder.
    // `std::fmt::Result` is spelled out because this file imports the crate's own `Result` alias.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiConfig")
            .field("api_key", &"[REDACTED]")
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("timeout", &self.timeout)
            .finish()
    }
}
48
49impl GeminiConfig {
50    /// Create config with API key and defaults
51    pub fn new(api_key: String) -> Self {
52        Self {
53            api_key,
54            model: DEFAULT_MODEL.to_string(),
55            max_tokens: 8192,
56            timeout: Duration::from_secs(120),
57        }
58    }
59
60    /// Override model
61    #[must_use]
62    pub fn with_model(mut self, model: impl Into<String>) -> Self {
63        self.model = model.into();
64        self
65    }
66}
67
/// Google Gemini provider adapter
///
/// Holds the provider settings together with the HTTP client that
/// `with_config` builds from them.
pub struct GeminiProvider {
    /// API key, model, output-token limit and timeout used by this provider.
    config: GeminiConfig,
    /// `reqwest` client created in `with_config` with the configured timeout.
    client: Client,
}
73
74impl GeminiProvider {
75    /// Create with API key and defaults
76    ///
77    /// # Example
78    ///
79    /// ```no_run
80    /// use stygian_graph::adapters::ai::gemini::GeminiProvider;
81    /// let p = GeminiProvider::new("AIza...".to_string());
82    /// ```
83    pub fn new(api_key: String) -> Self {
84        Self::with_config(GeminiConfig::new(api_key))
85    }
86
87    /// Create with custom configuration
88    ///
89    /// # Example
90    ///
91    /// ```no_run
92    /// use stygian_graph::adapters::ai::gemini::{GeminiProvider, GeminiConfig};
93    /// let config = GeminiConfig::new("AIza...".to_string()).with_model("gemini-1.5-pro");
94    /// let p = GeminiProvider::with_config(config);
95    /// ```
96    pub fn with_config(config: GeminiConfig) -> Self {
97        // SAFETY: TLS backend (rustls) is always available; build() only fails if no TLS backend.
98        #[allow(clippy::expect_used)]
99        let client = Client::builder()
100            .timeout(config.timeout)
101            .build()
102            .expect("Failed to build HTTP client");
103        Self { config, client }
104    }
105
106    fn api_url(&self) -> String {
107        format!(
108            "{}/{}:generateContent?key={}",
109            API_BASE, self.config.model, self.config.api_key
110        )
111    }
112
113    fn build_body(&self, content: &str, schema: &Value) -> Value {
114        let prompt = format!(
115            "Extract structured data from the following content according to this JSON schema.\n\
116             Return ONLY valid JSON matching the schema.\n\
117             Schema: {}\n\nContent:\n{}",
118            serde_json::to_string(schema).unwrap_or_default(),
119            content
120        );
121
122        json!({
123            "contents": [{"parts": [{"text": prompt}]}],
124            "generationConfig": {
125                "maxOutputTokens": self.config.max_tokens,
126                "responseMimeType": "application/json",
127                "responseSchema": schema
128            }
129        })
130    }
131
132    fn parse_response(response: &Value) -> Result<Value> {
133        let text = response
134            .pointer("/candidates/0/content/parts/0/text")
135            .and_then(Value::as_str)
136            .ok_or_else(|| {
137                StygianError::Provider(ProviderError::ApiError(
138                    "No text in Gemini response".to_string(),
139                ))
140            })?;
141
142        serde_json::from_str(text).map_err(|e| {
143            StygianError::Provider(ProviderError::ApiError(format!(
144                "Failed to parse Gemini JSON response: {e}"
145            )))
146        })
147    }
148
149    fn map_http_error(status: u16, body: &str) -> StygianError {
150        match status {
151            400 if body.contains("API_KEY") => {
152                StygianError::Provider(ProviderError::InvalidCredentials)
153            }
154            429 => StygianError::Provider(ProviderError::ApiError(format!(
155                "Gemini rate limited: {body}"
156            ))),
157            _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
158        }
159    }
160}
161
162#[async_trait]
163impl AIProvider for GeminiProvider {
164    async fn extract(&self, content: String, schema: Value) -> Result<Value> {
165        let body = self.build_body(&content, &schema);
166        let url = self.api_url();
167
168        let response = self
169            .client
170            .post(&url)
171            .header("Content-Type", "application/json")
172            .json(&body)
173            .send()
174            .await
175            .map_err(|e| {
176                StygianError::Provider(ProviderError::ApiError(format!(
177                    "Gemini request failed: {e}"
178                )))
179            })?;
180
181        let status = response.status().as_u16();
182        let text = response
183            .text()
184            .await
185            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
186
187        if status != 200 {
188            return Err(Self::map_http_error(status, &text));
189        }
190
191        let json_val: Value = serde_json::from_str(&text)
192            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
193
194        Self::parse_response(&json_val)
195    }
196
197    async fn stream_extract(
198        &self,
199        content: String,
200        schema: Value,
201    ) -> Result<BoxStream<'static, Result<Value>>> {
202        let result = self.extract(content, schema).await;
203        Ok(Box::pin(stream::once(async move { result })))
204    }
205
206    fn capabilities(&self) -> ProviderCapabilities {
207        ProviderCapabilities {
208            streaming: true,
209            vision: true,
210            tool_use: false,
211            json_mode: true,
212        }
213    }
214
215    fn name(&self) -> &'static str {
216        "gemini"
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Provider with default configuration and the given key; no test here
    /// performs a network call.
    fn provider(key: &str) -> GeminiProvider {
        GeminiProvider::new(key.to_string())
    }

    /// Minimal response envelope holding `text` where `parse_response` looks for it.
    fn envelope(text: &str) -> Value {
        json!({
            "candidates": [{
                "content": {"parts": [{"text": text}]}
            }]
        })
    }

    /// True when `err` is a provider `ApiError` whose message contains `needle`.
    fn api_error_contains(err: &StygianError, needle: &str) -> bool {
        matches!(
            err,
            StygianError::Provider(ProviderError::ApiError(msg)) if msg.contains(needle)
        )
    }

    #[test]
    fn test_name() {
        let p = provider("k");
        assert_eq!(p.name(), "gemini");
    }

    #[test]
    fn test_capabilities() {
        let ProviderCapabilities {
            json_mode, vision, ..
        } = provider("k").capabilities();
        assert!(json_mode);
        assert!(vision);
    }

    #[test]
    fn test_api_url_contains_model_and_key() {
        let url = provider("my-key").api_url();
        for expected in [DEFAULT_MODEL, "my-key"] {
            assert!(url.contains(expected));
        }
    }

    #[test]
    fn test_build_body_has_response_mime() {
        let schema = json!({"type": "object"});
        let body = provider("k").build_body("content", &schema);
        let mime = body
            .pointer("/generationConfig/responseMimeType")
            .and_then(Value::as_str);
        assert_eq!(mime, Some("application/json"));
    }

    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let parsed = GeminiProvider::parse_response(&envelope("{\"name\": \"Alice\"}"))?;
        let name = parsed.pointer("/name").and_then(Value::as_str);
        assert_eq!(name, Some("Alice"));
        Ok(())
    }

    #[test]
    fn test_parse_response_no_candidates() {
        let without_candidates = json!({"promptFeedback": {}});
        let outcome = GeminiProvider::parse_response(&without_candidates);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_parse_response_invalid_json_text() {
        let outcome = GeminiProvider::parse_response(&envelope("not json at all"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_map_http_error_api_key() {
        let err = GeminiProvider::map_http_error(400, "Invalid API_KEY provided");
        let is_invalid_credentials = matches!(
            err,
            StygianError::Provider(ProviderError::InvalidCredentials)
        );
        assert!(is_invalid_credentials);
    }

    #[test]
    fn test_map_http_error_429() {
        let err = GeminiProvider::map_http_error(429, "quota exceeded");
        assert!(api_error_contains(&err, "rate limited"));
    }

    #[test]
    fn test_map_http_error_server_error() {
        let err = GeminiProvider::map_http_error(503, "unavailable");
        assert!(api_error_contains(&err, "503"));
    }

    #[test]
    fn test_config_with_model() {
        let base = GeminiConfig::new("AIza".to_string());
        let cfg = base.with_model("gemini-1.5-pro");
        assert_eq!(cfg.model, "gemini-1.5-pro");
    }
}