Files
Pilz/Pilz.Net/Api/ApiServer.cs
Pilzinsel64 ffa6c647f8 final fix
2025-12-03 08:54:28 +01:00

648 lines
22 KiB
C#

using Castle.Core.Logging;
using Pilz.Data;
using Pilz.Extensions.Reflection;
using Pilz.Jobs;
using System.Diagnostics.CodeAnalysis;
using System.Net;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;
using System.Web;
using static Pilz.Net.Api.IApiServer;
namespace Pilz.Net.Api;
/// <summary>
/// HTTP API server on top of <see cref="HttpListener"/>. Incoming requests are dispatched by HTTP method
/// and route to registered handler delegates (methods marked with <see cref="ApiMessageHandlerAttribute"/>
/// that return an <see cref="ApiResult"/>).
/// </summary>
public class ApiServer : IApiServer
{
    // Win32 error codes of HttpListenerException that are treated as fatal for the listener:
    // ERROR_OPERATION_ABORTED (995) and ERROR_NETNAME_DELETED (64).
    private const int ErrorOperationAborted = 995;
    private const int ErrorNetNameDeleted = 64;

    // Regex substituted for each "{name}" route placeholder. It contains no '/', so it matches inside one path segment.
    private const string RouteParameterPattern = "[A-Za-z0-9%_.-]+";

    protected record struct ThreadHolder(Thread? Thread);

    /// <summary>Thrown when a data manager is needed but no <see cref="OnGetNewDataManager"/> handler provided one.</summary>
    public class MissingDataManagerException : Exception { }

    protected readonly List<PrivateMessageHandler> handlers = [];
    protected readonly List<object> handlerObjects = [];
    protected readonly Dictionary<Type, IApiMessageSerializer> serializers = [];
    protected readonly Dictionary<ThreadHolder, IDataManager> managers = [];
    protected HttpListener httpListener;
    protected uint restartAttempts = 0;
    protected DateTime lastRestartAttempt;
    protected SemaphoreSlim? semaphore;
    protected bool doListen;
    protected bool isAutoRestarting;
    protected bool initializedHandlers;

    public event OnCheckAuthenticationEventHandler? OnCheckAuthentication;
    public event OnCheckContextEventHandler? OnCheckContext;
    public event OnCheckContextEventHandler? OnCheckContextCompleted;
    public event OnGetNewDataManagerEventHandler? OnGetNewDataManager;
    public event DataManagerEventHandler? OnResetDataManager;

    protected record PrivateParameterInfo(string Name, int Index);
    protected record PrivateMessageHandler(string Url, bool UseRegEx, Delegate Handler, PrivateParameterInfo[] Parameters, ApiMessageHandlerAttribute Attribute);
    protected record PrivateApiResult(ApiResult Original, object? ResultContent);

    public int HandlersCount => handlers.Count;
    public int ManagersCount => managers.Count;
    public uint RestartAttempts => restartAttempts;
    public string ApiUrl { get; }
    public uint ApiVersion { get; set; } = 1;
    public virtual bool EnableAuth { get; set; }
    public IApiMessageSerializer Serializer { get; set; } = new DefaultApiMessageSerializer();
    public ILogger Log { get; set; } = NullLogger.Instance;
    // When set, exceptions while handling a request are rethrown after being logged.
    public bool DebugMode { get; set; }
    // When set, the next request is received before the current one is handled (limited by MaxConcurentConnections).
    public bool AllowMultipleRequests { get; set; }
    // Milliseconds Stop() waits after stopping/aborting the listener before closing it.
    public int StopDelay { get; set; } = 5000;
    public bool AutoRestartOnError { get; set; } = true;
    public int MaxAutoRestartsPerMinute { get; set; } = 10;
    // Only enforced when AllowMultipleRequests is set.
    public int MaxConcurentConnections { get; set; } = 5;
    public IDataManager Manager => GetManager();
    // When set, every thread gets its own data manager instead of one shared instance.
    public bool ThreadedDataManager { get; set; }
    public JobCenter Jobs { get; } = new();

    public ApiServer(string apiUrl) : this(apiUrl, null)
    {
    }

    public ApiServer(string apiUrl, HttpListener? httpListener)
    {
        ApiUrl = apiUrl;
        this.httpListener = httpListener ?? CreateDefaultHttpListener();
        Jobs.BeforeExecute += Jobs_BeforeExecute;
        Jobs.AfterExecute += Jobs_AfterExecute;
    }

    // Drop the cached data manager before and after each job run so job executions don't share one.
    private void Jobs_BeforeExecute(object? sender, EventArgs e)
    {
        ResetManager();
    }

    private void Jobs_AfterExecute(object? sender, EventArgs e)
    {
        ResetManager();
    }

    // Key into `managers` for the current scope: the current thread when ThreadedDataManager is set,
    // otherwise one shared (null) slot.
    private ThreadHolder GetManagerKey() => new(ThreadedDataManager ? Thread.CurrentThread : null);

    // Returns the data manager cached for the current scope, creating it via OnGetNewDataManager on first use.
    // Throws MissingDataManagerException when no subscriber returns a manager.
    private IDataManager GetManager()
    {
        var key = GetManagerKey();
        lock (managers)
        {
            if (managers.TryGetValue(key, out var existing))
                return existing;
            if (OnGetNewDataManager?.Invoke(this, EventArgs.Empty) is not IDataManager created)
                throw new MissingDataManagerException();
            managers.Add(key, created);
            return created;
        }
    }

    /// <summary>
    /// Removes the data manager of the current scope (if any) and raises <see cref="OnResetDataManager"/> for it.
    /// </summary>
    public virtual void ResetManager()
    {
        IDataManager? manager;
        // Same lock as GetManager: with AllowMultipleRequests several requests access the dictionary concurrently.
        lock (managers)
        {
            if (!managers.Remove(GetManagerKey(), out manager))
                return;
        }
        // Raised outside the lock so slow subscribers don't block other requests.
        OnResetDataManager?.Invoke(this, new(manager));
    }

    // Creates the listener used when none is passed to the constructor and on Restart().
    protected virtual HttpListener CreateDefaultHttpListener()
    {
        var httpListener = new HttpListener();
        httpListener.TimeoutManager.IdleConnection = new TimeSpan(0, 2, 0);
        // The managed (non-Windows) listener doesn't support these timeouts, so they're only set on Windows.
        if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
        {
            httpListener.TimeoutManager.RequestQueue = new TimeSpan(0, 2, 0);
            httpListener.TimeoutManager.HeaderWait = new TimeSpan(0, 2, 0);
        }
        return httpListener;
    }

    /// <summary>
    /// Initializes the handler objects (once), listens on <c>ApiUrl + "/"</c>, starts the jobs and begins receiving requests.
    /// </summary>
    public virtual void Start()
    {
        Log.Info("Starting listener");
        InitializeHandlers();
        httpListener.Prefixes.Add(ApiUrl + "/");
        doListen = true;
        httpListener.Start();
        Log.Info("Started listener");
        Jobs.Start();
        Receive();
    }

    /// <summary>
    /// Stops the jobs and the listener (<see cref="HttpListener.Stop"/> when graceful, otherwise
    /// <see cref="HttpListener.Abort"/>), waits <see cref="StopDelay"/> ms and closes the listener.
    /// </summary>
    public virtual void Stop(bool graceful)
    {
        Log.Info("Stopping listener");
        Jobs.Stop();
        doListen = false;
        if (graceful)
            httpListener.Stop();
        else
            httpListener.Abort();
        Thread.Sleep(StopDelay);
        httpListener.Close();
        Log.Info("Stopped listener");
    }

    /// <summary>
    /// Stops the server, replaces the closed listener with a new default one, frees all request slots and starts again.
    /// </summary>
    public virtual void Restart(bool graceful)
    {
        Log.Info("Restarting listener");
        Stop(graceful);
        // A closed HttpListener can't be started again; use the (overridable) factory so the timeouts are configured.
        httpListener = CreateDefaultHttpListener();

        // Return every occupied slot so requests waiting in WaitForSlot can continue.
        // SemaphoreSlim.Release(n) throws SemaphoreFullException when n exceeds the free capacity,
        // so only the occupied slots are released.
        if (semaphore is { } slots)
        {
            var occupied = MaxConcurentConnections - slots.CurrentCount;
            if (occupied > 0)
                slots.Release(occupied);
        }

        Start();
        Log.Info("Restarted listener");
    }

    /// <summary>
    /// Restarts the server after a fatal listener error, at most <see cref="MaxAutoRestartsPerMinute"/> times per minute.
    /// </summary>
    /// <returns>
    /// <c>true</c> when the server was restarted (or another restart is in progress) and the current context
    /// should be skipped; <c>false</c> when auto-restart is disabled, the limit was reached (the server is then
    /// stopped) or the restart failed.
    /// </returns>
    protected virtual bool AutoRestart(bool graceful)
    {
        if (!AutoRestartOnError)
            return false;

        if (isAutoRestarting)
            return true;

        // UTC so clock changes (e.g. DST) don't distort the one-minute window.
        var now = DateTime.UtcNow;
        if (now - lastRestartAttempt > TimeSpan.FromMinutes(1))
        {
            lastRestartAttempt = now;
            restartAttempts = 0;
        }

        if (restartAttempts >= MaxAutoRestartsPerMinute)
        {
            Log.Fatal("Reached maximum auto-restart attempts");
            Stop(false);
            return false;
        }

        restartAttempts += 1;
        isAutoRestarting = true;
        try
        {
            Restart(graceful);
        }
        catch (Exception ex)
        {
            // Called from the listener callback: an escaping exception would be unhandled on that thread.
            Log.Fatal("Auto-restart failed", ex);
            return false;
        }
        finally
        {
            isAutoRestarting = false;
        }
        return true;
    }

    // Blocks until one of MaxConcurentConnections slots is free.
    protected virtual void WaitForSlot()
    {
        // Without parallel handling the next request is only received after the current one finished
        // (see ListenerCallback), so no slot limit is needed.
        if (!AllowMultipleRequests)
            return;
        semaphore ??= new(MaxConcurentConnections, MaxConcurentConnections);
        semaphore.Wait();
    }

    // Returns a slot taken by WaitForSlot; never releases beyond MaxConcurentConnections.
    protected virtual void FreeSlot()
    {
        if (semaphore != null && semaphore.CurrentCount < MaxConcurentConnections)
            semaphore.Release();
    }

    // Calls Initialize() on every registered handler object implementing IApiHandlerInitializer.
    protected virtual void InitializeHandlers()
    {
        // Start() runs again on every Restart(); initialize the handler objects only once.
        if (initializedHandlers)
            return;
        foreach (var initializer in handlerObjects.OfType<IApiHandlerInitializer>())
            initializer.Initialize();
        initializedHandlers = true;
    }

    public virtual T? GetHandler<T>()
    {
        return handlerObjects.OfType<T>().FirstOrDefault();
    }

    /// <summary>Returns all registered routes (sorted) with the HTTP methods of their handlers.</summary>
    public virtual Dictionary<string, string[]> GetEndpoints()
    {
        return handlers
            .OrderBy(h => h.Attribute.Route)
            .GroupBy(h => h.Attribute.Route)
            .ToDictionary(g => g.Key, g => g.SelectMany(h => h.Attribute.Methods.OrderBy(m => m)).ToArray());
    }

    /// <summary>
    /// Registers the object and every public instance method of it that carries an
    /// <see cref="ApiMessageHandlerAttribute"/> and returns an <see cref="ApiResult"/>.
    /// </summary>
    public virtual void RegisterHandler<T>(T instance) where T : class
    {
        if (!handlerObjects.Contains(instance))
            handlerObjects.Add(instance);

        foreach (var method in instance.GetType().GetMethods(BindingFlags.Instance | BindingFlags.Public))
        {
            // Methods without the attribute would be rejected anyway; skip them before creating a delegate.
            if (method.GetCustomAttribute<ApiMessageHandlerAttribute>() is null)
                continue;
            RegisterHandler(method.CreateDelegate(instance), false);
        }
    }

    public virtual void RegisterHandler(Delegate handler)
    {
        RegisterHandler(handler, true);
    }

    public virtual void RegisterHandler(Delegate handler, bool throwOnError)
    {
        var method = handler.Method;

        // Sanity checks
        if (method.GetCustomAttribute<ApiMessageHandlerAttribute>() is not ApiMessageHandlerAttribute attribute
            || !method.ReturnType.IsAssignableTo(typeof(ApiResult)))
        {
            if (throwOnError)
                throw new NotSupportedException("The first parameter needs to be of type ApiMessage and must return an ApiResult object and the method must have the MessageHandlerAttribute.");
            return;
        }

        RegisterHandler(handler, attribute, throwOnError);
    }

    /// <summary>
    /// Registers a handler for <paramref name="attribute"/>'s route. Each "{name}" placeholder becomes a
    /// single-segment regex and is bound to the handler parameter of the same name.
    /// </summary>
    /// <exception cref="ArgumentException">The route has a '{' without matching '}' and <paramref name="throwOnError"/> is set.</exception>
    public virtual void RegisterHandler(Delegate handler, ApiMessageHandlerAttribute attribute, bool throwOnError)
    {
        // Resolve parameters
        var url = attribute.Route;
        var useRegEx = false;
        var parameters = new List<PrivateParameterInfo>();
        var openBracket = url.IndexOf('{');
        while (openBracket != -1)
        {
            var closeBracket = url.IndexOf('}', openBracket + 1);
            if (closeBracket == -1)
            {
                if (throwOnError)
                    throw new ArgumentException($"Route '{attribute.Route}' contains a '{{' without a matching '}}'.", nameof(attribute));
                Log.Warn($"Skipped handler for {attribute.Route}: unmatched '{{' in route");
                return;
            }

            var name = url.Substring(openBracket + 1, closeBracket - openBracket - 1);
            // Replace exactly this placeholder with the segment pattern
            url = url[..openBracket] + RouteParameterPattern + url[(closeBracket + 1)..];
            // Segment index = number of '/' before the placeholder (matches url.Split('/') in BuildParameters)
            var index = url[..(openBracket + 1)].Split('/').Length - 1;
            parameters.Add(new(name, index));
            useRegEx = true;
            // Continue behind the inserted pattern; everything after it moved by the replacement
            openBracket = url.IndexOf('{', openBracket + RouteParameterPattern.Length);
        }

        if (useRegEx)
        {
            url = url.Replace(".", "\\."); // Only '.' is escaped; other regex metacharacters in the route stay active
            url = $"^{url}$"; // Define start and end of line (matches whole string)
        }

        // Add handler
        Log.InfoFormat("Added handler for {0}", attribute.Route);
        handlers.Add(new(url, useRegEx, handler, [.. parameters], attribute));
    }

    // Starts waiting for the next request. The listener is passed as async state so ListenerCallback
    // completes the receive on the instance that started it, even if Restart() replaced the field meanwhile.
    protected void Receive()
    {
        var listener = httpListener;
        if (!listener.IsListening)
            return;
        try
        {
            listener.BeginGetContext(ListenerCallback, listener);
        }
        catch (InvalidOperationException) when (!listener.IsListening)
        {
            // Stop()/Restart() stopped or closed the listener between the check and the call
            // (ObjectDisposedException derives from InvalidOperationException).
            Log.Debug("Listener stopped before receiving the next request");
        }
    }

    // Completes a receive, handles the request and schedules the next receive.
    protected void ListenerCallback(IAsyncResult result)
    {
        var listener = result.AsyncState as HttpListener ?? httpListener;
        HttpListenerContext? context;

        // Skip if not listening anymore
        if (!listener.IsListening || !doListen)
            return;

        // Wait for a free slot
        try
        {
            WaitForSlot();
        }
        catch (ObjectDisposedException)
        {
            return;
        }
        catch (Exception ex)
        {
            Log.Fatal($"Too many concurent connections", ex);
            Thread.Sleep(1000);
            FreeSlot();
            return;
        }

        // Get context
        try
        {
            context = listener.EndGetContext(result);
        }
        catch (ObjectDisposedException)
        {
            // Listener was closed meanwhile; give back the slot taken above.
            FreeSlot();
            return;
        }
        catch (HttpListenerException ex)
        {
            if (ex.ErrorCode == ErrorOperationAborted || ex.ErrorCode == ErrorNetNameDeleted)
            {
                Log.Fatal($"Fatal http error retriving context with code {ex.ErrorCode}", ex);
                if (AutoRestart(false)) // Try restart the server and skip this context
                    return;
            }
            else
                Log.Error($"Http error retriving context with code {ex.ErrorCode}", ex);
            context = null;
        }
        catch (Exception ex)
        {
            Log.Error("Error retriving context", ex);
            context = null;
        }

        // Immediately listen for the next request so it can be handled in parallel
        if (AllowMultipleRequests)
            Receive();

        // Check context
        var success = false;
        if (context is not null)
        {
            Log.Info("Request retrived for " + context.Request.RawUrl);
            ResetManager();
            try
            {
                // Inside the try so a throwing subscriber is logged like any other handling error.
                OnCheckContext?.Invoke(this, new(context));
                CheckContext(context);
                success = true;
            }
            catch (MissingDataManagerException mdmex)
            {
                Log.Error("DataManager is not supported by this server instance!", mdmex);
                if (DebugMode) throw;
            }
            catch (Exception ex)
            {
                Log.Error("Error checking context", ex);
                if (DebugMode) throw;
            }
            finally
            {
                OnCheckContextCompleted?.Invoke(this, new(context));
                ResetManager();
            }
        }

        // Try closing the request on fail
        if (!success && context is not null)
            CloseFailedResponse(context);

        FreeSlot();

        // Listen for new request
        if (!AllowMultipleRequests)
            Receive();
    }

    // Best effort: answer a request whose handling failed with 500 instead of the default 200 and close it.
    // Either step may fail when the response was already (partially) sent or closed.
    private void CloseFailedResponse(HttpListenerContext context)
    {
        try
        {
            context.Response.StatusCode = (int)HttpStatusCode.InternalServerError;
        }
        catch (Exception ex)
        {
            Log.Debug("Could not set error status code", ex);
        }

        try
        {
            context.Response.Close();
        }
        catch (Exception ex)
        {
            Log.Debug("Could not close failed response", ex);
        }
    }

    /// <summary>
    /// Handles one request: resolves the handler for the path and HTTP method, runs it and writes the result.
    /// Requests without a path, without a matching handler or that couldn't be handled are answered with 400.
    /// </summary>
    protected virtual void CheckContext(HttpListenerContext context)
    {
        Log.Debug("Start handling request");

        void close(bool badRequest)
        {
            Log.Debug("End handling request");
            if (badRequest)
                context.Response.StatusCode = (int)HttpStatusCode.BadRequest;
            context.Response.OutputStream.Close();
        }

        // Parse url
        Log.Debug("Parse url");
        var path = context.Request.Url?.AbsolutePath;
        var query = context.Request.Url?.Query;
        if (string.IsNullOrWhiteSpace(path))
        {
            Log.Warn("Request has no path");
            close(true);
            return;
        }

        // Find handler
        Log.Debug("Find handler");
        if (!TryGetHandler(path, query, context.Request.HttpMethod, out var handler))
        {
            Log.Warn("Request handler couldn't be found");
            close(true);
            return;
        }

        // Get auth key (null when the header is missing)
        Log.Debug("Get auth key");
        var authKey = context.Request.Headers.Get("API-AUTH-KEY");

        // Handle message
        Log.Debug("Handle mssage");
        if (HandleMessage(context, path, query, handler, authKey) is not PrivateApiResult result)
        {
            Log.Warn("Request couldn't be handled");
            close(true);
            return;
        }

        // Set response header
        Log.Debug("Set response headers");
        context.Response.AppendHeader("API-VERSION", ApiVersion.ToString());
        if (result.Original.Message is ApiRawMessage apiRawMessage && !string.IsNullOrWhiteSpace(apiRawMessage.FileName))
            context.Response.AppendHeader("Content-Disposition", $"filename=\"{apiRawMessage.FileName}\"");

        // Set response parameters
        Log.Debug("Set response parameters");
        context.Response.StatusCode = (int)result.Original.StatusCode;

        // Write response content
        Log.Debug("Create response");
        if (result.ResultContent is string resultJson)
        {
            Log.Info("Sending json response for " + context.Request.RawUrl);
            context.Response.ContentType = "application/json";
            using StreamWriter output = new(context.Response.OutputStream);
            output.Write(resultJson);
            output.Flush();
        }
        else if (result.ResultContent is byte[] resultBytes)
        {
            Log.Info("Sending raw bytes response for " + context.Request.RawUrl);
            context.Response.ContentType = "application/octet-stream";
            context.Response.ContentLength64 = resultBytes.Length;
            context.Response.OutputStream.Write(resultBytes, 0, resultBytes.Length);
            context.Response.OutputStream.Flush();
        }
        else if (result.ResultContent is Stream resultStream)
        {
            Log.Info("Sending stream response for " + context.Request.RawUrl);
            // Dispose the handler's stream even if copying fails (e.g. the client disconnected)
            using (resultStream)
            {
                context.Response.ContentType = "application/octet-stream";
                // Length is only available on seekable streams; otherwise the content length stays unset
                if (resultStream.CanSeek)
                    context.Response.ContentLength64 = resultStream.Length - resultStream.Position;
                resultStream.CopyTo(context.Response.OutputStream);
                context.Response.OutputStream.Flush();
            }
        }

        Log.Debug("Finish response");
        close(false);
    }

    /// <summary>
    /// Authenticates the request, reads the request body into the handler's message type, invokes the
    /// handler and serializes its result.
    /// </summary>
    /// <returns>
    /// The handler result with its serialized content, an Unauthorized/InternalServerError result, or
    /// <c>null</c> when the request body couldn't be read or the handler takes no message for it.
    /// </returns>
    protected virtual PrivateApiResult? HandleMessage(HttpListenerContext context, string url, string? query, PrivateMessageHandler handler, string? authKey)
    {
        // Check authentication
        Log.Debug("Check authentication");
        if (string.IsNullOrWhiteSpace(authKey) || DecodeAuthKey(authKey) is not string authKeyDecoded)
            authKeyDecoded = null!;
        var isAuthenticated = CheckAuthentication(authKeyDecoded, handler.Handler, context);
        if (!isAuthenticated)
            return new(ApiResult.Unauthorized(), null);

        // Get required infos
        Log.Debug("Identify message parameter type and serializer");
        var targetType = handler.Handler.Method.GetParameters().FirstOrDefault(p => p.ParameterType.IsAssignableTo(typeof(ApiMessage)))?.ParameterType;
        var serializer = GetSerializer(handler.Attribute.Serializer);

        // Read input content (media type without parameters such as "; charset=utf-8")
        Log.Debug("Read input content");
        ApiMessage? message = null;
        switch (GetMediaType(context.Request.ContentType))
        {
            case "application/json":
                try
                {
                    Log.Debug("Deserialize message");
                    using StreamReader input = new(context.Request.InputStream);
                    var contentJson = input.ReadToEnd();
                    if (targetType == null)
                        return null;
                    message = serializer.Deserialize(contentJson, targetType);
                }
                catch (OutOfMemoryException)
                {
                    Log.Error("Error reading remote data due to missing memory");
                    return null;
                }
                catch (Exception ex)
                {
                    Log.Error("Error reading remote data", ex);
                    return null;
                }
                break;
            case "application/octet-stream":
                Log.Debug("Process raw message");
                if (targetType == null)
                    return null;
                if (targetType.IsAssignableTo(typeof(ApiRawByteMessage)))
                {
                    try
                    {
                        message = new ApiRawByteMessage(ReadRequestBody(context.Request));
                    }
                    catch (OutOfMemoryException)
                    {
                        Log.Error("Error reading remote data due to missing memory");
                        return null;
                    }
                    catch (Exception ex)
                    {
                        Log.Error("Error reading remote data", ex);
                        return null;
                    }
                }
                else if (targetType.IsAssignableTo(typeof(ApiRawStreamMessage)))
                    message = new ApiRawInputStreamMessage(context.Request.InputStream, context.Request.ContentLength64);
                break;
        }

        // Invoke handler
        Log.Debug("Invoke handler");
        var parameters = BuildParameters(url, query, handler, () => message, () => new(message, isAuthenticated, authKeyDecoded, url, context.Request.HttpMethod, context));
        if (handler.Handler.DynamicInvoke(parameters) is not ApiResult result)
            return new(ApiResult.InternalServerError(), null);

        // Return result without message
        Log.Debug("Check message");
        if (result.Message is null)
            return new(result, null);

        // Return result with raw data
        if (result.Message is ApiRawByteMessage dataMsg)
            return new(result, dataMsg.Data);
        if (result.Message is ApiRawStreamMessage streamMsg)
            return new(result, streamMsg.Data);

        // Serializer
        Log.Debug("Serialize message");
        if (serializer.Serialize(result.Message) is not string resultStr)
            return new(ApiResult.InternalServerError(), null);

        // Return result with message
        Log.Debug("Complete result");
        return new(result, resultStr);
    }

    // Lower-cased media type of a Content-Type header without parameters ("Application/JSON; charset=utf-8" -> "application/json").
    private static string? GetMediaType(string? contentType)
    {
        if (string.IsNullOrWhiteSpace(contentType))
            return null;
        var separator = contentType.IndexOf(';');
        var mediaType = separator == -1 ? contentType : contentType[..separator];
        return mediaType.Trim().ToLowerInvariant();
    }

    // Reads the whole request body. A single Stream.Read may return fewer bytes than requested and
    // ContentLength64 is -1 for chunked requests, so the stream is copied until its end instead.
    private static byte[] ReadRequestBody(HttpListenerRequest request)
    {
        using var buffer = new MemoryStream();
        request.InputStream.CopyTo(buffer);
        return buffer.ToArray();
    }

    // Finds the handler for the HTTP method and path. Exact routes (case-insensitive) take precedence
    // over parameterized (regex) routes.
    protected virtual bool TryGetHandler(string url, string? query, string method, [NotNullWhen(true)] out PrivateMessageHandler? handler)
    {
        // Filter by method
        var candidates = handlers.Where(h => h.Attribute.Methods.Contains(method)).ToList();

        handler = candidates.FirstOrDefault(h => h.Url.Equals(url, StringComparison.OrdinalIgnoreCase))
            ?? candidates.FirstOrDefault(h => h.UseRegEx && Regex.IsMatch(url, h.Url, RegexOptions.IgnoreCase));

        return handler != null;
    }

    /// <summary>
    /// Builds the argument list for the handler. Each parameter receives, in this order of precedence:
    /// the message, the request info, all query values (dictionary parameters), the route segment of the
    /// same name, the query value of the same name, its default value, or <c>null</c>.
    /// </summary>
    protected virtual object?[]? BuildParameters(string url, string? query, PrivateMessageHandler handler, Func<ApiMessage?> getMessage, Func<ApiRequestInfo> getRequestInfo)
    {
        var infos = handler.Handler.Method.GetParameters();
        var objs = new List<object?>();
        // ParseQueryString already URL-decodes keys and values; decoding again would corrupt values
        // containing '+' or '%' (e.g. "a%2Bb" must stay "a+b", not become "a b").
        var queryparams = query == null ? [] : HttpUtility.ParseQueryString(query);
        var segments = url.Split('/');

        foreach (var info in infos)
        {
            if (info.ParameterType.IsAssignableTo(typeof(ApiMessage)))
                objs.Add(getMessage());
            else if (info.ParameterType.IsAssignableTo(typeof(ApiRequestInfo)))
                objs.Add(getRequestInfo());
            else if (info.ParameterType.IsAssignableTo(typeof(IDictionary<string, string>)))
                objs.Add(queryparams.AllKeys.OfType<string>().ToDictionary(n => n, n => queryparams.Get(n)));
            else if (handler.Parameters.FirstOrDefault(p => p.Name.Equals(info.Name, StringComparison.OrdinalIgnoreCase)) is PrivateParameterInfo parameterInfo
                && segments.ElementAtOrDefault(parameterInfo.Index) is string parameterValue)
                // AbsolutePath is still percent-encoded, so the route segment is decoded here (once)
                objs.Add(ConvertParameter(HttpUtility.UrlDecode(parameterValue), info.ParameterType));
            else if (queryparams.AllKeys.FirstOrDefault(n => n != null && n.Equals(info.Name, StringComparison.OrdinalIgnoreCase)) is string querykey)
                objs.Add(ConvertParameter(queryparams.Get(querykey), info.ParameterType));
            else if (info.HasDefaultValue)
                objs.Add(info.DefaultValue);
            else
                objs.Add(null);
        }

        return [.. objs];
    }

    // Converts a raw route/query string to a handler parameter type. Besides what Convert.ChangeType supports,
    // Nullable<T> (empty -> null) and enums (name or number, case-insensitive) are handled.
    private static object? ConvertParameter(string? value, Type targetType)
    {
        var underlyingType = Nullable.GetUnderlyingType(targetType);
        if (underlyingType is not null)
        {
            if (string.IsNullOrEmpty(value))
                return null;
            targetType = underlyingType;
        }

        if (targetType.IsEnum && value is not null)
            return Enum.Parse(targetType, value, true);

        return Convert.ChangeType(value, targetType);
    }

    // Returns a cached instance of serializer type t (created on first use), or Serializer when t is null
    // or can't be instantiated as an IApiMessageSerializer.
    protected virtual IApiMessageSerializer GetSerializer(Type? t)
    {
        if (t is null)
            return Serializer;

        // Requests may run in parallel; guard the check-then-add on the cache.
        lock (serializers)
        {
            if (serializers.TryGetValue(t, out var cached) && cached is not null)
                return cached;
            if (Activator.CreateInstance(t) is IApiMessageSerializer created)
            {
                serializers[t] = created;
                return created;
            }
        }

        return Serializer;
    }

    // Asks the OnCheckAuthentication subscribers whether the request is authenticated.
    // Without a subscriber the request is treated as not authenticated.
    protected virtual bool CheckAuthentication(string authKey, Delegate? handler, HttpListenerContext context)
    {
        var callback = OnCheckAuthentication;
        if (callback is null)
            return false;

        var args = new ApiAuthCheckEventArgs(authKey, handler, context);
        callback(this, args);
        return args.Valid;
    }

    // Hook to transform the raw API-AUTH-KEY header value; returns it unchanged by default.
    protected virtual string? DecodeAuthKey(string authKey)
    {
        return authKey;
    }
}