From bcbdbf4ce3185f3b3f882c9b995340b3a08c3686 Mon Sep 17 00:00:00 2001 From: Chris Conlon Date: Fri, 12 Apr 2024 15:01:02 -0600 Subject: [PATCH] JCE: refactor KeyAgreement threaded test to use AtomicIntegerArray --- .../jce/test/WolfCryptKeyAgreementTest.java | 75 +++++++++++-------- 1 file changed, 44 insertions(+), 31 deletions(-) diff --git a/src/test/java/com/wolfssl/provider/jce/test/WolfCryptKeyAgreementTest.java b/src/test/java/com/wolfssl/provider/jce/test/WolfCryptKeyAgreementTest.java index 4d316f8e..8aba56f5 100644 --- a/src/test/java/com/wolfssl/provider/jce/test/WolfCryptKeyAgreementTest.java +++ b/src/test/java/com/wolfssl/provider/jce/test/WolfCryptKeyAgreementTest.java @@ -29,10 +29,12 @@ import java.util.ArrayList; import java.util.Random; import java.util.Iterator; +import java.util.concurrent.TimeUnit; import java.util.concurrent.Executors; import java.util.concurrent.ExecutorService; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicIntegerArray; import javax.crypto.KeyAgreement; import javax.crypto.ShortBufferException; @@ -537,9 +539,18 @@ private void threadRunnerKeyAgreeTest(String algo) int numThreads = 10; ExecutorService service = Executors.newFixedThreadPool(numThreads); final CountDownLatch latch = new CountDownLatch(numThreads); - final LinkedBlockingQueue results = new LinkedBlockingQueue<>(); final String currentAlgo = algo; + /* Used to detect timeout of CountDownLatch, don't run indefinitely + * if threads are stalled out or deadlocked */ + boolean returnWithoutTimeout = true; + + /* Keep track of failure and success count */ + final AtomicIntegerArray failures = new AtomicIntegerArray(1); + final AtomicIntegerArray success = new AtomicIntegerArray(1); + failures.set(0, 0); + success.set(0, 0); + /* DH Tests */ AlgorithmParameterGenerator paramGen = AlgorithmParameterGenerator.getInstance("DH"); @@ -552,7 +563,6 @@ private void threadRunnerKeyAgreeTest(String algo) service.submit(new Runnable() { @Override public void run() { - int failed = 0; KeyPairGenerator keyGen = null; KeyAgreement aKeyAgree = null; KeyAgreement bKeyAgree = null; @@ -603,53 +613,56 @@ private void threadRunnerKeyAgreeTest(String algo) byte secretB[] = bKeyAgree.generateSecret(); if (!Arrays.equals(secretA, secretB)) { - failed = 1; + throw new Exception( + "Secrets A and B to not match"); } - if (failed == 0) { - cKeyAgree = KeyAgreement.getInstance( - currentAlgo, "wolfJCE"); - cPair = keyGen.generateKeyPair(); - cKeyAgree.init(cPair.getPrivate()); + cKeyAgree = KeyAgreement.getInstance( + currentAlgo, "wolfJCE"); + cPair = keyGen.generateKeyPair(); + cKeyAgree.init(cPair.getPrivate()); - aKeyAgree.doPhase(cPair.getPublic(), true); - cKeyAgree.doPhase(aPair.getPublic(), true); + aKeyAgree.doPhase(cPair.getPublic(), true); + cKeyAgree.doPhase(aPair.getPublic(), true); - byte secretA2[] = aKeyAgree.generateSecret(); - byte secretC[] = cKeyAgree.generateSecret(); + byte secretA2[] = aKeyAgree.generateSecret(); + byte secretC[] = cKeyAgree.generateSecret(); - if (!Arrays.equals(secretA2, secretC)) { - failed = 1; - } + if (!Arrays.equals(secretA2, secretC)) { + throw new Exception( + "Secrets A2 and C do not match"); } + /* Log success */ + success.incrementAndGet(0); + } catch (Exception e) { e.printStackTrace(); - failed = 1; + + /* Log failure */ + failures.incrementAndGet(0); } finally { latch.countDown(); } - - if (failed == 1) { - results.add(1); - } - else { - results.add(0); - } } }); } /* wait for all threads to complete */ - latch.await(); - - /* Look for any failures that happened */ - Iterator listIterator = results.iterator(); - while (listIterator.hasNext()) { - Integer cur = listIterator.next(); - if (cur == 1) { - fail("Threading error in KeyAgreement thread test"); + returnWithoutTimeout = latch.await(10, TimeUnit.SECONDS); + service.shutdown(); + + /* Check failure count and success count against thread count */ + if ((failures.get(0) != 0) || + (success.get(0) != numThreads)) { + if (returnWithoutTimeout == true) { + fail("KeyAgreement test threading error: " + + failures.get(0) + " failures, " + + success.get(0) + " success, " + + numThreads + " num threads total"); + } else { + fail("KeyAgreement test threading error, threads timed out"); } } }