diff --git a/servers/clientmanager.go b/servers/clientmanager.go index 18a3c3c..6add426 100644 --- a/servers/clientmanager.go +++ b/servers/clientmanager.go @@ -1,6 +1,7 @@ package servers import ( + "context" "encoding/json" "errors" log "github.com/sirupsen/logrus" @@ -13,29 +14,20 @@ import ( // 连接管理 type ClientManager struct { - ClientIdMap map[string]*Client // 全部的连接 - ClientIdMapLock sync.RWMutex // 读写锁 + ClientIdMap sync.Map // 全部的连接 + GroupClientIdMap sync.Map // Group to ClientId 映射 + SystemClients sync.Map // systemId to []clientId 映射 + mutex sync.Mutex // 用于保持一致性的全局互斥锁 Connect chan *Client // 连接处理 DisConnect chan *Client // 断开连接处理 - - GroupLock sync.RWMutex - Groups map[string][]string - - SystemClientsLock sync.RWMutex - SystemClients map[string][]string } -func NewClientManager() (clientManager *ClientManager) { - clientManager = &ClientManager{ - ClientIdMap: make(map[string]*Client), - Connect: make(chan *Client, 10000), - DisConnect: make(chan *Client, 10000), - Groups: make(map[string][]string, 100), - SystemClients: make(map[string][]string, 100), +func NewClientManager() *ClientManager { + return &ClientManager{ + Connect: make(chan *Client, 10000), + DisConnect: make(chan *Client, 10000), } - - return } // 管道处理程序 @@ -45,154 +37,156 @@ func (manager *ClientManager) Start() { case client := <-manager.Connect: // 建立连接事件 manager.EventConnect(client) - case conn := <-manager.DisConnect: + case client := <-manager.DisConnect: // 断开连接事件 - manager.EventDisconnect(conn) + manager.EventDisconnect(client) } } } // 建立连接事件 func (manager *ClientManager) EventConnect(client *Client) { - manager.AddClient(client) + // 新开协程,1-同一个锁,lock 时 死锁的问题;2-避免 channel 监听挂掉 + util.SafeGo(context.Background(), func() { + manager.AddClient(client) + + log.WithFields(log.Fields{ + "host": setting.GlobalSetting.LocalHost, + "port": setting.CommonSetting.HttpPort, + "clientId": client.ClientId, + "counts": manager.Count(), + }).Info("客户端已连接") + }) - log.WithFields(log.Fields{ - "host": setting.GlobalSetting.LocalHost, - "port": setting.CommonSetting.HttpPort, - "clientId": client.ClientId, - "counts": Manager.Count(), - }).Info("客户端已连接") } -// 断开连接时间 +// 断开连接事件 func (manager *ClientManager) EventDisconnect(client *Client) { - //关闭连接 - _ = client.Socket.Close() - manager.DelClient(client) + util.SafeGo(context.Background(), func() { + //关闭连接 + if err := client.Socket.Close(); err != nil { + log.Errorf("Socket close error: %v", err) + } - mJson, _ := json.Marshal(map[string]string{ - "clientId": client.ClientId, - "userId": client.UserId, - "extend": client.Extend, - }) - data := string(mJson) - sendUserId := "" + manager.DelClient(client) + + mJson, err := json.Marshal(map[string]string{ + "clientId": client.ClientId, + "userId": client.UserId, + "extend": client.Extend, + }) + if err != nil { + log.Errorf("JSON Marshal error: %v", err) + return + } + data := string(mJson) + sendUserId := "" - //发送下线通知 - if len(client.GroupList) > 0 { + // 发送下线通知 for _, groupName := range client.GroupList { SendMessage2Group(client.SystemId, sendUserId, groupName, retcode.OFFLINE_MESSAGE_CODE, "客户端下线", &data) } - } - log.WithFields(log.Fields{ - "host": setting.GlobalSetting.LocalHost, - "port": setting.CommonSetting.HttpPort, - "clientId": client.ClientId, - "counts": Manager.Count(), - "seconds": uint64(time.Now().Unix()) - client.ConnectTime, - }).Info("客户端已断开") + log.WithFields(log.Fields{ + "host": setting.GlobalSetting.LocalHost, + "port": setting.CommonSetting.HttpPort, + "clientId": client.ClientId, + "counts": manager.Count(), + "seconds": uint64(time.Now().Unix()) - client.ConnectTime, + }).Info("客户端已断开") + + // 标记销毁 + client.IsDeleted = true + }) - //标记销毁 - client.IsDeleted = true - client = nil } // 添加客户端 func (manager *ClientManager) AddClient(client *Client) { - manager.ClientIdMapLock.Lock() - defer manager.ClientIdMapLock.Unlock() - - manager.ClientIdMap[client.ClientId] = client + manager.ClientIdMap.Store(client.ClientId, client) } // 获取所有的客户端 func (manager *ClientManager) AllClient() map[string]*Client { - manager.ClientIdMapLock.RLock() - defer manager.ClientIdMapLock.RUnlock() - - return manager.ClientIdMap + clientsCopy := make(map[string]*Client) + manager.ClientIdMap.Range(func(key, value interface{}) bool { + clientsCopy[key.(string)] = value.(*Client) + return true + }) + return clientsCopy } // 客户端数量 func (manager *ClientManager) Count() int { - manager.ClientIdMapLock.RLock() - defer manager.ClientIdMapLock.RUnlock() - return len(manager.ClientIdMap) + count := 0 + manager.ClientIdMap.Range(func(key, value interface{}) bool { + count++ + return true + }) + return count } // 删除客户端 func (manager *ClientManager) DelClient(client *Client) { - manager.delClientIdMap(client.ClientId) + manager.mutex.Lock() + defer manager.mutex.Unlock() - //删除所在的分组 - if len(client.GroupList) > 0 { - for _, groupName := range client.GroupList { - manager.delGroupClient(util.GenGroupKey(client.SystemId, groupName), client.ClientId) - } + manager.ClientIdMap.Delete(client.ClientId) + + // 删除所在的分组 + for _, groupName := range client.GroupList { + groupKey := util.GenGroupKey(client.SystemId, groupName) + manager.delGroupClientUnsafe(groupKey, client.ClientId) } // 删除系统里的客户端 - manager.delSystemClient(client) -} - -// 删除clientIdMap -func (manager *ClientManager) delClientIdMap(clientId string) { - manager.ClientIdMapLock.Lock() - defer manager.ClientIdMapLock.Unlock() - - delete(manager.ClientIdMap, clientId) + manager.delSystemClientUnsafe(client.SystemId, client.ClientId) } -// 通过clientId获取 +// 通过clientId获取客户端 func (manager *ClientManager) GetByClientId(clientId string) (*Client, error) { - manager.ClientIdMapLock.RLock() - defer manager.ClientIdMapLock.RUnlock() - - if client, ok := manager.ClientIdMap[clientId]; !ok { + value, ok := manager.ClientIdMap.Load(clientId) + if !ok { return nil, errors.New("客户端不存在") - } else { - return client, nil } + return value.(*Client), nil } // 发送到本机分组 func (manager *ClientManager) SendMessage2LocalGroup(systemId, messageId, sendUserId, groupName string, code int, msg string, data *string) { - if len(groupName) > 0 { - clientIds := manager.GetGroupClientList(util.GenGroupKey(systemId, groupName)) - if len(clientIds) > 0 { - for _, clientId := range clientIds { - if _, err := Manager.GetByClientId(clientId); err == nil { - //添加到本地 - SendMessage2LocalClient(messageId, clientId, sendUserId, code, msg, data) - } else { - //删除分组 - manager.delGroupClient(util.GenGroupKey(systemId, groupName), clientId) - } - } + groupKey := util.GenGroupKey(systemId, groupName) + value, _ := manager.GroupClientIdMap.Load(groupKey) + clientIds, _ := value.([]string) + for _, clientId := range clientIds { + if _, err := manager.GetByClientId(clientId); err == nil { + //发送消息到本地客户端 + SendMessage2LocalClient(messageId, clientId, sendUserId, code, msg, data) + } else { + //删除无效的分组客户端 + manager.delGroupClient(groupKey, clientId) } } } -//发送给指定业务系统 +// 发送给指定业务系统 func (manager *ClientManager) SendMessage2LocalSystem(systemId, messageId string, sendUserId string, code int, msg string, data *string) { - if len(systemId) > 0 { - clientIds := Manager.GetSystemClientList(systemId) - if len(clientIds) > 0 { - for _, clientId := range clientIds { - SendMessage2LocalClient(messageId, clientId, sendUserId, code, msg, data) - } - } + value, _ := manager.SystemClients.Load(systemId) + clientIds, _ := value.([]string) + for _, clientId := range clientIds { + SendMessage2LocalClient(messageId, clientId, sendUserId, code, msg, data) } } // 添加到本地分组 func (manager *ClientManager) AddClient2LocalGroup(groupName string, client *Client, userId string, extend string) { + manager.mutex.Lock() + defer manager.mutex.Unlock() + //标记当前客户端的userId client.UserId = userId client.Extend = extend - //判断之前是否有添加过 + // 判断之前是否有添加过 for _, groupValue := range client.GroupList { if groupValue == groupName { return @@ -202,15 +196,19 @@ func (manager *ClientManager) AddClient2LocalGroup(groupName string, client *Cli // 为属性添加分组信息 groupKey := util.GenGroupKey(client.SystemId, groupName) - manager.addClient2Group(groupKey, client) + manager.addClient2GroupUnsafe(groupKey, client.ClientId) client.GroupList = append(client.GroupList, groupName) - mJson, _ := json.Marshal(map[string]string{ + mJson, err := json.Marshal(map[string]string{ "clientId": client.ClientId, "userId": client.UserId, "extend": client.Extend, }) + if err != nil { + log.Errorf("JSON Marshal error: %v", err) + return + } data := string(mJson) sendUserId := "" @@ -218,54 +216,67 @@ func (manager *ClientManager) AddClient2LocalGroup(groupName string, client *Cli SendMessage2Group(client.SystemId, sendUserId, groupName, retcode.ONLINE_MESSAGE_CODE, "客户端上线", &data) } -// 添加到本地分组 -func (manager *ClientManager) addClient2Group(groupKey string, client *Client) { - manager.GroupLock.Lock() - defer manager.GroupLock.Unlock() - manager.Groups[groupKey] = append(manager.Groups[groupKey], client.ClientId) +// 添加到本地分组 (非线程安全内部方法) +func (manager *ClientManager) addClient2GroupUnsafe(groupKey string, clientId string) { + value, _ := manager.GroupClientIdMap.Load(groupKey) + clientIds, _ := value.([]string) + clientIds = append(clientIds, clientId) + manager.GroupClientIdMap.Store(groupKey, clientIds) } // 删除分组里的客户端 func (manager *ClientManager) delGroupClient(groupKey string, clientId string) { - manager.GroupLock.Lock() - defer manager.GroupLock.Unlock() + manager.mutex.Lock() + defer manager.mutex.Unlock() - for index, groupClientId := range manager.Groups[groupKey] { - if groupClientId == clientId { - manager.Groups[groupKey] = append(manager.Groups[groupKey][:index], manager.Groups[groupKey][index+1:]...) + manager.delGroupClientUnsafe(groupKey, clientId) +} + +// 删除分组里的客户端 (非线程安全内部方法) +func (manager *ClientManager) delGroupClientUnsafe(groupKey string, clientId string) { + value, _ := manager.GroupClientIdMap.Load(groupKey) + clientIds, _ := value.([]string) + for index, id := range clientIds { + if id == clientId { + clientIds = append(clientIds[:index], clientIds[index+1:]...) + break } } + manager.GroupClientIdMap.Store(groupKey, clientIds) } // 获取本地分组的成员 func (manager *ClientManager) GetGroupClientList(groupKey string) []string { - manager.GroupLock.RLock() - defer manager.GroupLock.RUnlock() - return manager.Groups[groupKey] + value, _ := manager.GroupClientIdMap.Load(groupKey) + return value.([]string) } // 添加到系统客户端列表 func (manager *ClientManager) AddClient2SystemClient(systemId string, client *Client) { - manager.SystemClientsLock.Lock() - defer manager.SystemClientsLock.Unlock() - manager.SystemClients[systemId] = append(manager.SystemClients[systemId], client.ClientId) -} + manager.mutex.Lock() + defer manager.mutex.Unlock() -// 删除系统里的客户端 -func (manager *ClientManager) delSystemClient(client *Client) { - manager.SystemClientsLock.Lock() - defer manager.SystemClientsLock.Unlock() + value, _ := manager.SystemClients.Load(systemId) + clientIds, _ := value.([]string) + clientIds = append(clientIds, client.ClientId) + manager.SystemClients.Store(systemId, clientIds) +} - for index, clientId := range manager.SystemClients[client.SystemId] { - if clientId == client.ClientId { - manager.SystemClients[client.SystemId] = append(manager.SystemClients[client.SystemId][:index], manager.SystemClients[client.SystemId][index+1:]...) +// 删除系统里的客户端 (非线程安全内部方法) +func (manager *ClientManager) delSystemClientUnsafe(systemId string, clientId string) { + value, _ := manager.SystemClients.Load(systemId) + clientIds, _ := value.([]string) + for index, id := range clientIds { + if id == clientId { + clientIds = append(clientIds[:index], clientIds[index+1:]...) + break } } + manager.SystemClients.Store(systemId, clientIds) } // 获取指定系统的客户端列表 func (manager *ClientManager) GetSystemClientList(systemId string) []string { - manager.SystemClientsLock.RLock() - defer manager.SystemClientsLock.RUnlock() - return manager.SystemClients[systemId] + value, _ := manager.SystemClients.Load(systemId) + return value.([]string) } diff --git a/servers/server.go b/servers/server.go index 0016a31..248a1b8 100644 --- a/servers/server.go +++ b/servers/server.go @@ -9,10 +9,10 @@ import ( "time" ) -//channel通道 +// channel通道 var ToClientChan chan clientInfo -//channel通道结构体 +// channel通道结构体 type clientInfo struct { ClientId string SendUserId string @@ -46,7 +46,7 @@ func StartWebSocket() { go Manager.Start() } -//发送信息到指定客户端 +// 发送信息到指定客户端 func SendMessage2Client(clientId string, sendUserId string, code int, msg string, data *string) (messageId string) { messageId = util.GenUUID() if util.IsCluster() { @@ -71,7 +71,7 @@ func SendMessage2Client(clientId string, sendUserId string, code int, msg string return } -//关闭客户端 +// 关闭客户端 func CloseClient(clientId, systemId string) { if util.IsCluster() { addr, _, _, isLocal, err := util.GetAddrInfoAndIsLocal(clientId) @@ -95,7 +95,7 @@ func CloseClient(clientId, systemId string) { return } -//添加客户端到分组 +// 添加客户端到分组 func AddClient2Group(systemId string, groupName string, clientId string, userId string, extend string) { //如果是集群则用redis共享数据 if util.IsCluster() { @@ -125,7 +125,7 @@ func AddClient2Group(systemId string, groupName string, clientId string, userId } } -//发送信息到指定分组 +// 发送信息到指定分组 func SendMessage2Group(systemId, sendUserId, groupName string, code int, msg string, data *string) (messageId string) { messageId = util.GenUUID() if util.IsCluster() { @@ -138,7 +138,7 @@ func SendMessage2Group(systemId, sendUserId, groupName string, code int, msg str return } -//发送信息到指定系统 +// 发送信息到指定系统 func SendMessage2System(systemId, sendUserId string, code int, msg string, data string) { messageId := util.GenUUID() if util.IsCluster() { @@ -150,7 +150,7 @@ func SendMessage2System(systemId, sendUserId string, code int, msg string, data } } -//获取分组列表 +// 获取分组列表 func GetOnlineList(systemId *string, groupName *string) map[string]interface{} { var clientList []string if util.IsCluster() { @@ -168,7 +168,7 @@ func GetOnlineList(systemId *string, groupName *string) map[string]interface{} { } } -//通过本服务器发送信息 +// 通过本服务器发送信息 func SendMessage2LocalClient(messageId, clientId string, sendUserId string, code int, msg string, data *string) { log.WithFields(log.Fields{ "host": setting.GlobalSetting.LocalHost, @@ -179,7 +179,7 @@ func SendMessage2LocalClient(messageId, clientId string, sendUserId string, code return } -//发送关闭信号 +// 发送关闭信号 func CloseLocalClient(clientId, systemId string) { if conn, err := Manager.GetByClientId(clientId); err == nil && conn != nil { if conn.SystemId != systemId { @@ -195,7 +195,7 @@ func CloseLocalClient(clientId, systemId string) { return } -//监听并发送给客户端信息 +// 监听并发送给客户端信息 func WriteMessage() { for { clientInfo := <-ToClientChan @@ -233,7 +233,7 @@ func Render(conn *websocket.Conn, messageId string, sendUserId string, code int, }) } -//启动定时器进行心跳检测 +// 启动定时器进行心跳检测 func PingTimer() { go func() { ticker := time.NewTicker(heartbeatInterval) diff --git a/tools/util/goroutine.go b/tools/util/goroutine.go new file mode 100644 index 0000000..11c7fe5 --- /dev/null +++ b/tools/util/goroutine.go @@ -0,0 +1,30 @@ +package util + +import ( + "context" + log "github.com/sirupsen/logrus" + "reflect" + "runtime" + "runtime/debug" + "time" +) + +// Recovery : recover and log the error information. +func Recovery(ctx context.Context) { + if err := recover(); err != nil { + log.WithContext(ctx).Errorf("Catch Panic: %+v\nStackTrace:\n%s", err, debug.Stack()) + } +} + +func timeTrack(ctx context.Context, start time.Time, name string) { + elapsed := time.Since(start) + log.WithContext(ctx).Infof("%s took %dms", name, elapsed.Nanoseconds()/1000000) +} + +func SafeGo(ctx context.Context, fun func()) { + go func() { + defer Recovery(ctx) + defer timeTrack(ctx, time.Now(), runtime.FuncForPC(reflect.ValueOf(fun).Pointer()).Name()) + fun() + }() +}