chat.go 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. package chat
  import (
  	"context"
  	"log"
  	"net/http"
  	"sync"
  	"time"

  	"github.com/gogf/gf/v2/net/ghttp"
  	"github.com/gorilla/websocket"

  	v1 "cris/api/chat/v1"
  )
  11. // 维护 WebSocket 客户端连接池
  12. var clients = make(map[*websocket.Conn]bool)
  13. var broadcast = make(chan string) // 广播消息的通道
  14. var mu sync.Mutex //互斥锁
  // chat is a stateless type with no fields; it serves as the receiver of the
  // Chat handler method.
  type chat struct{}
  16. func New() *chat {
  17. return &chat{}
  18. }
  19. func (s *chat) Chat(ctx context.Context, r *ghttp.Request, req *v1.ChatReq) {
  20. // 升级 HTTP 请求为 WebSocket
  21. upgrader := websocket.Upgrader{
  22. CheckOrigin: func(r *http.Request) bool {
  23. return true
  24. },
  25. ReadBufferSize: 1024,
  26. WriteBufferSize: 1024,
  27. }
  28. // 建立 WebSocket 连接
  29. conn, err := upgrader.Upgrade(r.Response.Writer, r.Request, nil)
  30. if err != nil {
  31. log.Println("WebSocket upgrade failed:", err)
  32. return
  33. }
  34. // 将新连接加入到连接池
  35. mu.Lock()
  36. clients[conn] = true
  37. mu.Unlock()
  38. //销毁
  39. defer func() {
  40. mu.Lock()
  41. delete(clients, conn)
  42. mu.Unlock()
  43. conn.Close()
  44. }()
  45. for {
  46. _, msg, err := conn.ReadMessage()
  47. if err != nil {
  48. log.Println("Error reading message:", err)
  49. break
  50. }
  51. broadcast <- string(msg)
  52. }
  53. }
  54. // 广播消息
  55. func handleBroadcast() {
  56. for {
  57. msg := <-broadcast
  58. mu.Lock()
  59. for client := range clients {
  60. err := client.WriteMessage(websocket.TextMessage, []byte(msg))
  61. if err != nil {
  62. log.Println("Error writing message:", err)
  63. client.Close()
  64. delete(clients, client)
  65. }
  66. }
  67. mu.Unlock()
  68. }
  69. }
  70. // 启动广播处理 goroutine
  71. func init() {
  72. go handleBroadcast()
  73. }