Skip to main content

stygian_graph/adapters/
webhook.rs

//! Webhook trigger adapter — axum-based HTTP listener.
//!
//! Implements [`WebhookTrigger`](crate::ports::webhook::WebhookTrigger) with an embedded axum server that accepts
//! inbound webhooks, verifies HMAC-SHA256 signatures, enforces body-size limits,
//! and publishes [`WebhookEvent`](crate::ports::webhook::WebhookEvent)s over a tokio broadcast channel.
//!
//! Also implements [`ScrapingService`](crate::ports::ScrapingService) so a pipeline node can start a webhook
//! listener and wait for the next event as its input.
//!
//! # Feature gate
//!
//! Requires `feature = "api"`.
13
14use crate::domain::error::{Result, ServiceError, StygianError};
15use crate::ports::webhook::{WebhookConfig, WebhookEvent, WebhookListenerHandle, WebhookTrigger};
16use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
17use async_trait::async_trait;
18use axum::Router;
19use axum::body::Bytes;
20use axum::extract::State;
21use axum::http::{HeaderMap, StatusCode};
22use axum::response::IntoResponse;
23use axum::routing::{get, post};
24use hmac::{Hmac, KeyInit, Mac};
25use serde_json::json;
26use sha2::Sha256;
27use std::collections::HashMap;
28use std::time::{SystemTime, UNIX_EPOCH};
29use tokio::net::TcpListener;
30use tokio::sync::{Mutex, broadcast};
31use tracing::{debug, info, warn};
32
33type HmacSha256 = Hmac<Sha256>;
34
35// ─── Shared state ─────────────────────────────────────────────────────────────
36
37#[derive(Clone)]
38struct AppState {
39    config: WebhookConfig,
40    tx: broadcast::Sender<WebhookEvent>,
41}
42
43// ─── Adapter ──────────────────────────────────────────────────────────────────
44
45/// Axum-based webhook trigger adapter.
46pub struct AxumWebhookTrigger {
47    tx: broadcast::Sender<WebhookEvent>,
48    rx: Mutex<broadcast::Receiver<WebhookEvent>>,
49    shutdown: Mutex<Option<tokio::sync::oneshot::Sender<()>>>,
50}
51
52impl AxumWebhookTrigger {
53    /// Create a new [`AxumWebhookTrigger`].
54    #[must_use]
55    pub fn new() -> Self {
56        let (tx, rx) = broadcast::channel(256);
57        Self {
58            tx,
59            rx: Mutex::new(rx),
60            shutdown: Mutex::new(None),
61        }
62    }
63
64    /// Verify an HMAC-SHA256 signature.
65    ///
66    /// The `signature` should be in the form `sha256=<hex>`.
67    fn verify_hmac(secret: &str, signature: &str, body: &[u8]) -> bool {
68        let Some(hex_sig) = signature.strip_prefix("sha256=") else {
69            return false;
70        };
71
72        let Ok(expected_bytes) = hex_decode(hex_sig) else {
73            return false;
74        };
75
76        let Ok(mut mac) = HmacSha256::new_from_slice(secret.as_bytes()) else {
77            return false;
78        };
79
80        mac.update(body);
81        mac.verify_slice(&expected_bytes).is_ok()
82    }
83}
84
85impl Default for AxumWebhookTrigger {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
/// Decode a string of hexadecimal digits (either case) into bytes.
///
/// Returns `Err(())` when the input has an odd byte length or contains any
/// byte that is not an ASCII hex digit. Decoding works on raw bytes rather
/// than `&str` slices, so non-ASCII input is rejected instead of panicking on
/// a char boundary. Sign characters such as `+` (which `u8::from_str_radix`
/// would silently accept) are also rejected.
fn hex_decode(hex: &str) -> std::result::Result<Vec<u8>, ()> {
    /// Numeric value of a single ASCII hex digit, or `None` for anything else.
    fn nibble(digit: u8) -> Option<u8> {
        match digit {
            b'0'..=b'9' => Some(digit - b'0'),
            b'a'..=b'f' => Some(digit - b'a' + 10),
            b'A'..=b'F' => Some(digit - b'A' + 10),
            _ => None,
        }
    }

    let pairs = hex.as_bytes().chunks_exact(2);
    // A dangling half-byte means the input is not valid hex.
    if !pairs.remainder().is_empty() {
        return Err(());
    }
    pairs
        .map(|pair| match (nibble(pair[0]), nibble(pair[1])) {
            (Some(high), Some(low)) => Ok((high << 4) | low),
            _ => Err(()),
        })
        .collect()
}
101
102// ─── Routes ───────────────────────────────────────────────────────────────────
103
104async fn trigger_handler(
105    State(state): State<AppState>,
106    headers: HeaderMap,
107    body: Bytes,
108) -> impl IntoResponse {
109    // Enforce body size
110    if body.len() > state.config.max_body_size {
111        warn!(
112            size = body.len(),
113            max = state.config.max_body_size,
114            "webhook body too large"
115        );
116        return StatusCode::PAYLOAD_TOO_LARGE;
117    }
118
119    let body_str = String::from_utf8_lossy(&body).to_string();
120
121    // Extract signature header
122    let signature = headers
123        .get("x-hub-signature-256")
124        .or_else(|| headers.get("x-signature-256"))
125        .and_then(|v| v.to_str().ok())
126        .map(String::from);
127
128    // Verify signature if secret is configured
129    if let Some(ref secret) = state.config.secret {
130        if let Some(sig) = &signature {
131            if !AxumWebhookTrigger::verify_hmac(secret, sig, &body) {
132                warn!("webhook signature verification failed");
133                return StatusCode::UNAUTHORIZED;
134            }
135            debug!("webhook signature verified");
136        } else {
137            warn!("webhook missing signature header, secret is configured");
138            return StatusCode::UNAUTHORIZED;
139        }
140    }
141
142    // Build filtered headers map
143    let filtered_headers: HashMap<String, String> = headers
144        .iter()
145        .filter_map(|(k, v)| {
146            let key = k.as_str().to_lowercase();
147            // Filter to relevant headers
148            if key.starts_with("x-")
149                || key == "content-type"
150                || key == "user-agent"
151                || key == "accept"
152            {
153                v.to_str().ok().map(|val| (key, val.to_string()))
154            } else {
155                None
156            }
157        })
158        .collect();
159
160    let source_ip = headers
161        .get("x-forwarded-for")
162        .or_else(|| headers.get("x-real-ip"))
163        .and_then(|v| v.to_str().ok())
164        .map(String::from);
165
166    let received_at_ms: u64 = SystemTime::now()
167        .duration_since(UNIX_EPOCH)
168        .unwrap_or_default()
169        .as_millis()
170        .try_into()
171        .unwrap_or(0);
172
173    let event = WebhookEvent {
174        method: "POST".into(),
175        path: state.config.path_prefix.clone(),
176        headers: filtered_headers,
177        body: body_str,
178        received_at_ms,
179        signature,
180        source_ip,
181    };
182
183    info!(path = %event.path, "webhook event received");
184
185    if state.tx.send(event).is_err() {
186        warn!("no webhook subscribers connected");
187    }
188
189    StatusCode::OK
190}
191
192async fn health_handler() -> impl IntoResponse {
193    StatusCode::OK
194}
195
196// ─── WebhookTrigger ───────────────────────────────────────────────────────────
197
198#[async_trait]
199impl WebhookTrigger for AxumWebhookTrigger {
200    async fn start_listener(&self, config: WebhookConfig) -> Result<WebhookListenerHandle> {
201        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
202
203        let state = AppState {
204            config: config.clone(),
205            tx: self.tx.clone(),
206        };
207
208        let trigger_path = format!("{}/trigger", config.path_prefix);
209        let health_path = format!("{}/health", config.path_prefix);
210
211        let app = Router::new()
212            .route(&trigger_path, post(trigger_handler))
213            .route(&health_path, get(health_handler))
214            .with_state(state);
215
216        let listener = TcpListener::bind(&config.bind_address).await.map_err(|e| {
217            StygianError::Service(ServiceError::Unavailable(format!(
218                "failed to bind webhook listener on {}: {e}",
219                config.bind_address
220            )))
221        })?;
222
223        let handle_id = format!("webhook-{}", config.bind_address);
224
225        info!(bind = %config.bind_address, prefix = %config.path_prefix, "webhook listener started");
226
227        tokio::spawn(async move {
228            axum::serve(listener, app)
229                .with_graceful_shutdown(async {
230                    let _ = shutdown_rx.await;
231                })
232                .await
233                .ok();
234        });
235
236        *self.shutdown.lock().await = Some(shutdown_tx);
237
238        Ok(WebhookListenerHandle { id: handle_id })
239    }
240
241    async fn stop_listener(&self, handle: WebhookListenerHandle) -> Result<()> {
242        let shutdown_tx = {
243            let mut shutdown = self.shutdown.lock().await;
244            shutdown.take()
245        };
246        if let Some(tx) = shutdown_tx {
247            let _ = tx.send(());
248            info!(id = %handle.id, "webhook listener stopped");
249        }
250        Ok(())
251    }
252
253    async fn recv_event(&self) -> Result<Option<WebhookEvent>> {
254        let mut rx = self.rx.lock().await;
255        match rx.recv().await {
256            Ok(event) => Ok(Some(event)),
257            Err(broadcast::error::RecvError::Closed) => Ok(None),
258            Err(broadcast::error::RecvError::Lagged(n)) => {
259                warn!(skipped = n, "webhook receiver lagged, events dropped");
260                // Try again after lag
261                Ok(rx.recv().await.ok())
262            }
263        }
264    }
265
266    fn verify_signature(&self, secret: &str, signature: &str, body: &[u8]) -> bool {
267        Self::verify_hmac(secret, signature, body)
268    }
269}
270
271// ─── ScrapingService ──────────────────────────────────────────────────────────
272
273#[async_trait]
274impl ScrapingService for AxumWebhookTrigger {
275    /// Start a webhook listener and wait for the next event.
276    ///
277    /// The `input.url` is used as the bind address. Params:
278    /// - `"path_prefix"`: URL path prefix (default: `"/webhooks"`)
279    /// - `"secret"`: Optional HMAC secret
280    /// - `"timeout_secs"`: Max seconds to wait for an event (default: 60)
281    async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
282        let path_prefix = input
283            .params
284            .get("path_prefix")
285            .and_then(|v| v.as_str())
286            .unwrap_or("/webhooks")
287            .to_string();
288
289        let secret = input
290            .params
291            .get("secret")
292            .and_then(|v| v.as_str())
293            .map(String::from);
294
295        let timeout_secs = input
296            .params
297            .get("timeout_secs")
298            .and_then(serde_json::Value::as_u64)
299            .unwrap_or(60);
300
301        let config = WebhookConfig {
302            bind_address: input.url.clone(),
303            path_prefix,
304            secret,
305            max_body_size: 1_048_576,
306        };
307
308        let handle = self.start_listener(config).await?;
309
310        let event = tokio::time::timeout(
311            std::time::Duration::from_secs(timeout_secs),
312            self.recv_event(),
313        )
314        .await;
315
316        // Stop listener regardless of outcome
317        let _ = self.stop_listener(handle).await;
318
319        match event {
320            Ok(Ok(Some(evt))) => Ok(ServiceOutput {
321                data: evt.body.clone(),
322                metadata: json!({
323                    "source": "webhook",
324                    "method": evt.method,
325                    "path": evt.path,
326                    "received_at_ms": evt.received_at_ms,
327                    "source_ip": evt.source_ip,
328                }),
329            }),
330            Ok(Ok(None)) => Err(StygianError::Service(ServiceError::Unavailable(
331                "webhook listener closed without receiving event".into(),
332            ))),
333            Ok(Err(e)) => Err(e),
334            Err(_) => Err(StygianError::Service(ServiceError::Timeout(
335                timeout_secs * 1000,
336            ))),
337        }
338    }
339
340    fn name(&self) -> &'static str {
341        "webhook-trigger"
342    }
343}
344
345// ─── Tests ────────────────────────────────────────────────────────────────────
346
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Write as _;

    /// Build a `sha256=<hex>` signature of `body` keyed with `key`.
    fn sign(key: &[u8], body: &[u8]) -> std::result::Result<String, Box<dyn std::error::Error>> {
        let mut mac = HmacSha256::new_from_slice(key)
            .map_err(|err| std::io::Error::other(format!("hmac init failed: {err}")))?;
        mac.update(body);
        let digest = mac.finalize().into_bytes();
        let mut signature = String::with_capacity("sha256=".len() + 2 * digest.len());
        signature.push_str("sha256=");
        for byte in digest {
            write!(signature, "{byte:02x}")?;
        }
        Ok(signature)
    }

    #[test]
    fn test_hex_decode_valid() {
        assert_eq!(hex_decode("48656c6c6f"), Ok(b"Hello".to_vec()));
    }

    #[test]
    fn test_hex_decode_empty() {
        assert_eq!(hex_decode(""), Ok(Vec::new()));
    }

    #[test]
    fn test_hex_decode_odd_length() {
        assert!(matches!(hex_decode("abc"), Err(())));
    }

    #[test]
    fn test_hex_decode_invalid_chars() {
        assert!(matches!(hex_decode("zzzz"), Err(())));
    }

    #[test]
    fn test_verify_hmac_valid() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let body = b"test body";
        let signature = sign(b"test-secret", body)?;
        assert!(AxumWebhookTrigger::verify_hmac("test-secret", &signature, body));
        Ok(())
    }

    #[test]
    fn test_verify_hmac_invalid_signature() {
        let accepted = AxumWebhookTrigger::verify_hmac("secret", "sha256=invalidhex", b"body");
        assert!(!accepted);
    }

    #[test]
    fn test_verify_hmac_wrong_prefix() {
        let accepted = AxumWebhookTrigger::verify_hmac("secret", "md5=abc123", b"body");
        assert!(!accepted);
    }

    #[test]
    fn test_verify_hmac_wrong_secret() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let body = b"test body";
        let signature = sign(b"correct-secret", body)?;
        assert!(!AxumWebhookTrigger::verify_hmac("wrong-secret", &signature, body));
        Ok(())
    }

    #[test]
    fn test_default_trigger() {
        assert_eq!(AxumWebhookTrigger::default().name(), "webhook-trigger");
    }

    #[tokio::test]
    async fn test_start_and_stop_listener() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let trigger = AxumWebhookTrigger::new();
        // Port 0 lets the OS pick a free port.
        let config = WebhookConfig {
            bind_address: "127.0.0.1:0".into(),
            ..Default::default()
        };

        let handle = trigger.start_listener(config).await?;
        assert!(handle.id.starts_with("webhook-"));

        trigger.stop_listener(handle).await?;
        Ok(())
    }
}