stygian_graph/adapters/ai/
claude.rs1use std::time::Duration;
27
28use async_trait::async_trait;
29use futures::stream::{self, BoxStream};
30use reqwest::Client;
31use serde_json::{Value, json};
32
33use crate::domain::error::{ProviderError, Result, StygianError};
34use crate::ports::{AIProvider, ProviderCapabilities};
35
/// Anthropic Messages API endpoint; every extraction request is POSTed here.
const API_URL: &str = "https://api.anthropic.com/v1/messages";

/// Value sent in the `anthropic-version` header on every request.
const ANTHROPIC_VERSION: &str = "2023-06-01";

/// Model that [`ClaudeConfig::new`] starts with; override it with `with_model`.
const DEFAULT_MODEL: &str = "claude-sonnet-4-5";
44
/// Settings for the Claude provider: credentials, model selection, output
/// budget and HTTP timeout.
///
/// `Debug` is implemented by hand so the API key is never written to logs or
/// panic messages; every other field is printed normally.
#[derive(Clone)]
pub struct ClaudeConfig {
    /// Anthropic API key, sent as the `x-api-key` request header.
    pub api_key: String,
    /// Model identifier placed in the request body's `model` field.
    pub model: String,
    /// Value of the request's `max_tokens` field (upper bound on generated tokens).
    pub max_tokens: u32,
    /// Request timeout applied to the HTTP client built from this config.
    pub timeout: Duration,
}

impl std::fmt::Debug for ClaudeConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Redact the secret; derived Debug would print it verbatim.
        f.debug_struct("ClaudeConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("timeout", &self.timeout)
            .finish()
    }
}
57
58impl ClaudeConfig {
59 pub fn new(api_key: String) -> Self {
61 Self {
62 api_key,
63 model: DEFAULT_MODEL.to_string(),
64 max_tokens: 4096,
65 timeout: Duration::from_secs(120),
66 }
67 }
68
69 #[must_use]
71 pub fn with_model(mut self, model: impl Into<String>) -> Self {
72 self.model = model.into();
73 self
74 }
75
76 #[must_use]
78 pub const fn with_max_tokens(mut self, max_tokens: u32) -> Self {
79 self.max_tokens = max_tokens;
80 self
81 }
82}
83
/// Claude adapter implementing [`AIProvider`] on top of the Anthropic
/// Messages API (`/v1/messages`).
///
/// Holds its configuration together with one HTTP client built from it,
/// which is reused for every request.
pub struct ClaudeProvider {
    /// Provider settings: API key, model, token budget and timeout.
    config: ClaudeConfig,
    /// HTTP client created in `with_config` with `config.timeout` applied.
    client: Client,
}
92
93impl ClaudeProvider {
94 pub fn new(api_key: String) -> Self {
104 let config = ClaudeConfig::new(api_key);
105 Self::with_config(config)
106 }
107
108 pub fn with_config(config: ClaudeConfig) -> Self {
120 #[allow(clippy::expect_used)]
122 let client = Client::builder()
123 .timeout(config.timeout)
124 .build()
125 .expect("Failed to build HTTP client");
126 Self { config, client }
127 }
128
129 fn build_extract_body(&self, content: &str, schema: &Value) -> Value {
134 let system = "You are a precise data extraction assistant. \
135 Extract the requested information from the provided content and \
136 return it using the extract_data tool. \
137 Always extract exactly what the schema requests — nothing more, nothing less.";
138
139 let tool = json!({
140 "name": "extract_data",
141 "description": "Extract structured data from the provided content according to the schema.",
142 "input_schema": schema
143 });
144
145 json!({
146 "model": self.config.model,
147 "max_tokens": self.config.max_tokens,
148 "system": system,
149 "tools": [tool],
150 "tool_choice": {"type": "tool", "name": "extract_data"},
151 "messages": [
152 {
153 "role": "user",
154 "content": format!("Extract data from the following content:\n\n{content}")
155 }
156 ]
157 })
158 }
159
160 #[allow(dead_code, clippy::indexing_slicing)]
162 fn build_stream_body(&self, content: &str, schema: &Value) -> Value {
163 let mut body = self.build_extract_body(content, schema);
164 body["stream"] = json!(true);
165 body
166 }
167
168 fn parse_extract_response(response: &Value) -> Result<Value> {
170 let content = response
172 .get("content")
173 .and_then(Value::as_array)
174 .ok_or_else(|| {
175 StygianError::Provider(ProviderError::ApiError(
176 "No content in Claude response".to_string(),
177 ))
178 })?;
179
180 for block in content {
181 if block.get("type").and_then(Value::as_str) == Some("tool_use")
182 && let Some(input) = block.get("input")
183 {
184 return Ok(input.clone());
185 }
186 }
187
188 Err(StygianError::Provider(ProviderError::ApiError(
189 "Claude response contained no tool_use block".to_string(),
190 )))
191 }
192
193 fn map_http_error(status: u16, body: &str) -> StygianError {
195 match status {
196 401 => StygianError::Provider(ProviderError::InvalidCredentials),
197 429 => StygianError::Provider(ProviderError::ApiError(format!(
198 "Rate limited by Anthropic API: {body}"
199 ))),
200 400 => {
201 if body.contains("token") {
202 StygianError::Provider(ProviderError::TokenLimitExceeded(body.to_string()))
203 } else if body.contains("policy") {
204 StygianError::Provider(ProviderError::ContentPolicyViolation(body.to_string()))
205 } else {
206 StygianError::Provider(ProviderError::ApiError(body.to_string()))
207 }
208 }
209 _ => StygianError::Provider(ProviderError::ApiError(format!("HTTP {status}: {body}"))),
210 }
211 }
212}
213
214#[async_trait]
215impl AIProvider for ClaudeProvider {
216 async fn extract(&self, content: String, schema: Value) -> Result<Value> {
236 let body = self.build_extract_body(&content, &schema);
237
238 let response = self
239 .client
240 .post(API_URL)
241 .header("x-api-key", &self.config.api_key)
242 .header("anthropic-version", ANTHROPIC_VERSION)
243 .header("content-type", "application/json")
244 .json(&body)
245 .send()
246 .await
247 .map_err(|e| {
248 StygianError::Provider(ProviderError::ApiError(format!(
249 "Request to Anthropic API failed: {e}"
250 )))
251 })?;
252
253 let status = response.status().as_u16();
254 let text = response.text().await.map_err(|e| {
255 StygianError::Provider(ProviderError::ApiError(format!(
256 "Failed to read Anthropic response body: {e}"
257 )))
258 })?;
259
260 if status != 200 {
261 return Err(Self::map_http_error(status, &text));
262 }
263
264 let json_value: Value = serde_json::from_str(&text).map_err(|e| {
265 StygianError::Provider(ProviderError::ApiError(format!(
266 "Failed to parse Anthropic response JSON: {e}"
267 )))
268 })?;
269
270 Self::parse_extract_response(&json_value)
271 }
272
273 async fn stream_extract(
293 &self,
294 content: String,
295 schema: Value,
296 ) -> Result<BoxStream<'static, Result<Value>>> {
297 let result = self.extract(content, schema).await;
302 let stream = stream::once(async move { result });
303 Ok(Box::pin(stream))
304 }
305
306 fn capabilities(&self) -> ProviderCapabilities {
307 ProviderCapabilities {
308 streaming: true,
309 vision: true,
310 tool_use: true,
311 json_mode: true,
312 }
313 }
314
315 fn name(&self) -> &'static str {
316 "claude"
317 }
318}
319
#[cfg(test)]
mod tests {
    //! Unit tests for request construction, response parsing and HTTP error
    //! mapping. Only `test_stream_extract_returns_stream` touches the network.

    use super::*;
    use serde_json::json;

    #[test]
    fn test_provider_name() {
        let p = ClaudeProvider::new("key".to_string());
        assert_eq!(p.name(), "claude");
    }

    #[test]
    fn test_capabilities() {
        let p = ClaudeProvider::new("key".to_string());
        let caps = p.capabilities();
        assert!(caps.streaming);
        assert!(caps.vision);
        assert!(caps.tool_use);
        assert!(caps.json_mode);
    }

    #[test]
    fn test_build_extract_body_contains_tool() -> std::result::Result<(), Box<dyn std::error::Error>>
    {
        let p = ClaudeProvider::new("key".to_string());
        let schema = json!({"type": "object"});
        let body = p.build_extract_body("some content", &schema);

        assert_eq!(
            body.get("model").and_then(Value::as_str),
            Some(DEFAULT_MODEL)
        );
        let tools = body
            .get("tools")
            .and_then(Value::as_array)
            .ok_or("no tools field")?;
        assert_eq!(tools.len(), 1);
        assert_eq!(
            tools
                .first()
                .and_then(|t| t.get("name"))
                .and_then(Value::as_str),
            Some("extract_data")
        );
        assert_eq!(
            body.get("tool_choice")
                .and_then(|tc| tc.get("name"))
                .and_then(Value::as_str),
            Some("extract_data")
        );
        Ok(())
    }

    #[test]
    fn test_build_stream_body_sets_stream_flag() {
        let p = ClaudeProvider::new("key".to_string());
        let body = p.build_stream_body("some content", &json!({"type": "object"}));
        assert_eq!(body.get("stream").and_then(Value::as_bool), Some(true));
        // The streaming body must still carry the forced tool call.
        assert_eq!(
            body.get("tool_choice")
                .and_then(|tc| tc.get("name"))
                .and_then(Value::as_str),
            Some("extract_data")
        );
    }

    #[test]
    fn test_parse_extract_response_success() -> Result<()> {
        let response = json!({
            "content": [
                {"type": "tool_use", "name": "extract_data", "input": {"title": "Hello"}}
            ]
        });
        let result = ClaudeProvider::parse_extract_response(&response)?;
        assert_eq!(result.get("title").and_then(Value::as_str), Some("Hello"));
        Ok(())
    }

    #[test]
    fn test_parse_extract_response_no_tool_use() {
        let response = json!({
            "content": [{"type": "text", "text": "some text"}]
        });
        let err_result = ClaudeProvider::parse_extract_response(&response);
        assert!(err_result.is_err(), "expected Err but got Ok");
        if let Err(e) = err_result {
            assert!(e.to_string().contains("tool_use"));
        }
    }

    #[test]
    fn test_parse_extract_response_no_content() {
        let response = json!({"stop_reason": "end_turn"});
        let err_result = ClaudeProvider::parse_extract_response(&response);
        assert!(err_result.is_err(), "expected Err but got Ok");
        if let Err(e) = err_result {
            assert!(e.to_string().contains("content") || e.to_string().contains("API error"));
        }
    }

    #[test]
    fn test_map_http_error_401() {
        let e = ClaudeProvider::map_http_error(401, "unauthorized");
        assert!(matches!(
            e,
            StygianError::Provider(ProviderError::InvalidCredentials)
        ));
    }

    #[test]
    fn test_map_http_error_429() {
        let e = ClaudeProvider::map_http_error(429, "rate limited");
        assert!(e.to_string().contains("Rate limited"));
    }

    #[test]
    fn test_map_http_error_400_token_limit() {
        let e = ClaudeProvider::map_http_error(
            400,
            "prompt is too long: 250000 tokens > 200000 maximum",
        );
        assert!(matches!(
            e,
            StygianError::Provider(ProviderError::TokenLimitExceeded(_))
        ));
    }

    #[test]
    fn test_map_http_error_400_policy() {
        let e = ClaudeProvider::map_http_error(400, "request blocked by usage policy");
        assert!(matches!(
            e,
            StygianError::Provider(ProviderError::ContentPolicyViolation(_))
        ));
    }

    #[test]
    fn test_map_http_error_other_status_includes_code() {
        let e = ClaudeProvider::map_http_error(500, "boom");
        assert!(matches!(
            e,
            StygianError::Provider(ProviderError::ApiError(ref msg)) if msg == "HTTP 500: boom"
        ));
    }

    #[test]
    fn test_config_defaults() {
        let config = ClaudeConfig::new("key".to_string());
        assert_eq!(config.model, DEFAULT_MODEL);
        assert_eq!(config.max_tokens, 4096);
        assert_eq!(config.timeout, Duration::from_secs(120));
    }

    #[test]
    fn test_config_builder() {
        let config = ClaudeConfig::new("key".to_string())
            .with_model("claude-3-5-sonnet-20241022")
            .with_max_tokens(2048);
        assert_eq!(config.model, "claude-3-5-sonnet-20241022");
        assert_eq!(config.max_tokens, 2048);
    }

    #[tokio::test]
    async fn test_stream_extract_returns_stream() {
        use futures::StreamExt;
        // This performs a real HTTP request with an invalid key, so the single item
        // is expected to be an error. The timeout is capped so that an offline or
        // firewalled environment fails quickly instead of blocking for the 120 s default.
        let config = ClaudeConfig {
            timeout: Duration::from_secs(10),
            ..ClaudeConfig::new("invalid-key".to_string())
        };
        let p = ClaudeProvider::with_config(config);
        let schema = json!({"type": "object"});
        let result = p.stream_extract("content".to_string(), schema).await;
        assert!(result.is_ok(), "stream_extract should return Ok(stream)");
        if let Ok(mut s) = result {
            // Whatever the network outcome, the stream must yield an item.
            let item = s.next().await;
            assert!(item.is_some());
        }
    }
}