package bkosmi import ( "context" "encoding/json" "fmt" "net/http" "sync" "time" "github.com/gorilla/websocket" "github.com/sirupsen/logrus" ) // GraphQL WebSocket message types const ( typeConnectionInit = "connection_init" typeConnectionAck = "connection_ack" typeConnectionError = "connection_error" typeConnectionKeepAlive = "ka" typeStart = "start" typeData = "data" typeError = "error" typeComplete = "complete" typeStop = "stop" typeNext = "next" ) // GraphQLMessage represents a GraphQL WebSocket message type GraphQLMessage struct { ID string `json:"id,omitempty"` Type string `json:"type"` Payload json.RawMessage `json:"payload,omitempty"` } // NewMessagePayload represents the payload structure for new messages type NewMessagePayload struct { Data struct { NewMessage struct { Body string `json:"body"` Time int64 `json:"time"` User struct { DisplayName string `json:"displayName"` Username string `json:"username"` } `json:"user"` } `json:"newMessage"` } `json:"data"` } // GraphQLClient manages the WebSocket connection to Kosmi's GraphQL API type GraphQLClient struct { conn *websocket.Conn url string roomID string log *logrus.Entry subscriptionID string mu sync.RWMutex connected bool reconnectDelay time.Duration messageHandlers []func(*NewMessagePayload) ctx context.Context cancel context.CancelFunc } // NewGraphQLClient creates a new GraphQL WebSocket client func NewGraphQLClient(url, roomID string, log *logrus.Entry) *GraphQLClient { ctx, cancel := context.WithCancel(context.Background()) return &GraphQLClient{ url: url, roomID: roomID, log: log, reconnectDelay: 5 * time.Second, messageHandlers: []func(*NewMessagePayload){}, ctx: ctx, cancel: cancel, } } // Connect establishes the WebSocket connection and performs the GraphQL handshake func (c *GraphQLClient) Connect() error { c.log.Infof("Connecting to Kosmi GraphQL WebSocket: %s", c.url) // Set up WebSocket dialer with graphql-ws subprotocol dialer := websocket.Dialer{ Subprotocols: []string{"graphql-ws"}, ReadBufferSize: 1024, WriteBufferSize: 1024, } // Connect to WebSocket conn, resp, err := dialer.Dial(c.url, http.Header{}) if err != nil { if resp != nil { c.log.Errorf("WebSocket dial failed with status %d: %v", resp.StatusCode, err) } return fmt.Errorf("failed to connect to WebSocket: %w", err) } defer resp.Body.Close() c.mu.Lock() c.conn = conn c.mu.Unlock() c.log.Info("WebSocket connection established") // Send connection_init message initMsg := GraphQLMessage{ Type: typeConnectionInit, Payload: json.RawMessage(`{}`), } if err := c.writeMessage(initMsg); err != nil { return fmt.Errorf("failed to send connection_init: %w", err) } c.log.Debug("Sent connection_init message") // Wait for connection_ack if err := c.waitForConnectionAck(); err != nil { return fmt.Errorf("failed to receive connection_ack: %w", err) } c.mu.Lock() c.connected = true c.mu.Unlock() c.log.Info("GraphQL WebSocket handshake completed") return nil } // waitForConnectionAck waits for the connection_ack message func (c *GraphQLClient) waitForConnectionAck() error { c.mu.RLock() conn := c.conn c.mu.RUnlock() // Set a timeout for the ack conn.SetReadDeadline(time.Now().Add(10 * time.Second)) defer conn.SetReadDeadline(time.Time{}) for { var msg GraphQLMessage if err := conn.ReadJSON(&msg); err != nil { return fmt.Errorf("failed to read message: %w", err) } c.log.Debugf("Received message type: %s", msg.Type) switch msg.Type { case typeConnectionAck: c.log.Info("Received connection_ack") return nil case typeConnectionError: return fmt.Errorf("connection error: %s", string(msg.Payload)) case typeConnectionKeepAlive: c.log.Debug("Received keep-alive") // Continue waiting for ack default: c.log.Warnf("Unexpected message type during handshake: %s", msg.Type) } } } // SubscribeToMessages subscribes to new messages in the room func (c *GraphQLClient) SubscribeToMessages() error { c.mu.Lock() if !c.connected { c.mu.Unlock() return fmt.Errorf("not connected") } c.subscriptionID = "newMessage-1" c.mu.Unlock() // GraphQL subscription query for new messages query := fmt.Sprintf(` subscription { newMessage(roomId: "%s") { body time user { displayName username } } } `, c.roomID) payload := map[string]interface{}{ "query": query, "variables": map[string]interface{}{}, } payloadJSON, err := json.Marshal(payload) if err != nil { return fmt.Errorf("failed to marshal subscription payload: %w", err) } msg := GraphQLMessage{ ID: c.subscriptionID, Type: typeStart, Payload: payloadJSON, } if err := c.writeMessage(msg); err != nil { return fmt.Errorf("failed to send subscription: %w", err) } c.log.Infof("Subscribed to messages in room: %s", c.roomID) return nil } // OnMessage registers a handler for incoming messages func (c *GraphQLClient) OnMessage(handler func(*NewMessagePayload)) { c.mu.Lock() defer c.mu.Unlock() c.messageHandlers = append(c.messageHandlers, handler) } // Listen starts listening for messages from the WebSocket func (c *GraphQLClient) Listen() { c.log.Info("Starting message listener") for { select { case <-c.ctx.Done(): c.log.Info("Message listener stopped") return default: c.mu.RLock() conn := c.conn c.mu.RUnlock() if conn == nil { c.log.Warn("Connection is nil, stopping listener") return } var msg GraphQLMessage if err := conn.ReadJSON(&msg); err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { c.log.Errorf("WebSocket error: %v", err) } c.log.Warn("Connection closed, stopping listener") c.mu.Lock() c.connected = false c.mu.Unlock() return } c.handleMessage(&msg) } } } // handleMessage processes incoming GraphQL messages func (c *GraphQLClient) handleMessage(msg *GraphQLMessage) { c.log.Debugf("Received GraphQL message type: %s", msg.Type) switch msg.Type { case typeNext, typeData: // Parse the message payload var payload NewMessagePayload if err := json.Unmarshal(msg.Payload, &payload); err != nil { c.log.Errorf("Failed to parse message payload: %v", err) return } // Call all registered handlers c.mu.RLock() handlers := c.messageHandlers c.mu.RUnlock() for _, handler := range handlers { handler(&payload) } case typeConnectionKeepAlive: c.log.Debug("Received keep-alive") case typeError: c.log.Errorf("GraphQL error: %s", string(msg.Payload)) case typeComplete: c.log.Infof("Subscription %s completed", msg.ID) default: c.log.Debugf("Unhandled message type: %s", msg.Type) } } // SendMessage sends a message to the Kosmi room func (c *GraphQLClient) SendMessage(text string) error { c.mu.RLock() if !c.connected { c.mu.RUnlock() return fmt.Errorf("not connected") } c.mu.RUnlock() // GraphQL mutation to send a message mutation := fmt.Sprintf(` mutation { sendMessage(roomId: "%s", body: "%s") { id } } `, c.roomID, escapeGraphQLString(text)) payload := map[string]interface{}{ "query": mutation, "variables": map[string]interface{}{}, } payloadJSON, err := json.Marshal(payload) if err != nil { return fmt.Errorf("failed to marshal mutation payload: %w", err) } msg := GraphQLMessage{ ID: fmt.Sprintf("sendMessage-%d", time.Now().UnixNano()), Type: typeStart, Payload: payloadJSON, } if err := c.writeMessage(msg); err != nil { return fmt.Errorf("failed to send message: %w", err) } c.log.Debugf("Sent message: %s", text) return nil } // writeMessage writes a GraphQL message to the WebSocket func (c *GraphQLClient) writeMessage(msg GraphQLMessage) error { c.mu.RLock() conn := c.conn c.mu.RUnlock() if conn == nil { return fmt.Errorf("connection is nil") } return conn.WriteJSON(msg) } // Close closes the WebSocket connection func (c *GraphQLClient) Close() error { c.log.Info("Closing GraphQL client") c.cancel() c.mu.Lock() defer c.mu.Unlock() if c.conn != nil { // Send close message closeMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") c.conn.WriteMessage(websocket.CloseMessage, closeMsg) c.conn.Close() c.conn = nil } c.connected = false return nil } // IsConnected returns whether the client is connected func (c *GraphQLClient) IsConnected() bool { c.mu.RLock() defer c.mu.RUnlock() return c.connected } // escapeGraphQLString escapes special characters in GraphQL strings func escapeGraphQLString(s string) string { // Replace special characters that need escaping in GraphQL jsonBytes, err := json.Marshal(s) if err != nil { return s } // Remove surrounding quotes from JSON string if len(jsonBytes) >= 2 { return string(jsonBytes[1 : len(jsonBytes)-1]) } return string(jsonBytes) }