Skip to content

Commit

Permalink
refactor: move check if client is authorized to connect to access_con…
Browse files Browse the repository at this point in the history
…trol.zig
  • Loading branch information
sectasy0 committed Feb 17, 2024
1 parent 2028203 commit f2ac3cd
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 33 deletions.
65 changes: 65 additions & 0 deletions src/server/access_control.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
const std = @import("std");

const log = @import("logger.zig");
const Config = @import("config.zig").Config;
const utils = @import("utils.zig");

pub const AccessControl = struct {
logger: *const log.Logger,
config: *const Config,

pub fn init(config: *const Config, logger: *const log.Logger) AccessControl {
return AccessControl{ .logger = logger, .config = config };
}

// Check whether a new connection, represented by the provided network address, can be established.
//
// # Arguments
// * `self` - The AccessControl instance managing access control configurations.
// * `address` - The network address of the incoming connection to be verified.
// * `connections` - A pointer to the current number of active connections.
pub fn verify(self: AccessControl, address: std.net.Address, connections: *u16) !void {
try self.check_connection_limit(connections);
try self.allow_whitelisted(address);
}

// Checks whether the current number of connections exceeds the maximum allowed
// connections specified in config file.
// # Arguments
// * `self` - The AccessControl instance managing access control configurations.
// * `connections` - A pointer to the current number of active connections.
fn check_connection_limit(self: AccessControl, connections: *u16) !void {
if (connections.* > self.config.max_connections) {
self.logger.log(
log.LogLevel.Info,
"* maximum connection exceeds, rejected",
.{},
);
return error.MaxClientsReached;
}
}

// Checks whether the provided network address is whitelisted.
// If whitelisting is enabled (whitelist capacity is greater than 0) and the given address is not whitelisted
//
// # Arguments
// * `self` - The AccessControl instance managing access control configurations.
// * `address` - The network address of the incoming connection to be checked.
fn allow_whitelisted(
self: AccessControl,
address: std.net.Address,
) !void {
if (self.config.whitelist.capacity > 0 and !utils.is_whitelisted(
self.config.whitelist,
address,
)) {
self.logger.log(
log.LogLevel.Info,
"* connection from {any} is not whitelisted, rejected",
.{address},
);

return error.NotWhitelisted;
}
}
};
49 changes: 16 additions & 33 deletions src/server/listener.zig
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
const std = @import("std");
const ProtocolHandler = @import("../protocol/handler.zig").ProtocolHandler;

const ZType = @import("../protocol/types.zig").ZType;
const ProtocolHandler = @import("../protocol/handler.zig").ProtocolHandler;
const AccessControl = @import("access_control.zig").AccessControl;
const MemoryStorage = @import("storage.zig").MemoryStorage;
const errors = @import("err_handler.zig");
const ZType = @import("../protocol/types.zig").ZType;
const CMDHandler = @import("cmd_handler.zig").CMDHandler;
const Config = @import("config.zig").Config;

const errors = @import("err_handler.zig");
const log = @import("logger.zig");
const utils = @import("utils.zig");

const Address = std.net.Address;
const Allocator = std.mem.Allocator;
const Pool = std.Thread.Pool;
const activeTag = std.meta.activeTag;

const Connection = std.net.StreamServer.Connection;

pub const ServerListener = struct {
server: std.net.StreamServer,
addr: *const std.net.Address,
addr: *const Address,

allocator: Allocator,
cmd_handler: CMDHandler,
Expand All @@ -29,7 +30,7 @@ pub const ServerListener = struct {
connections: u16 = 0,

logger: *const log.Logger,
pool: *std.Thread.Pool,
pool: *Pool,

pub fn init(
addr: *const Address,
Expand Down Expand Up @@ -69,21 +70,21 @@ pub const ServerListener = struct {
while (true) {
var connection: Connection = try self.server.accept();

if (self.connections > self.config.max_connections) {
const err = error.MaxClientsReached;
errors.handle(connection.stream, err, .{}, self.logger) catch {
self.logger.log(log.LogLevel.Error, "* failed to send error: {any}", .{err});
};

return;
}

self.logger.log(
log.LogLevel.Info,
"* new connection from {any}",
.{connection.address},
);

const access_control = AccessControl.init(self.config, self.logger);
access_control.verify(connection.address, &self.connections) catch |err| {
errors.handle(connection.stream, err, .{}, self.logger) catch {
self.logger.log(log.LogLevel.Error, "* failed to send error: {any}", .{err});
};

self.close_connection(connection);
};

self.connections += 1;

// Adds Task to the queue, then workers do its stuff
Expand All @@ -108,24 +109,6 @@ pub const ServerListener = struct {
var protocol = ProtocolHandler.init(self.allocator) catch return;
defer protocol.deinit();

if (self.config.whitelist.capacity > 0 and !utils.is_whitelisted(
self.config.whitelist,
connection.address,
)) {
self.logger.log(
log.LogLevel.Info,
"* connection from {any} is not whitelisted, rejected",
.{connection.address},
);

var err = error.NotWhitelisted;
errors.handle(connection.stream, err, .{}, self.logger) catch {
self.logger.log(log.LogLevel.Error, "* failed to send error: {any}", .{err});
};

return;
}

// reading data from client and then try to parse to protocol.
const reader: std.net.Stream.Reader = connection.stream.reader();
const result: ZType = protocol.serialize(&reader) catch |err| {
Expand Down
1 change: 1 addition & 0 deletions tests/run.zig
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ comptime {
_ = @import("server/logger.zig");
_ = @import("server/utils.zig");
_ = @import("server/persistance.zig");
_ = @import("server/access_control.zig");
}
47 changes: 47 additions & 0 deletions tests/server/access_control.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
const std = @import("std");

const log = @import("../../src/server/logger.zig");
const Config = @import("../../src/server/config.zig").Config;
const AccessControl = @import("../../src/server/access_control.zig").AccessControl;

test "should not return errors" {
const config = try Config.load(std.testing.allocator, null, null);
const logger = try log.Logger.init(std.testing.allocator, null);

const access_control = AccessControl.init(&config, &logger);
const address = std.net.Address.initIp4(.{ 192, 168, 0, 1 }, 1234);
var connections: u16 = 0;
const result = access_control.verify(address, &connections);
try std.testing.expectEqual(result, void{});
}

test "should return MaxClientsReached" {
var config = try Config.load(std.testing.allocator, null, null);
const logger = try log.Logger.init(std.testing.allocator, null);

config.max_connections = 1;

const access_control = AccessControl.init(&config, &logger);
const address = std.net.Address.initIp4(.{ 192, 168, 0, 1 }, 1234);
var connections: u16 = 2;
const result = access_control.verify(address, &connections);
try std.testing.expectEqual(result, error.MaxClientsReached);
}

test "should return NotWhitelisted" {
var config = try Config.load(std.testing.allocator, null, null);
const logger = try log.Logger.init(std.testing.allocator, null);

const whitelisted = std.net.Address.initIp4(.{ 192, 168, 0, 2 }, 1234);
var whitelist = std.ArrayList(std.net.Address).init(std.testing.allocator);
try whitelist.append(whitelisted);
defer whitelist.deinit();

config.whitelist = whitelist;

const access_control = AccessControl.init(&config, &logger);
const address = std.net.Address.initIp4(.{ 192, 168, 0, 1 }, 1234);
var connections: u16 = 1;
const result = access_control.verify(address, &connections);
try std.testing.expectEqual(result, error.NotWhitelisted);
}

0 comments on commit f2ac3cd

Please sign in to comment.