Skip to main content

stygian_graph/adapters/
webhook.rs

1//! Webhook trigger adapter — axum-based HTTP listener.
2//!
3//! Implements [`WebhookTrigger`](crate::ports::webhook::WebhookTrigger) with an embedded axum server that accepts
4//! inbound webhooks, verifies HMAC-SHA256 signatures, enforces body-size limits,
5//! and emits [`WebhookEvent`](crate::ports::webhook::WebhookEvent)s via a channel.
6//!
7//! Also implements [`ScrapingService`](crate::ports::ScrapingService) so a pipeline node can start a webhook
8//! listener and wait for the next event as input.
9//!
10//! # Feature gate
11//!
12//! Requires `feature = "api"`.
13
14use crate::domain::error::{Result, ServiceError, StygianError};
15use crate::ports::webhook::{WebhookConfig, WebhookEvent, WebhookListenerHandle, WebhookTrigger};
16use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
17use async_trait::async_trait;
18use axum::Router;
19use axum::body::Bytes;
20use axum::extract::State;
21use axum::http::{HeaderMap, StatusCode};
22use axum::response::IntoResponse;
23use axum::routing::{get, post};
24use hmac::{Hmac, Mac};
25use serde_json::json;
26use sha2::Sha256;
27use std::collections::HashMap;
28use std::time::{SystemTime, UNIX_EPOCH};
29use tokio::net::TcpListener;
30use tokio::sync::{Mutex, broadcast};
31use tracing::{debug, info, warn};
32
/// HMAC keyed with SHA-256, used to check inbound webhook signatures.
type HmacSha256 = Hmac<Sha256>;
34
35// ─── Shared state ─────────────────────────────────────────────────────────────
36
37#[derive(Clone)]
38struct AppState {
39    config: WebhookConfig,
40    tx: broadcast::Sender<WebhookEvent>,
41}
42
43// ─── Adapter ──────────────────────────────────────────────────────────────────
44
/// Axum-based webhook trigger adapter.
///
/// Owns a single broadcast channel shared by every listener it starts, plus the
/// shutdown signal of the most recently started listener.
pub struct AxumWebhookTrigger {
    /// Publishing side of the event channel; cloned into each listener's handler state.
    tx: broadcast::Sender<WebhookEvent>,
    /// Receiver created together with `tx` and read by `recv_event`; the mutex
    /// serialises concurrent readers.
    rx: Mutex<broadcast::Receiver<WebhookEvent>>,
    /// Shutdown signal for the current listener, if any. `start_listener` stores it
    /// and `stop_listener` takes it and fires it.
    shutdown: Mutex<Option<tokio::sync::oneshot::Sender<()>>>,
}
51
52impl AxumWebhookTrigger {
53    /// Create a new [`AxumWebhookTrigger`].
54    pub fn new() -> Self {
55        let (tx, rx) = broadcast::channel(256);
56        Self {
57            tx,
58            rx: Mutex::new(rx),
59            shutdown: Mutex::new(None),
60        }
61    }
62
63    /// Verify an HMAC-SHA256 signature.
64    ///
65    /// The `signature` should be in the form `sha256=<hex>`.
66    fn verify_hmac(secret: &str, signature: &str, body: &[u8]) -> bool {
67        let Some(hex_sig) = signature.strip_prefix("sha256=") else {
68            return false;
69        };
70
71        let Ok(expected_bytes) = hex_decode(hex_sig) else {
72            return false;
73        };
74
75        let Ok(mut mac) = HmacSha256::new_from_slice(secret.as_bytes()) else {
76            return false;
77        };
78
79        mac.update(body);
80        mac.verify_slice(&expected_bytes).is_ok()
81    }
82}
83
84impl Default for AxumWebhookTrigger {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
/// Decode a hex string (upper- or lower-case digits) into bytes.
///
/// Returns `Err(())` if the input has an odd length or contains any byte that is
/// not an ASCII hex digit; an empty string decodes to an empty vector.
///
/// The input is processed as raw bytes. Slicing the `&str` at fixed byte offsets
/// would panic on multi-byte UTF-8, and `u8::from_str_radix` would also accept a
/// leading `+` sign (e.g. `"+f"`), which is not valid hex.
fn hex_decode(hex: &str) -> std::result::Result<Vec<u8>, ()> {
    // Value of one ASCII hex digit, or `Err(())` for any other byte.
    fn nibble(byte: u8) -> std::result::Result<u8, ()> {
        match byte {
            b'0'..=b'9' => Ok(byte - b'0'),
            b'a'..=b'f' => Ok(byte - b'a' + 10),
            b'A'..=b'F' => Ok(byte - b'A' + 10),
            _ => Err(()),
        }
    }

    let pairs = hex.as_bytes().chunks_exact(2);
    // A leftover single byte means the input length is odd.
    if !pairs.remainder().is_empty() {
        return Err(());
    }
    pairs
        .map(|pair| -> std::result::Result<u8, ()> {
            match *pair {
                [hi, lo] => Ok((nibble(hi)? << 4) | nibble(lo)?),
                _ => Err(()),
            }
        })
        .collect()
}
100
101// ─── Routes ───────────────────────────────────────────────────────────────────
102
103async fn trigger_handler(
104    State(state): State<AppState>,
105    headers: HeaderMap,
106    body: Bytes,
107) -> impl IntoResponse {
108    // Enforce body size
109    if body.len() > state.config.max_body_size {
110        warn!(
111            size = body.len(),
112            max = state.config.max_body_size,
113            "webhook body too large"
114        );
115        return StatusCode::PAYLOAD_TOO_LARGE;
116    }
117
118    let body_str = String::from_utf8_lossy(&body).to_string();
119
120    // Extract signature header
121    let signature = headers
122        .get("x-hub-signature-256")
123        .or_else(|| headers.get("x-signature-256"))
124        .and_then(|v| v.to_str().ok())
125        .map(String::from);
126
127    // Verify signature if secret is configured
128    if let Some(ref secret) = state.config.secret {
129        if let Some(sig) = &signature {
130            if !AxumWebhookTrigger::verify_hmac(secret, sig, &body) {
131                warn!("webhook signature verification failed");
132                return StatusCode::UNAUTHORIZED;
133            }
134            debug!("webhook signature verified");
135        } else {
136            warn!("webhook missing signature header, secret is configured");
137            return StatusCode::UNAUTHORIZED;
138        }
139    }
140
141    // Build filtered headers map
142    let filtered_headers: HashMap<String, String> = headers
143        .iter()
144        .filter_map(|(k, v)| {
145            let key = k.as_str().to_lowercase();
146            // Filter to relevant headers
147            if key.starts_with("x-")
148                || key == "content-type"
149                || key == "user-agent"
150                || key == "accept"
151            {
152                v.to_str().ok().map(|val| (key, val.to_string()))
153            } else {
154                None
155            }
156        })
157        .collect();
158
159    let source_ip = headers
160        .get("x-forwarded-for")
161        .or_else(|| headers.get("x-real-ip"))
162        .and_then(|v| v.to_str().ok())
163        .map(String::from);
164
165    let received_at_ms: u64 = SystemTime::now()
166        .duration_since(UNIX_EPOCH)
167        .unwrap_or_default()
168        .as_millis()
169        .try_into()
170        .unwrap_or(0);
171
172    let event = WebhookEvent {
173        method: "POST".into(),
174        path: state.config.path_prefix.clone(),
175        headers: filtered_headers,
176        body: body_str,
177        received_at_ms,
178        signature,
179        source_ip,
180    };
181
182    info!(path = %event.path, "webhook event received");
183
184    if state.tx.send(event).is_err() {
185        warn!("no webhook subscribers connected");
186    }
187
188    StatusCode::OK
189}
190
/// Health probe for `GET {path_prefix}/health`; always answers `200 OK`.
async fn health_handler() -> impl IntoResponse {
    StatusCode::OK
}
194
195// ─── WebhookTrigger ───────────────────────────────────────────────────────────
196
197#[async_trait]
198impl WebhookTrigger for AxumWebhookTrigger {
199    async fn start_listener(&self, config: WebhookConfig) -> Result<WebhookListenerHandle> {
200        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
201
202        let state = AppState {
203            config: config.clone(),
204            tx: self.tx.clone(),
205        };
206
207        let trigger_path = format!("{}/trigger", config.path_prefix);
208        let health_path = format!("{}/health", config.path_prefix);
209
210        let app = Router::new()
211            .route(&trigger_path, post(trigger_handler))
212            .route(&health_path, get(health_handler))
213            .with_state(state);
214
215        let listener = TcpListener::bind(&config.bind_address).await.map_err(|e| {
216            StygianError::Service(ServiceError::Unavailable(format!(
217                "failed to bind webhook listener on {}: {e}",
218                config.bind_address
219            )))
220        })?;
221
222        let handle_id = format!("webhook-{}", config.bind_address);
223
224        info!(bind = %config.bind_address, prefix = %config.path_prefix, "webhook listener started");
225
226        tokio::spawn(async move {
227            axum::serve(listener, app)
228                .with_graceful_shutdown(async {
229                    let _ = shutdown_rx.await;
230                })
231                .await
232                .ok();
233        });
234
235        *self.shutdown.lock().await = Some(shutdown_tx);
236
237        Ok(WebhookListenerHandle { id: handle_id })
238    }
239
240    async fn stop_listener(&self, handle: WebhookListenerHandle) -> Result<()> {
241        let shutdown_tx = {
242            let mut shutdown = self.shutdown.lock().await;
243            shutdown.take()
244        };
245        if let Some(tx) = shutdown_tx {
246            let _ = tx.send(());
247            info!(id = %handle.id, "webhook listener stopped");
248        }
249        Ok(())
250    }
251
252    async fn recv_event(&self) -> Result<Option<WebhookEvent>> {
253        let mut rx = self.rx.lock().await;
254        match rx.recv().await {
255            Ok(event) => Ok(Some(event)),
256            Err(broadcast::error::RecvError::Closed) => Ok(None),
257            Err(broadcast::error::RecvError::Lagged(n)) => {
258                warn!(skipped = n, "webhook receiver lagged, events dropped");
259                // Try again after lag
260                Ok(rx.recv().await.ok())
261            }
262        }
263    }
264
265    fn verify_signature(&self, secret: &str, signature: &str, body: &[u8]) -> bool {
266        Self::verify_hmac(secret, signature, body)
267    }
268}
269
270// ─── ScrapingService ──────────────────────────────────────────────────────────
271
272#[async_trait]
273impl ScrapingService for AxumWebhookTrigger {
274    /// Start a webhook listener and wait for the next event.
275    ///
276    /// The `input.url` is used as the bind address. Params:
277    /// - `"path_prefix"`: URL path prefix (default: `"/webhooks"`)
278    /// - `"secret"`: Optional HMAC secret
279    /// - `"timeout_secs"`: Max seconds to wait for an event (default: 60)
280    async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
281        let path_prefix = input
282            .params
283            .get("path_prefix")
284            .and_then(|v| v.as_str())
285            .unwrap_or("/webhooks")
286            .to_string();
287
288        let secret = input
289            .params
290            .get("secret")
291            .and_then(|v| v.as_str())
292            .map(String::from);
293
294        let timeout_secs = input
295            .params
296            .get("timeout_secs")
297            .and_then(serde_json::Value::as_u64)
298            .unwrap_or(60);
299
300        let config = WebhookConfig {
301            bind_address: input.url.clone(),
302            path_prefix,
303            secret,
304            max_body_size: 1_048_576,
305        };
306
307        let handle = self.start_listener(config).await?;
308
309        let event = tokio::time::timeout(
310            std::time::Duration::from_secs(timeout_secs),
311            self.recv_event(),
312        )
313        .await;
314
315        // Stop listener regardless of outcome
316        let _ = self.stop_listener(handle).await;
317
318        match event {
319            Ok(Ok(Some(evt))) => Ok(ServiceOutput {
320                data: evt.body.clone(),
321                metadata: json!({
322                    "source": "webhook",
323                    "method": evt.method,
324                    "path": evt.path,
325                    "received_at_ms": evt.received_at_ms,
326                    "source_ip": evt.source_ip,
327                }),
328            }),
329            Ok(Ok(None)) => Err(StygianError::Service(ServiceError::Unavailable(
330                "webhook listener closed without receiving event".into(),
331            ))),
332            Ok(Err(e)) => Err(e),
333            Err(_) => Err(StygianError::Service(ServiceError::Timeout(
334                timeout_secs * 1000,
335            ))),
336        }
337    }
338
339    fn name(&self) -> &'static str {
340        "webhook-trigger"
341    }
342}
343
344// ─── Tests ────────────────────────────────────────────────────────────────────
345
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Write as _;

    /// Result type shared by the fallible tests below.
    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    /// Compute the `sha256=<hex>` signature header value for `body` under `key`.
    fn sign(key: &[u8], body: &[u8]) -> std::result::Result<String, Box<dyn std::error::Error>> {
        let mut mac = HmacSha256::new_from_slice(key)
            .map_err(|err| std::io::Error::other(format!("hmac init failed: {err}")))?;
        mac.update(body);
        let mut header = String::from("sha256=");
        for byte in mac.finalize().into_bytes() {
            write!(header, "{byte:02x}")?;
        }
        Ok(header)
    }

    /// `hex_decode` with its unit error mapped to something `?` can propagate.
    fn decode(hex: &str) -> std::result::Result<Vec<u8>, std::io::Error> {
        hex_decode(hex).map_err(|()| std::io::Error::other("hex decode failed"))
    }

    #[test]
    fn test_hex_decode_valid() -> TestResult {
        assert_eq!(decode("48656c6c6f")?, b"Hello");
        Ok(())
    }

    #[test]
    fn test_hex_decode_empty() -> TestResult {
        assert!(decode("")?.is_empty());
        Ok(())
    }

    #[test]
    fn test_hex_decode_odd_length() {
        assert!(hex_decode("abc").is_err());
    }

    #[test]
    fn test_hex_decode_invalid_chars() {
        assert!(hex_decode("zzzz").is_err());
    }

    #[test]
    fn test_verify_hmac_valid() -> TestResult {
        let (secret, body) = ("test-secret", b"test body");
        let signature = sign(secret.as_bytes(), body)?;
        assert!(AxumWebhookTrigger::verify_hmac(secret, &signature, body));
        Ok(())
    }

    #[test]
    fn test_verify_hmac_invalid_signature() {
        let accepted = AxumWebhookTrigger::verify_hmac("secret", "sha256=invalidhex", b"body");
        assert!(!accepted);
    }

    #[test]
    fn test_verify_hmac_wrong_prefix() {
        let accepted = AxumWebhookTrigger::verify_hmac("secret", "md5=abc123", b"body");
        assert!(!accepted);
    }

    #[test]
    fn test_verify_hmac_wrong_secret() -> TestResult {
        let body = b"test body";
        let signature = sign(b"correct-secret", body)?;
        assert!(!AxumWebhookTrigger::verify_hmac(
            "wrong-secret",
            &signature,
            body
        ));
        Ok(())
    }

    #[test]
    fn test_default_trigger() {
        assert_eq!(AxumWebhookTrigger::default().name(), "webhook-trigger");
    }

    #[tokio::test]
    async fn test_start_and_stop_listener() -> TestResult {
        let trigger = AxumWebhookTrigger::new();
        // Port 0 lets the OS pick a free port so the test never collides.
        let config = WebhookConfig {
            bind_address: "127.0.0.1:0".into(),
            ..Default::default()
        };

        let handle = trigger.start_listener(config).await?;
        assert!(handle.id.starts_with("webhook-"));

        trigger.stop_listener(handle).await?;
        Ok(())
    }
}