Skip to main content

stygian_graph/adapters/ai/
copilot.rs

1//! GitHub Copilot AI provider adapter
2//!
3//! Implements the `AIProvider` port using GitHub's AI model inference gateway.
4//! Authenticates with a GitHub personal access token (PAT) or GitHub App token.
5//! Routes requests to GitHub's hosted AI models endpoint.
6//!
7//! # Example
8//!
9//! ```no_run
10//! use stygian_graph::adapters::ai::copilot::{CopilotProvider, CopilotConfig};
11//! use stygian_graph::ports::AIProvider;
12//! use serde_json::json;
13//!
14//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
15//! let provider = CopilotProvider::new("ghp_...".to_string());
16//! let schema = json!({"type": "object", "properties": {"title": {"type": "string"}}});
17//! // let result = provider.extract("<html>Hello</html>".to_string(), schema).await.unwrap();
18//! # });
19//! ```
20
21use std::time::Duration;
22
23use async_trait::async_trait;
24use futures::stream::{self, BoxStream};
25use reqwest::Client;
26use serde_json::{Value, json};
27
28use crate::domain::error::{ProviderError, Result, StygianError};
29use crate::ports::{AIProvider, ProviderCapabilities};
30
/// Chat-completions URL of the GitHub Models inference gateway, built from the
/// gateway host plus the OpenAI-compatible completions path.
///
/// NOTE(review): confirm this host is still the supported GitHub Models gateway
/// before relying on it long-term.
const API_URL: &str = concat!(
    "https://models.inference.ai.azure.com",
    "/chat/completions"
);

/// Model requested when the caller does not override it via `CopilotConfig::with_model`.
const DEFAULT_MODEL: &str = "gpt-4o";
36
/// Configuration for the GitHub Copilot provider.
///
/// `Debug` is implemented by hand so that formatting a config (e.g. in logs or
/// error reports) never reveals the access token.
#[derive(Clone)]
pub struct CopilotConfig {
    /// GitHub PAT or App token with `models:read` scope
    pub token: String,
    /// Model identifier sent as the `model` field of each request
    pub model: String,
    /// Maximum response tokens (sent as `max_tokens`)
    pub max_tokens: u32,
    /// Request timeout applied to the HTTP client
    pub timeout: Duration,
}

impl std::fmt::Debug for CopilotConfig {
    /// Formats every field except `token`, which is replaced by a placeholder so
    /// the credential cannot leak through `{:?}` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CopilotConfig")
            .field("token", &"<redacted>")
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("timeout", &self.timeout)
            .finish()
    }
}
49
50impl CopilotConfig {
51    /// Create config with token and defaults
52    #[must_use]
53    pub fn new(token: String) -> Self {
54        Self {
55            token,
56            model: DEFAULT_MODEL.to_string(),
57            max_tokens: 4096,
58            timeout: Duration::from_mins(2),
59        }
60    }
61
62    /// Override model
63    #[must_use]
64    pub fn with_model(mut self, model: impl Into<String>) -> Self {
65        self.model = model.into();
66        self
67    }
68}
69
70/// GitHub Copilot AI provider adapter
71pub struct CopilotProvider {
72    config: CopilotConfig,
73    client: Client,
74}
75
76impl CopilotProvider {
77    /// Create with GitHub token and defaults
78    ///
79    /// # Example
80    ///
81    /// ```no_run
82    /// use stygian_graph::adapters::ai::copilot::CopilotProvider;
83    /// let p = CopilotProvider::new("ghp_...".to_string());
84    /// ```
85    #[must_use]
86    pub fn new(token: String) -> Self {
87        Self::with_config(CopilotConfig::new(token))
88    }
89
90    /// Create with custom configuration
91    ///
92    /// # Panics
93    ///
94    /// Panics if the underlying HTTP client fails to build. With `rustls` as the
95    /// TLS backend this is unreachable in practice (build only fails when no TLS
96    /// backend is configured).
97    ///
98    /// # Example
99    ///
100    /// ```no_run
101    /// use stygian_graph::adapters::ai::copilot::{CopilotProvider, CopilotConfig};
102    /// let config = CopilotConfig::new("ghp_...".to_string()).with_model("Meta-Llama-3.1-405B-Instruct");
103    /// let p = CopilotProvider::with_config(config);
104    /// ```
105    #[must_use]
106    pub fn with_config(config: CopilotConfig) -> Self {
107        // SAFETY: TLS backend (rustls) is always available; build() only fails if no TLS backend.
108        #[allow(clippy::expect_used)]
109        let client = Client::builder()
110            .timeout(config.timeout)
111            .build()
112            .expect("Failed to build HTTP client");
113        Self { config, client }
114    }
115
116    fn build_body(&self, content: &str, schema: &Value) -> Value {
117        let system = "You are a precise data extraction assistant. \
118            Extract structured data from the provided content matching the given JSON schema. \
119            Return ONLY valid JSON matching the schema, no extra text.";
120
121        let user_msg = format!(
122            "Schema: {}\n\nContent:\n{}",
123            serde_json::to_string(schema).unwrap_or_default(),
124            content
125        );
126
127        json!({
128            "model": self.config.model,
129            "max_tokens": self.config.max_tokens,
130            "response_format": {"type": "json_object"},
131            "messages": [
132                {"role": "system", "content": system},
133                {"role": "user", "content": user_msg}
134            ]
135        })
136    }
137
138    fn parse_response(response: &Value) -> Result<Value> {
139        let text = response
140            .pointer("/choices/0/message/content")
141            .and_then(Value::as_str)
142            .ok_or_else(|| {
143                StygianError::Provider(ProviderError::ApiError(
144                    "No content in GitHub Models response".to_string(),
145                ))
146            })?;
147
148        serde_json::from_str(text).map_err(|e| {
149            StygianError::Provider(ProviderError::ApiError(format!(
150                "Failed to parse GitHub Models JSON response: {e}"
151            )))
152        })
153    }
154
155    fn map_http_error(status: u16, body: &str) -> StygianError {
156        match status {
157            401 | 403 => StygianError::Provider(ProviderError::InvalidCredentials),
158            429 => StygianError::Provider(ProviderError::ApiError(format!(
159                "GitHub Models rate limited: {body}"
160            ))),
161            _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
162        }
163    }
164}
165
166#[async_trait]
167impl AIProvider for CopilotProvider {
168    async fn extract(&self, content: String, schema: Value) -> Result<Value> {
169        let body = self.build_body(&content, &schema);
170
171        let response = self
172            .client
173            .post(API_URL)
174            .header("Authorization", format!("Bearer {}", &self.config.token))
175            .header("Content-Type", "application/json")
176            .json(&body)
177            .send()
178            .await
179            .map_err(|e| {
180                StygianError::Provider(ProviderError::ApiError(format!(
181                    "GitHub Models request failed: {e}"
182                )))
183            })?;
184
185        let status = response.status().as_u16();
186        let text = response
187            .text()
188            .await
189            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
190
191        if status != 200 {
192            return Err(Self::map_http_error(status, &text));
193        }
194
195        let json_val: Value = serde_json::from_str(&text)
196            .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
197
198        Self::parse_response(&json_val)
199    }
200
201    async fn stream_extract(
202        &self,
203        content: String,
204        schema: Value,
205    ) -> Result<BoxStream<'static, Result<Value>>> {
206        let result = self.extract(content, schema).await;
207        Ok(Box::pin(stream::once(async move { result })))
208    }
209
210    fn capabilities(&self) -> ProviderCapabilities {
211        ProviderCapabilities {
212            streaming: true,
213            vision: false,
214            tool_use: true,
215            json_mode: true,
216        }
217    }
218
219    fn name(&self) -> &'static str {
220        "copilot"
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_name() {
        assert_eq!(CopilotProvider::new("t".to_string()).name(), "copilot");
    }

    #[test]
    fn test_capabilities_json_mode() {
        let caps = CopilotProvider::new("t".to_string()).capabilities();
        assert!(caps.json_mode);
    }

    #[test]
    fn test_build_body_json_format() {
        let p = CopilotProvider::new("t".to_string());
        let body = p.build_body("c", &json!({"type": "object"}));
        assert_eq!(
            body.get("response_format").and_then(|rf| rf.get("type")),
            Some(&json!("json_object"))
        );
    }

    /// The request body carries the configured model, token limit, and a user
    /// message that embeds the serialized schema followed by the content.
    #[test]
    fn test_build_body_carries_model_and_prompt() {
        let p = CopilotProvider::with_config(
            CopilotConfig::new("t".to_string()).with_model("gpt-4o-mini"),
        );
        let body = p.build_body("<p>hi</p>", &json!({"type": "object"}));
        assert_eq!(body.get("model"), Some(&json!("gpt-4o-mini")));
        assert_eq!(
            body.get("max_tokens").and_then(Value::as_u64),
            Some(4096)
        );
        assert_eq!(
            body.pointer("/messages/0/role").and_then(Value::as_str),
            Some("system")
        );
        assert_eq!(
            body.pointer("/messages/1/content").and_then(Value::as_str),
            Some("Schema: {\"type\":\"object\"}\n\nContent:\n<p>hi</p>")
        );
    }

    #[test]
    fn test_map_http_error_403() {
        assert!(matches!(
            CopilotProvider::map_http_error(403, ""),
            StygianError::Provider(ProviderError::InvalidCredentials)
        ));
    }

    #[test]
    fn test_map_http_error_401() {
        assert!(matches!(
            CopilotProvider::map_http_error(401, ""),
            StygianError::Provider(ProviderError::InvalidCredentials)
        ));
    }

    #[test]
    fn test_map_http_error_429() {
        let err = CopilotProvider::map_http_error(429, "rate limit hit");
        assert!(
            matches!(err, StygianError::Provider(ProviderError::ApiError(ref msg)) if msg.contains("rate limited"))
        );
    }

    #[test]
    fn test_map_http_error_server_error() {
        let err = CopilotProvider::map_http_error(502, "bad gateway");
        assert!(
            matches!(err, StygianError::Provider(ProviderError::ApiError(ref msg)) if msg.contains("502"))
        );
    }

    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let resp = serde_json::json!({
            "choices": [{"message": {"content": "{\"title\": \"Hello\"}"}}]
        });
        let val = CopilotProvider::parse_response(&resp)?;
        assert_eq!(
            val.get("title").and_then(serde_json::Value::as_str),
            Some("Hello")
        );
        Ok(())
    }

    #[test]
    fn test_parse_response_no_content() {
        let resp = serde_json::json!({"choices": []});
        assert!(CopilotProvider::parse_response(&resp).is_err());
    }

    /// Message text that is present but not valid JSON must be reported as an error.
    #[test]
    fn test_parse_response_invalid_json_content() {
        let resp = json!({"choices": [{"message": {"content": "not json"}}]});
        assert!(CopilotProvider::parse_response(&resp).is_err());
    }

    /// A `null` content field (no text produced) must be reported as an error.
    #[test]
    fn test_parse_response_null_content() {
        let resp = json!({"choices": [{"message": {"content": null}}]});
        assert!(CopilotProvider::parse_response(&resp).is_err());
    }

    #[test]
    fn test_config_with_model() {
        let cfg = CopilotConfig::new("ghp_tok".to_string()).with_model("claude-3.5-sonnet");
        assert_eq!(cfg.model, "claude-3.5-sonnet");
    }

    /// `CopilotConfig::new` fills in the documented defaults.
    #[test]
    fn test_config_defaults() {
        let cfg = CopilotConfig::new("ghp_tok".to_string());
        assert_eq!(cfg.token, "ghp_tok");
        assert_eq!(cfg.model, DEFAULT_MODEL);
        assert_eq!(cfg.max_tokens, 4096);
        assert_eq!(cfg.timeout, Duration::from_secs(120));
    }
}