stygian_graph/adapters/ai/
copilot.rs

1//! GitHub Copilot AI provider adapter
2//!
3//! Implements the `AIProvider` port using GitHub's AI model inference gateway.
4//! Authenticates with a GitHub personal access token (PAT) or GitHub App token.
5//! Routes requests to GitHub's hosted AI models endpoint.
6//!
7//! # Example
8//!
9//! ```no_run
10//! use stygian_graph::adapters::ai::copilot::{CopilotProvider, CopilotConfig};
11//! use stygian_graph::ports::AIProvider;
12//! use serde_json::json;
13//!
14//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
15//! let provider = CopilotProvider::new("ghp_...".to_string());
16//! let schema = json!({"type": "object", "properties": {"title": {"type": "string"}}});
17//! // let result = provider.extract("<html>Hello</html>".to_string(), schema).await.unwrap();
18//! # });
19//! ```
20
21use std::time::Duration;
22
23use async_trait::async_trait;
24use futures::stream::{self, BoxStream};
25use reqwest::Client;
26use serde_json::{Value, json};
27
28use crate::domain::error::{ProviderError, Result, StygianError};
29use crate::ports::{AIProvider, ProviderCapabilities};
30
/// GitHub Models chat-completions endpoint.
///
/// The host is an Azure AI inference domain; requests to it are still
/// authenticated with the GitHub token held in [`CopilotConfig::token`].
const API_URL: &str = "https://models.inference.ai.azure.com/chat/completions";

/// Model used when [`CopilotConfig::with_model`] is not called.
const DEFAULT_MODEL: &str = "gpt-4o";

/// Default cap on the number of tokens the model may return.
const DEFAULT_MAX_TOKENS: u32 = 4096;

/// Default per-request timeout applied to the HTTP client.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(120);

/// Configuration for the GitHub Copilot provider.
///
/// `Debug` is implemented by hand (not derived) so the secret token is never
/// written to logs or panic messages.
#[derive(Clone)]
pub struct CopilotConfig {
    /// GitHub PAT or App token with `models:read` scope
    pub token: String,
    /// Model identifier
    pub model: String,
    /// Maximum response tokens
    pub max_tokens: u32,
    /// Request timeout
    pub timeout: Duration,
}

impl std::fmt::Debug for CopilotConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CopilotConfig")
            // Never expose the credential itself.
            .field("token", &"<redacted>")
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("timeout", &self.timeout)
            .finish()
    }
}

impl CopilotConfig {
    /// Create a config from a GitHub token.
    ///
    /// Uses the defaults: model `gpt-4o`, 4096 max response tokens and a
    /// 120 second request timeout.
    pub fn new(token: impl Into<String>) -> Self {
        Self {
            token: token.into(),
            model: DEFAULT_MODEL.to_string(),
            max_tokens: DEFAULT_MAX_TOKENS,
            timeout: DEFAULT_TIMEOUT,
        }
    }

    /// Override the model identifier.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Override the maximum number of response tokens.
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Override the per-request timeout.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}
68
69/// GitHub Copilot AI provider adapter
70pub struct CopilotProvider {
71    config: CopilotConfig,
72    client: Client,
73}
74
75impl CopilotProvider {
76    /// Create with GitHub token and defaults
77    ///
78    /// # Example
79    ///
80    /// ```no_run
81    /// use stygian_graph::adapters::ai::copilot::CopilotProvider;
82    /// let p = CopilotProvider::new("ghp_...".to_string());
83    /// ```
84    pub fn new(token: String) -> Self {
85        Self::with_config(CopilotConfig::new(token))
86    }
87
88    /// Create with custom configuration
89    ///
90    /// # Example
91    ///
92    /// ```no_run
93    /// use stygian_graph::adapters::ai::copilot::{CopilotProvider, CopilotConfig};
94    /// let config = CopilotConfig::new("ghp_...".to_string()).with_model("Meta-Llama-3.1-405B-Instruct");
95    /// let p = CopilotProvider::with_config(config);
96    /// ```
97    pub fn with_config(config: CopilotConfig) -> Self {
98        // SAFETY: TLS backend (rustls) is always available; build() only fails if no TLS backend.
99        #[allow(clippy::expect_used)]
100        let client = Client::builder()
101            .timeout(config.timeout)
102            .build()
103            .expect("Failed to build HTTP client");
104        Self { config, client }
105    }
106
107    fn build_body(&self, content: &str, schema: &Value) -> Value {
108        let system = "You are a precise data extraction assistant. \
109            Extract structured data from the provided content matching the given JSON schema. \
110            Return ONLY valid JSON matching the schema, no extra text.";
111
112        let user_msg = format!(
113            "Schema: {}\n\nContent:\n{}",
114            serde_json::to_string(schema).unwrap_or_default(),
115            content
116        );
117
118        json!({
119            "model": self.config.model,
120            "max_tokens": self.config.max_tokens,
121            "response_format": {"type": "json_object"},
122            "messages": [
123                {"role": "system", "content": system},
124                {"role": "user", "content": user_msg}
125            ]
126        })
127    }
128
129    fn parse_response(response: &Value) -> Result<Value> {
130        let text = response
131            .pointer("/choices/0/message/content")
132            .and_then(Value::as_str)
133            .ok_or_else(|| {
134                StygianError::Provider(ProviderError::ApiError(
135                    "No content in GitHub Models response".to_string(),
136                ))
137            })?;
138
139        serde_json::from_str(text).map_err(|e| {
140            StygianError::Provider(ProviderError::ApiError(format!(
141                "Failed to parse GitHub Models JSON response: {e}"
142            )))
143        })
144    }
145
146    fn map_http_error(status: u16, body: &str) -> StygianError {
147        match status {
148            401 | 403 => StygianError::Provider(ProviderError::InvalidCredentials),
149            429 => StygianError::Provider(ProviderError::ApiError(format!(
150                "GitHub Models rate limited: {body}"
151            ))),
152            _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
153        }
154    }
155}
156
157#[async_trait]
158impl AIProvider for CopilotProvider {
159    async fn extract(&self, content: String, schema: Value) -> Result<Value> {
160        let body = self.build_body(&content, &schema);
161
162        let response = self
163            .client
164            .post(API_URL)
165            .header("Authorization", format!("Bearer {}", &self.config.token))
166            .header("Content-Type", "application/json")
167            .json(&body)
168            .send()
169            .await
170            .map_err(|e| {
171                StygianError::Provider(ProviderError::ApiError(format!(
172                    "GitHub Models request failed: {e}"
173                )))
174            })?;
175
176        let status = response.status().as_u16();
177        let text = response
178            .text()
179            .await
180            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
181
182        if status != 200 {
183            return Err(Self::map_http_error(status, &text));
184        }
185
186        let json_val: Value = serde_json::from_str(&text)
187            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
188
189        Self::parse_response(&json_val)
190    }
191
192    async fn stream_extract(
193        &self,
194        content: String,
195        schema: Value,
196    ) -> Result<BoxStream<'static, Result<Value>>> {
197        let result = self.extract(content, schema).await;
198        Ok(Box::pin(stream::once(async move { result })))
199    }
200
201    fn capabilities(&self) -> ProviderCapabilities {
202        ProviderCapabilities {
203            streaming: true,
204            vision: false,
205            tool_use: true,
206            json_mode: true,
207        }
208    }
209
210    fn name(&self) -> &'static str {
211        "copilot"
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Provider with a dummy token; none of these tests touch the network.
    fn provider() -> CopilotProvider {
        CopilotProvider::new("t".to_string())
    }

    /// True when `err` is an `ApiError` whose message contains `needle`.
    fn api_error_mentions(err: &StygianError, needle: &str) -> bool {
        match err {
            StygianError::Provider(ProviderError::ApiError(msg)) => msg.contains(needle),
            _ => false,
        }
    }

    #[test]
    fn test_name() {
        let p = provider();
        assert_eq!(p.name(), "copilot");
    }

    #[test]
    fn test_capabilities_json_mode() {
        assert!(provider().capabilities().json_mode);
    }

    #[test]
    fn test_build_body_json_format() {
        let request = provider().build_body("c", &json!({"type": "object"}));
        let format_type = request.pointer("/response_format/type");
        assert_eq!(format_type, Some(&json!("json_object")));
    }

    #[test]
    fn test_map_http_error_403() {
        let err = CopilotProvider::map_http_error(403, "");
        assert!(matches!(err, StygianError::Provider(ProviderError::InvalidCredentials)));
    }

    #[test]
    fn test_map_http_error_401() {
        let err = CopilotProvider::map_http_error(401, "");
        assert!(matches!(err, StygianError::Provider(ProviderError::InvalidCredentials)));
    }

    #[test]
    fn test_map_http_error_429() {
        let err = CopilotProvider::map_http_error(429, "rate limit hit");
        assert!(api_error_mentions(&err, "rate limited"));
    }

    #[test]
    fn test_map_http_error_server_error() {
        let err = CopilotProvider::map_http_error(502, "bad gateway");
        assert!(api_error_mentions(&err, "502"));
    }

    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let completion = json!({
            "choices": [{"message": {"content": r#"{"title": "Hello"}"#}}]
        });
        let extracted = CopilotProvider::parse_response(&completion)?;
        assert_eq!(extracted["title"], "Hello");
        Ok(())
    }

    #[test]
    fn test_parse_response_no_content() {
        let empty_choices = json!({"choices": []});
        let outcome = CopilotProvider::parse_response(&empty_choices);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_config_with_model() {
        let chosen = CopilotConfig::new("ghp_tok".to_string())
            .with_model("claude-3.5-sonnet")
            .model;
        assert_eq!(chosen, "claude-3.5-sonnet");
    }
}