stygian_graph/adapters/ai/
copilot.rs1use std::time::Duration;
22
23use async_trait::async_trait;
24use futures::stream::{self, BoxStream};
25use reqwest::Client;
26use serde_json::{Value, json};
27
28use crate::domain::error::{ProviderError, Result, StygianError};
29use crate::ports::{AIProvider, ProviderCapabilities};
30
/// Chat-completions endpoint of the GitHub Models inference service; every
/// request made by this adapter is POSTed here.
const API_URL: &str = concat!(
    "https://models.inference.ai.azure.com",
    "/chat/completions"
);

/// Model identifier used by `CopilotConfig::new` unless overridden via
/// `CopilotConfig::with_model`.
const DEFAULT_MODEL: &str = "gpt-4o";

/// Configuration for [`CopilotProvider`].
///
/// `Debug` is implemented by hand rather than derived so that the API token is
/// never written to logs or panic messages; all other fields print normally.
#[derive(Clone)]
pub struct CopilotConfig {
    /// API token, sent as `Authorization: Bearer <token>` on every request.
    pub token: String,
    /// Model identifier placed in the request body's `model` field.
    pub model: String,
    /// Value of the request body's `max_tokens` field.
    pub max_tokens: u32,
    /// Timeout installed on the HTTP client built from this config.
    pub timeout: Duration,
}

impl std::fmt::Debug for CopilotConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CopilotConfig")
            // The token is a credential: show that one is configured, never its value.
            .field("token", &"<redacted>")
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("timeout", &self.timeout)
            .finish()
    }
}

50impl CopilotConfig {
51 #[must_use]
53 pub fn new(token: String) -> Self {
54 Self {
55 token,
56 model: DEFAULT_MODEL.to_string(),
57 max_tokens: 4096,
58 timeout: Duration::from_mins(2),
59 }
60 }
61
62 #[must_use]
64 pub fn with_model(mut self, model: impl Into<String>) -> Self {
65 self.model = model.into();
66 self
67 }
68}
69
70pub struct CopilotProvider {
72 config: CopilotConfig,
73 client: Client,
74}
75
76impl CopilotProvider {
77 #[must_use]
86 pub fn new(token: String) -> Self {
87 Self::with_config(CopilotConfig::new(token))
88 }
89
90 #[must_use]
106 pub fn with_config(config: CopilotConfig) -> Self {
107 #[allow(clippy::expect_used)]
109 let client = Client::builder()
110 .timeout(config.timeout)
111 .build()
112 .expect("Failed to build HTTP client");
113 Self { config, client }
114 }
115
116 fn build_body(&self, content: &str, schema: &Value) -> Value {
117 let system = "You are a precise data extraction assistant. \
118 Extract structured data from the provided content matching the given JSON schema. \
119 Return ONLY valid JSON matching the schema, no extra text.";
120
121 let user_msg = format!(
122 "Schema: {}\n\nContent:\n{}",
123 serde_json::to_string(schema).unwrap_or_default(),
124 content
125 );
126
127 json!({
128 "model": self.config.model,
129 "max_tokens": self.config.max_tokens,
130 "response_format": {"type": "json_object"},
131 "messages": [
132 {"role": "system", "content": system},
133 {"role": "user", "content": user_msg}
134 ]
135 })
136 }
137
138 fn parse_response(response: &Value) -> Result<Value> {
139 let text = response
140 .pointer("/choices/0/message/content")
141 .and_then(Value::as_str)
142 .ok_or_else(|| {
143 StygianError::Provider(ProviderError::ApiError(
144 "No content in GitHub Models response".to_string(),
145 ))
146 })?;
147
148 serde_json::from_str(text).map_err(|e| {
149 StygianError::Provider(ProviderError::ApiError(format!(
150 "Failed to parse GitHub Models JSON response: {e}"
151 )))
152 })
153 }
154
155 fn map_http_error(status: u16, body: &str) -> StygianError {
156 match status {
157 401 | 403 => StygianError::Provider(ProviderError::InvalidCredentials),
158 429 => StygianError::Provider(ProviderError::ApiError(format!(
159 "GitHub Models rate limited: {body}"
160 ))),
161 _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
162 }
163 }
164}
165
166#[async_trait]
167impl AIProvider for CopilotProvider {
168 async fn extract(&self, content: String, schema: Value) -> Result<Value> {
169 let body = self.build_body(&content, &schema);
170
171 let response = self
172 .client
173 .post(API_URL)
174 .header("Authorization", format!("Bearer {}", &self.config.token))
175 .header("Content-Type", "application/json")
176 .json(&body)
177 .send()
178 .await
179 .map_err(|e| {
180 StygianError::Provider(ProviderError::ApiError(format!(
181 "GitHub Models request failed: {e}"
182 )))
183 })?;
184
185 let status = response.status().as_u16();
186 let text = response
187 .text()
188 .await
189 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
190
191 if status != 200 {
192 return Err(Self::map_http_error(status, &text));
193 }
194
195 let json_val: Value = serde_json::from_str(&text)
196 .map_err(|e| StygianError::Provider(ProviderError::ApiError(e.to_string())))?;
197
198 Self::parse_response(&json_val)
199 }
200
201 async fn stream_extract(
202 &self,
203 content: String,
204 schema: Value,
205 ) -> Result<BoxStream<'static, Result<Value>>> {
206 let result = self.extract(content, schema).await;
207 Ok(Box::pin(stream::once(async move { result })))
208 }
209
210 fn capabilities(&self) -> ProviderCapabilities {
211 ProviderCapabilities {
212 streaming: true,
213 vision: false,
214 tool_use: true,
215 json_mode: true,
216 }
217 }
218
219 fn name(&self) -> &'static str {
220 "copilot"
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Provider with a throwaway token; none of these tests send a request.
    fn provider() -> CopilotProvider {
        CopilotProvider::new("t".to_string())
    }

    /// Message carried by an `ApiError`, or `None` for any other error.
    fn api_error_message(err: &StygianError) -> Option<&str> {
        match err {
            StygianError::Provider(ProviderError::ApiError(msg)) => Some(msg.as_str()),
            _ => None,
        }
    }

    #[test]
    fn test_name() {
        let p = provider();
        assert_eq!(p.name(), "copilot");
    }

    #[test]
    fn test_capabilities_json_mode() {
        assert!(provider().capabilities().json_mode);
    }

    #[test]
    fn test_build_body_json_format() {
        let body = provider().build_body("c", &json!({"type": "object"}));
        let format_type = body.pointer("/response_format/type");
        assert_eq!(format_type, Some(&json!("json_object")));
    }

    #[test]
    fn test_map_http_error_403() {
        let err = CopilotProvider::map_http_error(403, "");
        assert!(matches!(
            err,
            StygianError::Provider(ProviderError::InvalidCredentials)
        ));
    }

    #[test]
    fn test_map_http_error_401() {
        let err = CopilotProvider::map_http_error(401, "");
        assert!(matches!(
            err,
            StygianError::Provider(ProviderError::InvalidCredentials)
        ));
    }

    #[test]
    fn test_map_http_error_429() {
        let err = CopilotProvider::map_http_error(429, "rate limit hit");
        let msg = api_error_message(&err);
        assert!(msg.is_some_and(|m| m.contains("rate limited")));
    }

    #[test]
    fn test_map_http_error_server_error() {
        let err = CopilotProvider::map_http_error(502, "bad gateway");
        let msg = api_error_message(&err);
        assert!(msg.is_some_and(|m| m.contains("502")));
    }

    #[test]
    fn test_parse_response_valid() -> Result<()> {
        let resp = json!({
            "choices": [{"message": {"content": "{\"title\": \"Hello\"}"}}]
        });
        let parsed = CopilotProvider::parse_response(&resp)?;
        let title = parsed.pointer("/title").and_then(serde_json::Value::as_str);
        assert_eq!(title, Some("Hello"));
        Ok(())
    }

    #[test]
    fn test_parse_response_no_content() {
        let outcome = CopilotProvider::parse_response(&json!({"choices": []}));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_config_with_model() {
        let model = CopilotConfig::new("ghp_tok".to_string())
            .with_model("claude-3.5-sonnet")
            .model;
        assert_eq!(model, "claude-3.5-sonnet");
    }
}