diff --git a/src/main.rs b/src/main.rs index 8c3ed5a..fa0fdf4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ mod html; mod irc_task; mod owncast_api; mod webhook; +mod websocket; fn main() { println!("owncast-irc-bridge"); diff --git a/src/websocket.rs b/src/websocket.rs new file mode 100644 index 0000000..3f55406 --- /dev/null +++ b/src/websocket.rs @@ -0,0 +1,155 @@ +use std::time::Duration; + +use futures_util::StreamExt; +use tokio::sync::mpsc; +use tokio_tungstenite::connect_async; +use tracing::{info, warn}; + +use crate::events::{BridgeEvent, Source}; +use crate::html::strip_html; + +pub async fn run_websocket_task( + owncast_url: String, + event_tx: mpsc::Sender, + mut shutdown: tokio::sync::watch::Receiver, +) { + let mut backoff = Duration::from_secs(1); + let max_backoff = Duration::from_secs(60); + + loop { + let ws_url = build_ws_url(&owncast_url); + info!(url = %ws_url, "Connecting to Owncast WebSocket"); + + match connect_and_listen(&ws_url, &event_tx, &mut shutdown).await { + Ok(()) => { + info!("WebSocket task exiting cleanly"); + return; + } + Err(e) => { + warn!(error = %e, "WebSocket connection error"); + info!(backoff_secs = backoff.as_secs(), "Reconnecting after backoff"); + + tokio::select! { + _ = tokio::time::sleep(backoff) => {}, + _ = shutdown.changed() => return, + } + + backoff = (backoff * 2).min(max_backoff); + } + } + } +} + +fn build_ws_url(base_url: &str) -> String { + let base = base_url.trim_end_matches('/'); + let ws_base = if base.starts_with("https://") { + base.replacen("https://", "wss://", 1) + } else { + base.replacen("http://", "ws://", 1) + }; + format!("{}/ws", ws_base) +} + +async fn connect_and_listen( + ws_url: &str, + event_tx: &mpsc::Sender, + shutdown: &mut tokio::sync::watch::Receiver, +) -> anyhow::Result<()> { + let (ws_stream, _) = connect_async(ws_url).await?; + let (_write, mut read) = ws_stream.split(); + + info!("WebSocket connected"); + + loop { + tokio::select! { + msg = read.next() => { + match msg { + Some(Ok(ws_msg)) => { + if let Ok(text) = ws_msg.into_text() { + if let Some(event) = parse_ws_message(&text) { + if event_tx.send(event).await.is_err() { + return Ok(()); + } + } + } + } + Some(Err(e)) => return Err(e.into()), + None => return Err(anyhow::anyhow!("WebSocket stream ended")), + } + } + _ = shutdown.changed() => return Ok(()), + } + } +} + +fn parse_ws_message(text: &str) -> Option { + let value: serde_json::Value = serde_json::from_str(text).ok()?; + let msg_type = value.get("type")?.as_str()?; + + match msg_type { + "CHAT" => { + let user = value.get("user")?; + let is_bot = user.get("isBot").and_then(|v| v.as_bool()).unwrap_or(false); + if is_bot { + return None; + } + let display_name = user.get("displayName")?.as_str()?.to_string(); + let body = value.get("body")?.as_str()?; + let id = value.get("id").and_then(|v| v.as_str()).map(String::from); + + Some(BridgeEvent::ChatMessage { + source: Source::Owncast, + username: display_name, + body: strip_html(body), + id, + }) + } + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_build_ws_url_https() { + assert_eq!( + build_ws_url("https://owncast.example.com"), + "wss://owncast.example.com/ws" + ); + } + + #[test] + fn test_build_ws_url_http() { + assert_eq!( + build_ws_url("http://localhost:8080"), + "ws://localhost:8080/ws" + ); + } + + #[test] + fn test_build_ws_url_trailing_slash() { + assert_eq!( + build_ws_url("https://owncast.example.com/"), + "wss://owncast.example.com/ws" + ); + } + + #[test] + fn test_parse_ws_chat_message() { + let json = r#"{"type":"CHAT","id":"abc","body":"hello","user":{"displayName":"viewer","isBot":false}}"#; + let event = parse_ws_message(json); + assert!(matches!( + event, + Some(BridgeEvent::ChatMessage { ref username, ref body, .. }) + if username == "viewer" && body == "hello" + )); + } + + #[test] + fn test_parse_ws_bot_message_ignored() { + let json = r#"{"type":"CHAT","id":"abc","body":"hello","user":{"displayName":"bot","isBot":true}}"#; + assert!(parse_ws_message(json).is_none()); + } +}