Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable strict concurrency #54

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,9 @@ let package = Package(
],
cLanguageStandard: .gnu11
)

for target in package.targets {
var settings = target.swiftSettings ?? []
settings.append(.enableExperimentalFeature("StrictConcurrency=complete"))
target.swiftSettings = settings
}
19 changes: 11 additions & 8 deletions Sources/AsyncDNSResolver/c-ares/AresChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@ import Foundation
// MARK: - ares_channel

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
class AresChannel {
let pointer: UnsafeMutablePointer<ares_channel?>
let lock = NSLock()
final class AresChannel: @unchecked Sendable {
private let locked_pointer: UnsafeMutablePointer<ares_channel?>
private let lock = NSLock()

private var underlying: ares_channel? {
self.pointer.pointee
// For testing only.
var underlying: ares_channel? {
self.locked_pointer.pointee
}

deinit {
ares_destroy(pointer.pointee)
pointer.deallocate()
// Safe to perform without the lock, as in deinit we know that no more
// strong references to self exist, so nobody can be holding the lock.
ares_destroy(locked_pointer.pointee)
locked_pointer.deallocate()
ares_library_cleanup()
}

Expand All @@ -49,7 +52,7 @@ class AresChannel {
try checkAresResult { ares_set_sortlist(pointer.pointee, sortlist) }
}

self.pointer = pointer
self.locked_pointer = pointer
}

func withChannel(_ body: (ares_channel) -> Void) {
Expand Down
4 changes: 2 additions & 2 deletions Sources/AsyncDNSResolver/c-ares/AresOptions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import CAsyncDNSResolver
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension CAresDNSResolver {
/// Options for ``CAresDNSResolver``.
public struct Options {
public struct Options: Sendable {
public static var `default`: Options {
.init()
}
Expand Down Expand Up @@ -91,7 +91,7 @@ extension CAresDNSResolver {

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension CAresDNSResolver.Options {
public struct Flags: OptionSet {
public struct Flags: OptionSet, Sendable {
public let rawValue: Int32

public init(rawValue: Int32) {
Expand Down
33 changes: 18 additions & 15 deletions Sources/AsyncDNSResolver/c-ares/DNSResolver_c-ares.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
//===----------------------------------------------------------------------===//

import CAsyncDNSResolver
import Foundation

/// ``DNSResolver`` implementation backed by c-ares C library.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public class CAresDNSResolver: DNSResolver {
public final class CAresDNSResolver: DNSResolver, Sendable {
let options: Options
let ares: Ares

Expand Down Expand Up @@ -121,18 +122,15 @@ extension QueryType {
// MARK: - c-ares query wrapper

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
class Ares {
final class Ares: Sendable {
typealias QueryCallback = @convention(c) (
UnsafeMutableRawPointer?, CInt, CInt, UnsafeMutablePointer<CUnsignedChar>?, CInt
) -> Void

let options: AresOptions
let channel: AresChannel

private let channel: AresChannel
private let queryProcessor: QueryProcessor

init(options: AresOptions) throws {
self.options = options
self.channel = try AresChannel(options: options)

// Need to call `ares_process` or `ares_process_fd` for query callbacks to happen
Expand All @@ -145,7 +143,8 @@ class Ares {
name: String,
replyParser: ReplyParser
) async throws -> ReplyParser.Reply {
try await withTaskCancellationHandler(
let channel = self.channel
return try await withTaskCancellationHandler(
operation: {
try await withCheckedThrowingContinuation { continuation in
let handler = QueryReplyHandler(parser: replyParser, continuation)
Expand Down Expand Up @@ -178,7 +177,7 @@ class Ares {
}
},
onCancel: {
self.channel.withChannel { channel in
channel.withChannel { channel in
ares_cancel(channel)
}
}
Expand All @@ -198,16 +197,18 @@ extension Ares {
// https://github.com/dimbleby/c-ares-resolver/blob/master/src/unix/eventloop.rs // ignore-unacceptable-language
// https://github.com/dimbleby/rust-c-ares/blob/master/src/channel.rs // ignore-unacceptable-language
// https://github.com/dimbleby/rust-c-ares/blob/master/examples/event-loop.rs // ignore-unacceptable-language
class QueryProcessor {
final class QueryProcessor: @unchecked Sendable {
static let defaultPollInterval: UInt64 = 10 * 1_000_000 // 10ms

private let channel: AresChannel
private let pollIntervalNanos: UInt64

private var pollingTask: Task<Void, Error>?
private let lock = NSLock()
private var locked_pollingTask: Task<Void, Error>?

deinit {
self.pollingTask?.cancel()
// No need to lock here as there can exist no more strong references to self.
self.locked_pollingTask?.cancel()
}

init(channel: AresChannel, pollIntervalNanos: UInt64 = QueryProcessor.defaultPollInterval) {
Expand All @@ -218,7 +219,7 @@ extension Ares {
/// Asks c-ares for the set of socket descriptors we are waiting on for the `ares_channel`'s pending queries
/// then call `ares_process_fd` if any is ready for read and/or write.
/// c-ares returns up to `ARES_GETSOCK_MAXNUM` socket descriptors only. If more are in use (unlikely) they are not reported back.
func poll() async {
func poll() {
var socks = [ares_socket_t](repeating: ares_socket_t(), count: Int(ARES_GETSOCK_MAXNUM))

self.channel.withChannel { channel in
Expand Down Expand Up @@ -249,12 +250,14 @@ extension Ares {
}

private func schedule() {
self.pollingTask = Task { [weak self] in
self.lock.lock()
defer { self.lock.unlock() }
self.locked_pollingTask = Task { [weak self] in
guard let s = self else {
return
}
try await Task.sleep(nanoseconds: s.pollIntervalNanos)
await s.poll()
s.poll()
}
}
}
Expand Down Expand Up @@ -291,7 +294,7 @@ extension Ares {
// MARK: - c-ares query reply parsers

protocol AresQueryReplyParser {
associatedtype Reply
associatedtype Reply: Sendable

func parse(buffer: UnsafeMutablePointer<CUnsignedChar>?, length: CInt) throws -> Reply
}
Expand Down
6 changes: 3 additions & 3 deletions Sources/AsyncDNSResolver/dnssd/DNSResolver_dnssd.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import dnssd

/// ``DNSResolver`` implementation backed by dnssd framework.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public struct DNSSDDNSResolver: DNSResolver {
public struct DNSSDDNSResolver: DNSResolver, Sendable {
let dnssd: DNSSD

init() {
Expand Down Expand Up @@ -100,7 +100,7 @@ extension QueryType {
// MARK: - dnssd query wrapper

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
struct DNSSD {
struct DNSSD: Sendable {
// Reference: https://gist.github.com/fikeminkel/a9c4bc4d0348527e8df3690e242038d3
func query<ReplyHandler: DNSSDQueryReplyHandler>(
type: QueryType,
Expand Down Expand Up @@ -225,7 +225,7 @@ extension DNSSD {
// MARK: - dnssd query reply handlers

protocol DNSSDQueryReplyHandler {
associatedtype Record
associatedtype Record: Sendable
associatedtype Reply

func parseRecord(data: UnsafeRawPointer?, length: UInt16) throws -> Record?
Expand Down
2 changes: 1 addition & 1 deletion Tests/AsyncDNSResolverTests/c-ares/AresChannelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ final class AresChannelTests: XCTestCase {
guard let channel = try? AresChannel(options: options) else {
return XCTFail("Channel not initialized")
}
guard let _ = channel.pointer.pointee else {
guard let _ = channel.underlying else {
return XCTFail("Underlying ares_channel is nil")
}
}
Expand Down
40 changes: 21 additions & 19 deletions Tests/AsyncDNSResolverTests/c-ares/CAresDNSResolverTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ final class CAresDNSResolverTests: XCTestCase {
func test_concurrency() async throws {
func run(
times: Int = 100,
_ query: @escaping (_ index: Int) async throws -> Void
_ query: @Sendable @escaping (_ index: Int) async throws -> Void
) async throws {
try await withThrowingTaskGroup(of: Void.self) { group in
for i in 1...times {
Expand All @@ -136,65 +136,67 @@ final class CAresDNSResolverTests: XCTestCase {
}
}

let resolver = self.resolver!
let verbose = self.verbose
try await run { i in
let reply = try await self.resolver.queryA(name: "apple.com")
if self.verbose {
let reply = try await resolver.queryA(name: "apple.com")
if verbose {
print("[A] run #\(i) result: \(reply)")
}
}

try await run { i in
let reply = try await self.resolver.queryAAAA(name: "apple.com")
if self.verbose {
let reply = try await resolver.queryAAAA(name: "apple.com")
if verbose {
print("[AAAA] run #\(i) result: \(reply)")
}
}

try await run { i in
let reply = try await self.resolver.queryNS(name: "apple.com")
if self.verbose {
let reply = try await resolver.queryNS(name: "apple.com")
if verbose {
print("[NS] run #\(i) result: \(reply)")
}
}

try await run { i in
let reply = try await self.resolver.queryCNAME(name: "www.apple.com")
if self.verbose {
let reply = try await resolver.queryCNAME(name: "www.apple.com")
if verbose {
print("[CNAME] run #\(i) result: \(String(describing: reply))")
}
}

try await run { i in
let reply = try await self.resolver.querySOA(name: "apple.com")
if self.verbose {
let reply = try await resolver.querySOA(name: "apple.com")
if verbose {
print("[SOA] run #\(i) result: \(String(describing: reply))")
}
}

try await run { i in
let reply = try await self.resolver.queryPTR(name: "47.224.172.17.in-addr.arpa")
if self.verbose {
let reply = try await resolver.queryPTR(name: "47.224.172.17.in-addr.arpa")
if verbose {
print("[PTR] run #\(i) result: \(reply)")
}
}

try await run { i in
let reply = try await self.resolver.queryMX(name: "apple.com")
if self.verbose {
let reply = try await resolver.queryMX(name: "apple.com")
if verbose {
print("[MX] run #\(i) result: \(reply)")
}
}

try await run { i in
let reply = try await self.resolver.queryTXT(name: "apple.com")
if self.verbose {
let reply = try await resolver.queryTXT(name: "apple.com")
if verbose {
print("[TXT] run #\(i) result: \(reply)")
}
}

try await run { i in
let reply = try await self.resolver.querySRV(name: "_caldavs._tcp.google.com")
if self.verbose {
let reply = try await resolver.querySRV(name: "_caldavs._tcp.google.com")
if verbose {
print("[SRV] run #\(i) result: \(reply)")
}
}
Expand Down
Loading
Loading