-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: move check if client is authorized to connect to access_con…
…trol.zig
- Loading branch information
Showing
4 changed files
with
129 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
const std = @import("std"); | ||
|
||
const log = @import("logger.zig"); | ||
const Config = @import("config.zig").Config; | ||
const utils = @import("utils.zig"); | ||
|
||
pub const AccessControl = struct { | ||
logger: *const log.Logger, | ||
config: *const Config, | ||
|
||
pub fn init(config: *const Config, logger: *const log.Logger) AccessControl { | ||
return AccessControl{ .logger = logger, .config = config }; | ||
} | ||
|
||
// Check whether a new connection, represented by the provided network address, can be established. | ||
// | ||
// # Arguments | ||
// * `self` - The AccessControl instance managing access control configurations. | ||
// * `address` - The network address of the incoming connection to be verified. | ||
// * `connections` - A pointer to the current number of active connections. | ||
pub fn verify(self: AccessControl, address: std.net.Address, connections: *u16) !void { | ||
try self.check_connection_limit(connections); | ||
try self.allow_whitelisted(address); | ||
} | ||
|
||
// Checks whether the current number of connections exceeds the maximum allowed | ||
// connections specified in config file. | ||
// # Arguments | ||
// * `self` - The AccessControl instance managing access control configurations. | ||
// * `connections` - A pointer to the current number of active connections. | ||
fn check_connection_limit(self: AccessControl, connections: *u16) !void { | ||
if (connections.* > self.config.max_connections) { | ||
self.logger.log( | ||
log.LogLevel.Info, | ||
"* maximum connection exceeds, rejected", | ||
.{}, | ||
); | ||
return error.MaxClientsReached; | ||
} | ||
} | ||
|
||
// Checks whether the provided network address is whitelisted. | ||
// If whitelisting is enabled (whitelist capacity is greater than 0) and the given address is not whitelisted | ||
// | ||
// # Arguments | ||
// * `self` - The AccessControl instance managing access control configurations. | ||
// * `address` - The network address of the incoming connection to be checked. | ||
fn allow_whitelisted( | ||
self: AccessControl, | ||
address: std.net.Address, | ||
) !void { | ||
if (self.config.whitelist.capacity > 0 and !utils.is_whitelisted( | ||
self.config.whitelist, | ||
address, | ||
)) { | ||
self.logger.log( | ||
log.LogLevel.Info, | ||
"* connection from {any} is not whitelisted, rejected", | ||
.{address}, | ||
); | ||
|
||
return error.NotWhitelisted; | ||
} | ||
} | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
const std = @import("std"); | ||
|
||
const log = @import("../../src/server/logger.zig"); | ||
const Config = @import("../../src/server/config.zig").Config; | ||
const AccessControl = @import("../../src/server/access_control.zig").AccessControl; | ||
|
||
test "should not return errors" { | ||
const config = try Config.load(std.testing.allocator, null, null); | ||
const logger = try log.Logger.init(std.testing.allocator, null); | ||
|
||
const access_control = AccessControl.init(&config, &logger); | ||
const address = std.net.Address.initIp4(.{ 192, 168, 0, 1 }, 1234); | ||
var connections: u16 = 0; | ||
const result = access_control.verify(address, &connections); | ||
try std.testing.expectEqual(result, void{}); | ||
} | ||
|
||
test "should return MaxClientsReached" { | ||
var config = try Config.load(std.testing.allocator, null, null); | ||
const logger = try log.Logger.init(std.testing.allocator, null); | ||
|
||
config.max_connections = 1; | ||
|
||
const access_control = AccessControl.init(&config, &logger); | ||
const address = std.net.Address.initIp4(.{ 192, 168, 0, 1 }, 1234); | ||
var connections: u16 = 2; | ||
const result = access_control.verify(address, &connections); | ||
try std.testing.expectEqual(result, error.MaxClientsReached); | ||
} | ||
|
||
test "should return NotWhitelisted" { | ||
var config = try Config.load(std.testing.allocator, null, null); | ||
const logger = try log.Logger.init(std.testing.allocator, null); | ||
|
||
const whitelisted = std.net.Address.initIp4(.{ 192, 168, 0, 2 }, 1234); | ||
var whitelist = std.ArrayList(std.net.Address).init(std.testing.allocator); | ||
try whitelist.append(whitelisted); | ||
defer whitelist.deinit(); | ||
|
||
config.whitelist = whitelist; | ||
|
||
const access_control = AccessControl.init(&config, &logger); | ||
const address = std.net.Address.initIp4(.{ 192, 168, 0, 1 }, 1234); | ||
var connections: u16 = 1; | ||
const result = access_control.verify(address, &connections); | ||
try std.testing.expectEqual(result, error.NotWhitelisted); | ||
} |