stygian_graph/adapters/ai/
gemini.rs1use std::time::Duration;
21
22use async_trait::async_trait;
23use futures::stream::{self, BoxStream};
24use reqwest::Client;
25use serde_json::{Value, json};
26
27use crate::domain::error::{ProviderError, Result, StygianError};
28use crate::ports::{AIProvider, ProviderCapabilities};
29
/// Model identifier used by `GeminiConfig::new` when the caller does not pick one.
const DEFAULT_MODEL: &str = "gemini-2.0-flash";

/// URL prefix for model endpoints; `api_url` appends `/{model}:generateContent` to it.
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models";
35
/// Settings used by the Gemini provider for every request.
///
/// `Debug` is written by hand instead of derived so that formatting a configuration
/// (logs, panic messages, `dbg!`) never prints the API key.
#[derive(Clone)]
pub struct GeminiConfig {
    /// API key sent with each request.
    pub api_key: String,
    /// Model identifier interpolated into the request URL.
    pub model: String,
    /// Value sent as `generationConfig.maxOutputTokens`.
    pub max_tokens: u32,
    /// Timeout applied to the HTTP client.
    pub timeout: Duration,
}

impl std::fmt::Debug for GeminiConfig {
    /// Formats every field except `api_key`, which is replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("timeout", &self.timeout)
            .finish()
    }
}
48
49impl GeminiConfig {
50 pub fn new(api_key: String) -> Self {
52 Self {
53 api_key,
54 model: DEFAULT_MODEL.to_string(),
55 max_tokens: 8192,
56 timeout: Duration::from_secs(120),
57 }
58 }
59
60 #[must_use]
62 pub fn with_model(mut self, model: impl Into<String>) -> Self {
63 self.model = model.into();
64 self
65 }
66}
67
/// `AIProvider` implementation that calls the Gemini `generateContent` endpoint.
pub struct GeminiProvider {
    /// Key, model, output-token limit and timeout used for requests.
    config: GeminiConfig,
    /// HTTP client built once in `with_config` with the configured timeout.
    client: Client,
}
73
74impl GeminiProvider {
75 pub fn new(api_key: String) -> Self {
84 Self::with_config(GeminiConfig::new(api_key))
85 }
86
87 pub fn with_config(config: GeminiConfig) -> Self {
97 #[allow(clippy::expect_used)]
99 let client = Client::builder()
100 .timeout(config.timeout)
101 .build()
102 .expect("Failed to build HTTP client");
103 Self { config, client }
104 }
105
106 fn api_url(&self) -> String {
107 format!(
108 "{}/{}:generateContent?key={}",
109 API_BASE, self.config.model, self.config.api_key
110 )
111 }
112
113 fn build_body(&self, content: &str, schema: &Value) -> Value {
114 let prompt = format!(
115 "Extract structured data from the following content according to this JSON schema.\n\
116 Return ONLY valid JSON matching the schema.\n\
117 Schema: {}\n\nContent:\n{}",
118 serde_json::to_string(schema).unwrap_or_default(),
119 content
120 );
121
122 json!({
123 "contents": [{"parts": [{"text": prompt}]}],
124 "generationConfig": {
125 "maxOutputTokens": self.config.max_tokens,
126 "responseMimeType": "application/json",
127 "responseSchema": schema
128 }
129 })
130 }
131
132 fn parse_response(response: &Value) -> Result<Value> {
133 let text = response
134 .pointer("/candidates/0/content/parts/0/text")
135 .and_then(Value::as_str)
136 .ok_or_else(|| {
137 StygianError::Provider(ProviderError::ApiError(
138 "No text in Gemini response".to_string(),
139 ))
140 })?;
141
142 serde_json::from_str(text).map_err(|e| {
143 StygianError::Provider(ProviderError::ApiError(format!(
144 "Failed to parse Gemini JSON response: {e}"
145 )))
146 })
147 }
148
149 fn map_http_error(status: u16, body: &str) -> StygianError {
150 match status {
151 400 if body.contains("API_KEY") => {
152 StygianError::Provider(ProviderError::InvalidCredentials)
153 }
154 429 => StygianError::Provider(ProviderError::ApiError(format!(
155 "Gemini rate limited: {body}"
156 ))),
157 _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
158 }
159 }
160}
161
162#[async_trait]
163impl AIProvider for GeminiProvider {
164 async fn extract(&self, content: String, schema: Value) -> Result<Value> {
165 let body = self.build_body(&content, &schema);
166 let url = self.api_url();
167
168 let response = self
169 .client
170 .post(&url)
171 .header("Content-Type", "application/json")
172 .json(&body)
173 .send()
174 .await
175 .map_err(|e| {
176 StygianError::Provider(ProviderError::ApiError(format!(
177 "Gemini request failed: {e}"
178 )))
179 })?;
180
181 let status = response.status().as_u16();
182 let text = response
183 .text()
184 .await
185 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
186
187 if status != 200 {
188 return Err(Self::map_http_error(status, &text));
189 }
190
191 let json_val: Value = serde_json::from_str(&text)
192 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
193
194 Self::parse_response(&json_val)
195 }
196
197 async fn stream_extract(
198 &self,
199 content: String,
200 schema: Value,
201 ) -> Result<BoxStream<'static, Result<Value>>> {
202 let result = self.extract(content, schema).await;
203 Ok(Box::pin(stream::once(async move { result })))
204 }
205
206 fn capabilities(&self) -> ProviderCapabilities {
207 ProviderCapabilities {
208 streaming: true,
209 vision: true,
210 tool_use: false,
211 json_mode: true,
212 }
213 }
214
215 fn name(&self) -> &'static str {
216 "gemini"
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Provider with a throwaway key; the tests below only call non-network helpers.
    fn provider() -> GeminiProvider {
        GeminiProvider::new("k".to_string())
    }

    /// Builds a response whose first candidate's first part carries `text`.
    fn response_with_text(text: &str) -> Value {
        json!({
            "candidates": [{
                "content": {"parts": [{"text": text}]}
            }]
        })
    }

    #[test]
    fn test_name() {
        let p = provider();
        assert_eq!(p.name(), "gemini");
    }

    #[test]
    fn test_capabilities() {
        let ProviderCapabilities {
            json_mode, vision, ..
        } = provider().capabilities();
        assert!(json_mode);
        assert!(vision);
    }

    #[test]
    fn test_api_url_contains_model_and_key() {
        let url = GeminiProvider::new("my-key".to_string()).api_url();
        for needle in [DEFAULT_MODEL, "my-key"] {
            assert!(url.contains(needle));
        }
    }

    #[test]
    fn test_build_body_has_response_mime() {
        let schema = json!({"type": "object"});
        let body = provider().build_body("content", &schema);
        let mime = body
            .pointer("/generationConfig/responseMimeType")
            .and_then(Value::as_str);
        assert_eq!(mime, Some("application/json"));
    }

    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let resp = response_with_text(r#"{"name": "Alice"}"#);
        let parsed = GeminiProvider::parse_response(&resp)?;
        let name = parsed.get("name").and_then(Value::as_str);
        assert_eq!(name, Some("Alice"));
        Ok(())
    }

    #[test]
    fn test_parse_response_no_candidates() {
        let resp = json!({"promptFeedback": {}});
        let outcome = GeminiProvider::parse_response(&resp);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_parse_response_invalid_json_text() {
        let resp = response_with_text("not json at all");
        let outcome = GeminiProvider::parse_response(&resp);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_map_http_error_api_key() {
        let err = GeminiProvider::map_http_error(400, "Invalid API_KEY provided");
        let is_invalid_credentials = matches!(
            err,
            StygianError::Provider(ProviderError::InvalidCredentials)
        );
        assert!(is_invalid_credentials);
    }

    #[test]
    fn test_map_http_error_429() {
        let err = GeminiProvider::map_http_error(429, "quota exceeded");
        let is_rate_limited = matches!(
            err,
            StygianError::Provider(ProviderError::ApiError(ref msg)) if msg.contains("rate limited")
        );
        assert!(is_rate_limited);
    }

    #[test]
    fn test_map_http_error_server_error() {
        let err = GeminiProvider::map_http_error(503, "unavailable");
        let mentions_status = matches!(
            err,
            StygianError::Provider(ProviderError::ApiError(ref msg)) if msg.contains("503")
        );
        assert!(mentions_status);
    }

    #[test]
    fn test_config_with_model() {
        let base = GeminiConfig::new("AIza".to_string());
        let cfg = base.with_model("gemini-1.5-pro");
        assert_eq!(cfg.model, "gemini-1.5-pro");
    }
}