From 828d001d817c68846192a2c7284b11e8c250676e Mon Sep 17 00:00:00 2001 From: caffix Date: Fri, 13 Dec 2024 10:20:04 -0500 Subject: [PATCH] starting to implement the Repository interface for neo4j --- cache/entity.go | 8 +- repository/neo4j/db_test.go | 14 +- repository/neo4j/entity.go | 218 ++++++++++++++++++++++++++++++ repository/neo4j/entity_test.go | 7 + repository/repository.go | 2 +- repository/sqlrepo/entity.go | 6 +- repository/sqlrepo/entity_test.go | 2 +- 7 files changed, 247 insertions(+), 10 deletions(-) create mode 100644 repository/neo4j/entity.go create mode 100644 repository/neo4j/entity_test.go diff --git a/cache/entity.go b/cache/entity.go index 1a38543..543225c 100644 --- a/cache/entity.go +++ b/cache/entity.go @@ -59,9 +59,9 @@ func (c *Cache) FindEntityById(id string) (*types.Entity, error) { return c.cache.FindEntityById(id) } -// FindEntityByContent implements the Repository interface. -func (c *Cache) FindEntityByContent(asset oam.Asset, since time.Time) ([]*types.Entity, error) { - entities, err := c.cache.FindEntityByContent(asset, since) +// FindEntitiesByContent implements the Repository interface. +func (c *Cache) FindEntitiesByContent(asset oam.Asset, since time.Time) ([]*types.Entity, error) { + entities, err := c.cache.FindEntitiesByContent(asset, since) if err == nil && len(entities) > 0 { return entities, nil } @@ -70,7 +70,7 @@ func (c *Cache) FindEntityByContent(asset oam.Asset, since time.Time) ([]*types. return nil, err } - dbentities, dberr := c.db.FindEntityByContent(asset, since) + dbentities, dberr := c.db.FindEntitiesByContent(asset, since) if dberr != nil { return entities, err } diff --git a/repository/neo4j/db_test.go b/repository/neo4j/db_test.go index e7642a6..001b727 100644 --- a/repository/neo4j/db_test.go +++ b/repository/neo4j/db_test.go @@ -12,12 +12,13 @@ import ( "testing" neomigrations "github.com/owasp-amass/asset-db/migrations/neo4j" + "github.com/stretchr/testify/assert" ) var store *neoRepository func TestMain(m *testing.M) { - dsn := "bolt://neo4j:hackme4fun@localhost:7687/assetdb" + dsn := "bolt://neo4j:hackme4fun@localhost:7687/amass" store, err := New("neo4j", dsn) if err != nil { @@ -32,3 +33,14 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } + +func TestClose(t *testing.T) { + err := store.Close() + assert.NoError(t, err) +} + +func TestGetDBType(t *testing.T) { + if db := store.GetDBType(); db != Neo4j { + t.Errorf("Failed to return the correct database type") + } +} diff --git a/repository/neo4j/entity.go b/repository/neo4j/entity.go new file mode 100644 index 0000000..2113e87 --- /dev/null +++ b/repository/neo4j/entity.go @@ -0,0 +1,218 @@ +// Copyright © by Jeff Foley 2017-2024. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. +// SPDX-License-Identifier: Apache-2.0 + +package neo4j + +import ( + "context" + "errors" + "strconv" + "time" + + neo4jdb "github.com/neo4j/neo4j-go-driver/v5/neo4j" + "github.com/owasp-amass/asset-db/types" + "gorm.io/gorm" +) + +// CreateEntity creates a new entity in the database. +// It takes an Entity as input and persists it in the database. +// The asset is serialized to JSON and stored in the Content field of the Entity struct. +// Returns the created entity as a types.Entity or an error if the creation fails. +func (neo *neoRepository) CreateEntity(input *types.Entity) (*types.Entity, error) { + jsonContent, err := input.Asset.JSON() + if err != nil { + return nil, err + } + + entity := Entity{ + Type: string(input.Asset.AssetType()), + Content: jsonContent, + } + + // ensure that duplicate entities are not entered into the database + if entities, err := neo.FindEntitiesByContent(input.Asset, time.Time{}); err == nil && len(entities) == 1 { + e := entities[0] + + if input.Asset.AssetType() == e.Asset.AssetType() { + if id, err := strconv.ParseUint(e.ID, 10, 64); err == nil { + entity.ID = id + entity.CreatedAt = e.CreatedAt + entity.UpdatedAt = time.Now().UTC() + } + } + } else { + if input.CreatedAt.IsZero() { + entity.CreatedAt = time.Now().UTC() + } else { + entity.CreatedAt = input.CreatedAt.UTC() + } + + if input.LastSeen.IsZero() { + entity.UpdatedAt = time.Now().UTC() + } else { + entity.UpdatedAt = input.LastSeen.UTC() + } + } + + result := sql.db.Save(&entity) + if err := result.Error; err != nil { + return nil, err + } + + return &types.Entity{ + ID: strconv.FormatUint(entity.ID, 10), + CreatedAt: entity.CreatedAt.In(time.UTC).Local(), + LastSeen: entity.UpdatedAt.In(time.UTC).Local(), + Asset: input.Asset, + }, nil +} + +// CreateAsset creates a new entity in the database. +// It takes an oam.Asset as input and persists it in the database. +// The asset is serialized to JSON and stored in the Content field of the Entity struct. +// Returns the created entity as a types.Entity or an error if the creation fails. +func (neo *neoRepository) CreateAsset(asset oam.Asset) (*types.Entity, error) { + return neo.CreateEntity(&types.Entity{Asset: asset}) +} + +// UpdateEntityLastSeen performs an update on the entity. +func (neo *neoRepository) UpdateEntityLastSeen(id string) error { + result := sql.db.Exec("UPDATE entities SET updated_at = current_timestamp WHERE entity_id = ?", id) + if err := result.Error; err != nil { + return err + } + return nil +} + +// FindEntityById finds an entity in the database by the ID. +// It takes a string representing the entity ID and retrieves the corresponding entity from the database. +// Returns the found entity as a types.Entity or an error if the asset is not found. +func (neo *neoRepository) FindEntityById(id string) (*types.Entity, error) { + entityId, err := strconv.ParseUint(id, 10, 64) + if err != nil { + return nil, err + } + + entity := Entity{ID: entityId} + result := sql.db.First(&entity) + if err := result.Error; err != nil { + return nil, err + } + + assetData, err := entity.Parse() + if err != nil { + return nil, err + } + + return &types.Entity{ + ID: strconv.FormatUint(entity.ID, 10), + CreatedAt: entity.CreatedAt.In(time.UTC).Local(), + LastSeen: entity.UpdatedAt.In(time.UTC).Local(), + Asset: assetData, + }, nil +} + +// FindEntitiesByContent finds entities in the database that match the provided asset data and last seen after +// the since parameter. It takes an oam.Asset as input and searches for entities with matching content in the database. +// If since.IsZero(), the parameter will be ignored. +// The asset data is serialized to JSON and compared against the Content field of the Entity struct. +// Returns a slice of matching entities as []*types.Entity or an error if the search fails. +func (neo *neoRepository) FindEntitiesByContent(assetData oam.Asset, since time.Time) ([]*types.Entity, error) { + jsonContent, err := assetData.JSON() + if err != nil { + return nil, err + } + + entity := Entity{ + Type: string(assetData.AssetType()), + Content: jsonContent, + } + + jsonQuery, err := entity.JSONQuery() + if err != nil { + return nil, err + } + + tx := sql.db.Where("etype = ?", entity.Type) + if !since.IsZero() { + tx = tx.Where("updated_at >= ?", since.UTC()) + } + + var entities []Entity + tx = tx.Where(jsonQuery).Find(&entities) + if err := tx.Error; err != nil { + return nil, err + } + + var results []*types.Entity + for _, e := range entities { + if assetData, err := e.Parse(); err == nil { + results = append(results, &types.Entity{ + ID: strconv.FormatUint(e.ID, 10), + CreatedAt: e.CreatedAt.In(time.UTC).Local(), + LastSeen: e.UpdatedAt.In(time.UTC).Local(), + Asset: assetData, + }) + } + } + + if len(results) == 0 { + return nil, errors.New("zero entities found") + } + return results, nil +} + +// FindEntitiesByType finds all entities in the database of the provided asset type and last seen after the since parameter. +// It takes an asset type and retrieves the corresponding entities from the database. +// If since.IsZero(), the parameter will be ignored. +// Returns a slice of matching entities as []*types.Entity or an error if the search fails. +func (neo *neoRepository) FindEntitiesByType(atype oam.AssetType, since time.Time) ([]*types.Entity, error) { + var entities []Entity + var result *gorm.DB + + if since.IsZero() { + result = sql.db.Where("etype = ?", atype).Find(&entities) + } else { + result = sql.db.Where("etype = ? AND updated_at >= ?", atype, since.UTC()).Find(&entities) + } + if err := result.Error; err != nil { + return nil, err + } + + var results []*types.Entity + for _, e := range entities { + if f, err := e.Parse(); err == nil { + results = append(results, &types.Entity{ + ID: strconv.FormatUint(e.ID, 10), + CreatedAt: e.CreatedAt.In(time.UTC).Local(), + LastSeen: e.UpdatedAt.In(time.UTC).Local(), + Asset: f, + }) + } + } + + if len(results) == 0 { + return nil, errors.New("no entities of the specified type") + } + return results, nil +} + +// DeleteEntity removes an entity in the database by its ID. +// It takes a string representing the entity ID and removes the corresponding entity from the database. +// Returns an error if the entity is not found. +func (neo *neoRepository) DeleteEntity(id string) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, err := neo4jdb.ExecuteQuery(ctx, neo.db, + "MATCH (n:Entity {entity_id: $entity_id}) DETACH DELETE n", + map[string]interface{}{ + "entity_id": id, + }, + neo4jdb.EagerResultTransformer, + neo4jdb.ExecuteQueryWithDatabase(neo.dbname), + ) + + return err +} diff --git a/repository/neo4j/entity_test.go b/repository/neo4j/entity_test.go new file mode 100644 index 0000000..985d310 --- /dev/null +++ b/repository/neo4j/entity_test.go @@ -0,0 +1,7 @@ +//go:build integration + +// Copyright © by Jeff Foley 2017-2024. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. +// SPDX-License-Identifier: Apache-2.0 + +package neo4j diff --git a/repository/repository.go b/repository/repository.go index 9c5b4b3..ebc915a 100644 --- a/repository/repository.go +++ b/repository/repository.go @@ -22,7 +22,7 @@ type Repository interface { CreateEntity(entity *types.Entity) (*types.Entity, error) CreateAsset(asset oam.Asset) (*types.Entity, error) FindEntityById(id string) (*types.Entity, error) - FindEntityByContent(asset oam.Asset, since time.Time) ([]*types.Entity, error) + FindEntitiesByContent(asset oam.Asset, since time.Time) ([]*types.Entity, error) FindEntitiesByType(atype oam.AssetType, since time.Time) ([]*types.Entity, error) DeleteEntity(id string) error CreateEdge(edge *types.Edge) (*types.Edge, error) diff --git a/repository/sqlrepo/entity.go b/repository/sqlrepo/entity.go index eefa4ab..2c7dd5e 100644 --- a/repository/sqlrepo/entity.go +++ b/repository/sqlrepo/entity.go @@ -112,12 +112,12 @@ func (sql *sqlRepository) FindEntityById(id string) (*types.Entity, error) { }, nil } -// FindEntityByContent finds entities in the database that match the provided asset data and last seen after the since parameter. -// It takes an oam.Asset as input and searches for entities with matching content in the database. +// FindEntitiesByContent finds entities in the database that match the provided asset data and last seen after +// the since parameter. It takes an oam.Asset as input and searches for entities with matching content in the database. // If since.IsZero(), the parameter will be ignored. // The asset data is serialized to JSON and compared against the Content field of the Entity struct. // Returns a slice of matching entities as []*types.Entity or an error if the search fails. -func (sql *sqlRepository) FindEntityByContent(assetData oam.Asset, since time.Time) ([]*types.Entity, error) { +func (sql *sqlRepository) FindEntitiesByContent(assetData oam.Asset, since time.Time) ([]*types.Entity, error) { jsonContent, err := assetData.JSON() if err != nil { return nil, err diff --git a/repository/sqlrepo/entity_test.go b/repository/sqlrepo/entity_test.go index 0d9ae5c..4da68d9 100644 --- a/repository/sqlrepo/entity_test.go +++ b/repository/sqlrepo/entity_test.go @@ -258,7 +258,7 @@ func TestRepository(t *testing.T) { t.Fatalf("failed to find entity by id: expected entity %s, got %s", sourceEntity.Asset, foundAsset.Asset) } - foundAssetByContent, err := store.FindEntityByContent(sourceEntity.Asset, start) + foundAssetByContent, err := store.FindEntitiesByContent(sourceEntity.Asset, start) assert.NoError(t, err) assert.NotEqual(t, foundAssetByContent, nil)