Skip to content

Commit

Permalink
Fix some thread safe issues of regex matcher (#6)
Browse files Browse the repository at this point in the history
Fix some thread safe and memory issues of regex matcher
  • Loading branch information
silas-sager authored Sep 26, 2022
1 parent d9f3ebe commit fd38999
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 103 deletions.
4 changes: 2 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>ir.sahab</groupId>
<artifactId>regex-matcher</artifactId>
<version>1.0.0</version>
<version>1.0.1</version>
<description>
This library provides facilities to match an input string against a collection of regex patterns.
</description>
Expand Down Expand Up @@ -59,7 +59,7 @@
<plugin>
<groupId>org.jacoco</groupId>
<artifactId>jacoco-maven-plugin</artifactId>
<version>0.8.6</version>
<version>0.8.7</version>
<executions>
<execution>
<goals>
Expand Down
43 changes: 19 additions & 24 deletions src/main/cpp/jni/hyperscan_wrapper.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "hyperscan_wrapper.h"

#include <cassert>
#include <cstdlib>
#include <vector>

Expand All @@ -23,17 +22,12 @@ HyperscanWrapper::~HyperscanWrapper() {
}

void HyperscanWrapper::AddPattern(unsigned int id, const char* pattern, bool is_case_sensitive) {
assert(pattern != nullptr);
assert(id > 0);

patterns_.insert(pair<unsigned int, pair<string, bool> >(
id, pair<string, bool>(pattern, is_case_sensitive)));
is_compile_required_ = true;
}

bool HyperscanWrapper::RemovePattern(unsigned int id) {
assert(id > 0);

if (0 == patterns_.erase(id)) {
return false;
}
Expand All @@ -42,7 +36,7 @@ bool HyperscanWrapper::RemovePattern(unsigned int id) {
return true;
}

int HyperscanWrapper::CompilePatterns() {
int64_t HyperscanWrapper::CompilePatterns() {
if (!is_compile_required_) {
return 0;
}
Expand Down Expand Up @@ -76,26 +70,27 @@ int HyperscanWrapper::CompilePatterns() {
auto error = ch_compile_multi(cstr_patterns.data(), flags.data(), pattern_ids.data(),
cstr_patterns.size(), CH_MODE_NOGROUPS, nullptr, &pattern_database_, &compile_err);
if (error != CH_SUCCESS) {
if (error == CH_COMPILER_ERROR) {
auto erroneous_id = compile_err->expression;
last_error_.assign("Unable to compile patterns: error = ").append(compile_err->message);
if (erroneous_id >= 0) {
// Convert index to pattern id
erroneous_id = pattern_ids[erroneous_id];
last_error_.append(", erroneous pattern id = ");
last_error_ += to_string(erroneous_id);
}
ch_free_compile_error(compile_err);
return erroneous_id;
auto erroneous_pattern_index = compile_err->expression;

last_error_ = "Unable to compile patterns: error = ";
last_error_.append(compile_err->message);
ch_free_compile_error(compile_err);
if (erroneous_pattern_index >= 0) {
// Convert index to pattern id
auto erroneous_pattern_id = pattern_ids[erroneous_pattern_index];
last_error_.append(", erroneous pattern id = ");
last_error_ += to_string(erroneous_pattern_id);
return erroneous_pattern_id;
} else {
last_error_ = "An unexpected error occurred: " + to_string(error);
last_error_.append(", unknown expression index = ");
last_error_ += to_string(erroneous_pattern_index);
return -1;
}
}

error = ch_alloc_scratch(pattern_database_, &scratch_);
if (error != CH_SUCCESS) {
last_error_ = "ERROR: Unable to allocate scratch: error = " + to_string(error);
last_error_ = "Unable to allocate scratch: error = " + to_string(error);
return -1;
}

Expand Down Expand Up @@ -129,23 +124,23 @@ bool HyperscanWrapper::Match(const string& input, set<unsigned int>* results) {
return true;
}

const char* HyperscanWrapper::GetLastError() const {
return last_error_.c_str();
std::string HyperscanWrapper::GetLastError() const {
return last_error_;
}

void HyperscanWrapper::CleanUp() {
ch_error_t error;
if (scratch_ != nullptr) {
error = ch_free_scratch(scratch_);
if (error != CH_SUCCESS) {
fprintf(stderr, "ERROR: Unable to free scratch: error = %d\n", error);
fprintf(stderr, "Unable to free scratch: error = %d\n", error);
}
scratch_ = nullptr;
}
if (pattern_database_ != nullptr) {
error = ch_free_database(pattern_database_);
if (error != CH_SUCCESS) {
fprintf(stderr, "ERROR: Unable to free pattern database: error = %d\n", error);
fprintf(stderr, "Unable to free pattern database: error = %d\n", error);
}
pattern_database_ = nullptr;
}
Expand Down
4 changes: 2 additions & 2 deletions src/main/cpp/jni/hyperscan_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ class HyperscanWrapper {
// The caller MUST call this method before calling Match() if the patterns set has changed.
// Returns 0 on success, < 0 if error is not specific to a pattern, > 0 id of the first
// erroneous pattern. Call GetLastError() for a string explanation of the error.
int CompilePatterns();
int64_t CompilePatterns();

// Matches the given string against all patterns in the current pattern set and stores the
// pattern ids that match in results.
// Returns false if an error occurs. Call GetLastError() for more information.
bool Match(const std::string& input, std::set<unsigned int>* results);

// Returns a string explanation of the last error that has occurred.
const char* GetLastError() const;
std::string GetLastError() const;

private:
void CleanUp();
Expand Down
131 changes: 84 additions & 47 deletions src/main/cpp/jni/ir_sahab_regexmatcher_RegexMatcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
#include <set>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <climits>

using sahab::HyperscanWrapper;

const char* java_assertion_error_path = "java/lang/AssertionError";
const char* java_illegal_argument_exception_path = "java/lang/IllegalArgumentException";
const char* java_pattern_preparation_exception_path = "ir/sahab/regexmatcher/exception/PatternPreparationException";

// The whole point of these class is to provide a mapping from Java world to C++ world.
Expand All @@ -20,99 +24,130 @@ const char* java_pattern_preparation_exception_path = "ir/sahab/regexmatcher/exc
// 3) javah -cp . <full_package_name_of_desired_class> (example ir.sahab.regexmatcher.RegexMatcher)
// Note: We also keep track of live HyperscanWrapper instances. This is done by keeping a map of
// created instances. An instance is created by calling newInstance() and it is destroyed by calling close().
unsigned int num_created_instances = 0;
std::map<unsigned int, std::unique_ptr<HyperscanWrapper> > instances;
int64_t num_created_instances = 0;
std::map<int64_t, std::unique_ptr<HyperscanWrapper> > instances;
std::mutex instances_map_mutex;

static void ThrowJavaException(JNIEnv* jenv, const char* java_error_class_path, std::string message) {
jclass clazz = jenv->FindClass(java_error_class_path);
if (jenv->ThrowNew(clazz, message.c_str()) < 0) {
fprintf(stderr, "Failed to throw exception: %s.", message.c_str());
std::exit(EXIT_FAILURE);
}
}

static HyperscanWrapper* GetHyperscanInstance(JNIEnv* jenv, jlong handle) {
static HyperscanWrapper* GetHyperscanInstance(JNIEnv* jenv, jlong jinstance_id) {
HyperscanWrapper* instance = nullptr;
auto instance_id = static_cast<int64_t>(jinstance_id);
try {
instance = instances.at(static_cast<unsigned int>(handle)).get();
instances_map_mutex.lock();
instance = instances.at(instance_id).get();
} catch (std::out_of_range& e) {
jclass clazz = jenv->FindClass(java_assertion_error_path);
std::string msg = "Invalid instance handle: handle = ";
msg += handle;
if (jenv->ThrowNew(clazz, msg.c_str()) < 0) {
fprintf(stderr, "Failed to throw exception: %s.", msg.c_str());
std::exit(EXIT_FAILURE);
}
ThrowJavaException(jenv, java_assertion_error_path, "Either instance closed or not valid: Instance ID = "
+ std::to_string(instance_id));
}
instances_map_mutex.unlock();
return instance;
}

JNIEXPORT jlong JNICALL Java_ir_sahab_regexmatcher_RegexMatcher_newInstance(
JNIEnv* jenv, jobject jobj) {
++num_created_instances;
instances_map_mutex.lock();
num_created_instances++;
instances[num_created_instances] = std::unique_ptr<HyperscanWrapper>(new HyperscanWrapper());
return num_created_instances;
instances_map_mutex.unlock();
return static_cast<jlong>(num_created_instances);
}

JNIEXPORT void JNICALL Java_ir_sahab_regexmatcher_RegexMatcher_close(
JNIEnv* jenv, jobject jobj, jlong jinstance_handle) {
instances.erase(static_cast<unsigned int>(jinstance_handle));
JNIEnv* jenv, jobject jobj, jlong jinstance_id) {
instances_map_mutex.lock();
instances.erase(static_cast<int64_t>(jinstance_id));
instances_map_mutex.unlock();
}

JNIEXPORT void JNICALL Java_ir_sahab_regexmatcher_RegexMatcher_addPattern(
JNIEnv* jenv, jobject jobj, jlong jinstance_handle, jlong jpattern_id, jstring jpattern,
JNIEnv* jenv, jobject jobj, jlong jinstance_id, jlong jpattern_id, jstring jpattern,
jboolean is_case_sensitive) {
HyperscanWrapper* instance = GetHyperscanInstance(jenv, jinstance_handle);
if (instance == nullptr)
return;
HyperscanWrapper* instance = GetHyperscanInstance(jenv, jinstance_id);
if (instance == nullptr) {
return; // C++ continues to work after Java exception is thrown
}

auto pattern_id = static_cast<int64_t>(jpattern_id);
if (pattern_id <= 0 || pattern_id > UINT_MAX) {
ThrowJavaException(jenv, java_illegal_argument_exception_path, "Pattern ID must between 0 and "
+ std::to_string(UINT_MAX) + ": pattern ID = " + std::to_string(pattern_id));
return; // C++ continues to work after Java exception is thrown
}

auto pattern = jenv->GetStringUTFChars(jpattern, nullptr);
instance->AddPattern(static_cast<unsigned int>(jpattern_id), pattern, is_case_sensitive);
if (pattern == nullptr) {
ThrowJavaException(jenv, java_assertion_error_path, "Unable to convert java 'pattern' string to cpp string!");
return; // C++ continues to work after Java exception is thrown
}
instance->AddPattern(static_cast<unsigned int>(pattern_id), pattern, (bool)(is_case_sensitive == JNI_TRUE));
jenv->ReleaseStringUTFChars(jpattern, pattern);
}

JNIEXPORT jboolean JNICALL Java_ir_sahab_regexmatcher_RegexMatcher_removePattern(
JNIEnv* jenv, jobject jobj, jlong jinstance_handle, jlong jpattern_id) {
HyperscanWrapper* instance = GetHyperscanInstance(jenv, jinstance_handle);
if (instance == nullptr)
return JNI_FALSE;
JNIEnv* jenv, jobject jobj, jlong jinstance_id, jlong jpattern_id) {
HyperscanWrapper* instance = GetHyperscanInstance(jenv, jinstance_id);
if (instance == nullptr) {
return JNI_FALSE; // C++ continues to work after Java exception is thrown
}

auto pattern_id = static_cast<int64_t>(jpattern_id);
if (pattern_id <= 0 || pattern_id > UINT_MAX) {
ThrowJavaException(jenv, java_illegal_argument_exception_path, "Pattern ID must between 0 and "
+ std::to_string(UINT_MAX) + ": pattern ID = " + std::to_string(pattern_id));
return JNI_FALSE; // C++ continues to work after Java exception is thrown
}

return instance->RemovePattern(static_cast<unsigned int>(jpattern_id)) ?
JNI_TRUE : JNI_FALSE;
return instance->RemovePattern(static_cast<unsigned int>(pattern_id)) ? JNI_TRUE : JNI_FALSE;
}

JNIEXPORT void JNICALL Java_ir_sahab_regexmatcher_RegexMatcher_preparePatterns(
JNIEnv* jenv, jobject jobj, jlong jinstance_handle) {
HyperscanWrapper* instance = GetHyperscanInstance(jenv, jinstance_handle);
if (instance == nullptr)
return;
JNIEnv* jenv, jobject jobj, jlong jinstance_id) {
HyperscanWrapper* instance = GetHyperscanInstance(jenv, jinstance_id);
if (instance == nullptr) {
return; // C++ continues to work after Java exception is thrown
}

auto result = static_cast<jlong>(instance->CompilePatterns());
auto result = instance->CompilePatterns();
if (result != 0) {
jclass clazz = jenv->FindClass(java_pattern_preparation_exception_path);
jmethodID clazz_constructor = jenv->GetMethodID(clazz, "<init>", "(Ljava/lang/String;J)V");
std::string msg = "Failed to prepare patterns: ";
msg.append(instance->GetLastError());
std::string msg = "Failed to prepare patterns: " + instance->GetLastError();
auto jexception = jenv->NewObject(clazz, clazz_constructor,
jenv->NewStringUTF(msg.c_str()), result);
jenv->NewStringUTF(msg.c_str()), static_cast<jlong>(result));
if (jenv->Throw(static_cast<jthrowable>(jexception)) < 0) {
fprintf(stderr, "Failed to throw exception: %s.", msg.c_str());
std::exit(EXIT_FAILURE);
}
return;
}
return;
}

JNIEXPORT jobject JNICALL Java_ir_sahab_regexmatcher_RegexMatcher_match(
JNIEnv* jenv, jobject, jlong jinstance_handle, jstring jinput) {
HyperscanWrapper* instance = GetHyperscanInstance(jenv, jinstance_handle);
if (instance == nullptr)
return nullptr;
JNIEnv* jenv, jobject, jlong jinstance_id, jstring jinput) {
HyperscanWrapper* instance = GetHyperscanInstance(jenv, jinstance_id);
if (instance == nullptr) {
return nullptr; // C++ continues to work after Java exception is thrown
}

auto input = jenv->GetStringUTFChars(jinput, nullptr);
if (input == nullptr) {
ThrowJavaException(jenv, java_assertion_error_path, "Unable to convert java 'input' string to cpp string!");
return nullptr; // C++ continues to work after Java exception is thrown
}
std::set<unsigned int> results;
auto error_occurred = !instance->Match(input, &results);
jenv->ReleaseStringUTFChars(jinput, input);

if (error_occurred) {
jclass clazz = jenv->FindClass(java_assertion_error_path);
if (jenv->ThrowNew(clazz, instance->GetLastError()) < 0) {
fprintf(stderr, "Failed to throw exception: %s.", instance->GetLastError());
std::exit(EXIT_FAILURE);
}
return nullptr;
ThrowJavaException(jenv, java_assertion_error_path, instance->GetLastError());
return nullptr; // C++ continues to work after Java exception is thrown
}

auto jarraylist_clazz = jenv->FindClass("java/util/ArrayList");
Expand All @@ -126,10 +161,12 @@ JNIEXPORT jobject JNICALL Java_ir_sahab_regexmatcher_RegexMatcher_match(
for (auto& result : results) {
auto element = jenv->NewObject(jlong_clazz, jlong_constructor, static_cast<jlong>(result));
if (JNI_TRUE != jenv->CallBooleanMethod(jresult, jarraylist_add, element)) {
fprintf(stderr, "Element was not added to array: %d", result);
ThrowJavaException(jenv, java_assertion_error_path, "Element was not added to array: "
+ std::to_string(result));
return nullptr; // C++ continues to work after Java exception is thrown
}
}
}

return jresult;
}
}
Loading

0 comments on commit fd38999

Please sign in to comment.