Skip to main content

stygian_graph/adapters/
websocket.rs

1//! WebSocket stream source adapter.
2//!
3//! Implements [`StreamSourcePort`](crate::ports::stream_source::StreamSourcePort) and [`ScrapingService`](crate::ports::ScrapingService) for consuming
4//! WebSocket feeds.  Uses `tokio-tungstenite` for the underlying connection.
5//!
6//! # Example
7//!
8//! ```no_run
9//! use stygian_graph::adapters::websocket::WebSocketSource;
10//! use stygian_graph::ports::stream_source::StreamSourcePort;
11//!
12//! # async fn example() {
13//! let source = WebSocketSource::default();
14//! let events = source.subscribe("wss://api.example.com/ws", Some(10)).await.unwrap();
15//! println!("received {} events", events.len());
16//! # }
17//! ```
18
19use async_trait::async_trait;
20use futures::stream::StreamExt;
21use serde_json::json;
22use std::time::Duration;
23use tokio::time::timeout;
24use tokio_tungstenite::tungstenite::Message;
25use tokio_tungstenite::tungstenite::client::IntoClientRequest;
26
27use crate::domain::error::{Result, ServiceError, StygianError};
28use crate::ports::stream_source::{StreamEvent, StreamSourcePort};
29use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
30
31// ─── Configuration ────────────────────────────────────────────────────────────
32
/// Configuration for a WebSocket connection.
///
/// # Example
///
/// ```
/// use stygian_graph::adapters::websocket::WebSocketConfig;
///
/// let config = WebSocketConfig {
///     subscribe_message: Some(r#"{"type":"subscribe","channel":"prices"}"#.into()),
///     bearer_token: None,
///     timeout_secs: 30,
///     max_reconnect_attempts: 3,
/// };
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocketConfig {
    /// Optional text message sent immediately after connecting (e.g. a
    /// subscription request).
    pub subscribe_message: Option<String>,
    /// Optional bearer token, sent as `Authorization: Bearer <token>` on the
    /// upgrade request.
    pub bearer_token: Option<String>,
    /// Timeout in seconds applied both to establishing the connection and to
    /// waiting for each subsequent message. If no message arrives within this
    /// window, collection stops and the events gathered so far are returned.
    pub timeout_secs: u64,
    /// Number of extra attempts [`StreamSourcePort::subscribe`] makes after a
    /// failed attempt, so at most `max_reconnect_attempts + 1` attempts in total.
    ///
    /// Only failures to establish the session are retried: an invalid URL or
    /// auth header, a connect error or timeout, or a failed subscribe send. A
    /// connection that ends after it was established is not re-opened.
    pub max_reconnect_attempts: u32,
}
58
59impl Default for WebSocketConfig {
60    fn default() -> Self {
61        Self {
62            subscribe_message: None,
63            bearer_token: None,
64            timeout_secs: 30,
65            max_reconnect_attempts: 3,
66        }
67    }
68}
69
70// ─── Adapter ──────────────────────────────────────────────────────────────────
71
72/// WebSocket stream source adapter.
73///
74/// Connects to a WebSocket endpoint and collects messages until `max_events`
75/// is reached, the stream closes, or a connection timeout occurs.
76#[derive(Default)]
77pub struct WebSocketSource {
78    config: WebSocketConfig,
79}
80
81impl WebSocketSource {
82    /// Create a new WebSocket source with custom configuration.
83    #[must_use]
84    pub const fn new(config: WebSocketConfig) -> Self {
85        Self { config }
86    }
87
88    /// Extract configuration from `ServiceInput.params` overrides.
89    fn config_from_params(&self, params: &serde_json::Value) -> WebSocketConfig {
90        let mut cfg = self.config.clone();
91        if let Some(msg) = params.get("subscribe_message").and_then(|v| v.as_str()) {
92            cfg.subscribe_message = Some(msg.to_string());
93        }
94        if let Some(token) = params.get("bearer_token").and_then(|v| v.as_str()) {
95            cfg.bearer_token = Some(token.to_string());
96        }
97        if let Some(t) = params
98            .get("timeout_secs")
99            .and_then(serde_json::Value::as_u64)
100        {
101            cfg.timeout_secs = t;
102        }
103        if let Some(r) = params
104            .get("max_reconnect_attempts")
105            .and_then(serde_json::Value::as_u64)
106        {
107            cfg.max_reconnect_attempts = u32::try_from(r).unwrap_or(u32::MAX);
108        }
109        cfg
110    }
111
112    /// Connect and collect events from a WebSocket endpoint.
113    async fn collect_events(
114        &self,
115        url: &str,
116        max_events: Option<usize>,
117        cfg: &WebSocketConfig,
118    ) -> Result<Vec<StreamEvent>> {
119        let mut request = url.into_client_request().map_err(|e| {
120            StygianError::Service(ServiceError::Unavailable(format!(
121                "invalid WebSocket URL: {e}"
122            )))
123        })?;
124
125        // Inject auth header if configured
126        if let Some(token) = &cfg.bearer_token {
127            request.headers_mut().insert(
128                reqwest::header::AUTHORIZATION,
129                format!("Bearer {token}").parse().map_err(|e| {
130                    StygianError::Service(ServiceError::Unavailable(format!(
131                        "invalid auth header: {e}"
132                    )))
133                })?,
134            );
135        }
136
137        let connect_timeout = Duration::from_secs(cfg.timeout_secs);
138        let (ws_stream, _) = timeout(connect_timeout, tokio_tungstenite::connect_async(request))
139            .await
140            .map_err(|_| {
141                StygianError::Service(ServiceError::Unavailable(
142                    "WebSocket connection timed out".into(),
143                ))
144            })?
145            .map_err(|e| {
146                StygianError::Service(ServiceError::Unavailable(format!(
147                    "WebSocket connection failed: {e}"
148                )))
149            })?;
150
151        let (mut write, mut read) = ws_stream.split();
152
153        // Send subscribe message if configured
154        if let Some(ref sub_msg) = cfg.subscribe_message {
155            use futures::SinkExt;
156            write
157                .send(Message::Text(sub_msg.clone().into()))
158                .await
159                .map_err(|e| {
160                    StygianError::Service(ServiceError::Unavailable(format!(
161                        "failed to send subscribe message: {e}"
162                    )))
163                })?;
164        }
165
166        let mut events = Vec::new();
167        let mut frame_idx: u64 = 0;
168
169        while let Some(msg_result) = timeout(Duration::from_secs(cfg.timeout_secs), read.next())
170            .await
171            .ok()
172            .flatten()
173        {
174            match msg_result {
175                Ok(msg) => {
176                    if let Some(event) = map_message_to_event(msg, frame_idx) {
177                        events.push(event);
178                        frame_idx += 1;
179
180                        if let Some(max) = max_events
181                            && events.len() >= max
182                        {
183                            break;
184                        }
185                    }
186                }
187                Err(e) => {
188                    tracing::warn!("WebSocket receive error: {e}");
189                    break;
190                }
191            }
192        }
193
194        Ok(events)
195    }
196}
197
198/// Map a WebSocket message to a [`StreamEvent`].
199///
200/// Returns `None` for internal frames (Pong, Close, Frame).
201fn map_message_to_event(msg: Message, frame_idx: u64) -> Option<StreamEvent> {
202    match msg {
203        Message::Text(text) => Some(StreamEvent {
204            id: Some(frame_idx.to_string()),
205            event_type: Some("text".into()),
206            data: text.to_string(),
207        }),
208        Message::Binary(data) => {
209            use base64::Engine;
210            let encoded = base64::engine::general_purpose::STANDARD.encode(&data);
211            Some(StreamEvent {
212                id: Some(frame_idx.to_string()),
213                event_type: Some("binary".into()),
214                data: encoded,
215            })
216        }
217        Message::Ping(data) => Some(StreamEvent {
218            id: Some(frame_idx.to_string()),
219            event_type: Some("ping".into()),
220            data: String::from_utf8_lossy(&data).to_string(),
221        }),
222        // Pong, Close, and Frame are internal — skip
223        Message::Pong(_) | Message::Close(_) | Message::Frame(_) => None,
224    }
225}
226
227// ─── StreamSourcePort ─────────────────────────────────────────────────────────
228
229#[async_trait]
230impl StreamSourcePort for WebSocketSource {
231    async fn subscribe(&self, url: &str, max_events: Option<usize>) -> Result<Vec<StreamEvent>> {
232        let cfg = self.config.clone();
233        let mut last_err = None;
234
235        for attempt in 0..=cfg.max_reconnect_attempts {
236            match self.collect_events(url, max_events, &cfg).await {
237                Ok(events) => return Ok(events),
238                Err(e) => {
239                    tracing::warn!(
240                        "WebSocket attempt {}/{} failed: {e}",
241                        attempt + 1,
242                        cfg.max_reconnect_attempts + 1
243                    );
244                    last_err = Some(e);
245
246                    if attempt < cfg.max_reconnect_attempts {
247                        // Exponential backoff: 1s, 2s, 4s ...
248                        let backoff = Duration::from_secs(1 << attempt);
249                        tokio::time::sleep(backoff).await;
250                    }
251                }
252            }
253        }
254
255        Err(last_err.unwrap_or_else(|| {
256            StygianError::Service(ServiceError::Unavailable(
257                "WebSocket connection failed after all retries".into(),
258            ))
259        }))
260    }
261
262    fn source_name(&self) -> &'static str {
263        "websocket"
264    }
265}
266
267// ─── ScrapingService ──────────────────────────────────────────────────────────
268
269#[async_trait]
270impl ScrapingService for WebSocketSource {
271    /// Collect messages from a WebSocket and return as JSON array.
272    ///
273    /// # Params (optional)
274    ///
275    /// * `max_events` — integer; maximum messages to collect.
276    /// * `subscribe_message` — string; message to send on connect.
277    /// * `bearer_token` — string; Bearer token for auth header.
278    /// * `timeout_secs` — integer; connection/read timeout.
279    async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
280        let cfg = self.config_from_params(&input.params);
281        let max_events = input
282            .params
283            .get("max_events")
284            .and_then(serde_json::Value::as_u64)
285            .map(|n| usize::try_from(n).unwrap_or(usize::MAX));
286
287        let events = self.collect_events(&input.url, max_events, &cfg).await?;
288        let count = events.len();
289
290        let data = serde_json::to_string(&events).map_err(|e| {
291            StygianError::Service(ServiceError::InvalidResponse(format!(
292                "websocket serialization failed: {e}"
293            )))
294        })?;
295
296        Ok(ServiceOutput {
297            data,
298            metadata: json!({
299                "source": "websocket",
300                "event_count": count,
301                "source_url": input.url,
302            }),
303        })
304    }
305
306    fn name(&self) -> &'static str {
307        "websocket"
308    }
309}
310
311// ─── Tests ────────────────────────────────────────────────────────────────────
312
#[cfg(test)]
mod tests {
    use base64::Engine;

    use super::*;

    #[test]
    fn map_text_frame() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let msg = Message::Text(r#"{"price": 42.5}"#.into());
        let event =
            map_message_to_event(msg, 0).ok_or_else(|| std::io::Error::other("should map"))?;
        assert_eq!(event.id.as_deref(), Some("0"));
        assert_eq!(event.event_type.as_deref(), Some("text"));
        assert_eq!(event.data, r#"{"price": 42.5}"#);
        Ok(())
    }

    #[test]
    fn map_binary_frame_to_base64() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let data = vec![0xDE, 0xAD, 0xBE, 0xEF];
        let msg = Message::Binary(data.into());
        let event =
            map_message_to_event(msg, 1).ok_or_else(|| std::io::Error::other("should map"))?;
        assert_eq!(event.event_type.as_deref(), Some("binary"));
        // Verify it's valid base64
        let decoded = base64::engine::general_purpose::STANDARD.decode(&event.data)?;
        assert_eq!(decoded, vec![0xDE, 0xAD, 0xBE, 0xEF]);
        Ok(())
    }

    #[test]
    fn map_ping_frame() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let msg = Message::Ping(vec![1, 2, 3].into());
        let event =
            map_message_to_event(msg, 2).ok_or_else(|| std::io::Error::other("should map"))?;
        assert_eq!(event.id.as_deref(), Some("2"));
        assert_eq!(event.event_type.as_deref(), Some("ping"));
        // Ping payloads are rendered as (lossy) UTF-8 text, not base64.
        assert_eq!(event.data, "\u{1}\u{2}\u{3}");
        Ok(())
    }

    #[test]
    fn pong_frame_is_skipped() {
        let msg = Message::Pong(vec![].into());
        assert!(map_message_to_event(msg, 0).is_none());
    }

    #[test]
    fn close_frame_is_skipped() {
        let msg = Message::Close(None);
        assert!(map_message_to_event(msg, 0).is_none());
    }

    #[test]
    fn default_config() {
        let cfg = WebSocketConfig::default();
        assert_eq!(cfg.timeout_secs, 30);
        assert_eq!(cfg.max_reconnect_attempts, 3);
        assert!(cfg.subscribe_message.is_none());
        assert!(cfg.bearer_token.is_none());
    }

    #[test]
    fn config_from_params_overrides() {
        let source = WebSocketSource::default();
        let params = json!({
            "subscribe_message": "{\"action\":\"sub\"}",
            "bearer_token": "tok123",
            "timeout_secs": 60,
            "max_reconnect_attempts": 5
        });
        let cfg = source.config_from_params(&params);
        assert_eq!(
            cfg.subscribe_message.as_deref(),
            Some("{\"action\":\"sub\"}")
        );
        assert_eq!(cfg.bearer_token.as_deref(), Some("tok123"));
        assert_eq!(cfg.timeout_secs, 60);
        assert_eq!(cfg.max_reconnect_attempts, 5);
    }

    #[test]
    fn config_from_params_without_overrides_keeps_base() {
        let source = WebSocketSource::new(WebSocketConfig {
            subscribe_message: Some("sub".into()),
            bearer_token: Some("tok".into()),
            timeout_secs: 7,
            max_reconnect_attempts: 1,
        });
        let cfg = source.config_from_params(&json!({}));
        assert_eq!(cfg.subscribe_message.as_deref(), Some("sub"));
        assert_eq!(cfg.bearer_token.as_deref(), Some("tok"));
        assert_eq!(cfg.timeout_secs, 7);
        assert_eq!(cfg.max_reconnect_attempts, 1);
    }

    #[test]
    fn config_from_params_ignores_mistyped_values() {
        let source = WebSocketSource::default();
        let cfg = source.config_from_params(&json!({
            "subscribe_message": 1,
            "bearer_token": null,
            "timeout_secs": "60",
            "max_reconnect_attempts": -1
        }));
        assert!(cfg.subscribe_message.is_none());
        assert!(cfg.bearer_token.is_none());
        assert_eq!(cfg.timeout_secs, 30);
        assert_eq!(cfg.max_reconnect_attempts, 3);
    }

    #[test]
    fn config_from_params_saturates_reconnect_attempts() {
        let source = WebSocketSource::default();
        let cfg = source.config_from_params(&json!({ "max_reconnect_attempts": u64::MAX }));
        assert_eq!(cfg.max_reconnect_attempts, u32::MAX);
    }

    #[test]
    fn frame_index_increments() {
        let msgs = vec![
            Message::Text("a".into()),
            Message::Pong(vec![].into()), // skipped
            Message::Text("b".into()),
        ];

        let mut idx: u64 = 0;
        let mut events = Vec::new();
        for msg in msgs {
            if let Some(event) = map_message_to_event(msg, idx) {
                events.push(event);
                idx += 1;
            }
        }

        assert_eq!(events.len(), 2);
        assert_eq!(events.first().and_then(|e| e.id.as_deref()), Some("0"));
        assert_eq!(events.get(1).and_then(|e| e.id.as_deref()), Some("1"));
    }

    // Integration tests require a running WebSocket server — marked #[ignore]
    #[tokio::test]
    #[ignore = "requires WebSocket echo server"]
    async fn connect_to_echo_server() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let source = WebSocketSource::new(WebSocketConfig {
            subscribe_message: Some("hello".into()),
            timeout_secs: 5,
            ..WebSocketConfig::default()
        });
        let events = source
            .subscribe("ws://127.0.0.1:9001/echo", Some(1))
            .await?;
        assert!(!events.is_empty());
        Ok(())
    }
}