Skip to main content

stygian_graph/adapters/ai/
gemini.rs

1//! Google Gemini AI provider adapter
2//!
3//! Implements the `AIProvider` port using Google's Generative Language API.
4//! Supports Gemini 1.5 Pro and Gemini 2.0 Flash with response schema enforcement.
5//!
6//! # Example
7//!
8//! ```no_run
9//! use stygian_graph::adapters::ai::gemini::{GeminiProvider, GeminiConfig};
10//! use stygian_graph::ports::AIProvider;
11//! use serde_json::json;
12//!
13//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
14//! let provider = GeminiProvider::new("AIza...".to_string());
15//! let schema = json!({"type": "object", "properties": {"title": {"type": "string"}}});
16//! // let result = provider.extract("<html>Hello</html>".to_string(), schema).await.unwrap();
17//! # });
18//! ```
19
20use std::time::Duration;
21
22use async_trait::async_trait;
23use futures::stream::{self, BoxStream};
24use reqwest::Client;
25use serde_json::{Value, json};
26
27use crate::domain::error::{ProviderError, Result, StygianError};
28use crate::ports::{AIProvider, ProviderCapabilities};
29
/// Model identifier that `GeminiConfig::new` uses when the caller does not pick one.
const DEFAULT_MODEL: &str = "gemini-2.0-flash";

/// Base URL of the Google Generative Language API (`v1beta`) models collection.
///
/// `GeminiProvider::api_url` appends `/<model>:generateContent` and the API key to it.
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models";
35
/// Configuration for the Gemini provider.
///
/// All fields are public so callers can adjust them directly after construction.
///
/// `Debug` is implemented by hand rather than derived so that the API key is never
/// written to logs or panic messages; every other field is printed as usual.
#[derive(Clone)]
pub struct GeminiConfig {
    /// Google AI API key
    pub api_key: String,
    /// Model identifier
    pub model: String,
    /// Maximum output tokens
    pub max_tokens: u32,
    /// Request timeout
    pub timeout: Duration,
}

impl std::fmt::Debug for GeminiConfig {
    /// Formats the configuration with the API key replaced by a `<redacted>` placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiConfig")
            // The key is a credential: show that it exists, never its value.
            .field("api_key", &format_args!("<redacted>"))
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("timeout", &self.timeout)
            .finish()
    }
}
48
49impl GeminiConfig {
50    /// Create config with API key and defaults
51    #[must_use]
52    pub fn new(api_key: String) -> Self {
53        Self {
54            api_key,
55            model: DEFAULT_MODEL.to_string(),
56            max_tokens: 8192,
57            timeout: Duration::from_mins(2),
58        }
59    }
60
61    /// Override model
62    #[must_use]
63    pub fn with_model(mut self, model: impl Into<String>) -> Self {
64        self.model = model.into();
65        self
66    }
67}
68
69/// Google Gemini provider adapter
70pub struct GeminiProvider {
71    config: GeminiConfig,
72    client: Client,
73}
74
75impl GeminiProvider {
76    /// Create with API key and defaults
77    ///
78    /// # Example
79    ///
80    /// ```no_run
81    /// use stygian_graph::adapters::ai::gemini::GeminiProvider;
82    /// let p = GeminiProvider::new("AIza...".to_string());
83    /// ```
84    #[must_use]
85    pub fn new(api_key: String) -> Self {
86        Self::with_config(GeminiConfig::new(api_key))
87    }
88
89    /// Create with custom configuration
90    ///
91    /// # Panics
92    ///
93    /// Panics if the underlying HTTP client fails to build. With `rustls` as the
94    /// TLS backend this is unreachable in practice (build only fails when no TLS
95    /// backend is configured).
96    ///
97    /// # Example
98    ///
99    /// ```no_run
100    /// use stygian_graph::adapters::ai::gemini::{GeminiProvider, GeminiConfig};
101    /// let config = GeminiConfig::new("AIza...".to_string()).with_model("gemini-1.5-pro");
102    /// let p = GeminiProvider::with_config(config);
103    /// ```
104    #[must_use]
105    pub fn with_config(config: GeminiConfig) -> Self {
106        // SAFETY: TLS backend (rustls) is always available; build() only fails if no TLS backend.
107        #[allow(clippy::expect_used)]
108        let client = Client::builder()
109            .timeout(config.timeout)
110            .build()
111            .expect("Failed to build HTTP client");
112        Self { config, client }
113    }
114
115    fn api_url(&self) -> String {
116        format!(
117            "{}/{}:generateContent?key={}",
118            API_BASE, self.config.model, self.config.api_key
119        )
120    }
121
122    fn build_body(&self, content: &str, schema: &Value) -> Value {
123        let prompt = format!(
124            "Extract structured data from the following content according to this JSON schema.\n\
125             Return ONLY valid JSON matching the schema.\n\
126             Schema: {}\n\nContent:\n{}",
127            serde_json::to_string(schema).unwrap_or_default(),
128            content
129        );
130
131        json!({
132            "contents": [{"parts": [{"text": prompt}]}],
133            "generationConfig": {
134                "maxOutputTokens": self.config.max_tokens,
135                "responseMimeType": "application/json",
136                "responseSchema": schema
137            }
138        })
139    }
140
141    fn parse_response(response: &Value) -> Result<Value> {
142        let text = response
143            .pointer("/candidates/0/content/parts/0/text")
144            .and_then(Value::as_str)
145            .ok_or_else(|| {
146                StygianError::Provider(ProviderError::ApiError(
147                    "No text in Gemini response".to_string(),
148                ))
149            })?;
150
151        serde_json::from_str(text).map_err(|e| {
152            StygianError::Provider(ProviderError::ApiError(format!(
153                "Failed to parse Gemini JSON response: {e}"
154            )))
155        })
156    }
157
158    fn map_http_error(status: u16, body: &str) -> StygianError {
159        match status {
160            400 if body.contains("API_KEY") => {
161                StygianError::Provider(ProviderError::InvalidCredentials)
162            }
163            429 => StygianError::Provider(ProviderError::ApiError(format!(
164                "Gemini rate limited: {body}"
165            ))),
166            _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
167        }
168    }
169}
170
171#[async_trait]
172impl AIProvider for GeminiProvider {
173    async fn extract(&self, content: String, schema: Value) -> Result<Value> {
174        let body = self.build_body(&content, &schema);
175        let url = self.api_url();
176
177        let response = self
178            .client
179            .post(&url)
180            .header("Content-Type", "application/json")
181            .json(&body)
182            .send()
183            .await
184            .map_err(|e| {
185                StygianError::Provider(ProviderError::ApiError(format!(
186                    "Gemini request failed: {e}"
187                )))
188            })?;
189
190        let status = response.status().as_u16();
191        let text = response
192            .text()
193            .await
194            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
195
196        if status != 200 {
197            return Err(Self::map_http_error(status, &text));
198        }
199
200        let json_val: Value = serde_json::from_str(&text)
201            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
202
203        Self::parse_response(&json_val)
204    }
205
206    async fn stream_extract(
207        &self,
208        content: String,
209        schema: Value,
210    ) -> Result<BoxStream<'static, Result<Value>>> {
211        let result = self.extract(content, schema).await;
212        Ok(Box::pin(stream::once(async move { result })))
213    }
214
215    fn capabilities(&self) -> ProviderCapabilities {
216        ProviderCapabilities {
217            streaming: true,
218            vision: true,
219            tool_use: false,
220            json_mode: true,
221        }
222    }
223
224    fn name(&self) -> &'static str {
225        "gemini"
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Provider built with a throwaway key; none of these tests touch the network.
    fn provider_with_key(key: &str) -> GeminiProvider {
        GeminiProvider::new(key.to_string())
    }

    /// Wrap `text` in the minimal response shape that `parse_response` reads.
    fn response_with_text(text: &str) -> Value {
        json!({
            "candidates": [{
                "content": {"parts": [{"text": text}]}
            }]
        })
    }

    #[test]
    fn test_name() {
        let provider = provider_with_key("k");
        assert_eq!(provider.name(), "gemini");
    }

    #[test]
    fn test_capabilities() {
        let ProviderCapabilities {
            json_mode, vision, ..
        } = provider_with_key("k").capabilities();
        assert!(json_mode);
        assert!(vision);
    }

    #[test]
    fn test_api_url_contains_model_and_key() {
        let endpoint = provider_with_key("my-key").api_url();
        for expected in [DEFAULT_MODEL, "my-key"] {
            assert!(endpoint.contains(expected));
        }
    }

    #[test]
    fn test_build_body_has_response_mime() {
        let schema = json!({"type": "object"});
        let request = provider_with_key("k").build_body("content", &schema);
        let mime = request
            .pointer("/generationConfig/responseMimeType")
            .and_then(Value::as_str);
        assert_eq!(mime, Some("application/json"));
    }

    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let reply = response_with_text("{\"name\": \"Alice\"}");
        let parsed = GeminiProvider::parse_response(&reply)?;
        assert_eq!(parsed.pointer("/name").and_then(Value::as_str), Some("Alice"));
        Ok(())
    }

    #[test]
    fn test_parse_response_no_candidates() {
        let reply = json!({"promptFeedback": {}});
        let outcome = GeminiProvider::parse_response(&reply);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_parse_response_invalid_json_text() {
        let reply = response_with_text("not json at all");
        let outcome = GeminiProvider::parse_response(&reply);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_map_http_error_api_key() {
        let mapped = GeminiProvider::map_http_error(400, "Invalid API_KEY provided");
        let is_credentials_error = matches!(
            mapped,
            StygianError::Provider(ProviderError::InvalidCredentials)
        );
        assert!(is_credentials_error);
    }

    #[test]
    fn test_map_http_error_429() {
        let mapped = GeminiProvider::map_http_error(429, "quota exceeded");
        let is_rate_limit_error = matches!(
            &mapped,
            StygianError::Provider(ProviderError::ApiError(msg)) if msg.contains("rate limited")
        );
        assert!(is_rate_limit_error);
    }

    #[test]
    fn test_map_http_error_server_error() {
        let mapped = GeminiProvider::map_http_error(503, "unavailable");
        let mentions_status = matches!(
            &mapped,
            StygianError::Provider(ProviderError::ApiError(msg)) if msg.contains("503")
        );
        assert!(mentions_status);
    }

    #[test]
    fn test_config_with_model() {
        let GeminiConfig { model, .. } =
            GeminiConfig::new("AIza".to_string()).with_model("gemini-1.5-pro");
        assert_eq!(model, "gemini-1.5-pro");
    }
}