From 68294f096788225e286b7d1a9a45a7b2b0fd9540 Mon Sep 17 00:00:00 2001 From: melonedo <44501064+melonedo@users.noreply.github.com> Date: Mon, 11 Mar 2024 15:04:01 +0800 Subject: [PATCH] Really use TlsClient instead of std.crypto.tls.Client --- src/HttpClient.zig | 10 ++-- src/TlsClient.zig | 14 ++--- src/crypto/Bundle.zig | 4 +- src/crypto/Bundle/macos.zig | 114 ------------------------------------ src/crypto/Certificate.zig | 8 +-- 5 files changed, 18 insertions(+), 132 deletions(-) delete mode 100644 src/crypto/Bundle/macos.zig diff --git a/src/HttpClient.zig b/src/HttpClient.zig index ac0a23c..b76fc42 100644 --- a/src/HttpClient.zig +++ b/src/HttpClient.zig @@ -18,7 +18,7 @@ pub const default_connection_pool_size = 32; pub const connection_pool_size = std.options.http_connection_pool_size; allocator: Allocator, -ca_bundle: std.crypto.Certificate.Bundle = .{}, +ca_bundle: Certificate.Bundle = .{}, ca_bundle_mutex: std.Thread.Mutex = .{}, /// When this is `true`, the next time this client performs an HTTPS request, /// it will first rescan the system for root certificates. @@ -152,7 +152,7 @@ pub const Connection = struct { stream: net.Stream, /// undefined unless protocol is tls. - tls_client: *std.crypto.tls.Client, + tls_client: *TlsClient, protocol: Protocol, host: []u8, @@ -288,7 +288,7 @@ pub const Connection = struct { pub fn close(conn: *Connection, client: *const Client) void { if (conn.protocol == .tls) { // try to cleanly close the TLS connection, for any server that cares. - _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; + _ = conn.tls_client.writeEnd(conn.stream, "", true, .application_data) catch {}; client.allocator.destroy(conn.tls_client); } @@ -908,10 +908,10 @@ pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol: switch (protocol) { .plain => {}, .tls => { - conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); + conn.data.tls_client = try client.allocator.create(TlsClient); errdefer client.allocator.destroy(conn.data.tls_client); - conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed; + conn.data.tls_client.* = TlsClient.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed; // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. conn.data.tls_client.allow_truncation_attacks = true; diff --git a/src/TlsClient.zig b/src/TlsClient.zig index 5fd47da..283598f 100644 --- a/src/TlsClient.zig +++ b/src/TlsClient.zig @@ -426,10 +426,10 @@ pub fn init(stream: std.net.Stream, ca_bundle: Certificate.Bundle, host: []const error.IdentityElement => return error.InsufficientEntropy, }; - const mul = pk.p.mulPublic(secp256r1_kp.secret_key.bytes, .big) catch { + const mul = pk.p.mulPublic(secp256r1_kp.secret_key.bytes, .Big) catch { return error.TlsDecryptFailure; }; - shared_key = &mul.affineCoordinates().x.toBytes(.big); + shared_key = &mul.affineCoordinates().x.toBytes(.Big); break :blk &secp256r1_kp.public_key.toUncompressedSec1(); }, else => unreachable, @@ -854,7 +854,7 @@ pub fn readvAdvanced(c: *Client, stream: std.net.Stream, iovecs: []const std.os. // Skip `stream.readv` if there is a complete record unprocessed // This may happen when different types of traffic are mixed. if (c.ciphertext_slice.len > 5) { - const record_len = mem.readInt(u16, c.ciphertext_slice[3..5], .big); + const record_len = mem.readInt(u16, c.ciphertext_slice[3..5], .Big); if (record_len + 5 <= c.ciphertext_slice.len) break; } @@ -905,8 +905,8 @@ pub fn readvAdvanced(c: *Client, stream: std.net.Stream, iovecs: []const std.os. // Ensure a complete record is in `frag` const ct: tls.ContentType = @enumFromInt(frag[in]); - const legacy_version = mem.readInt(u16, frag[in..][1..3], .big); - const record_len = mem.readInt(u16, frag[in..][3..5], .big); + const legacy_version = mem.readInt(u16, frag[in..][1..3], .Big); + const record_len = mem.readInt(u16, frag[in..][3..5], .Big); if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; in += 5; const end = in + record_len; @@ -1035,8 +1035,8 @@ const native_endian = builtin.cpu.arch.endian(); inline fn big(x: anytype) @TypeOf(x) { return switch (native_endian) { - .big => x, - .little => @byteSwap(x), + .Big => x, + .Little => @byteSwap(x), }; } diff --git a/src/crypto/Bundle.zig b/src/crypto/Bundle.zig index 3ab090c..c3cabbd 100644 --- a/src/crypto/Bundle.zig +++ b/src/crypto/Bundle.zig @@ -68,8 +68,8 @@ pub fn rescan(cb: *Bundle, gpa: Allocator) RescanError!void { } } -const rescanMac = @import("Bundle/macos.zig").rescanMac; -const RescanMacError = @import("Bundle/macos.zig").RescanMacError; +const rescanMac = std.crypto.Certificate.Bundle.rescan; +const RescanMacError = std.crypto.Certificate.Bundle.RescanError; const RescanLinuxError = AddCertsFromFilePathError || AddCertsFromDirPathError; diff --git a/src/crypto/Bundle/macos.zig b/src/crypto/Bundle/macos.zig deleted file mode 100644 index 028275a..0000000 --- a/src/crypto/Bundle/macos.zig +++ /dev/null @@ -1,114 +0,0 @@ -const std = @import("std"); -const assert = std.debug.assert; -const fs = std.fs; -const mem = std.mem; -const Allocator = std.mem.Allocator; -const Bundle = @import("../Bundle.zig"); - -pub const RescanMacError = Allocator.Error || fs.File.OpenError || fs.File.ReadError || fs.File.SeekError || Bundle.ParseCertError || error{EndOfStream}; - -pub fn rescanMac(cb: *Bundle, gpa: Allocator) RescanMacError!void { - cb.bytes.clearRetainingCapacity(); - cb.map.clearRetainingCapacity(); - - const file = try fs.openFileAbsolute("/System/Library/Keychains/SystemRootCertificates.keychain", .{}); - defer file.close(); - - const bytes = try file.readToEndAlloc(gpa, std.math.maxInt(u32)); - defer gpa.free(bytes); - - var stream = std.io.fixedBufferStream(bytes); - const reader = stream.reader(); - - const db_header = try reader.readStructBig(ApplDbHeader); - assert(mem.eql(u8, "kych", &@as([4]u8, @bitCast(db_header.signature)))); - - try stream.seekTo(db_header.schema_offset); - - const db_schema = try reader.readStructBig(ApplDbSchema); - - var table_list = try gpa.alloc(u32, db_schema.table_count); - defer gpa.free(table_list); - - var table_idx: u32 = 0; - while (table_idx < table_list.len) : (table_idx += 1) { - table_list[table_idx] = try reader.readIntBig(u32); - } - - const now_sec = std.time.timestamp(); - - for (table_list) |table_offset| { - try stream.seekTo(db_header.schema_offset + table_offset); - - const table_header = try reader.readStructBig(TableHeader); - - if (@as(std.os.darwin.cssm.DB_RECORDTYPE, @enumFromInt(table_header.table_id)) != .X509_CERTIFICATE) { - continue; - } - - var record_list = try gpa.alloc(u32, table_header.record_count); - defer gpa.free(record_list); - - var record_idx: u32 = 0; - while (record_idx < record_list.len) : (record_idx += 1) { - record_list[record_idx] = try reader.readIntBig(u32); - } - - for (record_list) |record_offset| { - try stream.seekTo(db_header.schema_offset + table_offset + record_offset); - - const cert_header = try reader.readStructBig(X509CertHeader); - - try cb.bytes.ensureUnusedCapacity(gpa, cert_header.cert_size); - - const cert_start = @as(u32, @intCast(cb.bytes.items.len)); - const dest_buf = cb.bytes.allocatedSlice()[cert_start..]; - cb.bytes.items.len += try reader.readAtLeast(dest_buf, cert_header.cert_size); - - try cb.parseCert(gpa, cert_start, now_sec); - } - } - - cb.bytes.shrinkAndFree(gpa, cb.bytes.items.len); -} - -const ApplDbHeader = extern struct { - signature: @Vector(4, u8), - version: u32, - header_size: u32, - schema_offset: u32, - auth_offset: u32, -}; - -const ApplDbSchema = extern struct { - schema_size: u32, - table_count: u32, -}; - -const TableHeader = extern struct { - table_size: u32, - table_id: u32, - record_count: u32, - records: u32, - indexes_offset: u32, - free_list_head: u32, - record_numbers_count: u32, -}; - -const X509CertHeader = extern struct { - record_size: u32, - record_number: u32, - unknown1: u32, - unknown2: u32, - cert_size: u32, - unknown3: u32, - cert_type: u32, - cert_encoding: u32, - print_name: u32, - alias: u32, - subject: u32, - issuer: u32, - serial_number: u32, - subject_key_identifier: u32, - public_key_hash: u32, -}; diff --git a/src/crypto/Certificate.zig b/src/crypto/Certificate.zig index 39bf932..b322799 100644 --- a/src/crypto/Certificate.zig +++ b/src/crypto/Certificate.zig @@ -1109,7 +1109,7 @@ pub const rsa = struct { // Reject modulus below 512 bits. // 512-bit RSA was factored in 1999, so this limit barely means anything, // but establish some limit now to ratchet in what we can. - const _n = Modulus.fromBytes(modulus_bytes, .big) catch return error.CertificatePublicKeyInvalid; + const _n = Modulus.fromBytes(modulus_bytes, .Big) catch return error.CertificatePublicKeyInvalid; if (_n.bits() < 512) return error.CertificatePublicKeyInvalid; // Exponent must be odd and greater than 2. @@ -1119,7 +1119,7 @@ pub const rsa = struct { // Windows commonly does. // [1] https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-rsapubkey if (pub_bytes.len > 4) return error.CertificatePublicKeyInvalid; - const _e = Fe.fromBytes(_n, pub_bytes, .big) catch return error.CertificatePublicKeyInvalid; + const _e = Fe.fromBytes(_n, pub_bytes, .Big) catch return error.CertificatePublicKeyInvalid; if (!_e.isOdd()) return error.CertificatePublicKeyInvalid; const e_v = _e.toPrimitive(u32) catch return error.CertificatePublicKeyInvalid; if (e_v < 2) return error.CertificatePublicKeyInvalid; @@ -1150,10 +1150,10 @@ pub const rsa = struct { }; fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey) ![modulus_len]u8 { - const m = Fe.fromBytes(public_key.n, &msg, .big) catch return error.MessageTooLong; + const m = Fe.fromBytes(public_key.n, &msg, .Big) catch return error.MessageTooLong; const e = public_key.n.powPublic(m, public_key.e) catch unreachable; var res: [modulus_len]u8 = undefined; - e.toBytes(&res, .big) catch unreachable; + e.toBytes(&res, .Big) catch unreachable; return res; } };