chat.go 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. package chat
import (
	"context"
	"log"
	"net/http"
	"sync"
	"time"

	"github.com/gogf/gf/v2/net/ghttp"
	"github.com/gorilla/websocket"

	v1 "cris/api/chat/v1"
)
  11. // 维护 WebSocket 客户端连接池
  12. var clients = make(map[*websocket.Conn]bool)
  13. var broadcast = make(chan string) // 广播消息的通道
  14. var mu sync.Mutex //互斥锁
  15. type schat struct{}
  16. func (s *schat) Chat(ctx context.Context, r *ghttp.Request, req *v1.ChatReq) {
  17. // 升级 HTTP 请求为 WebSocket
  18. upgrader := websocket.Upgrader{
  19. CheckOrigin: func(r *http.Request) bool {
  20. return true
  21. },
  22. ReadBufferSize: 1024,
  23. WriteBufferSize: 1024,
  24. }
  25. // 建立 WebSocket 连接
  26. conn, err := upgrader.Upgrade(r.Response.Writer, r.Request, nil)
  27. if err != nil {
  28. log.Println("WebSocket upgrade failed:", err)
  29. return
  30. }
  31. // 将新连接加入到连接池
  32. mu.Lock()
  33. clients[conn] = true
  34. mu.Unlock()
  35. //销毁
  36. defer func() {
  37. mu.Lock()
  38. delete(clients, conn)
  39. mu.Unlock()
  40. conn.Close()
  41. }()
  42. for {
  43. _, msg, err := conn.ReadMessage()
  44. if err != nil {
  45. log.Println("Error reading message:", err)
  46. break
  47. }
  48. broadcast <- string(msg)
  49. }
  50. }
  51. // 广播消息
  52. func handleBroadcast() {
  53. for {
  54. msg := <-broadcast
  55. mu.Lock()
  56. for client := range clients {
  57. err := client.WriteMessage(websocket.TextMessage, []byte(msg))
  58. if err != nil {
  59. log.Println("Error writing message:", err)
  60. client.Close()
  61. delete(clients, client)
  62. }
  63. }
  64. mu.Unlock()
  65. }
  66. }
  67. // 启动广播处理 goroutine
  68. func init() {
  69. go handleBroadcast()
  70. }