diff --git a/src/server/access_control.zig b/src/server/access_control.zig new file mode 100644 index 0000000..c5ba030 --- /dev/null +++ b/src/server/access_control.zig @@ -0,0 +1,65 @@ +const std = @import("std"); + +const log = @import("logger.zig"); +const Config = @import("config.zig").Config; +const utils = @import("utils.zig"); + +pub const AccessControl = struct { + logger: *const log.Logger, + config: *const Config, + + pub fn init(config: *const Config, logger: *const log.Logger) AccessControl { + return AccessControl{ .logger = logger, .config = config }; + } + + // Check whether a new connection, represented by the provided network address, can be established. + // + // # Arguments + // * `self` - The AccessControl instance managing access control configurations. + // * `address` - The network address of the incoming connection to be verified. + // * `connections` - A pointer to the current number of active connections. + pub fn verify(self: AccessControl, address: std.net.Address, connections: *u16) !void { + try self.check_connection_limit(connections); + try self.allow_whitelisted(address); + } + + // Checks whether the current number of connections exceeds the maximum allowed + // connections specified in config file. + // # Arguments + // * `self` - The AccessControl instance managing access control configurations. + // * `connections` - A pointer to the current number of active connections. + fn check_connection_limit(self: AccessControl, connections: *u16) !void { + if (connections.* > self.config.max_connections) { + self.logger.log( + log.LogLevel.Info, + "* maximum connection exceeds, rejected", + .{}, + ); + return error.MaxClientsReached; + } + } + + // Checks whether the provided network address is whitelisted. + // If whitelisting is enabled (whitelist capacity is greater than 0) and the given address is not whitelisted + // + // # Arguments + // * `self` - The AccessControl instance managing access control configurations. + // * `address` - The network address of the incoming connection to be checked. + fn allow_whitelisted( + self: AccessControl, + address: std.net.Address, + ) !void { + if (self.config.whitelist.capacity > 0 and !utils.is_whitelisted( + self.config.whitelist, + address, + )) { + self.logger.log( + log.LogLevel.Info, + "* connection from {any} is not whitelisted, rejected", + .{address}, + ); + + return error.NotWhitelisted; + } + } +}; diff --git a/src/server/listener.zig b/src/server/listener.zig index a20d590..aefa156 100644 --- a/src/server/listener.zig +++ b/src/server/listener.zig @@ -1,24 +1,25 @@ const std = @import("std"); -const ProtocolHandler = @import("../protocol/handler.zig").ProtocolHandler; -const ZType = @import("../protocol/types.zig").ZType; +const ProtocolHandler = @import("../protocol/handler.zig").ProtocolHandler; +const AccessControl = @import("access_control.zig").AccessControl; const MemoryStorage = @import("storage.zig").MemoryStorage; -const errors = @import("err_handler.zig"); +const ZType = @import("../protocol/types.zig").ZType; const CMDHandler = @import("cmd_handler.zig").CMDHandler; const Config = @import("config.zig").Config; + +const errors = @import("err_handler.zig"); const log = @import("logger.zig"); const utils = @import("utils.zig"); const Address = std.net.Address; const Allocator = std.mem.Allocator; const Pool = std.Thread.Pool; -const activeTag = std.meta.activeTag; const Connection = std.net.StreamServer.Connection; pub const ServerListener = struct { server: std.net.StreamServer, - addr: *const std.net.Address, + addr: *const Address, allocator: Allocator, cmd_handler: CMDHandler, @@ -29,7 +30,7 @@ pub const ServerListener = struct { connections: u16 = 0, logger: *const log.Logger, - pool: *std.Thread.Pool, + pool: *Pool, pub fn init( addr: *const Address, @@ -69,21 +70,21 @@ pub const ServerListener = struct { while (true) { var connection: Connection = try self.server.accept(); - if (self.connections > self.config.max_connections) { - const err = error.MaxClientsReached; - errors.handle(connection.stream, err, .{}, self.logger) catch { - self.logger.log(log.LogLevel.Error, "* failed to send error: {any}", .{err}); - }; - - return; - } - self.logger.log( log.LogLevel.Info, "* new connection from {any}", .{connection.address}, ); + const access_control = AccessControl.init(self.config, self.logger); + access_control.verify(connection.address, &self.connections) catch |err| { + errors.handle(connection.stream, err, .{}, self.logger) catch { + self.logger.log(log.LogLevel.Error, "* failed to send error: {any}", .{err}); + }; + + self.close_connection(connection); + }; + self.connections += 1; // Adds Task to the queue, then workers do its stuff @@ -108,24 +109,6 @@ pub const ServerListener = struct { var protocol = ProtocolHandler.init(self.allocator) catch return; defer protocol.deinit(); - if (self.config.whitelist.capacity > 0 and !utils.is_whitelisted( - self.config.whitelist, - connection.address, - )) { - self.logger.log( - log.LogLevel.Info, - "* connection from {any} is not whitelisted, rejected", - .{connection.address}, - ); - - var err = error.NotWhitelisted; - errors.handle(connection.stream, err, .{}, self.logger) catch { - self.logger.log(log.LogLevel.Error, "* failed to send error: {any}", .{err}); - }; - - return; - } - // reading data from client and then try to parse to protocol. const reader: std.net.Stream.Reader = connection.stream.reader(); const result: ZType = protocol.serialize(&reader) catch |err| { diff --git a/tests/run.zig b/tests/run.zig index 6851782..534eb2a 100644 --- a/tests/run.zig +++ b/tests/run.zig @@ -15,4 +15,5 @@ comptime { _ = @import("server/logger.zig"); _ = @import("server/utils.zig"); _ = @import("server/persistance.zig"); + _ = @import("server/access_control.zig"); } diff --git a/tests/server/access_control.zig b/tests/server/access_control.zig new file mode 100644 index 0000000..71c63bb --- /dev/null +++ b/tests/server/access_control.zig @@ -0,0 +1,47 @@ +const std = @import("std"); + +const log = @import("../../src/server/logger.zig"); +const Config = @import("../../src/server/config.zig").Config; +const AccessControl = @import("../../src/server/access_control.zig").AccessControl; + +test "should not return errors" { + const config = try Config.load(std.testing.allocator, null, null); + const logger = try log.Logger.init(std.testing.allocator, null); + + const access_control = AccessControl.init(&config, &logger); + const address = std.net.Address.initIp4(.{ 192, 168, 0, 1 }, 1234); + var connections: u16 = 0; + const result = access_control.verify(address, &connections); + try std.testing.expectEqual(result, void{}); +} + +test "should return MaxClientsReached" { + var config = try Config.load(std.testing.allocator, null, null); + const logger = try log.Logger.init(std.testing.allocator, null); + + config.max_connections = 1; + + const access_control = AccessControl.init(&config, &logger); + const address = std.net.Address.initIp4(.{ 192, 168, 0, 1 }, 1234); + var connections: u16 = 2; + const result = access_control.verify(address, &connections); + try std.testing.expectEqual(result, error.MaxClientsReached); +} + +test "should return NotWhitelisted" { + var config = try Config.load(std.testing.allocator, null, null); + const logger = try log.Logger.init(std.testing.allocator, null); + + const whitelisted = std.net.Address.initIp4(.{ 192, 168, 0, 2 }, 1234); + var whitelist = std.ArrayList(std.net.Address).init(std.testing.allocator); + try whitelist.append(whitelisted); + defer whitelist.deinit(); + + config.whitelist = whitelist; + + const access_control = AccessControl.init(&config, &logger); + const address = std.net.Address.initIp4(.{ 192, 168, 0, 1 }, 1234); + var connections: u16 = 1; + const result = access_control.verify(address, &connections); + try std.testing.expectEqual(result, error.NotWhitelisted); +}