// File: IRC-kosmi-relay/bridge/kosmi/graphql.go
// Snapshot metadata: 2025-10-31 16:17:04 -04:00, 392 lines, 9.0 KiB, Go

package bkosmi
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
)
// Message "type" values used on the graphql-ws WebSocket subprotocol that
// Connect requests from the server.
const (
	// Sent by this client.
	typeConnectionInit = "connection_init" // opens the GraphQL session (Connect)
	typeStart          = "start"           // starts an operation: subscription or mutation

	// Received during or after the handshake.
	typeConnectionAck       = "connection_ack"   // server accepted connection_init
	typeConnectionError     = "connection_error" // server rejected connection_init
	typeConnectionKeepAlive = "ka"               // keep-alive, only logged

	// Received for a running operation, keyed by its ID.
	typeData     = "data"     // operation result
	typeNext     = "next"     // treated exactly like "data" by handleMessage
	typeError    = "error"    // operation-level error
	typeComplete = "complete" // operation finished

	// Not referenced by the code in this file; presumably ends an operation — confirm.
	typeStop = "stop"
)
// GraphQLMessage is the JSON envelope of every frame exchanged over the
// GraphQL WebSocket. Type holds one of the type* constants. ID names the
// operation a frame belongs to; this client sets it on "start" frames and leaves
// it empty on connection_init. Payload is kept as raw JSON so it can be decoded
// once the frame type is known.
type GraphQLMessage struct {
	// Operation ID; omitted from the JSON when empty.
	ID string `json:"id,omitempty"`
	// Frame type, e.g. "connection_init", "start", "data".
	Type string `json:"type"`
	// Type-specific body, left undecoded; omitted from the JSON when empty.
	Payload json.RawMessage `json:"payload,omitempty"`
}
// NewMessagePayload is the decoded payload of a data/next frame produced by the
// newMessage subscription. Its shape mirrors the fields selected in the
// subscription query: data.newMessage.{body,time,user{displayName,username}}.
type NewMessagePayload struct {
	Data struct {
		NewMessage struct {
			// Message text.
			Body string `json:"body"`
			// NOTE(review): presumably a Unix timestamp; the unit (s vs ms) is not
			// visible here, so confirm before converting.
			Time int64 `json:"time"`
			// Author of the message.
			User struct {
				DisplayName string `json:"displayName"`
				Username    string `json:"username"`
			} `json:"user"`
		} `json:"newMessage"`
	} `json:"data"`
}
// GraphQLClient manages the WebSocket connection to Kosmi's GraphQL API.
//
// Construct it with NewGraphQLClient. Connect must succeed before
// SubscribeToMessages or SendMessage; both return "not connected" otherwise.
// Listen reads frames and hands chat messages to callbacks registered with
// OnMessage. Close cancels the client context and closes the socket.
type GraphQLClient struct {
	// Active WebSocket: set by Connect, nil before that and after Close.
	conn *websocket.Conn
	// Endpoint dialed by Connect.
	url string
	// Kosmi room used by the subscription query and the sendMessage mutation.
	roomID string
	log    *logrus.Entry
	// Operation ID of the newMessage subscription, set by SubscribeToMessages.
	subscriptionID string
	// mu guards conn, connected, subscriptionID and messageHandlers.
	mu sync.RWMutex
	// True once the connection_ack handshake succeeds; cleared by Listen after
	// a read failure and by Close.
	connected bool
	// Set to 5s by NewGraphQLClient; not read by any method in this file.
	reconnectDelay time.Duration
	// Callbacks registered with OnMessage, invoked from handleMessage.
	messageHandlers []func(*NewMessagePayload)
	// Client lifetime: Listen checks ctx between reads, and Close calls cancel.
	ctx    context.Context
	cancel context.CancelFunc
}
// NewGraphQLClient builds an unconnected client for the GraphQL WebSocket at
// url, scoped to roomID and logging through log. Call Connect before
// subscribing or sending. The client owns a cancellable context; Close cancels
// it, and Listen stops once it sees the cancellation.
func NewGraphQLClient(url, roomID string, log *logrus.Entry) *GraphQLClient {
	client := &GraphQLClient{
		url:             url,
		roomID:          roomID,
		log:             log,
		reconnectDelay:  5 * time.Second,
		messageHandlers: make([]func(*NewMessagePayload), 0),
	}
	client.ctx, client.cancel = context.WithCancel(context.Background())
	return client
}
// Connect dials the Kosmi GraphQL WebSocket using the graphql-ws subprotocol and
// performs the connection_init / connection_ack handshake.
//
// On success the new socket is stored in c.conn and the client is marked
// connected. If the handshake fails, the freshly dialed socket is closed and
// detached from the client, so a failed Connect leaves no open socket and no
// half-initialized connection behind. A dial failure with an HTTP response also
// logs the HTTP status code.
func (c *GraphQLClient) Connect() error {
	c.log.Infof("Connecting to Kosmi GraphQL WebSocket: %s", c.url)

	// Set up WebSocket dialer with graphql-ws subprotocol
	dialer := websocket.Dialer{
		Subprotocols:    []string{"graphql-ws"},
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
	}

	conn, resp, err := dialer.Dial(c.url, http.Header{})
	if err != nil {
		if resp != nil {
			c.log.Errorf("WebSocket dial failed with status %d: %v", resp.StatusCode, err)
		}
		return fmt.Errorf("failed to connect to WebSocket: %w", err)
	}
	defer resp.Body.Close()

	c.mu.Lock()
	c.conn = conn
	c.mu.Unlock()
	c.log.Info("WebSocket connection established")

	// abandon releases conn after a failed handshake. It only clears c.conn if
	// nothing else has replaced it in the meantime.
	abandon := func() {
		c.mu.Lock()
		if c.conn == conn {
			c.conn = nil
		}
		c.mu.Unlock()
		_ = conn.Close() // best effort: the connection is being discarded anyway
	}

	initMsg := GraphQLMessage{
		Type:    typeConnectionInit,
		Payload: json.RawMessage(`{}`),
	}
	if err := c.writeMessage(initMsg); err != nil {
		abandon()
		return fmt.Errorf("failed to send connection_init: %w", err)
	}
	c.log.Debug("Sent connection_init message")

	if err := c.waitForConnectionAck(); err != nil {
		abandon()
		return fmt.Errorf("failed to receive connection_ack: %w", err)
	}

	c.mu.Lock()
	c.connected = true
	c.mu.Unlock()

	c.log.Info("GraphQL WebSocket handshake completed")
	return nil
}
// waitForConnectionAck blocks until the server answers connection_init.
//
// The server gets 10 seconds. Keep-alive frames are tolerated, and any other
// unexpected frame is logged and skipped. A connection_error frame or a read
// error (including the timeout) is returned as an error. The read deadline is
// cleared again before returning.
func (c *GraphQLClient) waitForConnectionAck() error {
	c.mu.RLock()
	ws := c.conn
	c.mu.RUnlock()

	const ackTimeout = 10 * time.Second
	ws.SetReadDeadline(time.Now().Add(ackTimeout))
	defer ws.SetReadDeadline(time.Time{})

	for {
		var frame GraphQLMessage
		if err := ws.ReadJSON(&frame); err != nil {
			return fmt.Errorf("failed to read message: %w", err)
		}
		c.log.Debugf("Received message type: %s", frame.Type)

		if frame.Type == typeConnectionAck {
			c.log.Info("Received connection_ack")
			return nil
		}
		if frame.Type == typeConnectionError {
			return fmt.Errorf("connection error: %s", string(frame.Payload))
		}
		if frame.Type == typeConnectionKeepAlive {
			c.log.Debug("Received keep-alive")
			continue // still waiting for the ack
		}
		c.log.Warnf("Unexpected message type during handshake: %s", frame.Type)
	}
}
// SubscribeToMessages starts the newMessage GraphQL subscription for the
// client's room under the operation ID "newMessage-1". Resulting data/next
// frames are read by Listen and dispatched to OnMessage handlers.
//
// It returns an error if the handshake has not completed ("not connected"), if
// the payload cannot be encoded, or if the start frame cannot be written.
func (c *GraphQLClient) SubscribeToMessages() error {
	const subscriptionID = "newMessage-1"

	c.mu.Lock()
	if !c.connected {
		c.mu.Unlock()
		return fmt.Errorf("not connected")
	}
	c.subscriptionID = subscriptionID
	c.mu.Unlock()

	// The room ID is spliced into the query text, so escape it the same way the
	// message body is escaped in SendMessage. An ID containing a quote or
	// backslash would otherwise produce a malformed or altered query.
	query := fmt.Sprintf(`
subscription {
  newMessage(roomId: "%s") {
    body
    time
    user {
      displayName
      username
    }
  }
}
`, escapeGraphQLString(c.roomID))

	payload := map[string]interface{}{
		"query":     query,
		"variables": map[string]interface{}{},
	}
	payloadJSON, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("failed to marshal subscription payload: %w", err)
	}

	// Use the local constant rather than re-reading c.subscriptionID outside the lock.
	msg := GraphQLMessage{
		ID:      subscriptionID,
		Type:    typeStart,
		Payload: payloadJSON,
	}
	if err := c.writeMessage(msg); err != nil {
		return fmt.Errorf("failed to send subscription: %w", err)
	}

	c.log.Infof("Subscribed to messages in room: %s", c.roomID)
	return nil
}
// OnMessage appends handler to the callbacks that handleMessage invokes, in
// registration order, with each decoded data/next payload. Registration is
// guarded by c.mu, so it may race safely with message dispatch.
func (c *GraphQLClient) OnMessage(handler func(*NewMessagePayload)) {
	c.mu.Lock()
	c.messageHandlers = append(c.messageHandlers, handler)
	c.mu.Unlock()
}
// Listen reads frames from the WebSocket and passes each decoded frame to
// handleMessage. It blocks until the client is closed or the connection fails.
//
// A frame that is not valid JSON is logged and skipped; it does not invalidate
// the connection. A read error is permanent for a gorilla/websocket connection,
// so on such an error Listen marks the client disconnected, closes the dead
// socket and returns. When the read error is caused by Close, Listen just
// reports that it stopped.
func (c *GraphQLClient) Listen() {
	c.log.Info("Starting message listener")

	for {
		select {
		case <-c.ctx.Done():
			c.log.Info("Message listener stopped")
			return
		default:
		}

		c.mu.RLock()
		conn := c.conn
		c.mu.RUnlock()

		if conn == nil {
			c.log.Warn("Connection is nil, stopping listener")
			return
		}

		_, data, err := conn.ReadMessage()
		if err != nil {
			// Close cancels ctx and then closes the socket to unblock this read.
			// That is a normal shutdown, and Close does the cleanup itself.
			if c.ctx.Err() != nil {
				c.log.Info("Message listener stopped")
				return
			}
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				c.log.Errorf("WebSocket error: %v", err)
			}
			c.log.Warn("Connection closed, stopping listener")

			c.mu.Lock()
			// Only touch the state if a concurrent Connect has not already
			// installed a new connection.
			if c.conn == conn {
				c.connected = false
				c.conn = nil
			}
			c.mu.Unlock()
			_ = conn.Close() // best effort: the connection is already unusable
			return
		}

		var msg GraphQLMessage
		if err := json.Unmarshal(data, &msg); err != nil {
			c.log.Errorf("Failed to decode GraphQL message: %v", err)
			continue
		}
		c.handleMessage(&msg)
	}
}
// handleMessage routes one frame received by Listen.
//
// Each operation started on this socket reports its result through data/next
// frames. That includes the one-shot sendMessage mutation issued by
// SendMessage. Only frames whose payload carries data.newMessage are chat
// messages: they are decoded into a NewMessagePayload and passed to every
// OnMessage handler in registration order. Other data frames are skipped rather
// than delivered as empty messages, and GraphQL "errors" inside a data frame
// are logged. Keep-alive, error and complete frames are only logged.
func (c *GraphQLClient) handleMessage(msg *GraphQLMessage) {
	c.log.Debugf("Received GraphQL message type: %s", msg.Type)

	switch msg.Type {
	case typeNext, typeData:
		// Decode only far enough to tell whether this is a newMessage event.
		var envelope struct {
			Data struct {
				NewMessage json.RawMessage `json:"newMessage"`
			} `json:"data"`
			Errors json.RawMessage `json:"errors"`
		}
		if err := json.Unmarshal(msg.Payload, &envelope); err != nil {
			c.log.Errorf("Failed to parse message payload: %v", err)
			return
		}
		if len(envelope.Errors) > 0 && string(envelope.Errors) != "null" {
			c.log.Errorf("GraphQL errors for operation %s: %s", msg.ID, string(envelope.Errors))
		}
		if len(envelope.Data.NewMessage) == 0 || string(envelope.Data.NewMessage) == "null" {
			c.log.Debugf("Ignoring non-message result for operation %s", msg.ID)
			return
		}

		var payload NewMessagePayload
		if err := json.Unmarshal(envelope.Data.NewMessage, &payload.Data.NewMessage); err != nil {
			c.log.Errorf("Failed to parse message payload: %v", err)
			return
		}

		// Snapshot under the lock. OnMessage only appends, so the elements
		// already in the snapshot never change.
		c.mu.RLock()
		handlers := c.messageHandlers
		c.mu.RUnlock()
		for _, handler := range handlers {
			handler(&payload)
		}

	case typeConnectionKeepAlive:
		c.log.Debug("Received keep-alive")

	case typeError:
		c.log.Errorf("GraphQL error: %s", string(msg.Payload))

	case typeComplete:
		c.log.Infof("Subscription %s completed", msg.ID)

	default:
		c.log.Debugf("Unhandled message type: %s", msg.Type)
	}
}
// SendMessage posts text to the client's Kosmi room by starting a sendMessage
// GraphQL mutation. The operation gets a unique, time-based ID.
//
// Both the room ID and the message body are spliced into the mutation text, so
// both are escaped with escapeGraphQLString. It returns an error if the client
// is not connected, if the payload cannot be encoded, or if the frame cannot
// be written. The mutation's result arrives later as a data frame; this method
// does not wait for it.
func (c *GraphQLClient) SendMessage(text string) error {
	c.mu.RLock()
	if !c.connected {
		c.mu.RUnlock()
		return fmt.Errorf("not connected")
	}
	c.mu.RUnlock()

	mutation := fmt.Sprintf(`
mutation {
  sendMessage(roomId: "%s", body: "%s") {
    id
  }
}
`, escapeGraphQLString(c.roomID), escapeGraphQLString(text))

	payload := map[string]interface{}{
		"query":     mutation,
		"variables": map[string]interface{}{},
	}
	payloadJSON, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("failed to marshal mutation payload: %w", err)
	}

	msg := GraphQLMessage{
		ID:      fmt.Sprintf("sendMessage-%d", time.Now().UnixNano()),
		Type:    typeStart,
		Payload: payloadJSON,
	}
	if err := c.writeMessage(msg); err != nil {
		return fmt.Errorf("failed to send message: %w", err)
	}

	c.log.Debugf("Sent message: %s", text)
	return nil
}
// writeMessage encodes msg as JSON and writes it to the current WebSocket.
//
// gorilla/websocket supports only one concurrent writer per connection.
// SendMessage is exported and may be called from any goroutine, including from
// handlers while Listen is dispatching. The write therefore happens while
// holding c.mu exclusively; a shared RLock would let writers overlap. A write
// deadline limits how long that lock can be held if the peer stops reading.
// It returns an error if there is no connection or the write fails.
func (c *GraphQLClient) writeMessage(msg GraphQLMessage) error {
	const writeTimeout = 10 * time.Second

	c.mu.Lock()
	defer c.mu.Unlock()

	if c.conn == nil {
		return fmt.Errorf("connection is nil")
	}
	if err := c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
		return fmt.Errorf("failed to set write deadline: %w", err)
	}
	return c.conn.WriteJSON(msg)
}
// Close shuts the client down. It cancels the client context, which stops
// Listen. It then sends a normal-closure close frame, closes the socket, and
// marks the client disconnected.
//
// Closing is best effort. Failures to send the close frame or to close the
// socket are logged at debug level, and Close always returns nil.
func (c *GraphQLClient) Close() error {
	c.log.Info("Closing GraphQL client")
	c.cancel()

	c.mu.Lock()
	defer c.mu.Unlock()

	if c.conn != nil {
		closeMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
		// WriteControl takes its own deadline and is safe to use alongside other
		// writers, so an unresponsive peer cannot hold c.mu here indefinitely.
		if err := c.conn.WriteControl(websocket.CloseMessage, closeMsg, time.Now().Add(time.Second)); err != nil {
			c.log.Debugf("Failed to send WebSocket close frame: %v", err)
		}
		if err := c.conn.Close(); err != nil {
			c.log.Debugf("Failed to close WebSocket: %v", err)
		}
		c.conn = nil
	}
	c.connected = false
	return nil
}
// IsConnected reports whether the connection_ack handshake has completed.
// Listen clears the flag after a read failure, and Close clears it too.
func (c *GraphQLClient) IsConnected() bool {
	c.mu.RLock()
	connected := c.connected
	c.mu.RUnlock()
	return connected
}
// escapeGraphQLString returns s escaped for the inside of a double-quoted
// GraphQL string literal, without the surrounding quotes.
//
// It reuses encoding/json's string quoting. JSON escapes (\", \\, \n, \uXXXX, ...)
// are also valid GraphQL string escapes. Note that HTML-sensitive characters
// (<, >, &) come out as \u003c, \u003e and \u0026.
func escapeGraphQLString(s string) string {
	quoted, err := json.Marshal(s)
	switch {
	case err != nil:
		return s
	case len(quoted) < 2:
		return string(quoted)
	default:
		// Drop the opening and closing quote added by json.Marshal.
		return string(quoted[1 : len(quoted)-1])
	}
}