using Castle.Core.Logging;
using Pilz.Data;
using Pilz.Extensions.Reflection;
using Pilz.Jobs;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Net;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.RegularExpressions;
using System.Web;
using static Pilz.Net.Api.IApiServer;

namespace Pilz.Net.Api;

/// <summary>
/// API server built on <see cref="HttpListener"/>. Incoming requests are dispatched to registered
/// handler delegates (methods marked with <see cref="ApiMessageHandlerAttribute"/>) by route and HTTP method.
/// </summary>
public class ApiServer : IApiServer
{
    /// <summary>
    /// Key for cached data managers: the owning thread when <see cref="ThreadedDataManager"/> is set,
    /// otherwise <c>null</c> (one shared manager).
    /// </summary>
    protected record struct ThreadHolder(Thread? Thread);

    /// <summary>Thrown when <see cref="OnGetNewDataManager"/> does not supply an <see cref="IDataManager"/>.</summary>
    public class MissingDataManagerException : Exception
    {
    }

    protected readonly List<PrivateMessageHandler> handlers = [];
    protected readonly List<object> handlerObjects = [];
    protected readonly Dictionary<Type, IApiMessageSerializer> serializers = [];
    protected readonly Dictionary<ThreadHolder, IDataManager> managers = [];
    protected HttpListener httpListener;
    protected uint restartAttempts = 0;
    protected DateTime lastRestartAttempt;
    protected SemaphoreSlim? semaphore;
    protected bool doListen;
    protected bool isAutoRestarting;
    protected bool initializedHandlers;

    public event OnCheckAuthenticationEventHandler? OnCheckAuthentication;
    public event OnCheckContextEventHandler? OnCheckContext;
    public event OnCheckContextEventHandler? OnCheckContextCompleted;
    public event OnGetNewDataManagerEventHandler? OnGetNewDataManager;
    public event DataManagerEventHandler? OnResetDataManager;

    /// <summary>A route placeholder: its name and its segment index in the '/'-split request path.</summary>
    protected record PrivateParameterInfo(string Name, int Index);

    /// <summary>
    /// A registered handler. <paramref name="Url"/> is the literal route, or an anchored regex when
    /// <paramref name="UseRegEx"/> is set (the route contained placeholders).
    /// </summary>
    protected record PrivateMessageHandler(string Url, bool UseRegEx, Delegate Handler, PrivateParameterInfo[] Parameters, ApiMessageHandlerAttribute Attribute);

    /// <summary>A handler result plus the payload to write: a JSON string, a byte array, a stream, or <c>null</c>.</summary>
    protected record PrivateApiResult(ApiResult Original, object? ResultContent);

    /// <summary>Number of registered handler delegates.</summary>
    public int HandlersCount => handlers.Count;

    /// <summary>Number of currently cached data managers.</summary>
    public int ManagersCount => managers.Count;

    /// <summary>Auto-restart attempts counted in the current one-minute window.</summary>
    public uint RestartAttempts => restartAttempts;

    /// <summary>Base URL; <c>"/"</c> is appended to form the listener prefix in <see cref="Start"/>.</summary>
    public string ApiUrl { get; }

    /// <summary>Value sent in the <c>API-VERSION</c> response header.</summary>
    public uint ApiVersion { get; set; } = 1;

    /// <summary>
    /// Not read by this class itself.
    /// NOTE(review): presumably consumed by subclasses or authentication handlers; confirm.
    /// </summary>
    public virtual bool EnableAuth { get; set; }

    /// <summary>Serializer used when a handler attribute does not name one (or it cannot be instantiated).</summary>
    public IApiMessageSerializer Serializer { get; set; } = new DefaultApiMessageSerializer();

    /// <summary>Logger for server diagnostics.</summary>
    public ILogger Log { get; set; } = NullLogger.Instance;

    /// <summary>When set, exceptions raised while handling a request are rethrown from the listener callback.</summary>
    public bool DebugMode { get; set; }

    /// <summary>
    /// When set, the next request is accepted before the current one has finished.
    /// Concurrency is then bounded by <see cref="MaxConcurentConnections"/>.
    /// </summary>
    public bool AllowMultipleRequests { get; set; }

    /// <summary>Milliseconds <see cref="Stop"/> waits between stopping and closing the listener.</summary>
    public int StopDelay { get; set; } = 5000;

    /// <summary>Whether fatal listener errors (codes 995 and 64) trigger an automatic restart.</summary>
    public bool AutoRestartOnError { get; set; } = true;

    /// <summary>Maximum number of automatic restarts allowed within one minute.</summary>
    public int MaxAutoRestartsPerMinute { get; set; } = 10;

    /// <summary>Slot count of the request semaphore; read when the semaphore is first created.</summary>
    public int MaxConcurentConnections { get; set; } = 5;

    /// <summary>The data manager for the current context (see <see cref="ThreadedDataManager"/>).</summary>
    public IDataManager Manager => GetManager();

    /// <summary>When set, each thread gets its own data manager; otherwise a single manager is shared.</summary>
    public bool ThreadedDataManager { get; set; }

    /// <summary>Job scheduler started and stopped together with the server.</summary>
    public JobCenter Jobs { get; } = new();

    public ApiServer(string apiUrl) : this(apiUrl, null)
    {
    }

    public ApiServer(string apiUrl, HttpListener? httpListener)
    {
        ApiUrl = apiUrl;
        this.httpListener = httpListener ?? CreateDefaultHttpListener();
        Jobs.BeforeExecute += Jobs_BeforeExecute;
        Jobs.AfterExecute += Jobs_AfterExecute;
    }

    private void Jobs_BeforeExecute(object? sender, EventArgs e)
    {
        ResetManager();
    }

    private void Jobs_AfterExecute(object? sender, EventArgs e)
    {
        ResetManager();
    }

    /// <summary>
    /// Returns the cached data manager for the current key (thread or shared).
    /// If none is cached, a new one is created via <see cref="OnGetNewDataManager"/>.
    /// </summary>
    /// <exception cref="MissingDataManagerException">No manager was supplied by the event.</exception>
    private IDataManager GetManager()
    {
        var curThread = ThreadedDataManager ? Thread.CurrentThread : null;
        var threadHolder = new ThreadHolder(curThread);
        IDataManager manager;

        lock (managers)
        {
            if (managers.TryGetValue(threadHolder, out var mgr))
                return mgr;

            if (OnGetNewDataManager?.Invoke(this, EventArgs.Empty) is not IDataManager managerr)
                throw new MissingDataManagerException();

            manager = managerr;
            managers.Add(threadHolder, manager);
        }

        return manager;
    }

    /// <summary>
    /// Drops the cached data manager for the current key, if one exists, and raises
    /// <see cref="OnResetDataManager"/> for it.
    /// </summary>
    public virtual void ResetManager()
    {
        var key = new ThreadHolder(ThreadedDataManager ? Thread.CurrentThread : null);
        IDataManager? manager;

        // Same lock as GetManager: requests may run on several threads at once.
        lock (managers)
        {
            if (!managers.Remove(key, out manager))
                return;
        }

        OnResetDataManager?.Invoke(this, new(manager));
    }

    /// <summary>Creates a listener with two-minute timeouts; some timeouts are only set on Windows.</summary>
    protected virtual HttpListener CreateDefaultHttpListener()
    {
        var httpListener = new HttpListener();

        httpListener.TimeoutManager.IdleConnection = new TimeSpan(0, 2, 0);

        if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
        {
            httpListener.TimeoutManager.RequestQueue = new TimeSpan(0, 2, 0);
            httpListener.TimeoutManager.HeaderWait = new TimeSpan(0, 2, 0);
        }

        return httpListener;
    }

    /// <summary>Initializes handlers, starts the listener and the job center, and begins receiving requests.</summary>
    public virtual void Start()
    {
        Log.Info("Starting listener");
        InitializeHandlers();
        httpListener.Prefixes.Add(ApiUrl + "/");
        doListen = true;
        httpListener.Start();
        Log.Info("Started listener");
        Jobs.Start();
        Receive();
    }

    /// <summary>
    /// Stops the job center and the listener, waits <see cref="StopDelay"/> milliseconds, then closes the listener.
    /// </summary>
    /// <param name="graceful"><c>true</c> to <see cref="HttpListener.Stop"/>, <c>false</c> to <see cref="HttpListener.Abort"/>.</param>
    public virtual void Stop(bool graceful)
    {
        Log.Info("Stopping listener");
        Jobs.Stop();
        doListen = false;

        if (graceful)
            httpListener.Stop();
        else
            httpListener.Abort();

        Thread.Sleep(StopDelay);
        httpListener.Close();
        Log.Info("Stopped listener");
    }

    /// <summary>Stops the server, replaces the listener with a fresh default one, frees all request slots and starts again.</summary>
    public virtual void Restart(bool graceful)
    {
        Log.Info("Restarting listener");
        Stop(graceful);
        httpListener = CreateDefaultHttpListener();

        // Release exactly the slots currently held so that callers blocked in WaitForSlot can proceed.
        // Releasing beyond the semaphore's maximum would throw SemaphoreFullException.
        if (semaphore is not null)
        {
            var held = MaxConcurentConnections - semaphore.CurrentCount;
            if (held > 0)
            {
                try
                {
                    semaphore.Release(held);
                }
                catch (SemaphoreFullException)
                {
                    // MaxConcurentConnections was changed after the semaphore was created; nothing more to release.
                }
            }
        }

        Start();
        Log.Info("Restarted listener");
    }

    /// <summary>
    /// Restarts the server after a fatal listener error, at most <see cref="MaxAutoRestartsPerMinute"/> times per minute.
    /// When the limit is reached, the server is stopped instead.
    /// </summary>
    /// <returns>
    /// <c>true</c> if a restart happened or one is already in progress;
    /// <c>false</c> if auto-restart is disabled or the limit was reached.
    /// </returns>
    protected virtual bool AutoRestart(bool graceful)
    {
        if (!AutoRestartOnError)
            return false;

        if (isAutoRestarting)
            return true;

        var now = DateTime.UtcNow;
        if (now - lastRestartAttempt > TimeSpan.FromMinutes(1))
        {
            lastRestartAttempt = now;
            restartAttempts = 0;
        }

        if (restartAttempts >= MaxAutoRestartsPerMinute)
        {
            Log.Fatal("Reached maximum auto-restart attempts");
            Stop(false);
            return false;
        }

        restartAttempts += 1;
        isAutoRestarting = true;
        try
        {
            Restart(graceful);
        }
        finally
        {
            // Without this, a failed restart would block every future auto-restart.
            isAutoRestarting = false;
        }

        return true;
    }

    /// <summary>Blocks until a request slot is free (only when <see cref="AllowMultipleRequests"/> is set).</summary>
    protected virtual void WaitForSlot()
    {
        // With a single request at a time the requests are already serialized, so no semaphore is needed.
        if (!AllowMultipleRequests)
            return;

        semaphore ??= new(MaxConcurentConnections, MaxConcurentConnections);
        semaphore.Wait();
    }

    /// <summary>Returns a request slot, never exceeding <see cref="MaxConcurentConnections"/> free slots.</summary>
    protected virtual void FreeSlot()
    {
        if (semaphore != null && semaphore.CurrentCount < MaxConcurentConnections)
            semaphore.Release();
    }

    /// <summary>Calls <see cref="IApiHandlerInitializer.Initialize"/> on every registered handler object that implements it.</summary>
    protected virtual void InitializeHandlers()
    {
        foreach (var instance in handlerObjects)
        {
            if (instance is IApiHandlerInitializer initializer)
                initializer.Initialize();
        }
    }

    /// <summary>Returns the first registered handler object of type <typeparamref name="T"/>, or <c>default</c>.</summary>
    public virtual T? GetHandler<T>()
    {
        return handlerObjects.OfType<T>().FirstOrDefault();
    }

    /// <summary>Maps each registered route to its HTTP methods, ordered by route and method.</summary>
    public virtual Dictionary<string, string[]> GetEndpoints()
    {
        return handlers
            .OrderBy(n => n.Attribute.Route)
            .GroupBy(n => n.Attribute.Route)
            .ToDictionary(n => n.Key, n => n.SelectMany(n => n.Attribute.Methods.OrderBy(n => n)).ToArray());
    }

    /// <summary>
    /// Registers every public instance method of <paramref name="instance"/> that carries
    /// <see cref="ApiMessageHandlerAttribute"/>. Other methods are ignored without error.
    /// </summary>
    public virtual void RegisterHandler<T>(T instance) where T : class
    {
        if (!handlerObjects.Contains(instance))
            handlerObjects.Add(instance);

        foreach (var method in instance.GetType().GetMethods(BindingFlags.Instance | BindingFlags.Public))
        {
            // Check the attribute first so no delegate is created for unrelated public methods
            // (for example the members inherited from object).
            if (method.GetCustomAttribute<ApiMessageHandlerAttribute>() is null)
                continue;

            RegisterHandler(method.CreateDelegate(instance), false);
        }
    }

    /// <summary>Registers a handler delegate; throws if it is not a valid handler.</summary>
    public virtual void RegisterHandler(Delegate handler)
    {
        RegisterHandler(handler, true);
    }

    /// <summary>
    /// Registers a handler delegate. Its method must carry <see cref="ApiMessageHandlerAttribute"/>
    /// and return an <see cref="ApiResult"/>.
    /// </summary>
    /// <exception cref="NotSupportedException">The delegate is invalid and <paramref name="throwOnError"/> is set.</exception>
    public virtual void RegisterHandler(Delegate handler, bool throwOnError)
    {
        var method = handler.Method;

        // Sanity checks
        if (method.GetCustomAttribute<ApiMessageHandlerAttribute>() is not ApiMessageHandlerAttribute attribute || !method.ReturnType.IsAssignableTo(typeof(ApiResult)))
        {
            if (throwOnError)
                throw new NotSupportedException("The first parameter needs to be of type ApiMessage and must return an ApiResult object and the method must have the MessageHandlerAttribute.");
            return;
        }

        RegisterHandler(handler, attribute, throwOnError);
    }

    /// <summary>
    /// Registers a handler for <see cref="ApiMessageHandlerAttribute.Route"/>. Each <c>{name}</c>
    /// placeholder becomes a regex segment, and its name and path-segment index are recorded for
    /// parameter binding. Routes without placeholders are matched literally (ignoring case).
    /// </summary>
    /// <exception cref="ArgumentException">The route has an unclosed placeholder and <paramref name="throwOnError"/> is set.</exception>
    public virtual void RegisterHandler(Delegate handler, ApiMessageHandlerAttribute attribute, bool throwOnError)
    {
        const string parameterRegex = "[A-Za-z0-9%_.-]+";

        var route = attribute.Route;
        var parameters = new List<PrivateParameterInfo>();
        var pattern = new StringBuilder();
        var literalStart = 0;
        var openBracket = route.IndexOf('{');

        while (openBracket != -1)
        {
            var closeBracket = route.IndexOf('}', openBracket + 1);
            if (closeBracket == -1)
            {
                // Malformed route: a '{' without a matching '}'.
                if (throwOnError)
                    throw new ArgumentException($"The route \"{route}\" contains an unclosed '{{' placeholder.", nameof(attribute));
                Log.WarnFormat("Skipped handler for {0}: unclosed route placeholder", route);
                return;
            }

            var name = route.Substring(openBracket + 1, closeBracket - openBracket - 1);
            // Segment index of the placeholder once the request path is split by '/'.
            var index = route[..(openBracket + 1)].Split('/').Length - 1;
            parameters.Add(new(name, index));

            // Literal route text is escaped so characters such as '.', '+' or '(' match themselves.
            pattern.Append(Regex.Escape(route[literalStart..openBracket]));
            pattern.Append(parameterRegex);

            literalStart = closeBracket + 1;
            openBracket = route.IndexOf('{', literalStart);
        }

        var useRegEx = parameters.Count > 0;
        string url;
        if (useRegEx)
        {
            pattern.Append(Regex.Escape(route[literalStart..]));
            url = $"^{pattern}$"; // Anchor so the whole path must match
        }
        else
        {
            url = route;
        }

        // Add handler
        Log.InfoFormat("Added handler for {0}", attribute.Route);
        handlers.Add(new(url, useRegEx, handler, [.. parameters], attribute));
    }

    /// <summary>
    /// Begins an asynchronous wait for the next request. The current listener is passed as state so
    /// the callback completes on the listener that started the wait.
    /// </summary>
    protected void Receive()
    {
        var listener = httpListener;
        if (listener.IsListening)
            listener.BeginGetContext(ListenerCallback, listener);
    }

    /// <summary>
    /// Handles one incoming request: acquires a slot, gets the context, runs <see cref="CheckContext"/>
    /// between the context events, then frees the slot and continues listening.
    /// </summary>
    protected void ListenerCallback(IAsyncResult result)
    {
        // After a restart the field points to a new listener; complete on the one that issued this wait.
        var listener = result.AsyncState as HttpListener ?? httpListener;
        HttpListenerContext? context;

        // Skip if not listening anymore
        if (!listener.IsListening || !doListen)
            return;

        // Wait for a free slot
        try
        {
            WaitForSlot();
        }
        catch (ObjectDisposedException)
        {
            return;
        }
        catch (Exception ex)
        {
            Log.Fatal($"Too many concurent connections", ex);
            Thread.Sleep(1000);
            FreeSlot();
            return;
        }

        // Get context
        try
        {
            context = listener.EndGetContext(result);
        }
        catch (ObjectDisposedException)
        {
            FreeSlot(); // Give back the slot acquired above
            return;
        }
        catch (HttpListenerException ex)
        {
            if (ex.ErrorCode == 995 || ex.ErrorCode == 64)
            {
                Log.Fatal($"Fatal http error retriving context with code {ex.ErrorCode}", ex);
                if (AutoRestart(false)) // Try restart the server and skip this context
                {
                    FreeSlot();
                    return;
                }
            }
            else
            {
                Log.Error($"Http error retriving context with code {ex.ErrorCode}", ex);
            }

            context = null;
        }
        catch (Exception ex)
        {
            Log.Error("Error retriving context", ex);
            context = null;
        }

        // Immediately listen for the next request when requests may run concurrently
        if (AllowMultipleRequests)
            Receive();

        // Check context
        var success = false;
        if (context is not null)
        {
            Log.Info("Request retrived for " + context.Request.RawUrl);
            ResetManager();
            OnCheckContext?.Invoke(this, new(context));

            try
            {
                CheckContext(context);
                success = true;
            }
            catch (MissingDataManagerException mdmex)
            {
                Log.Error("DataManager is not supported by this server instance!", mdmex);
                if (DebugMode)
                    throw;
            }
            catch (Exception ex)
            {
                Log.Error("Error checking context", ex);
                if (DebugMode)
                    throw;
            }
            finally
            {
                OnCheckContextCompleted?.Invoke(this, new(context));
                ResetManager();
            }
        }

        // Best effort: report the failure and close the request
        if (!success && context is not null)
        {
            try
            {
                // Without this the client would receive the default 200; this throws if headers were already sent.
                context.Response.StatusCode = (int)HttpStatusCode.InternalServerError;
            }
            catch (Exception)
            {
                // Headers already sent or response already disposed: keep whatever status was sent.
            }

            try
            {
                context.Response.Close();
            }
            catch (Exception)
            {
                // The connection may already be gone; nothing left to clean up.
            }
        }

        FreeSlot();

        // Listen for the next request once this one is done
        if (!AllowMultipleRequests)
            Receive();
    }

    /// <summary>
    /// Routes the request to a handler and writes its result: a JSON string, raw bytes or a stream.
    /// Responds with 400 Bad Request when there is no path, no matching handler, or the message cannot be handled.
    /// </summary>
    protected virtual void CheckContext(HttpListenerContext context)
    {
        Log.Debug("Start handling request");

        void CloseResponse(bool badRequest)
        {
            Log.Debug("End handling request");
            if (badRequest)
                context.Response.StatusCode = (int)HttpStatusCode.BadRequest;
            context.Response.OutputStream.Close();
        }

        // Parse url
        Log.Debug("Parse url");
        var path = context.Request.Url?.AbsolutePath;
        var query = context.Request.Url?.Query;
        if (string.IsNullOrWhiteSpace(path))
        {
            Log.Warn("Request has no path");
            CloseResponse(true);
            return;
        }

        // Find handler
        Log.Debug("Find handler");
        if (!TryGetHandler(path, query, context.Request.HttpMethod, out var handler))
        {
            Log.Warn("Request handler couldn't be found");
            CloseResponse(true);
            return;
        }

        // Get auth key
        Log.Debug("Get auth key");
        if (context.Request.Headers.Get("API-AUTH-KEY") is not string authKey)
            authKey = null!;

        // Handle message
        Log.Debug("Handle mssage");
        if (HandleMessage(context, path, query, handler, authKey) is not PrivateApiResult result)
        {
            Log.Warn("Request couldn't be handled");
            CloseResponse(true);
            return;
        }

        // Set response header
        Log.Debug("Set response headers");
        context.Response.AppendHeader("API-VERSION", ApiVersion.ToString());
        if (result.Original.Message is ApiRawMessage apiRawMessage && !string.IsNullOrWhiteSpace(apiRawMessage.FileName))
            context.Response.AppendHeader("Content-Disposition", $"filename=\"{apiRawMessage.FileName}\"");

        // Set response parameters
        Log.Debug("Set response parameters");
        context.Response.StatusCode = (int)result.Original.StatusCode;

        // Write response content
        Log.Debug("Create response");
        if (result.ResultContent is string resultJson)
        {
            Log.Info("Sending json response for " + context.Request.RawUrl);
            context.Response.ContentType = "application/json";
            using StreamWriter output = new(context.Response.OutputStream);
            output.Write(resultJson);
            output.Flush();
        }
        else if (result.ResultContent is byte[] resultBytes)
        {
            Log.Info("Sending raw bytes response for " + context.Request.RawUrl);
            context.Response.ContentType = "application/octet-stream";
            context.Response.ContentLength64 = resultBytes.Length;
            context.Response.OutputStream.Write(resultBytes, 0, resultBytes.Length);
            context.Response.OutputStream.Flush();
        }
        else if (result.ResultContent is Stream resultStream)
        {
            Log.Info("Sending stream response for " + context.Request.RawUrl);
            using (resultStream) // Closed even if copying fails
            {
                context.Response.ContentType = "application/octet-stream";
                // Only seekable streams expose Length; CopyTo starts at the current position.
                if (resultStream.CanSeek)
                    context.Response.ContentLength64 = Math.Max(0, resultStream.Length - resultStream.Position);
                resultStream.CopyTo(context.Response.OutputStream);
                context.Response.OutputStream.Flush();
            }
        }

        Log.Debug("Finish response");
        CloseResponse(false);
    }

    /// <summary>
    /// Authenticates the request, reads the body into an <see cref="ApiMessage"/> (JSON or raw octet-stream),
    /// invokes the handler, and packs its result together with the payload to send.
    /// </summary>
    /// <returns>
    /// The result. Returns an Unauthorized result when authentication fails.
    /// Returns <c>null</c> when the body or the parameters cannot be processed; the caller then answers with 400.
    /// </returns>
    protected virtual PrivateApiResult? HandleMessage(HttpListenerContext context, string url, string? query, PrivateMessageHandler handler, string? authKey)
    {
        // Check authentication
        Log.Debug("Check authentication");
        if (string.IsNullOrWhiteSpace(authKey) || DecodeAuthKey(authKey) is not string authKeyDecoded)
            authKeyDecoded = null!;
        var isAuthenticated = CheckAuthentication(authKeyDecoded, handler.Handler, context);
        if (!isAuthenticated)
            return new(ApiResult.Unauthorized(), null);

        // Get required infos
        Log.Debug("Identify message parameter type and serializer");
        var targetType = handler.Handler.Method.GetParameters().FirstOrDefault(p => p.ParameterType.IsAssignableTo(typeof(ApiMessage)))?.ParameterType;
        var serializer = GetSerializer(handler.Attribute.Serializer);

        // Read input content
        Log.Debug("Read input content");
        ApiMessage? message = null;

        // Media type without parameters such as "; charset=utf-8". Media types are case-insensitive.
        var contentType = context.Request.ContentType;
        var contentTypeEnd = contentType?.IndexOf(';') ?? -1;
        if (contentType != null && contentTypeEnd != -1)
            contentType = contentType[..contentTypeEnd];
        contentType = contentType?.Trim();

        if (string.Equals(contentType, "application/json", StringComparison.OrdinalIgnoreCase))
        {
            try
            {
                Log.Debug("Deserialize message");
                using StreamReader input = new(context.Request.InputStream);
                var contentJson = input.ReadToEnd();
                if (targetType != null)
                    message = serializer.Deserialize(contentJson, targetType);
            }
            catch (OutOfMemoryException)
            {
                Log.Error("Error reading remote data due to missing memory");
                return null;
            }
            catch (Exception ex)
            {
                Log.Error("Error reading remote data", ex);
                return null;
            }
        }
        else if (string.Equals(contentType, "application/octet-stream", StringComparison.OrdinalIgnoreCase))
        {
            Log.Debug("Process raw message");
            if (targetType == null)
                return null;

            if (targetType.IsAssignableTo(typeof(ApiRawByteMessage)))
            {
                // ContentLength64 is -1 when the length is unknown (e.g. chunked upload); MemoryStream rejects negative capacities.
                var length = context.Request.ContentLength64;
                var ms = new MemoryStream(length is > 0 and <= int.MaxValue ? (int)length : 0);
                context.Request.InputStream.CopyTo(ms);
                message = new ApiRawByteMessage(ms.ToArray());
            }
            else if (targetType.IsAssignableTo(typeof(ApiRawStreamMessage)))
            {
                message = new ApiRawInputStreamMessage(context.Request.InputStream, context.Request.ContentLength64);
            }
        }

        // Invoke handler
        Log.Debug("Invoke handler");
        object?[]? parameters;
        try
        {
            parameters = BuildParameters(url, query, handler, () => message, () => new(message, isAuthenticated, authKeyDecoded, url, context.Request.HttpMethod, context));
        }
        catch (Exception ex) when (ex is FormatException or InvalidCastException or OverflowException or ArgumentException)
        {
            // A route or query value could not be converted to the handler's parameter type: the client's fault.
            Log.Warn("Request parameters couldn't be converted", ex);
            return null;
        }

        if (handler.Handler.DynamicInvoke(parameters) is not ApiResult result)
            return new(ApiResult.InternalServerError(), null);

        // Return result without message
        Log.Debug("Check message");
        if (result.Message is null)
            return new(result, null);

        // Return result with raw data
        if (result.Message is ApiRawByteMessage dataMsg)
            return new(result, dataMsg.Data);
        if (result.Message is ApiRawStreamMessage streamMsg)
            return new(result, streamMsg.Data);

        // Serializer
        Log.Debug("Serialize message");
        if (serializer.Serialize(result.Message) is not string resultStr)
            return new(ApiResult.InternalServerError(), null);

        // Return result with message
        Log.Debug("Complete result");
        return new(result, resultStr);
    }

    /// <summary>
    /// Finds a handler that accepts <paramref name="method"/>. A literal route match (ignoring case)
    /// wins over a regex route match.
    /// </summary>
    protected virtual bool TryGetHandler(string url, string? query, string method, [NotNullWhen(true)] out PrivateMessageHandler? handler)
    {
        // Filter by method
        var filtered = handlers.Where(handler => handler.Attribute.Methods.Contains(method));

        // Check if equals via string comparison (ignore-case)
        handler = filtered.FirstOrDefault(handler => handler.Url.Equals(url, StringComparison.OrdinalIgnoreCase));

        // Check if equals via RegEx
        handler ??= filtered.FirstOrDefault(handler => handler.UseRegEx && Regex.IsMatch(url, handler.Url, RegexOptions.IgnoreCase));

        return handler != null;
    }

    /// <summary>
    /// Builds the argument list for the handler.
    /// <list type="bullet">
    /// <item>An <see cref="ApiMessage"/> parameter receives the message.</item>
    /// <item>An <see cref="ApiRequestInfo"/> parameter receives the request info.</item>
    /// <item>An <see cref="IDictionary{TKey,TValue}"/> parameter receives all query parameters.</item>
    /// <item>Other parameters are bound by name (ignoring case): route placeholders first, then query
    /// values, then the parameter's default value, else <c>null</c>.</item>
    /// </list>
    /// </summary>
    /// <exception cref="FormatException">A value cannot be parsed as the parameter type.</exception>
    /// <exception cref="InvalidCastException">A value cannot be converted to the parameter type.</exception>
    /// <exception cref="OverflowException">A numeric value is out of range for the parameter type.</exception>
    /// <exception cref="ArgumentException">An enum value is not defined.</exception>
    protected virtual object?[]? BuildParameters(string url, string? query, PrivateMessageHandler handler, Func<ApiMessage?> getMessage, Func<ApiRequestInfo> getRequestInfo)
    {
        var infos = handler.Handler.Method.GetParameters();
        var objs = new List<object?>(infos.Length);
        // ParseQueryString already URL-decodes keys and values. Decoding them a second time would
        // corrupt values, e.g. an escaped "+" (%2B) would turn into a space.
        var queryparams = query == null ? [] : HttpUtility.ParseQueryString(query);
        var segments = url.Split('/');

        foreach (var info in infos)
        {
            if (info.ParameterType.IsAssignableTo(typeof(ApiMessage)))
                objs.Add(getMessage());
            else if (info.ParameterType.IsAssignableTo(typeof(ApiRequestInfo)))
                objs.Add(getRequestInfo());
            else if (info.ParameterType.IsAssignableTo(typeof(IDictionary<string, string>)))
                objs.Add(queryparams.AllKeys.OfType<string>().ToDictionary(n => n, n => queryparams.Get(n)));
            else if (handler.Parameters.FirstOrDefault(p => p.Name.Equals(info.Name, StringComparison.OrdinalIgnoreCase)) is PrivateParameterInfo parameterInfo && segments.ElementAtOrDefault(parameterInfo.Index) is string parameterValue)
                objs.Add(ConvertParameterValue(HttpUtility.UrlDecode(parameterValue), info.ParameterType)); // Path segments are still escaped; decode once
            else if (queryparams.AllKeys.FirstOrDefault(n => n != null && n.Equals(info.Name, StringComparison.OrdinalIgnoreCase)) is string querykey)
                objs.Add(ConvertParameterValue(queryparams.Get(querykey), info.ParameterType));
            else if (info.HasDefaultValue)
                objs.Add(info.DefaultValue);
            else
                objs.Add(null);
        }

        return [.. objs];
    }

    /// <summary>
    /// Converts a request string to <paramref name="targetType"/> using the invariant culture.
    /// Supports <see cref="Nullable{T}"/> (empty gives <c>null</c>) and enums (by name or number, ignoring case).
    /// </summary>
    private static object? ConvertParameterValue(string? value, Type targetType)
    {
        var underlying = Nullable.GetUnderlyingType(targetType);
        if (underlying is not null)
        {
            if (string.IsNullOrEmpty(value))
                return null;
            targetType = underlying;
        }

        if (targetType.IsEnum && value is not null)
            return Enum.Parse(targetType, value, ignoreCase: true);

        // URL values are machine-readable: parse them independently of the server's culture.
        return Convert.ChangeType(value, targetType, CultureInfo.InvariantCulture);
    }

    /// <summary>
    /// Returns a cached instance of serializer type <paramref name="t"/>, creating and caching it on first use.
    /// Falls back to <see cref="Serializer"/> when <paramref name="t"/> is null or does not yield an <see cref="IApiMessageSerializer"/>.
    /// </summary>
    protected virtual IApiMessageSerializer GetSerializer(Type? t)
    {
        if (t is null)
            return Serializer;

        // Requests may run concurrently; guard the shared cache.
        lock (serializers)
        {
            if (serializers.TryGetValue(t, out var cached) && cached is not null)
                return cached;

            if (Activator.CreateInstance(t) is IApiMessageSerializer created)
            {
                serializers[t] = created;
                return created;
            }
        }

        return Serializer;
    }

    /// <summary>
    /// Raises <see cref="OnCheckAuthentication"/> and returns the handler's verdict.
    /// Returns <c>false</c> when no handler is subscribed.
    /// </summary>
    protected virtual bool CheckAuthentication(string authKey, Delegate? handler, HttpListenerContext context)
    {
        var onCheckAuthentication = OnCheckAuthentication;
        if (onCheckAuthentication is null)
            return false;

        var args = new ApiAuthCheckEventArgs(authKey, handler, context);
        onCheckAuthentication(this, args);
        return args.Valid;
    }

    /// <summary>Decodes the raw <c>API-AUTH-KEY</c> header value. The default returns it unchanged.</summary>
    protected virtual string? DecodeAuthKey(string authKey)
    {
        return authKey;
    }
}