diff --git a/src/agent/agent_info/include/agent_info.hpp b/src/agent/agent_info/include/agent_info.hpp index 239970711a..d38119b3fd 100644 --- a/src/agent/agent_info/include/agent_info.hpp +++ b/src/agent/agent_info/include/agent_info.hpp @@ -4,9 +4,12 @@ #include #include +#include #include #include +class AgentInfoPersistance; + /// @brief Stores and manages information about an agent. /// /// This class provides methods for getting and setting the agent's name, key, @@ -25,10 +28,12 @@ class AgentInfo /// @param getOSInfo Function to retrieve OS information in JSON format. /// @param getNetworksInfo Function to retrieve network information in JSON format. /// @param agentIsRegistering True if the agent is being registered, false otherwise. + /// @param persistence Optional pointer to an AgentInfoPersistance object. AgentInfo(std::string dbFolderPath = config::DEFAULT_DATA_PATH, std::function getOSInfo = nullptr, std::function getNetworksInfo = nullptr, - bool agentIsRegistering = false); + bool agentIsRegistering = false, + std::shared_ptr persistence = nullptr); /// @brief Gets the agent's name. /// @return The agent's name. @@ -141,4 +146,7 @@ class AgentInfo /// @brief Specify if the agent is about to register. bool m_agentIsRegistering; + + /// @brief Pointer to the agent info persistence instance. + std::shared_ptr m_persistence; }; diff --git a/src/agent/agent_info/src/agent_info.cpp b/src/agent/agent_info/src/agent_info.cpp index dbc4433790..ab55bac147 100644 --- a/src/agent/agent_info/src/agent_info.cpp +++ b/src/agent/agent_info/src/agent_info.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -19,17 +20,18 @@ namespace AgentInfo::AgentInfo(std::string dbFolderPath, std::function getOSInfo, std::function getNetworksInfo, - bool agentIsRegistering) + bool agentIsRegistering, + std::shared_ptr persistence) : m_dataFolderPath(std::move(dbFolderPath)) , m_agentIsRegistering(agentIsRegistering) + , m_persistence(persistence ? std::move(persistence) : std::make_shared(m_dataFolderPath)) { if (!m_agentIsRegistering) { - AgentInfoPersistance agentInfoPersistance(m_dataFolderPath); - m_name = agentInfoPersistance.GetName(); - m_key = agentInfoPersistance.GetKey(); - m_uuid = agentInfoPersistance.GetUUID(); - m_groups = agentInfoPersistance.GetGroups(); + m_name = m_persistence->GetName(); + m_key = m_persistence->GetKey(); + m_uuid = m_persistence->GetUUID(); + m_groups = m_persistence->GetGroups(); } if (m_uuid.empty()) @@ -183,18 +185,16 @@ std::string AgentInfo::GetMetadataInfo() const void AgentInfo::Save() const { - AgentInfoPersistance agentInfoPersistance(m_dataFolderPath); - agentInfoPersistance.ResetToDefault(); - agentInfoPersistance.SetName(m_name); - agentInfoPersistance.SetKey(m_key); - agentInfoPersistance.SetUUID(m_uuid); - agentInfoPersistance.SetGroups(m_groups); + m_persistence->ResetToDefault(); + m_persistence->SetName(m_name); + m_persistence->SetKey(m_key); + m_persistence->SetUUID(m_uuid); + m_persistence->SetGroups(m_groups); } bool AgentInfo::SaveGroups() const { - AgentInfoPersistance agentInfoPersistance(m_dataFolderPath); - return agentInfoPersistance.SetGroups(m_groups); + return m_persistence->SetGroups(m_groups); } std::vector AgentInfo::GetActiveIPAddresses(const nlohmann::json& networksJson) const diff --git a/src/agent/agent_info/src/agent_info_persistance.cpp b/src/agent/agent_info/src/agent_info_persistance.cpp index 478b0438cc..15109b0d90 100644 --- a/src/agent/agent_info/src/agent_info_persistance.cpp +++ b/src/agent/agent_info/src/agent_info_persistance.cpp @@ -25,13 +25,20 @@ namespace const std::string AGENT_GROUP_NAME_COLUMN_NAME = "name"; } // namespace -AgentInfoPersistance::AgentInfoPersistance(const std::string& dbFolderPath) +AgentInfoPersistance::AgentInfoPersistance(const std::string& dbFolderPath, std::unique_ptr persistence) { const auto dbFilePath = dbFolderPath + "/" + AGENT_INFO_DB_NAME; try { - m_db = PersistenceFactory::CreatePersistence(PersistenceFactory::PersistenceType::SQLITE3, dbFilePath); + if (persistence) + { + m_db = std::move(persistence); + } + else + { + m_db = PersistenceFactory::CreatePersistence(PersistenceFactory::PersistenceType::SQLITE3, dbFilePath); + } if (!m_db->TableExists(AGENT_INFO_TABLE_NAME)) { @@ -126,17 +133,20 @@ void AgentInfoPersistance::InsertDefaultAgentInfo() } } -void AgentInfoPersistance::SetAgentInfoValue(const std::string& column, const std::string& value) +bool AgentInfoPersistance::SetAgentInfoValue(const std::string& column, const std::string& value) { try { const Row columns = {ColumnValue(column, ColumnType::TEXT, value)}; m_db->Update(AGENT_INFO_TABLE_NAME, columns); + return true; } catch (const std::exception& e) { LogError("Error updating {}: {}.", column, e.what()); } + + return false; } std::string AgentInfoPersistance::GetAgentInfoValue(const std::string& column) const @@ -201,24 +211,35 @@ std::vector AgentInfoPersistance::GetGroups() const return groupList; } -void AgentInfoPersistance::SetName(const std::string& name) +bool AgentInfoPersistance::SetName(const std::string& name) { - SetAgentInfoValue(AGENT_INFO_NAME_COLUMN_NAME, name); + return SetAgentInfoValue(AGENT_INFO_NAME_COLUMN_NAME, name); } -void AgentInfoPersistance::SetKey(const std::string& key) +bool AgentInfoPersistance::SetKey(const std::string& key) { - SetAgentInfoValue(AGENT_INFO_KEY_COLUMN_NAME, key); + return SetAgentInfoValue(AGENT_INFO_KEY_COLUMN_NAME, key); } -void AgentInfoPersistance::SetUUID(const std::string& uuid) +bool AgentInfoPersistance::SetUUID(const std::string& uuid) { - SetAgentInfoValue(AGENT_INFO_UUID_COLUMN_NAME, uuid); + return SetAgentInfoValue(AGENT_INFO_UUID_COLUMN_NAME, uuid); } bool AgentInfoPersistance::SetGroups(const std::vector& groupList) { - auto transaction = m_db->BeginTransaction(); + TransactionId transaction = 0; + + // Handle the exception separately since it would not be necessary to perform RollBack. + try + { + transaction = m_db->BeginTransaction(); + } + catch (const std::exception& e) + { + LogError("Failed to begin transaction: {}.", e.what()); + return false; + } try { @@ -235,13 +256,22 @@ bool AgentInfoPersistance::SetGroups(const std::vector& groupList) catch (const std::exception& e) { LogError("Error inserting group: {}.", e.what()); - m_db->RollbackTransaction(transaction); + + try + { + m_db->RollbackTransaction(transaction); + } + catch (const std::exception& ee) + { + LogError("Rollback failed: {}.", ee.what()); + } + return false; } return true; } -void AgentInfoPersistance::ResetToDefault() +bool AgentInfoPersistance::ResetToDefault() { try { @@ -250,9 +280,11 @@ void AgentInfoPersistance::ResetToDefault() CreateAgentInfoTable(); CreateAgentGroupTable(); InsertDefaultAgentInfo(); + return true; } catch (const std::exception& e) { LogError("Error resetting to default values: {}.", e.what()); } + return false; } diff --git a/src/agent/agent_info/src/agent_info_persistance.hpp b/src/agent/agent_info/src/agent_info_persistance.hpp index 2896f5deb6..2480210b25 100644 --- a/src/agent/agent_info/src/agent_info_persistance.hpp +++ b/src/agent/agent_info/src/agent_info_persistance.hpp @@ -1,18 +1,18 @@ #pragma once +#include + #include #include #include -class Persistence; - /// @brief Manages persistence of agent information and groups in a database. class AgentInfoPersistance { public: /// @brief Constructs the persistence manager for agent info, initializing the database and tables if necessary. /// @param dbFolderPath Path to the database folder. - explicit AgentInfoPersistance(const std::string& dbFolderPath); + explicit AgentInfoPersistance(const std::string& dbFolderPath, std::unique_ptr persistence = nullptr); /// @brief Destructor for AgentInfoPersistance. ~AgentInfoPersistance(); @@ -47,15 +47,18 @@ class AgentInfoPersistance /// @brief Sets the agent's name in the database. /// @param name The name to set. - void SetName(const std::string& name); + /// @return True if the operation was successful, false otherwise. + bool SetName(const std::string& name); /// @brief Sets the agent's key in the database. /// @param key The key to set. - void SetKey(const std::string& key); + /// @return True if the operation was successful, false otherwise. + bool SetKey(const std::string& key); /// @brief Sets the agent's UUID in the database. /// @param uuid The UUID to set. - void SetUUID(const std::string& uuid); + /// @return True if the operation was successful, false otherwise. + bool SetUUID(const std::string& uuid); /// @brief Sets the agent's group list in the database, replacing any existing groups. /// @param groupList A vector of strings, each representing a group name. @@ -63,7 +66,8 @@ class AgentInfoPersistance bool SetGroups(const std::vector& groupList); /// @brief Resets the database tables to default values, clearing all data. - void ResetToDefault(); + /// @return True if the reset was successful, false otherwise. + bool ResetToDefault(); private: /// @brief Checks if the agent info table is empty. @@ -82,7 +86,8 @@ class AgentInfoPersistance /// @brief Sets a specific agent info value in the database. /// @param column The name of the column to set. /// @param value The value to set in the specified column. - void SetAgentInfoValue(const std::string& column, const std::string& value); + /// @return True if the operation was successful, false otherwise. + bool SetAgentInfoValue(const std::string& column, const std::string& value); /// @brief Retrieves a specific agent info value from the database. /// @param column The name of the column to retrieve. diff --git a/src/agent/agent_info/tests/CMakeLists.txt b/src/agent/agent_info/tests/CMakeLists.txt index 4e4ff6c449..b640b82713 100644 --- a/src/agent/agent_info/tests/CMakeLists.txt +++ b/src/agent/agent_info/tests/CMakeLists.txt @@ -2,12 +2,16 @@ find_package(GTest CONFIG REQUIRED) add_executable(agent_info_test agent_info_test.cpp) configure_target(agent_info_test) -target_include_directories(agent_info_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../src) -target_link_libraries(agent_info_test PRIVATE AgentInfo Persistence GTest::gtest) +target_include_directories(agent_info_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../src + ${CMAKE_CURRENT_SOURCE_DIR}/../../persistence/tests/mocks) +target_link_libraries(agent_info_test PRIVATE AgentInfo Persistence GTest::gtest GTest::gmock GTest::gmock_main) add_test(NAME AgentInfoTest COMMAND agent_info_test) add_executable(agent_info_persistance_test agent_info_persistance_test.cpp) configure_target(agent_info_persistance_test) -target_include_directories(agent_info_persistance_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../src) -target_link_libraries(agent_info_persistance_test PRIVATE AgentInfo Persistence GTest::gtest) +target_include_directories(agent_info_persistance_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../src + ${CMAKE_CURRENT_SOURCE_DIR}/../../persistence/tests/mocks) +target_link_libraries(agent_info_persistance_test PRIVATE AgentInfo Persistence GTest::gtest GTest::gmock) add_test(NAME AgentInfoPersistanceTest COMMAND agent_info_persistance_test) diff --git a/src/agent/agent_info/tests/agent_info_persistance_test.cpp b/src/agent/agent_info/tests/agent_info_persistance_test.cpp index 0abd3a6dd1..0484ace062 100644 --- a/src/agent/agent_info/tests/agent_info_persistance_test.cpp +++ b/src/agent/agent_info/tests/agent_info_persistance_test.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include @@ -9,75 +10,397 @@ class AgentInfoPersistanceTest : public ::testing::Test { protected: + MockPersistence* mockPersistence = nullptr; + std::unique_ptr agentInfoPersistance; + void SetUp() override { - persistance = std::make_unique("."); - persistance->ResetToDefault(); - } + auto mockPersistencePtr = std::make_unique(); + mockPersistence = mockPersistencePtr.get(); - std::unique_ptr persistance; + EXPECT_CALL(*mockPersistence, TableExists("agent_info")).WillOnce(testing::Return(true)); + EXPECT_CALL(*mockPersistence, TableExists("agent_group")).WillOnce(testing::Return(true)); + EXPECT_CALL(*mockPersistence, GetCount("agent_info", testing::_, testing::_)) + .WillOnce(testing::Return(0)) + .WillOnce(testing::Return(0)); + EXPECT_CALL(*mockPersistence, Insert("agent_info", testing::_)).Times(1); + + agentInfoPersistance = std::make_unique("db_path", std::move(mockPersistencePtr)); + } }; TEST_F(AgentInfoPersistanceTest, TestConstruction) { - EXPECT_NE(persistance, nullptr); + EXPECT_NE(agentInfoPersistance, nullptr); +} + +TEST_F(AgentInfoPersistanceTest, TestGetNameValue) +{ + std::vector mockRowName = {{column::ColumnValue("name", column::ColumnType::TEXT, "name_test")}}; + EXPECT_CALL(*mockPersistence, + Select("agent_info", testing::_, testing::_, testing::_, testing::_, testing::_, testing::_)) + .WillOnce(testing::Return(mockRowName)); + EXPECT_EQ(agentInfoPersistance->GetName(), "name_test"); } -TEST_F(AgentInfoPersistanceTest, TestDefaultValues) +TEST_F(AgentInfoPersistanceTest, TestGetNameNotValue) { - EXPECT_EQ(persistance->GetName(), ""); - EXPECT_EQ(persistance->GetKey(), ""); - EXPECT_EQ(persistance->GetUUID(), ""); + std::vector mockRowName = {}; + EXPECT_CALL(*mockPersistence, + Select("agent_info", testing::_, testing::_, testing::_, testing::_, testing::_, testing::_)) + .WillOnce(testing::Return(mockRowName)); + EXPECT_EQ(agentInfoPersistance->GetName(), ""); +} + +TEST_F(AgentInfoPersistanceTest, TestGetNameCatch) +{ + EXPECT_CALL(*mockPersistence, + Select("agent_info", testing::_, testing::_, testing::_, testing::_, testing::_, testing::_)) + .WillOnce(testing::Throw(std::runtime_error("Error Select"))); + EXPECT_EQ(agentInfoPersistance->GetName(), ""); +} + +TEST_F(AgentInfoPersistanceTest, TestGetKeyValue) +{ + std::vector mockRowKey = {{column::ColumnValue("key", column::ColumnType::TEXT, "key_test")}}; + EXPECT_CALL(*mockPersistence, + Select("agent_info", testing::_, testing::_, testing::_, testing::_, testing::_, testing::_)) + .WillOnce(testing::Return(mockRowKey)); + EXPECT_EQ(agentInfoPersistance->GetKey(), "key_test"); +} + +TEST_F(AgentInfoPersistanceTest, TestGetKeyNotValue) +{ + std::vector mockRowKey = {}; + EXPECT_CALL(*mockPersistence, + Select("agent_info", testing::_, testing::_, testing::_, testing::_, testing::_, testing::_)) + .WillOnce(testing::Return(mockRowKey)); + EXPECT_EQ(agentInfoPersistance->GetKey(), ""); +} + +TEST_F(AgentInfoPersistanceTest, TestGetKeyCatch) +{ + EXPECT_CALL(*mockPersistence, + Select("agent_info", testing::_, testing::_, testing::_, testing::_, testing::_, testing::_)) + .WillOnce(testing::Throw(std::runtime_error("Error Select"))); + EXPECT_EQ(agentInfoPersistance->GetKey(), ""); +} + +TEST_F(AgentInfoPersistanceTest, TestGetUUIDValue) +{ + std::vector mockRowUUID = {{column::ColumnValue("uuid", column::ColumnType::TEXT, "uuid_test")}}; + EXPECT_CALL(*mockPersistence, + Select("agent_info", testing::_, testing::_, testing::_, testing::_, testing::_, testing::_)) + .WillOnce(testing::Return(mockRowUUID)); + EXPECT_EQ(agentInfoPersistance->GetUUID(), "uuid_test"); +} + +TEST_F(AgentInfoPersistanceTest, TestGetUUIDNotValue) +{ + std::vector mockRowUUID = {}; + EXPECT_CALL(*mockPersistence, + Select("agent_info", testing::_, testing::_, testing::_, testing::_, testing::_, testing::_)) + .WillOnce(testing::Return(mockRowUUID)); + EXPECT_EQ(agentInfoPersistance->GetUUID(), ""); +} + +TEST_F(AgentInfoPersistanceTest, TestGetUUIDCatch) +{ + EXPECT_CALL(*mockPersistence, + Select("agent_info", testing::_, testing::_, testing::_, testing::_, testing::_, testing::_)) + .WillOnce(testing::Throw(std::runtime_error("Error Select"))); + EXPECT_EQ(agentInfoPersistance->GetUUID(), ""); +} + +TEST_F(AgentInfoPersistanceTest, TestGetGroupsValue) +{ + std::vector mockRowGroups = {{column::ColumnValue("name", column::ColumnType::TEXT, "group_1")}, + {column::ColumnValue("name", column::ColumnType::TEXT, "group_2")}}; + EXPECT_CALL(*mockPersistence, + Select("agent_group", testing::_, testing::_, testing::_, testing::_, testing::_, testing::_)) + .WillOnce(testing::Return(mockRowGroups)); + + std::vector expectedGroups = {"group_1", "group_2"}; + EXPECT_EQ(agentInfoPersistance->GetGroups(), expectedGroups); +} + +TEST_F(AgentInfoPersistanceTest, TestGetGroupsNotValue) +{ + std::vector mockRowGroups = {}; + EXPECT_CALL(*mockPersistence, + Select("agent_group", testing::_, testing::_, testing::_, testing::_, testing::_, testing::_)) + .WillOnce(testing::Return(mockRowGroups)); + + std::vector expectedGroups = {}; + EXPECT_EQ(agentInfoPersistance->GetGroups(), expectedGroups); +} + +TEST_F(AgentInfoPersistanceTest, TestGetGroupsCatch) +{ + EXPECT_CALL(*mockPersistence, + Select("agent_group", testing::_, testing::_, testing::_, testing::_, testing::_, testing::_)) + .WillOnce(testing::Throw(std::runtime_error("Error Select"))); + + std::vector expectedGroups = {}; + EXPECT_EQ(agentInfoPersistance->GetGroups(), expectedGroups); } TEST_F(AgentInfoPersistanceTest, TestSetName) { - const std::string newName = "new_name"; - persistance->SetName(newName); - EXPECT_EQ(persistance->GetName(), newName); + std::string expectedColumn = "name"; + std::string newName = "new_name"; + + EXPECT_CALL(*mockPersistence, + Update(testing::Eq("agent_info"), + testing::AllOf(testing::SizeIs(1), + testing::Contains( + testing::AllOf(testing::Field(&column::ColumnValue::Value, newName), + testing::Field(&column::ColumnName::Name, expectedColumn)))), + testing::_, + testing::_)) + .Times(1); + EXPECT_TRUE(agentInfoPersistance->SetName(newName)); +} + +TEST_F(AgentInfoPersistanceTest, TestSetNameCatch) +{ + std::string expectedColumn = "name"; + std::string newName = "new_name"; + + EXPECT_CALL(*mockPersistence, + Update(testing::Eq("agent_info"), + testing::AllOf(testing::SizeIs(1), + testing::Contains( + testing::AllOf(testing::Field(&column::ColumnValue::Value, newName), + testing::Field(&column::ColumnName::Name, expectedColumn)))), + testing::_, + testing::_)) + .WillOnce(testing::Throw(std::runtime_error("Error Update"))); + EXPECT_FALSE(agentInfoPersistance->SetName(newName)); } TEST_F(AgentInfoPersistanceTest, TestSetKey) { - const std::string newKey = "new_key"; - persistance->SetKey(newKey); - EXPECT_EQ(persistance->GetKey(), newKey); + std::string expectedColumn = "key"; + std::string newKey = "new_key"; + + EXPECT_CALL(*mockPersistence, + Update(testing::Eq("agent_info"), + testing::AllOf(testing::SizeIs(1), + testing::Contains( + testing::AllOf(testing::Field(&column::ColumnValue::Value, newKey), + testing::Field(&column::ColumnName::Name, expectedColumn)))), + testing::_, + testing::_)) + .Times(1); + EXPECT_TRUE(agentInfoPersistance->SetKey(newKey)); +} + +TEST_F(AgentInfoPersistanceTest, TestSetKeyCatch) +{ + std::string expectedColumn = "key"; + std::string newKey = "new_key"; + + EXPECT_CALL(*mockPersistence, + Update(testing::Eq("agent_info"), + testing::AllOf(testing::SizeIs(1), + testing::Contains( + testing::AllOf(testing::Field(&column::ColumnValue::Value, newKey), + testing::Field(&column::ColumnName::Name, expectedColumn)))), + testing::_, + testing::_)) + .WillOnce(testing::Throw(std::runtime_error("Error Update"))); + EXPECT_FALSE(agentInfoPersistance->SetKey(newKey)); } TEST_F(AgentInfoPersistanceTest, TestSetUUID) { - const std::string newUUID = "new_uuid"; - persistance->SetUUID(newUUID); - EXPECT_EQ(persistance->GetUUID(), newUUID); + std::string expectedColumn = "uuid"; + std::string newUUID = "new_uuid"; + + EXPECT_CALL(*mockPersistence, + Update(testing::Eq("agent_info"), + testing::AllOf(testing::SizeIs(1), + testing::Contains( + testing::AllOf(testing::Field(&column::ColumnValue::Value, newUUID), + testing::Field(&column::ColumnName::Name, expectedColumn)))), + testing::_, + testing::_)) + .Times(1); + EXPECT_TRUE(agentInfoPersistance->SetUUID(newUUID)); } -TEST_F(AgentInfoPersistanceTest, TestSetGroups) +TEST_F(AgentInfoPersistanceTest, TestSetUUIDCatch) { - const std::vector newGroups = {"group_1", "group_2"}; - persistance->SetGroups(newGroups); - EXPECT_EQ(persistance->GetGroups(), newGroups); + std::string expectedColumn = "uuid"; + std::string newUUID = "new_uuid"; + + EXPECT_CALL(*mockPersistence, + Update(testing::Eq("agent_info"), + testing::AllOf(testing::SizeIs(1), + testing::Contains( + testing::AllOf(testing::Field(&column::ColumnValue::Value, newUUID), + testing::Field(&column::ColumnName::Name, expectedColumn)))), + testing::_, + testing::_)) + .WillOnce(testing::Throw(std::runtime_error("Error Update"))); + EXPECT_FALSE(agentInfoPersistance->SetUUID(newUUID)); } -TEST_F(AgentInfoPersistanceTest, TestSetGroupsDelete) +TEST_F(AgentInfoPersistanceTest, TestSetGroupsSuccess) { - const std::vector oldGroups = {"group_1", "group_2"}; - const std::vector newGroups = {"group_3", "group_4"}; - persistance->SetGroups(oldGroups); - EXPECT_EQ(persistance->GetGroups(), oldGroups); - persistance->SetGroups(newGroups); - EXPECT_EQ(persistance->GetGroups(), newGroups); + const std::vector newGroups = {"t_group_1", "t_group_2"}; + + EXPECT_CALL(*mockPersistence, BeginTransaction()).Times(1); + EXPECT_CALL(*mockPersistence, Remove("agent_group", testing::_, testing::_)).Times(1); + EXPECT_CALL(*mockPersistence, Insert("agent_group", testing::_)).Times(2); + EXPECT_CALL(*mockPersistence, CommitTransaction(testing::_)).Times(1); + + EXPECT_TRUE(agentInfoPersistance->SetGroups(newGroups)); +} + +TEST_F(AgentInfoPersistanceTest, TestSetGroupsBeginTransactionFails) +{ + const std::vector newGroups = {"t_group_1", "t_group_2"}; + + EXPECT_CALL(*mockPersistence, BeginTransaction()) + .WillOnce(testing::Throw(std::runtime_error("Error BeginTransaction"))); + + EXPECT_FALSE(agentInfoPersistance->SetGroups(newGroups)); +} + +TEST_F(AgentInfoPersistanceTest, TestSetGroupsRemoveFails) +{ + const std::vector newGroups = {"t_group_1", "t_group_2"}; + + EXPECT_CALL(*mockPersistence, BeginTransaction()).Times(1); + EXPECT_CALL(*mockPersistence, Remove("agent_group", testing::_, testing::_)) + .WillOnce(testing::Throw(std::runtime_error("Error Remove"))); + EXPECT_CALL(*mockPersistence, RollbackTransaction(testing::_)).Times(1); + + EXPECT_FALSE(agentInfoPersistance->SetGroups(newGroups)); +} + +TEST_F(AgentInfoPersistanceTest, TestSetGroupsInsertFails1) +{ + const std::vector newGroups = {"t_group_1", "t_group_2"}; + + EXPECT_CALL(*mockPersistence, BeginTransaction()).Times(1); + EXPECT_CALL(*mockPersistence, Remove("agent_group", testing::_, testing::_)).Times(1); + EXPECT_CALL(*mockPersistence, Insert("agent_group", testing::_)) + .WillOnce(testing::Throw(std::runtime_error("Error Insert"))); + EXPECT_CALL(*mockPersistence, RollbackTransaction(testing::_)).Times(1); + + EXPECT_FALSE(agentInfoPersistance->SetGroups(newGroups)); +} + +TEST_F(AgentInfoPersistanceTest, TestSetGroupsInsertFails2) +{ + const std::vector newGroups = {"t_group_1", "t_group_2"}; + + EXPECT_CALL(*mockPersistence, BeginTransaction()).Times(1); + EXPECT_CALL(*mockPersistence, Remove("agent_group", testing::_, testing::_)).Times(1); + + testing::Sequence seq; + EXPECT_CALL(*mockPersistence, Insert("agent_group", testing::_)) + .InSequence(seq) + .WillOnce(testing::Return()) + .WillOnce(testing::Throw(std::runtime_error("Error Insert"))); + EXPECT_CALL(*mockPersistence, RollbackTransaction(testing::_)).Times(1); + + EXPECT_FALSE(agentInfoPersistance->SetGroups(newGroups)); +} + +TEST_F(AgentInfoPersistanceTest, TestSetGroupsCommitFails) +{ + const std::vector newGroups = {"t_group_1", "t_group_2"}; + + EXPECT_CALL(*mockPersistence, BeginTransaction()).Times(1); + EXPECT_CALL(*mockPersistence, Remove("agent_group", testing::_, testing::_)).Times(1); + EXPECT_CALL(*mockPersistence, Insert("agent_group", testing::_)).Times(2); + EXPECT_CALL(*mockPersistence, CommitTransaction(testing::_)) + .WillOnce(testing::Throw(std::runtime_error("Error Commit"))); + EXPECT_CALL(*mockPersistence, RollbackTransaction(testing::_)).Times(1); + + EXPECT_FALSE(agentInfoPersistance->SetGroups(newGroups)); +} + +TEST_F(AgentInfoPersistanceTest, TestSetGroupsRollbackFails) +{ + const std::vector newGroups = {"t_group_1", "t_group_2"}; + + EXPECT_CALL(*mockPersistence, BeginTransaction()).Times(1); + EXPECT_CALL(*mockPersistence, Remove("agent_group", testing::_, testing::_)).Times(1); + EXPECT_CALL(*mockPersistence, Insert("agent_group", testing::_)).Times(2); + EXPECT_CALL(*mockPersistence, CommitTransaction(testing::_)) + .WillOnce(testing::Throw(std::runtime_error("Error Commit"))); + EXPECT_CALL(*mockPersistence, RollbackTransaction(testing::_)) + .WillOnce(testing::Throw(std::runtime_error("Error Rollback"))); + + EXPECT_FALSE(agentInfoPersistance->SetGroups(newGroups)); +} + +TEST_F(AgentInfoPersistanceTest, TestResetToDefaultSuccess) +{ + EXPECT_CALL(*mockPersistence, DropTable("agent_info")).Times(1); + EXPECT_CALL(*mockPersistence, DropTable("agent_group")).Times(1); + EXPECT_CALL(*mockPersistence, CreateTable("agent_info", testing::_)).Times(1); + EXPECT_CALL(*mockPersistence, CreateTable("agent_group", testing::_)).Times(1); + EXPECT_CALL(*mockPersistence, GetCount("agent_info", testing::_, testing::_)).WillOnce(testing::Return(0)); + EXPECT_CALL(*mockPersistence, Insert(testing::_, testing::_)).Times(1); + + EXPECT_TRUE(agentInfoPersistance->ResetToDefault()); +} + +TEST_F(AgentInfoPersistanceTest, TestResetToDefaultDropTableAgentInfoFails) +{ + EXPECT_CALL(*mockPersistence, DropTable("agent_info")) + .WillOnce(testing::Throw(std::runtime_error("Error DropTable"))); + + EXPECT_FALSE(agentInfoPersistance->ResetToDefault()); +} + +TEST_F(AgentInfoPersistanceTest, TestResetToDefaultDropTableAgentGroupFails) +{ + EXPECT_CALL(*mockPersistence, DropTable("agent_info")).Times(1); + EXPECT_CALL(*mockPersistence, DropTable("agent_group")) + .WillOnce(testing::Throw(std::runtime_error("Error DropTable"))); + + EXPECT_FALSE(agentInfoPersistance->ResetToDefault()); +} + +TEST_F(AgentInfoPersistanceTest, TestResetToDefaultCreateAgentInfoTableFails) +{ + EXPECT_CALL(*mockPersistence, DropTable("agent_info")).Times(1); + EXPECT_CALL(*mockPersistence, DropTable("agent_group")).Times(1); + EXPECT_CALL(*mockPersistence, CreateTable("agent_info", testing::_)) + .WillOnce(testing::Throw(std::runtime_error("Error CreateAgentInfoTable"))); + + EXPECT_FALSE(agentInfoPersistance->ResetToDefault()); +} + +TEST_F(AgentInfoPersistanceTest, TestResetToDefaultCreateAgentGroupTableFails) +{ + EXPECT_CALL(*mockPersistence, DropTable("agent_info")).Times(1); + EXPECT_CALL(*mockPersistence, DropTable("agent_group")).Times(1); + EXPECT_CALL(*mockPersistence, CreateTable("agent_info", testing::_)).Times(1); + EXPECT_CALL(*mockPersistence, CreateTable("agent_group", testing::_)) + .WillOnce(testing::Throw(std::runtime_error("Error CreateAgentGroupTable"))); + + EXPECT_FALSE(agentInfoPersistance->ResetToDefault()); } -TEST_F(AgentInfoPersistanceTest, TestResetToDefault) +TEST_F(AgentInfoPersistanceTest, TestResetToDefaultInsertFails) { - const std::string newName = "new_name"; - persistance->SetName(newName); - EXPECT_EQ(persistance->GetName(), newName); + EXPECT_CALL(*mockPersistence, DropTable("agent_info")).Times(1); + EXPECT_CALL(*mockPersistence, DropTable("agent_group")).Times(1); + EXPECT_CALL(*mockPersistence, CreateTable("agent_info", testing::_)).Times(1); + EXPECT_CALL(*mockPersistence, CreateTable("agent_group", testing::_)).Times(1); + EXPECT_CALL(*mockPersistence, GetCount("agent_info", testing::_, testing::_)).WillOnce(testing::Return(0)); + EXPECT_CALL(*mockPersistence, Insert("agent_info", testing::_)) + .WillOnce(testing::Throw(std::runtime_error("Error Insert"))); - persistance->ResetToDefault(); - EXPECT_EQ(persistance->GetName(), ""); - EXPECT_EQ(persistance->GetKey(), ""); - EXPECT_EQ(persistance->GetUUID(), ""); + EXPECT_FALSE(agentInfoPersistance->ResetToDefault()); } int main(int argc, char** argv) diff --git a/src/agent/agent_info/tests/agent_info_test.cpp b/src/agent/agent_info/tests/agent_info_test.cpp index 9d17f7eba5..796330db76 100644 --- a/src/agent/agent_info/tests/agent_info_test.cpp +++ b/src/agent/agent_info/tests/agent_info_test.cpp @@ -2,146 +2,180 @@ #include #include +#include +#include #include #include class AgentInfoTest : public ::testing::Test { protected: + MockPersistence* mockPersistence = nullptr; + std::shared_ptr agentPersistence; + std::unique_ptr agentInfo; + void SetUp() override { - // We need to reset the database to the default state before each test - AgentInfoPersistance agentInfoPersistance("."); - agentInfoPersistance.ResetToDefault(); + InitializeAgentInfo(); + } + + void InitializeAgentInfo(const std::function& osLambda = nullptr, + const std::function& networksLambda = nullptr, + bool agentIsRegistering = false) + { + auto mockPersistencePtr = std::make_unique(); + mockPersistence = mockPersistencePtr.get(); + + SetUpPersistenceMock(); + + if (!agentIsRegistering) + { + SetUpAgentInfoInitialization(); + } + + agentPersistence = std::make_shared("db_path", std::move(mockPersistencePtr)); + + agentInfo = std::make_unique("db_path", + osLambda ? osLambda : nullptr, + networksLambda ? networksLambda : nullptr, + agentIsRegistering, + std::move(agentPersistence)); + } + + void SetUpPersistenceMock() + { + EXPECT_CALL(*mockPersistence, TableExists("agent_info")).WillOnce(testing::Return(true)); + EXPECT_CALL(*mockPersistence, TableExists("agent_group")).WillOnce(testing::Return(true)); + EXPECT_CALL(*mockPersistence, GetCount("agent_info", testing::_, testing::_)).WillOnce(testing::Return(1)); + } + + void SetUpAgentInfoInitialization() + { + std::vector mockRowName = {{column::ColumnValue("name", column::ColumnType::TEXT, "name_test")}}; + std::vector mockRowKey = {{column::ColumnValue("key", column::ColumnType::TEXT, "key_test")}}; + std::vector mockRowUUID = {{column::ColumnValue("uuid", column::ColumnType::TEXT, "uuid_test")}}; + std::vector mockRowGroup = {{column::ColumnValue("name", column::ColumnType::TEXT, "group_test")}}; + + testing::Sequence seq; + EXPECT_CALL(*mockPersistence, + Select("agent_info", testing::_, testing::_, testing::_, testing::_, testing::_, testing::_)) + .InSequence(seq) + .WillOnce(testing::Return(mockRowName)) + .WillOnce(testing::Return(mockRowKey)) + .WillOnce(testing::Return(mockRowUUID)); + EXPECT_CALL(*mockPersistence, + Select("agent_group", testing::_, testing::_, testing::_, testing::_, testing::_, testing::_)) + .WillOnce(testing::Return(mockRowGroup)); } }; TEST_F(AgentInfoTest, TestDefaultConstructorDefaultValues) { EXPECT_NO_THROW({ - const AgentInfo agentInfo("."); - EXPECT_EQ(agentInfo.GetName(), ""); - EXPECT_EQ(agentInfo.GetKey(), ""); - EXPECT_NE(agentInfo.GetUUID(), ""); + EXPECT_EQ(agentInfo->GetName(), "name_test"); + EXPECT_EQ(agentInfo->GetKey(), "key_test"); + EXPECT_NE(agentInfo->GetUUID(), ""); }); } -TEST_F(AgentInfoTest, TestPersistedValues) -{ - AgentInfo agentInfo("."); - agentInfo.SetName("test_name"); - agentInfo.SetKey("4GhT7uFm1zQa9c2Vb7Lk8pYsX0WqZrNj"); - agentInfo.SetUUID("test_uuid"); - agentInfo.Save(); - const AgentInfo agentInfoReloaded("."); - EXPECT_EQ(agentInfoReloaded.GetName(), "test_name"); - EXPECT_EQ(agentInfoReloaded.GetKey(), "4GhT7uFm1zQa9c2Vb7Lk8pYsX0WqZrNj"); - EXPECT_EQ(agentInfoReloaded.GetUUID(), "test_uuid"); -} - TEST_F(AgentInfoTest, TestSetName) { - AgentInfo agentInfo("."); - const std::string oldName = agentInfo.GetName(); const std::string newName = "new_name"; - agentInfo.SetName(newName); - EXPECT_EQ(agentInfo.GetName(), newName); - - const AgentInfo agentInfoReloaded("."); - EXPECT_EQ(agentInfoReloaded.GetName(), oldName); + agentInfo->SetName(newName); + EXPECT_EQ(agentInfo->GetName(), newName); } TEST_F(AgentInfoTest, TestSetKey) { - AgentInfo agentInfo("."); - const std::string oldKey = agentInfo.GetKey(); const std::string newKey = "4GhT7uFm1zQa9c2Vb7Lk8pYsX0WqZrNj"; - agentInfo.SetKey(newKey); - EXPECT_EQ(agentInfo.GetKey(), newKey); - - const AgentInfo agentInfoReloaded("."); - EXPECT_EQ(agentInfoReloaded.GetKey(), oldKey); + agentInfo->SetKey(newKey); + EXPECT_EQ(agentInfo->GetKey(), newKey); } TEST_F(AgentInfoTest, TestSetBadKey) { - AgentInfo agentInfo("."); const std::string newKey1 = "4GhT7uFm"; const std::string newKey2 = "4GhT7uFm1zQa9c2Vb7Lk8pYsX0WqZrN="; - ASSERT_FALSE(agentInfo.SetKey(newKey1)); - ASSERT_FALSE(agentInfo.SetKey(newKey2)); + ASSERT_FALSE(agentInfo->SetKey(newKey1)); + ASSERT_FALSE(agentInfo->SetKey(newKey2)); } TEST_F(AgentInfoTest, TestSetEmptyKey) { - AgentInfo agentInfo("."); const std::string newKey; - const std::string oldKey = agentInfo.GetKey(); - - agentInfo.SetKey(newKey); - EXPECT_NE(agentInfo.GetKey(), newKey); - const AgentInfo agentInfoReloaded("."); - EXPECT_EQ(agentInfoReloaded.GetKey(), oldKey); + agentInfo->SetKey(newKey); + EXPECT_NE(agentInfo->GetKey(), newKey); } TEST_F(AgentInfoTest, TestSetUUID) { - AgentInfo agentInfo("."); const std::string newUUID = "new_uuid"; - agentInfo.SetUUID(newUUID); - EXPECT_EQ(agentInfo.GetUUID(), newUUID); - - const AgentInfo agentInfoReloaded("."); - EXPECT_NE(agentInfoReloaded.GetUUID(), newUUID); + agentInfo->SetUUID(newUUID); + EXPECT_EQ(agentInfo->GetUUID(), newUUID); } TEST_F(AgentInfoTest, TestSetGroups) { - AgentInfo agentInfo("."); - const std::vector oldGroups = agentInfo.GetGroups(); const std::vector newGroups = {"t_group_1", "t_group_2"}; - agentInfo.SetGroups(newGroups); - EXPECT_EQ(agentInfo.GetGroups(), newGroups); - - const AgentInfo agentInfoReloaded("."); - EXPECT_EQ(agentInfoReloaded.GetGroups(), oldGroups); + agentInfo->SetGroups(newGroups); + EXPECT_EQ(agentInfo->GetGroups(), newGroups); } TEST_F(AgentInfoTest, TestSaveGroups) { - AgentInfo agentInfo("."); - const std::vector oldGroups = agentInfo.GetGroups(); const std::vector newGroups = {"t_group_1", "t_group_2"}; - agentInfo.SetGroups(newGroups); - agentInfo.SaveGroups(); - EXPECT_EQ(agentInfo.GetGroups(), newGroups); + EXPECT_CALL(*mockPersistence, BeginTransaction()).Times(1); + EXPECT_CALL(*mockPersistence, Remove("agent_group", testing::_, testing::_)).Times(1); + EXPECT_CALL(*mockPersistence, Insert("agent_group", testing::_)).Times(2); + EXPECT_CALL(*mockPersistence, CommitTransaction(testing::_)).Times(1); + + agentInfo->SetGroups(newGroups); + EXPECT_TRUE(agentInfo->SaveGroups()); + EXPECT_EQ(agentInfo->GetGroups(), newGroups); +} - const AgentInfo agentInfoReloaded("."); - EXPECT_EQ(agentInfoReloaded.GetGroups(), newGroups); +TEST_F(AgentInfoTest, TestSave) +{ + // Mock for: m_persistence->ResetToDefault(); + EXPECT_CALL(*mockPersistence, DropTable("agent_info")).Times(1); + EXPECT_CALL(*mockPersistence, DropTable("agent_group")).Times(1); + EXPECT_CALL(*mockPersistence, CreateTable(testing::_, testing::_)).Times(2); + EXPECT_CALL(*mockPersistence, GetCount("agent_info", testing::_, testing::_)).WillOnce(testing::Return(0)); + EXPECT_CALL(*mockPersistence, Insert(testing::_, testing::_)).Times(1); + + // Mock for: m_persistence->SetName(m_name); m_persistence->SetKey(m_key); m_persistence->SetUUID(m_uuid); + EXPECT_CALL(*mockPersistence, Update("agent_info", testing::_, testing::_, testing::_)).Times(3); + + // Mock for: m_persistence->SetGroups(m_groups); + EXPECT_CALL(*mockPersistence, BeginTransaction()).Times(1); + EXPECT_CALL(*mockPersistence, Remove("agent_group", testing::_, testing::_)).Times(1); + EXPECT_CALL(*mockPersistence, Insert("agent_group", testing::_)).Times(1); + EXPECT_CALL(*mockPersistence, CommitTransaction(testing::_)).Times(1); + agentInfo->Save(); } TEST_F(AgentInfoTest, TestLoadMetadataInfoNoSysInfo) { - const AgentInfo agentInfo(".", nullptr, nullptr, true); + InitializeAgentInfo(nullptr, nullptr, true); - auto metadataInfo = nlohmann::json::parse(agentInfo.GetMetadataInfo()); + auto metadataInfo = nlohmann::json::parse(agentInfo->GetMetadataInfo()); EXPECT_TRUE(metadataInfo != nullptr); // Agent information - EXPECT_EQ(metadataInfo["type"], agentInfo.GetType()); - EXPECT_EQ(metadataInfo["version"], agentInfo.GetVersion()); - EXPECT_EQ(metadataInfo["id"], agentInfo.GetUUID()); - EXPECT_EQ(metadataInfo["name"], agentInfo.GetName()); - EXPECT_EQ(metadataInfo["key"], agentInfo.GetKey()); + EXPECT_EQ(metadataInfo["type"], agentInfo->GetType()); + EXPECT_EQ(metadataInfo["version"], agentInfo->GetVersion()); + EXPECT_EQ(metadataInfo["id"], agentInfo->GetUUID()); + EXPECT_EQ(metadataInfo["name"], agentInfo->GetName()); + EXPECT_EQ(metadataInfo["key"], agentInfo->GetKey()); EXPECT_TRUE(metadataInfo["groups"] == nullptr); // Endpoint information @@ -176,18 +210,18 @@ TEST_F(AgentInfoTest, TestLoadMetadataInfoRegistration) networks["iface"].push_back(ip); - const AgentInfo agentInfo(".", [os]() { return os; }, [networks]() { return networks; }, true); + InitializeAgentInfo([os]() { return os; }, [networks]() { return networks; }, true); - auto metadataInfo = nlohmann::json::parse(agentInfo.GetMetadataInfo()); + auto metadataInfo = nlohmann::json::parse(agentInfo->GetMetadataInfo()); EXPECT_TRUE(metadataInfo != nullptr); // Agent information - EXPECT_EQ(metadataInfo["type"], agentInfo.GetType()); - EXPECT_EQ(metadataInfo["version"], agentInfo.GetVersion()); - EXPECT_EQ(metadataInfo["id"], agentInfo.GetUUID()); - EXPECT_EQ(metadataInfo["name"], agentInfo.GetName()); - EXPECT_EQ(metadataInfo["key"], agentInfo.GetKey()); + EXPECT_EQ(metadataInfo["type"], agentInfo->GetType()); + EXPECT_EQ(metadataInfo["version"], agentInfo->GetVersion()); + EXPECT_EQ(metadataInfo["id"], agentInfo->GetUUID()); + EXPECT_EQ(metadataInfo["name"], agentInfo->GetName()); + EXPECT_EQ(metadataInfo["key"], agentInfo->GetKey()); EXPECT_TRUE(metadataInfo["groups"] == nullptr); // Endpoint information @@ -227,17 +261,17 @@ TEST_F(AgentInfoTest, TestLoadMetadataInfoConnected) networks["iface"].push_back(ip); - const AgentInfo agentInfo(".", [os]() { return os; }, [networks]() { return networks; }); + InitializeAgentInfo([os]() { return os; }, [networks]() { return networks; }); - auto metadataInfo = nlohmann::json::parse(agentInfo.GetMetadataInfo()); + auto metadataInfo = nlohmann::json::parse(agentInfo->GetMetadataInfo()); EXPECT_TRUE(metadataInfo["agent"] != nullptr); // Agent information - EXPECT_EQ(metadataInfo["agent"]["type"], agentInfo.GetType()); - EXPECT_EQ(metadataInfo["agent"]["version"], agentInfo.GetVersion()); - EXPECT_EQ(metadataInfo["agent"]["id"], agentInfo.GetUUID()); - EXPECT_EQ(metadataInfo["agent"]["name"], agentInfo.GetName()); + EXPECT_EQ(metadataInfo["agent"]["type"], agentInfo->GetType()); + EXPECT_EQ(metadataInfo["agent"]["version"], agentInfo->GetVersion()); + EXPECT_EQ(metadataInfo["agent"]["id"], agentInfo->GetUUID()); + EXPECT_EQ(metadataInfo["agent"]["name"], agentInfo->GetName()); EXPECT_TRUE(metadataInfo["agent"]["key"] == nullptr); EXPECT_TRUE(metadataInfo["agent"]["groups"] != nullptr); @@ -255,12 +289,10 @@ TEST_F(AgentInfoTest, TestLoadMetadataInfoConnected) TEST_F(AgentInfoTest, TestLoadHeaderInfo) { - const AgentInfo agentInfo("."); - - auto headerInfo = agentInfo.GetHeaderInfo(); + auto headerInfo = agentInfo->GetHeaderInfo(); EXPECT_NE(headerInfo, ""); - EXPECT_TRUE(headerInfo.starts_with("WazuhXDR/" + agentInfo.GetVersion() + " (" + agentInfo.GetType() + "; ")); + EXPECT_TRUE(headerInfo.starts_with("WazuhXDR/" + agentInfo->GetVersion() + " (" + agentInfo->GetType() + "; ")); } int main(int argc, char** argv) diff --git a/src/agent/include/agent_registration.hpp b/src/agent/include/agent_registration.hpp index 8c7e5d9df3..3f2c1fc31b 100644 --- a/src/agent/include/agent_registration.hpp +++ b/src/agent/include/agent_registration.hpp @@ -35,6 +35,7 @@ namespace agent_registration /// @param name The agent's name. /// @param dbFolderPath The path to the database folder. /// @param verificationMode The connection verification mode. + /// @param agentInfo Shared pointer to the AgentInfo object to manage agent information. AgentRegistration(std::unique_ptr httpClient, std::string url, std::string user, @@ -42,7 +43,8 @@ namespace agent_registration const std::string& key, const std::string& name, const std::string& dbFolderPath, - std::string verificationMode); + std::string verificationMode, + std::shared_ptr agentInfo = nullptr); /// @brief Registers the agent with the manager. /// @@ -61,8 +63,8 @@ namespace agent_registration /// @brief The system's information. SysInfo m_sysInfo; - /// @brief The agent's information. - AgentInfo m_agentInfo; + /// @brief Pointer to the AgentInfo instance. + std::shared_ptr m_agentInfo; /// @brief The URL of the manager. std::string m_serverUrl; diff --git a/src/agent/persistence/tests/mocks/mocks_persistence.hpp b/src/agent/persistence/tests/mocks/mocks_persistence.hpp new file mode 100644 index 0000000000..72acc3a90b --- /dev/null +++ b/src/agent/persistence/tests/mocks/mocks_persistence.hpp @@ -0,0 +1,50 @@ +#include + +#include + +#include +#include + +class MockPersistence : public Persistence +{ +public: + MOCK_METHOD(bool, TableExists, (const std::string& tableName), (override)); + MOCK_METHOD(void, CreateTable, (const std::string& tableName, const column::Keys& cols), (override)); + MOCK_METHOD(void, Insert, (const std::string& tableName, const column::Row& cols), (override)); + MOCK_METHOD(void, + Update, + (const std::string& tableName, + const column::Row& fields, + const column::Criteria& selCriteria, + column::LogicalOperator logOp), + (override)); + MOCK_METHOD(void, + Remove, + (const std::string& tableName, const column::Criteria& selCriteria, column::LogicalOperator logOp), + (override)); + MOCK_METHOD(void, DropTable, (const std::string& tableName), (override)); + MOCK_METHOD(std::vector, + Select, + (const std::string& tableName, + const column::Names& fields, + const column::Criteria& selCriteria, + column::LogicalOperator logOp, + const column::Names& orderBy, + column::OrderType orderType, + int limit), + (override)); + MOCK_METHOD(int, + GetCount, + (const std::string& tableName, const column::Criteria& selCriteria, column::LogicalOperator logOp), + (override)); + MOCK_METHOD(size_t, + GetSize, + (const std::string& tableName, + const column::Names& fields, + const column::Criteria& selCriteria, + column::LogicalOperator logOp), + (override)); + MOCK_METHOD(TransactionId, BeginTransaction, (), (override)); + MOCK_METHOD(void, CommitTransaction, (TransactionId transactionId), (override)); + MOCK_METHOD(void, RollbackTransaction, (TransactionId transactionId), (override)); +}; diff --git a/src/agent/src/agent_registration.cpp b/src/agent/src/agent_registration.cpp index 5fd86762af..8d8eb96dbb 100644 --- a/src/agent/src/agent_registration.cpp +++ b/src/agent/src/agent_registration.cpp @@ -15,10 +15,15 @@ namespace agent_registration const std::string& key, const std::string& name, const std::string& dbFolderPath, - std::string verificationMode) + std::string verificationMode, + std::shared_ptr agentInfo) : m_httpClient(std::move(httpClient)) - , m_agentInfo( - dbFolderPath, [this]() { return m_sysInfo.os(); }, [this]() { return m_sysInfo.networks(); }, true) + , m_agentInfo(agentInfo ? std::move(agentInfo) + : std::make_shared( + dbFolderPath, + [this]() { return m_sysInfo.os(); }, + [this]() { return m_sysInfo.networks(); }, + true)) , m_serverUrl(std::move(url)) , m_user(std::move(user)) , m_password(std::move(password)) @@ -29,12 +34,12 @@ namespace agent_registration throw std::runtime_error("Invalid HTTP Client passed"); } - if (!m_agentInfo.SetKey(key)) + if (!m_agentInfo->SetKey(key)) { throw std::invalid_argument("--key argument must be alphanumeric and 32 characters in length"); } - if (!m_agentInfo.SetName(name)) + if (!m_agentInfo->SetName(name)) { throw std::runtime_error("Couldn't set agent name"); } @@ -53,11 +58,11 @@ namespace agent_registration const auto reqParams = http_client::HttpRequestParams(http_client::MethodType::POST, m_serverUrl, "/agents", - m_agentInfo.GetHeaderInfo(), + m_agentInfo->GetHeaderInfo(), m_verificationMode, token.value(), "", - m_agentInfo.GetMetadataInfo()); + m_agentInfo->GetMetadataInfo()); const auto res = m_httpClient->PerformHttpRequest(reqParams); const auto res_status = std::get<0>(res); @@ -68,7 +73,7 @@ namespace agent_registration return false; } - m_agentInfo.Save(); + m_agentInfo->Save(); return true; } @@ -77,7 +82,7 @@ namespace agent_registration const auto reqParams = http_client::HttpRequestParams(http_client::MethodType::POST, m_serverUrl, "/security/user/authenticate", - m_agentInfo.GetHeaderInfo(), + m_agentInfo->GetHeaderInfo(), m_verificationMode, "", m_user + ":" + m_password); diff --git a/src/agent/tests/CMakeLists.txt b/src/agent/tests/CMakeLists.txt index b124a398f0..3d41bf7046 100644 --- a/src/agent/tests/CMakeLists.txt +++ b/src/agent/tests/CMakeLists.txt @@ -11,9 +11,11 @@ endif() add_executable(agent_registration_test agent_registration_test.cpp) configure_target(agent_registration_test) -target_include_directories(agent_registration_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../src - ${CMAKE_CURRENT_SOURCE_DIR}/../agent_info/src) -target_link_libraries(agent_registration_test PRIVATE Agent GTest::gmock GTest::gtest) +target_include_directories(agent_registration_test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../src + ${CMAKE_CURRENT_SOURCE_DIR}/../agent_info/src + ${CMAKE_CURRENT_SOURCE_DIR}/../persistence/tests/mocks) +target_link_libraries(agent_registration_test PRIVATE Agent Persistence GTest::gmock GTest::gtest) add_test(NAME AgentRegistrationTest COMMAND agent_registration_test) add_executable(signal_handler_test signal_handler_test.cpp) diff --git a/src/agent/tests/agent_registration_test.cpp b/src/agent/tests/agent_registration_test.cpp index 188f7c573d..3f65e9b323 100644 --- a/src/agent/tests/agent_registration_test.cpp +++ b/src/agent/tests/agent_registration_test.cpp @@ -7,6 +7,7 @@ #include #include "../http_client/tests/mocks/mock_http_client.hpp" +#include #include #include @@ -20,32 +21,63 @@ class RegisterTest : public ::testing::Test protected: void SetUp() override { + auto mockPersistencePtr = std::make_unique(); + mockPersistence = mockPersistencePtr.get(); + + SetConstructorPersistenceExpectCalls(); + SysInfo sysInfo; - agent = std::make_unique( + agent = std::make_shared( ".", - [&sysInfo]() mutable { return sysInfo.os(); }, - [&sysInfo]() mutable { return sysInfo.networks(); }, - true); + [sysInfo]() mutable { return sysInfo.os(); }, + [sysInfo]() mutable { return sysInfo.networks(); }, + true, + std::make_shared("db_path", std::move(mockPersistencePtr))); agent->SetKey("4GhT7uFm1zQa9c2Vb7Lk8pYsX0WqZrNj"); agent->SetName("agent_name"); - agent->Save(); } - std::unique_ptr agent; + void SetConstructorPersistenceExpectCalls() + { + EXPECT_CALL(*mockPersistence, TableExists("agent_info")).WillOnce(testing::Return(true)); + EXPECT_CALL(*mockPersistence, TableExists("agent_group")).WillOnce(testing::Return(true)); + EXPECT_CALL(*mockPersistence, GetCount("agent_info", testing::_, testing::_)) + .WillOnce(testing::Return(0)) + .WillOnce(testing::Return(0)); + EXPECT_CALL(*mockPersistence, Insert("agent_info", testing::_)).Times(1); + } + + void SetAgentInfoSaveExpectCalls() + { + // Mock for: m_persistence->ResetToDefault(); + EXPECT_CALL(*mockPersistence, DropTable("agent_info")).Times(1); + EXPECT_CALL(*mockPersistence, DropTable("agent_group")).Times(1); + EXPECT_CALL(*mockPersistence, CreateTable(testing::_, testing::_)).Times(2); + EXPECT_CALL(*mockPersistence, GetCount("agent_info", testing::_, testing::_)).WillOnce(testing::Return(0)); + EXPECT_CALL(*mockPersistence, Insert(testing::_, testing::_)).Times(1); + + // Mock for: m_persistence->SetName(m_name); m_persistence->SetKey(m_key); m_persistence->SetUUID(m_uuid); + EXPECT_CALL(*mockPersistence, Update("agent_info", testing::_, testing::_, testing::_)).Times(3); + + // Mock for: m_persistence->SetGroups(m_groups); + EXPECT_CALL(*mockPersistence, BeginTransaction()).Times(1); + EXPECT_CALL(*mockPersistence, Remove("agent_group", testing::_, testing::_)).Times(1); + EXPECT_CALL(*mockPersistence, CommitTransaction(testing::_)).Times(1); + } + + std::shared_ptr agent; std::unique_ptr registration; + MockPersistence* mockPersistence = nullptr; }; TEST_F(RegisterTest, RegistrationTestSuccess) { - AgentInfoPersistance agentInfoPersistance("."); - agentInfoPersistance.ResetToDefault(); - auto mockHttpClient = std::make_unique(); auto mockHttpClientPtr = mockHttpClient.get(); registration = std::make_unique( - std::move(mockHttpClient), "https://localhost:55000", "user", "password", "", "", ".", "full"); + std::move(mockHttpClient), "https://localhost:55000", "user", "password", "", "", ".", "full", agent); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) std::tuple expectedResponse1 {200, R"({"data":{"token":"token"}})"}; @@ -59,15 +91,15 @@ TEST_F(RegisterTest, RegistrationTestSuccess) .WillOnce(testing::Return(expectedResponse2)); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + + SetAgentInfoSaveExpectCalls(); + const bool res = registration->Register(); ASSERT_TRUE(res); } TEST_F(RegisterTest, RegistrationFailsIfAuthenticationFails) { - AgentInfoPersistance agentInfoPersistance("."); - agentInfoPersistance.ResetToDefault(); - auto mockHttpClient = std::make_unique(); auto mockHttpClientPtr = mockHttpClient.get(); @@ -78,7 +110,8 @@ TEST_F(RegisterTest, RegistrationFailsIfAuthenticationFails) "4GhT7uFm1zQa9c2Vb7Lk8pYsX0WqZrNj", "agent_name", ".", - "certificate"); + "certificate", + agent); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) std::tuple expectedResponse {401, ""}; @@ -92,9 +125,6 @@ TEST_F(RegisterTest, RegistrationFailsIfAuthenticationFails) TEST_F(RegisterTest, RegistrationFailsIfServerResponseIsNotOk) { - AgentInfoPersistance agentInfoPersistance("."); - agentInfoPersistance.ResetToDefault(); - auto mockHttpClient = std::make_unique(); auto mockHttpClientPtr = mockHttpClient.get(); @@ -105,7 +135,8 @@ TEST_F(RegisterTest, RegistrationFailsIfServerResponseIsNotOk) "4GhT7uFm1zQa9c2Vb7Lk8pYsX0WqZrNj", "agent_name", ".", - "none"); + "none", + agent); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) std::tuple expectedResponse1 {200, R"({"data":{"token":"token"}})"}; @@ -125,16 +156,11 @@ TEST_F(RegisterTest, RegistrationFailsIfServerResponseIsNotOk) TEST_F(RegisterTest, RegisteringWithoutAKeyGeneratesOneAutomatically) { - AgentInfoPersistance agentInfoPersistance("."); - agentInfoPersistance.ResetToDefault(); - - EXPECT_TRUE(agentInfoPersistance.GetKey().empty()); - auto mockHttpClient = std::make_unique(); auto mockHttpClientPtr = mockHttpClient.get(); registration = std::make_unique( - std::move(mockHttpClient), "https://localhost:55000", "user", "password", "", "agent_name", ".", "full"); + std::move(mockHttpClient), "https://localhost:55000", "user", "password", "", "agent_name", ".", "full", agent); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) std::tuple expectedResponse1 {200, R"({"data":{"token":"token"}})"}; @@ -148,10 +174,53 @@ TEST_F(RegisterTest, RegisteringWithoutAKeyGeneratesOneAutomatically) .WillOnce(testing::Return(expectedResponse2)); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + + // Mock for: m_persistence->ResetToDefault(); + EXPECT_CALL(*mockPersistence, DropTable("agent_info")).Times(1); + EXPECT_CALL(*mockPersistence, DropTable("agent_group")).Times(1); + EXPECT_CALL(*mockPersistence, CreateTable(testing::_, testing::_)).Times(2); + EXPECT_CALL(*mockPersistence, GetCount("agent_info", testing::_, testing::_)).WillOnce(testing::Return(0)); + EXPECT_CALL(*mockPersistence, Insert(testing::_, testing::_)).Times(1); + + // Mock for: m_persistence->SetName(m_name); m_persistence->SetKey(m_key); m_persistence->SetUUID(m_uuid); + testing::InSequence seq; + EXPECT_CALL(*mockPersistence, + Update(testing::Eq("agent_info"), + testing::AllOf( + testing::SizeIs(1), + testing::Contains(testing::AllOf(testing::Field(&column::ColumnValue::Value, "agent_name"), + testing::Field(&column::ColumnName::Name, "name")))), + testing::_, + testing::_)) + .Times(1); + + EXPECT_CALL(*mockPersistence, + Update(testing::Eq("agent_info"), + testing::AllOf(testing::SizeIs(1), + testing::Contains(testing::AllOf( + testing::Field(&column::ColumnValue::Value, testing::Not(testing::Eq(""))), + testing::Field(&column::ColumnName::Name, "key")))), + testing::_, + testing::_)) + .Times(1); + + EXPECT_CALL(*mockPersistence, + Update(testing::Eq("agent_info"), + testing::AllOf(testing::SizeIs(1), + testing::Contains(testing::AllOf( + testing::Field(&column::ColumnValue::Value, testing::Not(testing::Eq(""))), + testing::Field(&column::ColumnName::Name, "uuid")))), + testing::_, + testing::_)) + .Times(1); + + // Mock for: m_persistence->SetGroups(m_groups); + EXPECT_CALL(*mockPersistence, BeginTransaction()).Times(1); + EXPECT_CALL(*mockPersistence, Remove("agent_group", testing::_, testing::_)).Times(1); + EXPECT_CALL(*mockPersistence, CommitTransaction(testing::_)).Times(1); + const bool res = registration->Register(); ASSERT_TRUE(res); - - EXPECT_FALSE(agentInfoPersistance.GetKey().empty()); } TEST_F(RegisterTest, RegistrationTestFailWithBadKey) @@ -165,14 +234,15 @@ TEST_F(RegisterTest, RegistrationTestFailWithBadKey) "badKey", "agent_name", ".", - "full"), + "full", + agent), std::invalid_argument); } TEST_F(RegisterTest, RegistrationTestFailWithHttpClientError) { ASSERT_THROW(agent_registration::AgentRegistration( - nullptr, "https://localhost:55000", "user", "password", "", "agent_name", ".", "full"), + nullptr, "https://localhost:55000", "user", "password", "", "agent_name", ".", "full", agent), std::runtime_error); } @@ -182,7 +252,7 @@ TEST_F(RegisterTest, AuthenticateWithUserPassword_Success) auto mockHttpClientPtr = mockHttpClient.get(); registration = std::make_unique( - std::move(mockHttpClient), "https://localhost:55000", "user", "password", "", "", ".", "full"); + std::move(mockHttpClient), "https://localhost:55000", "user", "password", "", "", ".", "full", agent); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) std::tuple expectedResponse {200, R"({"data":{"token":"valid_token"}})"}; @@ -203,7 +273,7 @@ TEST_F(RegisterTest, AuthenticateWithUserPassword_Failure) auto mockHttpClientPtr = mockHttpClient.get(); registration = std::make_unique( - std::move(mockHttpClient), "https://localhost:55000", "user", "password", "", "", ".", "full"); + std::move(mockHttpClient), "https://localhost:55000", "user", "password", "", "", ".", "full", agent); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) std::tuple expectedResponse {401, ""};