stygian_graph/adapters/ai/
copilot.rs1use std::time::Duration;
22
23use async_trait::async_trait;
24use futures::stream::{self, BoxStream};
25use reqwest::Client;
26use serde_json::{Value, json};
27
28use crate::domain::error::{ProviderError, Result, StygianError};
29use crate::ports::{AIProvider, ProviderCapabilities};
30
/// Chat-completions endpoint of the GitHub Models inference API.
const API_URL: &str = "https://models.inference.ai.azure.com/chat/completions";

/// Model requested when [`CopilotConfig::new`] is used without [`CopilotConfig::with_model`].
const DEFAULT_MODEL: &str = "gpt-4o";

/// Connection and generation settings for the GitHub Models provider.
///
/// `Debug` is implemented by hand (instead of derived) so that the access token is
/// never written to logs or panic messages.
#[derive(Clone)]
pub struct CopilotConfig {
    /// GitHub token, sent as `Authorization: Bearer <token>` on every request.
    pub token: String,
    /// Model identifier placed in the request body's `model` field.
    pub model: String,
    /// Value placed in the request body's `max_tokens` field.
    pub max_tokens: u32,
    /// Timeout applied to the HTTP client used for requests.
    pub timeout: Duration,
}

impl std::fmt::Debug for CopilotConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CopilotConfig")
            // Never print the credential itself.
            .field("token", &"<redacted>")
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("timeout", &self.timeout)
            .finish()
    }
}

impl CopilotConfig {
    /// Creates a configuration for `token` using the default model (`gpt-4o`),
    /// `max_tokens` of 4096 and a 120-second timeout.
    #[must_use]
    pub fn new(token: String) -> Self {
        Self {
            token,
            model: DEFAULT_MODEL.to_string(),
            max_tokens: 4096,
            timeout: Duration::from_secs(120),
        }
    }

    /// Replaces the model identifier sent with each request.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Replaces the `max_tokens` value sent with each request.
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Replaces the HTTP timeout used when the provider builds its client.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}
68
69pub struct CopilotProvider {
71 config: CopilotConfig,
72 client: Client,
73}
74
75impl CopilotProvider {
76 pub fn new(token: String) -> Self {
85 Self::with_config(CopilotConfig::new(token))
86 }
87
88 pub fn with_config(config: CopilotConfig) -> Self {
98 #[allow(clippy::expect_used)]
100 let client = Client::builder()
101 .timeout(config.timeout)
102 .build()
103 .expect("Failed to build HTTP client");
104 Self { config, client }
105 }
106
107 fn build_body(&self, content: &str, schema: &Value) -> Value {
108 let system = "You are a precise data extraction assistant. \
109 Extract structured data from the provided content matching the given JSON schema. \
110 Return ONLY valid JSON matching the schema, no extra text.";
111
112 let user_msg = format!(
113 "Schema: {}\n\nContent:\n{}",
114 serde_json::to_string(schema).unwrap_or_default(),
115 content
116 );
117
118 json!({
119 "model": self.config.model,
120 "max_tokens": self.config.max_tokens,
121 "response_format": {"type": "json_object"},
122 "messages": [
123 {"role": "system", "content": system},
124 {"role": "user", "content": user_msg}
125 ]
126 })
127 }
128
129 fn parse_response(response: &Value) -> Result<Value> {
130 let text = response
131 .pointer("/choices/0/message/content")
132 .and_then(Value::as_str)
133 .ok_or_else(|| {
134 StygianError::Provider(ProviderError::ApiError(
135 "No content in GitHub Models response".to_string(),
136 ))
137 })?;
138
139 serde_json::from_str(text).map_err(|e| {
140 StygianError::Provider(ProviderError::ApiError(format!(
141 "Failed to parse GitHub Models JSON response: {e}"
142 )))
143 })
144 }
145
146 fn map_http_error(status: u16, body: &str) -> StygianError {
147 match status {
148 401 | 403 => StygianError::Provider(ProviderError::InvalidCredentials),
149 429 => StygianError::Provider(ProviderError::ApiError(format!(
150 "GitHub Models rate limited: {body}"
151 ))),
152 _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
153 }
154 }
155}
156
157#[async_trait]
158impl AIProvider for CopilotProvider {
159 async fn extract(&self, content: String, schema: Value) -> Result<Value> {
160 let body = self.build_body(&content, &schema);
161
162 let response = self
163 .client
164 .post(API_URL)
165 .header("Authorization", format!("Bearer {}", &self.config.token))
166 .header("Content-Type", "application/json")
167 .json(&body)
168 .send()
169 .await
170 .map_err(|e| {
171 StygianError::Provider(ProviderError::ApiError(format!(
172 "GitHub Models request failed: {e}"
173 )))
174 })?;
175
176 let status = response.status().as_u16();
177 let text = response
178 .text()
179 .await
180 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
181
182 if status != 200 {
183 return Err(Self::map_http_error(status, &text));
184 }
185
186 let json_val: Value = serde_json::from_str(&text)
187 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
188
189 Self::parse_response(&json_val)
190 }
191
192 async fn stream_extract(
193 &self,
194 content: String,
195 schema: Value,
196 ) -> Result<BoxStream<'static, Result<Value>>> {
197 let result = self.extract(content, schema).await;
198 Ok(Box::pin(stream::once(async move { result })))
199 }
200
201 fn capabilities(&self) -> ProviderCapabilities {
202 ProviderCapabilities {
203 streaming: true,
204 vision: false,
205 tool_use: true,
206 json_mode: true,
207 }
208 }
209
210 fn name(&self) -> &'static str {
211 "copilot"
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Provider with a throwaway token; none of these tests touch the network.
    fn provider() -> CopilotProvider {
        CopilotProvider::new(String::from("t"))
    }

    #[test]
    fn test_name() {
        let p = provider();
        assert_eq!(p.name(), "copilot");
    }

    #[test]
    fn test_capabilities_json_mode() {
        let ProviderCapabilities { json_mode, .. } = provider().capabilities();
        assert!(json_mode);
    }

    #[test]
    fn test_build_body_json_format() {
        let schema = json!({"type": "object"});
        let body = provider().build_body("c", &schema);
        let format_type = body.pointer("/response_format/type");
        assert_eq!(format_type, Some(&json!("json_object")));
    }

    #[test]
    fn test_map_http_error_403() {
        let err = CopilotProvider::map_http_error(403, "");
        let is_auth_error = matches!(
            err,
            StygianError::Provider(ProviderError::InvalidCredentials)
        );
        assert!(is_auth_error);
    }

    #[test]
    fn test_map_http_error_401() {
        let err = CopilotProvider::map_http_error(401, "");
        let is_auth_error = matches!(
            err,
            StygianError::Provider(ProviderError::InvalidCredentials)
        );
        assert!(is_auth_error);
    }

    #[test]
    fn test_map_http_error_429() {
        let err = CopilotProvider::map_http_error(429, "rate limit hit");
        let mentions_rate_limit = matches!(
            err,
            StygianError::Provider(ProviderError::ApiError(ref msg)) if msg.contains("rate limited")
        );
        assert!(mentions_rate_limit);
    }

    #[test]
    fn test_map_http_error_server_error() {
        let err = CopilotProvider::map_http_error(502, "bad gateway");
        let mentions_status = matches!(
            err,
            StygianError::Provider(ProviderError::ApiError(ref msg)) if msg.contains("502")
        );
        assert!(mentions_status);
    }

    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let payload = json!({
            "choices": [{"message": {"content": r#"{"title": "Hello"}"#}}]
        });
        let parsed = CopilotProvider::parse_response(&payload)?;
        let title = parsed.get("title").and_then(Value::as_str);
        assert_eq!(title, Some("Hello"));
        Ok(())
    }

    #[test]
    fn test_parse_response_no_content() {
        let empty_choices = json!({"choices": []});
        let outcome = CopilotProvider::parse_response(&empty_choices);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_config_with_model() {
        let model = "claude-3.5-sonnet";
        let cfg = CopilotConfig::new(String::from("ghp_tok")).with_model(model);
        assert_eq!(cfg.model, model);
    }
}