392 lines
9.0 KiB
Go
392 lines
9.0 KiB
Go
package bkosmi
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Message types used in the "type" field of GraphQLMessage frames exchanged
// over the "graphql-ws" WebSocket subprotocol. Values are compared and sent
// verbatim.
const (
	// Connection-level handshake and liveness messages.
	typeConnectionInit      = "connection_init"
	typeConnectionAck       = "connection_ack"
	typeConnectionError     = "connection_error"
	typeConnectionKeepAlive = "ka"

	// Per-operation messages; the operation is identified by GraphQLMessage.ID.
	typeStart    = "start"
	typeData     = "data"
	typeError    = "error"
	typeComplete = "complete"
	typeStop     = "stop"

	// typeNext carries an operation result just like typeData. It is
	// presumably here for servers speaking the newer graphql-transport-ws
	// protocol, which uses "next" instead of "data".
	typeNext = "next"
)
|
|
|
|
// GraphQLMessage represents a GraphQL WebSocket message
|
|
type GraphQLMessage struct {
|
|
ID string `json:"id,omitempty"`
|
|
Type string `json:"type"`
|
|
Payload json.RawMessage `json:"payload,omitempty"`
|
|
}
|
|
|
|
// NewMessagePayload mirrors the JSON result of the newMessage subscription:
//
//	{"data": {"newMessage": {"body": ..., "time": ..., "user": {...}}}}
//
// The nested structs are anonymous, so callers access fields by path, e.g.
// p.Data.NewMessage.User.Username.
type NewMessagePayload struct {
	Data struct {
		NewMessage struct {
			// Body is the chat message text.
			Body string `json:"body"`
			// Time is the message timestamp as sent by the server
			// (units not established here — TODO confirm seconds vs ms).
			Time int64 `json:"time"`
			// User identifies the author of the message.
			User struct {
				DisplayName string `json:"displayName"`
				Username    string `json:"username"`
			} `json:"user"`
		} `json:"newMessage"`
	} `json:"data"`
}
|
|
|
|
// GraphQLClient manages the WebSocket connection to Kosmi's GraphQL API
|
|
type GraphQLClient struct {
|
|
conn *websocket.Conn
|
|
url string
|
|
roomID string
|
|
log *logrus.Entry
|
|
subscriptionID string
|
|
mu sync.RWMutex
|
|
connected bool
|
|
reconnectDelay time.Duration
|
|
messageHandlers []func(*NewMessagePayload)
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
// NewGraphQLClient returns an unconnected client for the given WebSocket url
// and Kosmi room. The client gets its own cancellable context (cancelled by
// Close) and a 5 second reconnect delay; call Connect to open the socket.
func NewGraphQLClient(url, roomID string, log *logrus.Entry) *GraphQLClient {
	client := &GraphQLClient{
		url:             url,
		roomID:          roomID,
		log:             log,
		reconnectDelay:  5 * time.Second,
		messageHandlers: []func(*NewMessagePayload){},
	}
	client.ctx, client.cancel = context.WithCancel(context.Background())
	return client
}
|
|
|
|
// Connect dials the Kosmi GraphQL WebSocket endpoint and performs the
// graphql-ws handshake: it sends connection_init and waits for
// connection_ack. On success the socket is stored on the client and
// IsConnected reports true.
//
// If any handshake step fails, the freshly dialed socket is closed and
// detached from the client before the error is returned, so a failed
// attempt does not leak a half-open connection.
func (c *GraphQLClient) Connect() error {
	c.log.Infof("Connecting to Kosmi GraphQL WebSocket: %s", c.url)

	// The graphql-ws subprotocol is negotiated during the HTTP upgrade. A zero
	// Dialer has no handshake timeout, so bound it explicitly to avoid hanging
	// forever on an unresponsive server.
	dialer := websocket.Dialer{
		Subprotocols:     []string{"graphql-ws"},
		ReadBufferSize:   1024,
		WriteBufferSize:  1024,
		HandshakeTimeout: 10 * time.Second,
	}

	conn, resp, err := dialer.Dial(c.url, http.Header{})
	if err != nil {
		// On a rejected upgrade resp carries the HTTP status the server sent.
		if resp != nil {
			c.log.Errorf("WebSocket dial failed with status %d: %v", resp.StatusCode, err)
		}
		return fmt.Errorf("failed to connect to WebSocket: %w", err)
	}
	defer resp.Body.Close()

	c.mu.Lock()
	c.conn = conn
	c.mu.Unlock()

	c.log.Info("WebSocket connection established")

	// abort tears down this attempt's socket and returns err unchanged. The
	// client's conn is only cleared if it still points at this attempt.
	abort := func(err error) error {
		c.mu.Lock()
		if c.conn == conn {
			c.conn = nil
		}
		c.mu.Unlock()
		_ = conn.Close() // best effort; the handshake error is what matters
		return err
	}

	initMsg := GraphQLMessage{
		Type:    typeConnectionInit,
		Payload: json.RawMessage(`{}`),
	}
	if err := c.writeMessage(initMsg); err != nil {
		return abort(fmt.Errorf("failed to send connection_init: %w", err))
	}
	c.log.Debug("Sent connection_init message")

	if err := c.waitForConnectionAck(); err != nil {
		return abort(fmt.Errorf("failed to receive connection_ack: %w", err))
	}

	c.mu.Lock()
	c.connected = true
	c.mu.Unlock()

	c.log.Info("GraphQL WebSocket handshake completed")

	return nil
}
|
|
|
|
// waitForConnectionAck reads frames from the current connection until the
// server sends connection_ack (success) or connection_error (failure).
// Keep-alives are skipped and other frame types are logged and ignored. The
// whole wait is bounded by a 10 second read deadline, which is cleared again
// on return so the long-running listener is unaffected.
func (c *GraphQLClient) waitForConnectionAck() error {
	c.mu.RLock()
	ws := c.conn
	c.mu.RUnlock()

	ws.SetReadDeadline(time.Now().Add(10 * time.Second))
	defer ws.SetReadDeadline(time.Time{})

	for {
		// A fresh value per frame so fields omitted by one frame are not
		// inherited from the previous one.
		var incoming GraphQLMessage
		if err := ws.ReadJSON(&incoming); err != nil {
			return fmt.Errorf("failed to read message: %w", err)
		}

		c.log.Debugf("Received message type: %s", incoming.Type)

		if incoming.Type == typeConnectionAck {
			c.log.Info("Received connection_ack")
			return nil
		}
		if incoming.Type == typeConnectionError {
			return fmt.Errorf("connection error: %s", string(incoming.Payload))
		}
		if incoming.Type == typeConnectionKeepAlive {
			c.log.Debug("Received keep-alive")
			continue
		}

		c.log.Warnf("Unexpected message type during handshake: %s", incoming.Type)
	}
}
|
|
|
|
// SubscribeToMessages starts the newMessage subscription for the client's
// room by sending a graphql-ws "start" frame. Matching results arrive later
// via Listen and are forwarded to handlers registered with OnMessage.
//
// The room ID is escaped before being embedded in the GraphQL string literal,
// the same way SendMessage escapes message text, so a quote or backslash in
// the configured room cannot break the query. Returns an error if the client
// is not connected or the frame cannot be written.
func (c *GraphQLClient) SubscribeToMessages() error {
	c.mu.Lock()
	if !c.connected {
		c.mu.Unlock()
		return fmt.Errorf("not connected")
	}
	c.subscriptionID = "newMessage-1"
	// Copy under the lock so the frame below does not read shared state
	// after it has been released.
	subID := c.subscriptionID
	c.mu.Unlock()

	query := fmt.Sprintf(`
	subscription {
		newMessage(roomId: "%s") {
			body
			time
			user {
				displayName
				username
			}
		}
	}
	`, escapeGraphQLString(c.roomID))

	payload := map[string]interface{}{
		"query":     query,
		"variables": map[string]interface{}{},
	}

	payloadJSON, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("failed to marshal subscription payload: %w", err)
	}

	msg := GraphQLMessage{
		ID:      subID,
		Type:    typeStart,
		Payload: payloadJSON,
	}

	if err := c.writeMessage(msg); err != nil {
		return fmt.Errorf("failed to send subscription: %w", err)
	}

	c.log.Infof("Subscribed to messages in room: %s", c.roomID)

	return nil
}
|
|
|
|
// OnMessage registers a handler for incoming messages
|
|
func (c *GraphQLClient) OnMessage(handler func(*NewMessagePayload)) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
c.messageHandlers = append(c.messageHandlers, handler)
|
|
}
|
|
|
|
// Listen runs the receive loop and blocks until it stops. Each frame read
// from the socket is decoded and passed to handleMessage.
//
// The loop ends when the client's context is cancelled (Close), when there is
// no connection, or when reading from the socket fails. In that last case the
// client is marked as disconnected. A frame that is not valid JSON is logged
// and skipped rather than treated as a dead connection, because the socket
// itself is still usable.
func (c *GraphQLClient) Listen() {
	c.log.Info("Starting message listener")

	for {
		select {
		case <-c.ctx.Done():
			c.log.Info("Message listener stopped")
			return
		default:
		}

		c.mu.RLock()
		conn := c.conn
		c.mu.RUnlock()

		if conn == nil {
			c.log.Warn("Connection is nil, stopping listener")
			return
		}

		// Read the raw frame first so a transport failure can be told apart
		// from a merely malformed payload.
		_, data, err := conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				c.log.Errorf("WebSocket error: %v", err)
			}
			c.log.Warn("Connection closed, stopping listener")
			c.mu.Lock()
			c.connected = false
			c.mu.Unlock()
			return
		}

		var msg GraphQLMessage
		if err := json.Unmarshal(data, &msg); err != nil {
			c.log.Errorf("Failed to decode GraphQL message: %v", err)
			continue
		}

		c.handleMessage(&msg)
	}
}
|
|
|
|
// handleMessage dispatches one decoded frame received by Listen.
//
// "data"/"next" frames are results of some operation on this socket. That
// includes the results of sendMessage mutations issued by SendMessage, which
// carry no newMessage field. Only results with a non-null data.newMessage are
// therefore forwarded to the handlers registered with OnMessage. GraphQL
// "errors" reported in a result are logged. Keep-alive, error and complete
// frames are only logged.
func (c *GraphQLClient) handleMessage(msg *GraphQLMessage) {
	c.log.Debugf("Received GraphQL message type: %s", msg.Type)

	switch msg.Type {
	case typeNext, typeData:
		// Probe the result shape without committing to NewMessagePayload yet.
		var result struct {
			Data struct {
				NewMessage json.RawMessage `json:"newMessage"`
			} `json:"data"`
			Errors json.RawMessage `json:"errors"`
		}
		if err := json.Unmarshal(msg.Payload, &result); err != nil {
			c.log.Errorf("Failed to parse message payload: %v", err)
			return
		}

		if len(result.Errors) > 0 && string(result.Errors) != "null" {
			c.log.Errorf("GraphQL errors for operation %s: %s", msg.ID, string(result.Errors))
		}

		if len(result.Data.NewMessage) == 0 || string(result.Data.NewMessage) == "null" {
			c.log.Debugf("Ignoring result without newMessage for operation %s", msg.ID)
			return
		}

		var payload NewMessagePayload
		if err := json.Unmarshal(result.Data.NewMessage, &payload.Data.NewMessage); err != nil {
			c.log.Errorf("Failed to parse message payload: %v", err)
			return
		}

		// Snapshot the handler list so handlers run without holding the lock
		// (they may call back into the client, e.g. SendMessage).
		c.mu.RLock()
		handlers := c.messageHandlers
		c.mu.RUnlock()

		for _, handler := range handlers {
			handler(&payload)
		}

	case typeConnectionKeepAlive:
		c.log.Debug("Received keep-alive")

	case typeError:
		c.log.Errorf("GraphQL error: %s", string(msg.Payload))

	case typeComplete:
		c.log.Infof("Subscription %s completed", msg.ID)

	default:
		c.log.Debugf("Unhandled message type: %s", msg.Type)
	}
}
|
|
|
|
// SendMessage posts text to the client's Kosmi room via the sendMessage
// mutation. The mutation is sent as a graphql-ws "start" frame with a
// timestamp-based operation ID. The function does not wait for the
// mutation's result.
//
// Both the room ID and the message body are escaped before being embedded in
// the GraphQL string literals, so quotes, backslashes or newlines in either
// cannot break the document. Returns an error if the client is not connected
// or the frame cannot be written.
func (c *GraphQLClient) SendMessage(text string) error {
	c.mu.RLock()
	if !c.connected {
		c.mu.RUnlock()
		return fmt.Errorf("not connected")
	}
	c.mu.RUnlock()

	mutation := fmt.Sprintf(`
	mutation {
		sendMessage(roomId: "%s", body: "%s") {
			id
		}
	}
	`, escapeGraphQLString(c.roomID), escapeGraphQLString(text))

	payload := map[string]interface{}{
		"query":     mutation,
		"variables": map[string]interface{}{},
	}

	payloadJSON, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("failed to marshal mutation payload: %w", err)
	}

	msg := GraphQLMessage{
		ID:      fmt.Sprintf("sendMessage-%d", time.Now().UnixNano()),
		Type:    typeStart,
		Payload: payloadJSON,
	}

	if err := c.writeMessage(msg); err != nil {
		return fmt.Errorf("failed to send message: %w", err)
	}

	c.log.Debugf("Sent message: %s", text)

	return nil
}
|
|
|
|
// writeMessage writes a GraphQL message to the WebSocket
|
|
func (c *GraphQLClient) writeMessage(msg GraphQLMessage) error {
|
|
c.mu.RLock()
|
|
conn := c.conn
|
|
c.mu.RUnlock()
|
|
|
|
if conn == nil {
|
|
return fmt.Errorf("connection is nil")
|
|
}
|
|
|
|
return conn.WriteJSON(msg)
|
|
}
|
|
|
|
// Close shuts the client down. It cancels the client's context so Listen
// stops, sends a best-effort normal-closure frame, closes the socket, and
// marks the client as disconnected. It is safe to call more than once and
// currently always returns nil.
func (c *GraphQLClient) Close() error {
	c.log.Info("Closing GraphQL client")

	c.cancel()

	c.mu.Lock()
	defer c.mu.Unlock()

	if c.conn != nil {
		// WriteControl may be called concurrently with other writers and is
		// bounded by a deadline, so an unresponsive peer cannot stall Close.
		// Errors are ignored on purpose: the socket is closed regardless.
		closeMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
		_ = c.conn.WriteControl(websocket.CloseMessage, closeMsg, time.Now().Add(time.Second))
		_ = c.conn.Close()
		c.conn = nil
	}

	c.connected = false

	return nil
}
|
|
|
|
// IsConnected returns whether the client is connected
|
|
func (c *GraphQLClient) IsConnected() bool {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
return c.connected
|
|
}
|
|
|
|
// escapeGraphQLString returns s escaped for placement between double quotes
// in a GraphQL string literal. It reuses JSON string encoding, whose escape
// sequences (\" \\ \n \uXXXX ...) are also valid GraphQL escapes, and then
// strips the surrounding quotes that json.Marshal adds.
func escapeGraphQLString(s string) string {
	quoted, err := json.Marshal(s)
	switch {
	case err != nil:
		// Not expected for a plain string; fall back to the raw input.
		return s
	case len(quoted) < 2:
		return string(quoted)
	default:
		return string(quoted[1 : len(quoted)-1])
	}
}
|
|
|