From 90af4ebee76637dafaac5c2c6f4f30b4fd0ce474 Mon Sep 17 00:00:00 2001
From: waterliu99 <qqcc2012game@163.com>
Date: Wed, 17 Jul 2019 18:07:08 +0800
Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=9D=83=E9=99=90=E8=AE=A4?=
 =?UTF-8?q?=E8=AF=81=E6=8E=A5=E5=8F=A3=E5=8F=8A=E9=BB=98=E8=AE=A4=E5=AE=9E?=
 =?UTF-8?q?=E7=8E=B0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../Interfaces/IJT1078Authorization.cs        | 13 ++++
 .../Interfaces/IJT1078WebSocketBuilder.cs     |  1 +
 .../JT1078.DotNetty.Core.csproj               |  1 +
 .../Session/JT1078WebSocketSessionManager.cs  | 60 ++++---------------
 .../JT1078WebSocketPushHostedService.cs       | 59 ++++++++++--------
 src/JT1078.DotNetty.TestHosting/Program.cs    |  1 +
 .../JT1078AuthorizationDefault.cs             | 32 ++++++++++
 .../Handlers/JT1078WebSocketServerHandler.cs  | 34 +++++++----
 .../JT1078WebSocketBuilderDefault.cs          |  6 ++
 .../JT1078WebSocketDotnettyExtensions.cs      |  3 +
 10 files changed, 124 insertions(+), 86 deletions(-)
 create mode 100644 src/JT1078.DotNetty.Core/Interfaces/IJT1078Authorization.cs
 create mode 100644 src/JT1078.DotNetty.WebSocket/Authorization/JT1078AuthorizationDefault.cs

diff --git a/src/JT1078.DotNetty.Core/Interfaces/IJT1078Authorization.cs b/src/JT1078.DotNetty.Core/Interfaces/IJT1078Authorization.cs
new file mode 100644
index 0000000..3cbe1e1
--- /dev/null
+++ b/src/JT1078.DotNetty.Core/Interfaces/IJT1078Authorization.cs
@@ -0,0 +1,13 @@
+using DotNetty.Codecs.Http;
+using System;
+using System.Collections.Generic;
+using System.Security.Principal;
+using System.Text;
+
+namespace JT1078.DotNetty.Core.Interfaces
+{
+    public interface IJT1078Authorization
+    {
+        bool Authorization(IFullHttpRequest request, out IPrincipal principal);
+    }
+}
diff --git a/src/JT1078.DotNetty.Core/Interfaces/IJT1078WebSocketBuilder.cs b/src/JT1078.DotNetty.Core/Interfaces/IJT1078WebSocketBuilder.cs
index 2da46df..d96181b 100644
--- a/src/JT1078.DotNetty.Core/Interfaces/IJT1078WebSocketBuilder.cs
+++ b/src/JT1078.DotNetty.Core/Interfaces/IJT1078WebSocketBuilder.cs
@@ -9,5 +9,6 @@ namespace JT1078.DotNetty.Core.Interfaces
     {
         IJT1078Builder Instance { get; }
         IJT1078Builder Builder();
+        IJT1078WebSocketBuilder Replace<T>() where T : IJT1078Authorization;
     }
 }
diff --git a/src/JT1078.DotNetty.Core/JT1078.DotNetty.Core.csproj b/src/JT1078.DotNetty.Core/JT1078.DotNetty.Core.csproj
index da38c70..9eaa5a1 100644
--- a/src/JT1078.DotNetty.Core/JT1078.DotNetty.Core.csproj
+++ b/src/JT1078.DotNetty.Core/JT1078.DotNetty.Core.csproj
@@ -9,6 +9,7 @@
   </PropertyGroup>
 
   <ItemGroup>
+    <PackageReference Include="DotNetty.Codecs.Http" Version="0.6.0" />
     <PackageReference Include="DotNetty.Handlers" Version="0.6.0" />
     <PackageReference Include="DotNetty.Transport.Libuv" Version="0.6.0" />
     <PackageReference Include="DotNetty.Codecs" Version="0.6.0" />
diff --git a/src/JT1078.DotNetty.Core/Session/JT1078WebSocketSessionManager.cs b/src/JT1078.DotNetty.Core/Session/JT1078WebSocketSessionManager.cs
index 1edbeba..458cb99 100644
--- a/src/JT1078.DotNetty.Core/Session/JT1078WebSocketSessionManager.cs
+++ b/src/JT1078.DotNetty.Core/Session/JT1078WebSocketSessionManager.cs
@@ -21,79 +21,43 @@ namespace JT1078.DotNetty.Core.Session
             logger = loggerFactory.CreateLogger<JT1078WebSocketSessionManager>();
         }
 
-        private ConcurrentDictionary<string, JT1078WebSocketSession> SessionIdDict = new ConcurrentDictionary<string, JT1078WebSocketSession>(StringComparer.OrdinalIgnoreCase);
+        private ConcurrentDictionary<string, JT1078WebSocketSession> SessionDict = new ConcurrentDictionary<string,JT1078WebSocketSession>();
 
         public int SessionCount
         {
             get
             {
-                return SessionIdDict.Count;
+                return SessionDict.Count;
             }
         }
 
-        public JT1078WebSocketSession GetSession(string userId)
+        public List<JT1078WebSocketSession> GetSessions(string userId)
         {
-            if (string.IsNullOrEmpty(userId))
-                return default;
-            if (SessionIdDict.TryGetValue(userId, out JT1078WebSocketSession targetSession))
-            {
-                return targetSession;
-            }
-            else
-            {
-                return default;
-            }
-        }
-
-        public void TryAdd(string terminalPhoneNo,IChannel channel)
-        {
-            if (SessionIdDict.TryGetValue(terminalPhoneNo, out JT1078WebSocketSession oldSession))
-            {
-                oldSession.LastActiveTime = DateTime.Now;
-                oldSession.Channel = channel;
-                SessionIdDict.TryUpdate(terminalPhoneNo, oldSession, oldSession);
-            }
-            else
-            {
-                JT1078WebSocketSession session = new JT1078WebSocketSession(channel, terminalPhoneNo);
-                if (SessionIdDict.TryAdd(terminalPhoneNo, session))
-                {
-
-                }
-            }
+           return SessionDict.Where(m => m.Value.UserId == userId).Select(m=>m.Value).ToList();
         }
 
-        public JT1078WebSocketSession RemoveSession(string terminalPhoneNo)
+        public void TryAdd(string userId,IChannel channel)
         {
-            if (string.IsNullOrEmpty(terminalPhoneNo)) return default;
-            if (SessionIdDict.TryRemove(terminalPhoneNo, out JT1078WebSocketSession sessionRemove))
+            SessionDict.TryAdd(channel.Id.AsShortText(), new JT1078WebSocketSession(channel, userId));
+            if (logger.IsEnabled(LogLevel.Information))
             {
-                logger.LogInformation($">>>{terminalPhoneNo} Session Remove.");
-                return sessionRemove;
+                logger.LogInformation($">>>{userId},{channel.Id.AsShortText()} Channel Connection.");
             }
-            else
-            {
-                return default;
-            }  
         }
 
         public void RemoveSessionByChannel(IChannel channel)
         {
-            var terminalPhoneNos = SessionIdDict.Where(w => w.Value.Channel.Id == channel.Id).Select(s => s.Key).ToList();
-            if (terminalPhoneNos.Count > 0)
+            if (channel.Open&& SessionDict.TryRemove(channel.Id.AsShortText(), out var session))
             {
-                foreach (var key in terminalPhoneNos)
+                if (logger.IsEnabled(LogLevel.Information))
                 {
-                    SessionIdDict.TryRemove(key, out JT1078WebSocketSession sessionRemove);
+                    logger.LogInformation($">>>{session.UserId},{session.Channel.Id.AsShortText()} Channel Remove.");
                 }
-                string nos = string.Join(",", terminalPhoneNos);
-                logger.LogInformation($">>>{nos} Channel Remove.");
             }
         }
-
         public IEnumerable<JT1078WebSocketSession> GetAll()
         {
-            return SessionIdDict.Select(s => s.Value).ToList();
+            return SessionDict.Select(s => s.Value).ToList();
         }
     }
 }
diff --git a/src/JT1078.DotNetty.TestHosting/JT1078WebSocketPushHostedService.cs b/src/JT1078.DotNetty.TestHosting/JT1078WebSocketPushHostedService.cs
index 17a88cd..82650cc 100644
--- a/src/JT1078.DotNetty.TestHosting/JT1078WebSocketPushHostedService.cs
+++ b/src/JT1078.DotNetty.TestHosting/JT1078WebSocketPushHostedService.cs
@@ -38,37 +38,44 @@ namespace JT1078.DotNetty.TestHosting
                     {
                         foreach (var item in jT1078DataService.DataBlockingCollection.GetConsumingEnumerable())
                         {
+                            //if (jT1078WebSocketSessionManager.GetAll().Count() > 0)
+                            //{
+                            //    Parallel.ForEach(jT1078WebSocketSessionManager.GetAll(), new ParallelOptions { MaxDegreeOfParallelism = 5 }, session =>
+                            //    {
+                            //        //if (item.Label3.SubpackageType == JT1078SubPackageType.分包处理时的第一个包)
+                            //        //{
+                            //        //    SubcontractKey.TryRemove(item.SIM, out _);
+
+                            //        //    SubcontractKey.TryAdd(item.SIM, item.Bodies);
+                            //        //}
+                            //        //else if (item.Label3.SubpackageType == JT1078SubPackageType.分包处理时的中间包)
+                            //        //{
+                            //        //    if (SubcontractKey.TryGetValue(item.SIM, out var buffer))
+                            //        //    {
+                            //        //        SubcontractKey[item.SIM] = buffer.Concat(item.Bodies).ToArray();
+                            //        //    }
+                            //        //}
+                            //        //else if (item.Label3.SubpackageType == JT1078SubPackageType.分包处理时的最后一个包)
+                            //        //{
+                            //        //    if (SubcontractKey.TryGetValue(item.SIM, out var buffer))
+                            //        //    {
+                            //        //        session.Channel.WriteAndFlushAsync(new BinaryWebSocketFrame(Unpooled.WrappedBuffer(buffer.Concat(item.Bodies).ToArray())));
+                            //        //    }
+                            //        //}
+                            //        //else
+                            //        //{
+                            //                session.Channel.WriteAndFlushAsync(new BinaryWebSocketFrame(Unpooled.WrappedBuffer(item.Bodies)));
+                            //        // }
+                            //    });
+                            //}
+
                             if (jT1078WebSocketSessionManager.GetAll().Count() > 0)
                             {
                                 Parallel.ForEach(jT1078WebSocketSessionManager.GetAll(), new ParallelOptions { MaxDegreeOfParallelism = 5 }, session =>
-                                {
-                                    //if (item.Label3.SubpackageType == JT1078SubPackageType.分包处理时的第一个包)
-                                    //{
-                                    //    SubcontractKey.TryRemove(item.SIM, out _);
-
-                                    //    SubcontractKey.TryAdd(item.SIM, item.Bodies);
-                                    //}
-                                    //else if (item.Label3.SubpackageType == JT1078SubPackageType.分包处理时的中间包)
-                                    //{
-                                    //    if (SubcontractKey.TryGetValue(item.SIM, out var buffer))
-                                    //    {
-                                    //        SubcontractKey[item.SIM] = buffer.Concat(item.Bodies).ToArray();
-                                    //    }
-                                    //}
-                                    //else if (item.Label3.SubpackageType == JT1078SubPackageType.分包处理时的最后一个包)
-                                    //{
-                                    //    if (SubcontractKey.TryGetValue(item.SIM, out var buffer))
-                                    //    {
-                                    //        session.Channel.WriteAndFlushAsync(new BinaryWebSocketFrame(Unpooled.WrappedBuffer(buffer.Concat(item.Bodies).ToArray())));
-                                    //    }
-                                    //}
-                                    //else
-                                    //{
-                                            session.Channel.WriteAndFlushAsync(new BinaryWebSocketFrame(Unpooled.WrappedBuffer(item.Bodies)));
-                                    // }
+                                {                             
+                                    session.Channel.WriteAndFlushAsync(new BinaryWebSocketFrame(Unpooled.WrappedBuffer(item.Bodies)));
                                 });
                             }
-                            
                         }
                     }
                     catch
diff --git a/src/JT1078.DotNetty.TestHosting/Program.cs b/src/JT1078.DotNetty.TestHosting/Program.cs
index 25a9abd..0da1aa0 100644
--- a/src/JT1078.DotNetty.TestHosting/Program.cs
+++ b/src/JT1078.DotNetty.TestHosting/Program.cs
@@ -66,6 +66,7 @@ namespace JT1078.DotNetty.TestHosting
                             //.Replace<JT1078UdpMessageHandlers>()
                             //.Builder()
                             .AddJT1078WebSocketHost()
+                           // .Replace()
                             .Builder();
                     services.AddHostedService<JT1078WebSocketPushHostedService>();
                 });
diff --git a/src/JT1078.DotNetty.WebSocket/Authorization/JT1078AuthorizationDefault.cs b/src/JT1078.DotNetty.WebSocket/Authorization/JT1078AuthorizationDefault.cs
new file mode 100644
index 0000000..9679cac
--- /dev/null
+++ b/src/JT1078.DotNetty.WebSocket/Authorization/JT1078AuthorizationDefault.cs
@@ -0,0 +1,32 @@
+using DotNetty.Codecs.Http;
+using JT1078.DotNetty.Core.Interfaces;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Security.Claims;
+using System.Security.Principal;
+using System.Text;
+
+namespace JT1078.DotNetty.WebSocket.Authorization
+{
+    class JT1078AuthorizationDefault : IJT1078Authorization
+    {
+        public bool Authorization(IFullHttpRequest request, out IPrincipal principal)
+        {
+            var uriSpan = request.Uri.AsSpan();
+            var uriParamStr = uriSpan.Slice(uriSpan.IndexOf('?')+1).ToString().ToLower();
+            var uriParams = uriParamStr.Split('&');
+            var tokenParam = uriParams.FirstOrDefault(m => m.Contains("token"));
+            if (!string.IsNullOrEmpty(tokenParam))
+            {
+                principal = new ClaimsPrincipal(new GenericIdentity(tokenParam.Split('=')[1]));
+                return true;
+            }
+            else
+            {
+                principal = null;
+                return false;
+            }
+        }
+    }
+}
diff --git a/src/JT1078.DotNetty.WebSocket/Handlers/JT1078WebSocketServerHandler.cs b/src/JT1078.DotNetty.WebSocket/Handlers/JT1078WebSocketServerHandler.cs
index 6ef013d..6a1a86b 100644
--- a/src/JT1078.DotNetty.WebSocket/Handlers/JT1078WebSocketServerHandler.cs
+++ b/src/JT1078.DotNetty.WebSocket/Handlers/JT1078WebSocketServerHandler.cs
@@ -12,6 +12,7 @@ using static DotNetty.Codecs.Http.HttpResponseStatus;
 using Microsoft.Extensions.Logging;
 using JT1078.DotNetty.Core.Session;
 using System.Text.RegularExpressions;
+using JT1078.DotNetty.Core.Interfaces;
 
 namespace JT1078.DotNetty.WebSocket.Handlers
 {
@@ -24,12 +25,15 @@ namespace JT1078.DotNetty.WebSocket.Handlers
         private readonly ILogger<JT1078WebSocketServerHandler> logger;
 
         private readonly JT1078WebSocketSessionManager jT1078WebSocketSessionManager;
+        private readonly IJT1078Authorization iJT1078Authorization;
 
         public JT1078WebSocketServerHandler(
             JT1078WebSocketSessionManager jT1078WebSocketSessionManager,
+            IJT1078Authorization iJT1078Authorization,
             ILoggerFactory loggerFactory)
         {
             this.jT1078WebSocketSessionManager = jT1078WebSocketSessionManager;
+            this.iJT1078Authorization = iJT1078Authorization;
             logger = loggerFactory.CreateLogger<JT1078WebSocketServerHandler>();
         }
 
@@ -46,7 +50,7 @@ namespace JT1078.DotNetty.WebSocket.Handlers
         protected override void ChannelRead0(IChannelHandlerContext ctx, object msg)
         {
             if (msg is IFullHttpRequest request)
-            {
+            {             
                 this.HandleHttpRequest(ctx, request);
             }
             else if (msg is WebSocketFrame frame)
@@ -77,19 +81,24 @@ namespace JT1078.DotNetty.WebSocket.Handlers
                 SendHttpResponse(ctx, req, res);
                 return;
             }
-            // Handshake
-            var wsFactory = new WebSocketServerHandshakerFactory(GetWebSocketLocation(req), null, true, 5 * 1024 * 1024);
-            this.handshaker = wsFactory.NewHandshaker(req);
-            if (this.handshaker == null)
+            if (iJT1078Authorization.Authorization(req, out var principal))
             {
-                WebSocketServerHandshakerFactory.SendUnsupportedVersionResponse(ctx.Channel);
+                // Handshake
+                var wsFactory = new WebSocketServerHandshakerFactory(GetWebSocketLocation(req), null, true, 5 * 1024 * 1024);
+                this.handshaker = wsFactory.NewHandshaker(req);
+                if (this.handshaker == null)
+                {
+                    WebSocketServerHandshakerFactory.SendUnsupportedVersionResponse(ctx.Channel);
+                }
+                else
+                {
+                    this.handshaker.HandshakeAsync(ctx.Channel, req);
+                    jT1078WebSocketSessionManager.TryAdd(principal.Identity.Name, ctx.Channel);
+                }
             }
-            else
-            {
-                this.handshaker.HandshakeAsync(ctx.Channel, req);
-                var uriSpan = req.Uri.AsSpan();
-                var userId = uriSpan.Slice(uriSpan.IndexOf('?')).ToString().Split('=')[1];
-                jT1078WebSocketSessionManager.TryAdd(userId, ctx.Channel);
+            else {
+                SendHttpResponse(ctx, req, new DefaultFullHttpResponse(Http11, Unauthorized));
+                return;
             }
         }
 
@@ -141,6 +150,7 @@ namespace JT1078.DotNetty.WebSocket.Handlers
         public override void ExceptionCaught(IChannelHandlerContext ctx, Exception e)
         {
             logger.LogError(e, ctx.Channel.Id.AsShortText());
+            ctx.Channel.WriteAndFlushAsync(new DefaultFullHttpResponse(Http11, InternalServerError));
             jT1078WebSocketSessionManager.RemoveSessionByChannel(ctx.Channel);
             ctx.CloseAsync();
         }
diff --git a/src/JT1078.DotNetty.WebSocket/JT1078WebSocketBuilderDefault.cs b/src/JT1078.DotNetty.WebSocket/JT1078WebSocketBuilderDefault.cs
index 0c2c017..e4d7fe3 100644
--- a/src/JT1078.DotNetty.WebSocket/JT1078WebSocketBuilderDefault.cs
+++ b/src/JT1078.DotNetty.WebSocket/JT1078WebSocketBuilderDefault.cs
@@ -20,5 +20,11 @@ namespace JT1078.DotNetty.WebSocket
         {
             return Instance;
         }
+
+        public IJT1078WebSocketBuilder Replace<T>() where T : IJT1078Authorization
+        {
+            Instance.Services.Replace(new ServiceDescriptor(typeof(IJT1078Authorization), typeof(T), ServiceLifetime.Singleton));
+            return this;
+        }
     }
 }
diff --git a/src/JT1078.DotNetty.WebSocket/JT1078WebSocketDotnettyExtensions.cs b/src/JT1078.DotNetty.WebSocket/JT1078WebSocketDotnettyExtensions.cs
index b4acc99..81723e1 100644
--- a/src/JT1078.DotNetty.WebSocket/JT1078WebSocketDotnettyExtensions.cs
+++ b/src/JT1078.DotNetty.WebSocket/JT1078WebSocketDotnettyExtensions.cs
@@ -1,6 +1,8 @@
 using JT1078.DotNetty.Core.Codecs;
+using JT1078.DotNetty.Core.Impl;
 using JT1078.DotNetty.Core.Interfaces;
 using JT1078.DotNetty.Core.Session;
+using JT1078.DotNetty.WebSocket.Authorization;
 using JT1078.DotNetty.WebSocket.Handlers;
 using Microsoft.Extensions.DependencyInjection;
 using Microsoft.Extensions.DependencyInjection.Extensions;
@@ -14,6 +16,7 @@ namespace JT1078.DotNetty.WebSocket
         public static IJT1078WebSocketBuilder AddJT1078WebSocketHost(this IJT1078Builder builder)
         {
             builder.Services.TryAddSingleton<JT1078WebSocketSessionManager>();
+            builder.Services.TryAddSingleton<IJT1078Authorization,JT1078AuthorizationDefault>();
             builder.Services.AddScoped<JT1078WebSocketServerHandler>();
             builder.Services.AddHostedService<JT1078WebSocketServerHost>();
             return new JT1078WebSocketBuilderDefault(builder);