From cb13bcd6c5be677cdbd780c976c2e24f2cc7817e Mon Sep 17 00:00:00 2001 From: caffix Date: Fri, 20 Dec 2024 20:58:51 -0500 Subject: [PATCH] edge methods in the Respository interface are implemented --- repository/neo4j/edge.go | 247 ++++++++++++++++------------ repository/neo4j/edge_tag.go | 4 +- repository/neo4j/entity_tag.go | 4 +- repository/neo4j/entity_test.go | 11 +- repository/neo4j/extract_edge.go | 202 +++++++++++++++++++++++ repository/neo4j/query_relations.go | 55 +++++++ 6 files changed, 405 insertions(+), 118 deletions(-) create mode 100644 repository/neo4j/extract_edge.go create mode 100644 repository/neo4j/query_relations.go diff --git a/repository/neo4j/edge.go b/repository/neo4j/edge.go index 9294e8a..62d6294 100644 --- a/repository/neo4j/edge.go +++ b/repository/neo4j/edge.go @@ -9,12 +9,12 @@ import ( "errors" "fmt" "reflect" - "strconv" + "strings" "time" + neo4jdb "github.com/neo4j/neo4j-go-driver/v5/neo4j" "github.com/owasp-amass/asset-db/types" oam "github.com/owasp-amass/open-asset-model" - "gorm.io/gorm" ) // CreateEdge creates an edge between two entities in the database. @@ -32,50 +32,57 @@ func (neo *neoRepository) CreateEdge(edge *types.Edge) (*types.Edge, error) { edge.FromEntity.Asset.AssetType(), edge.Relation.Label(), edge.ToEntity.Asset.AssetType()) } - var updated time.Time if edge.LastSeen.IsZero() { - updated = time.Now().UTC() - } else { - updated = edge.LastSeen.UTC() + edge.LastSeen = time.Now() } // ensure that duplicate relationships are not entered into the database - if e, found := neo.isDuplicateEdge(edge, updated); found { + if e, found := neo.isDuplicateEdge(edge, edge.LastSeen); found { return e, nil } - fromEntityId, err := strconv.ParseUint(edge.FromEntity.ID, 10, 64) - if err != nil { - return nil, err + if edge.CreatedAt.IsZero() { + edge.CreatedAt = time.Now() } - toEntityId, err := strconv.ParseUint(edge.ToEntity.ID, 10, 64) + props, err := edgePropsMap(edge) if err != nil { return nil, err } - jsonContent, err := edge.Relation.JSON() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + from := fmt.Sprintf("MATCH (from:Entity {entity_id: '%s'})", edge.FromEntity.ID) + to := fmt.Sprintf("MATCH (to:Entity {entity_id: '%s'})", edge.ToEntity.ID) + query := fmt.Sprintf("%s %s CREATE (from)-[r:%s $props]->(to) RETURN r", from, to, strings.ToUpper(edge.Relation.Label())) + result, err := neo4jdb.ExecuteQuery(ctx, neo.db, query, + map[string]interface{}{"props": props}, + neo4jdb.EagerResultTransformer, + neo4jdb.ExecuteQueryWithDatabase(neo.dbname), + ) if err != nil { return nil, err } + if len(result.Records) == 0 { + return nil, errors.New("no records returned from the query") + } - r := Edge{ - Type: string(edge.Relation.RelationType()), - Content: jsonContent, - FromEntityID: fromEntityId, - ToEntityID: toEntityId, - UpdatedAt: updated, + rel, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Relationship](result.Records[0], "r") + if err != nil { + return nil, err } - if edge.CreatedAt.IsZero() { - r.CreatedAt = time.Now().UTC() - } else { - r.CreatedAt = edge.CreatedAt.UTC() + if isnil { + return nil, errors.New("the record value for the relationship is nil") } - result := sql.db.Create(&r) - if err := result.Error; err != nil { + r, err := relationshipToEdge(rel) + if err != nil { return nil, err } - return toEdge(r), nil + r.FromEntity = edge.FromEntity + r.ToEntity = edge.ToEntity + + return r, nil } // isDuplicateEdge checks if the relationship between source and dest already exists. @@ -104,41 +111,15 @@ func (neo *neoRepository) isDuplicateEdge(edge *types.Edge, updated time.Time) ( // edgeSeen updates the updated_at timestamp for the specified edge. func (neo *neoRepository) edgeSeen(rel *types.Edge, updated time.Time) error { - id, err := strconv.ParseUint(rel.ID, 10, 64) - if err != nil { - return err - } - - jsonContent, err := rel.Relation.JSON() - if err != nil { - return err - } - - fromEntityId, err := strconv.ParseUint(rel.FromEntity.ID, 10, 64) - if err != nil { - return err - } - - toEntityId, err := strconv.ParseUint(rel.ToEntity.ID, 10, 64) - if err != nil { - return err - } - - r := Edge{ - ID: id, - Type: string(rel.Relation.RelationType()), - Content: jsonContent, - FromEntityID: fromEntityId, - ToEntityID: toEntityId, - CreatedAt: rel.CreatedAt, - UpdatedAt: updated, - } + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() - result := sql.db.Save(&r) - if err := result.Error; err != nil { - return err - } - return nil + query := fmt.Sprintf("MATCH ()-[r]->() WHERE r.elementId = '%s' SET r.updated_at = localDateTime('%s')", rel.ID, timeToNeo4jTime(updated)) + _, err := neo4jdb.ExecuteQuery(ctx, neo.db, query, nil, + neo4jdb.EagerResultTransformer, + neo4jdb.ExecuteQueryWithDatabase(neo.dbname), + ) + return err } func (neo *neoRepository) FindEdgeById(id string) (*types.Edge, error) { @@ -198,88 +179,136 @@ func (neo *neoRepository) FindEdgeById(id string) (*types.Edge, error) { // If since.IsZero(), the parameter will be ignored. // If no labels are specified, all incoming eges are returned. func (neo *neoRepository) IncomingEdges(entity *types.Entity, since time.Time, labels ...string) ([]*types.Edge, error) { - entityId, err := strconv.ParseInt(entity.ID, 10, 64) - if err != nil { - return nil, err - } + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() - var edges []Edge - var result *gorm.DB - if since.IsZero() { - result = sql.db.Where("to_entity_id = ?", entityId).Find(&edges) - } else { - result = sql.db.Where("to_entity_id = ? AND updated_at >= ?", entityId, since.UTC()).Find(&edges) + query := fmt.Sprintf("MATCH (:Entity {entity_id: '%s'})<-[r]-(from:Entity) RETURN r, from.entity_id AS fid", entity.ID) + if !since.IsZero() { + query = fmt.Sprintf("MATCH (:Entity {entity_id: '%s'})<-[r]-(from:Entity) WHERE r.updated_at >= localDateTime('%s') RETURN r, from.entity_id AS fid", entity.ID, timeToNeo4jTime(since)) } - if err := result.Error; err != nil { + + result, err := neo4jdb.ExecuteQuery(ctx, neo.db, query, nil, + neo4jdb.EagerResultTransformer, + neo4jdb.ExecuteQueryWithDatabase(neo.dbname), + ) + if err != nil { return nil, err } - var results []Edge - if len(labels) > 0 { - for _, edge := range edges { - e := &edge - - if rel, err := e.Parse(); err == nil { - for _, label := range labels { - if label == rel.Label() { - results = append(results, edge) - break - } + var results []*types.Edge + for _, record := range result.Records { + r, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Relationship](record, "r") + if err != nil { + continue + } + if isnil { + continue + } + + if len(labels) > 0 { + var found bool + + for _, label := range labels { + if strings.EqualFold(label, r.Type) { + found = true + break } } + + if !found { + continue + } + } + + fid, isnil, err := neo4jdb.GetRecordValue[string](record, "fid") + if err != nil { + continue + } + if isnil { + continue } - } else { - results = edges + + edge, err := relationshipToEdge(r) + if err != nil { + continue + } + edge.FromEntity = &types.Entity{ID: fid} + edge.ToEntity = entity + results = append(results, edge) } if len(results) == 0 { return nil, errors.New("zero edges found") } - return toEdges(results), nil + return results, nil } // OutgoingEdges finds all edges from the entity of the specified labels and last seen after the since parameter. // If since.IsZero(), the parameter will be ignored. // If no labels are specified, all outgoing edges are returned. func (neo *neoRepository) OutgoingEdges(entity *types.Entity, since time.Time, labels ...string) ([]*types.Edge, error) { - entityId, err := strconv.ParseInt(entity.ID, 10, 64) - if err != nil { - return nil, err - } + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() - var edges []Edge - var result *gorm.DB - if since.IsZero() { - result = sql.db.Where("from_entity_id = ?", entityId).Find(&edges) - } else { - result = sql.db.Where("from_entity_id = ? AND updated_at >= ?", entityId, since.UTC()).Find(&edges) + query := fmt.Sprintf("MATCH (:Entity {entity_id: '%s'})-[r]->(to:Entity) RETURN r, to.entity_id AS tid", entity.ID) + if !since.IsZero() { + query = fmt.Sprintf("MATCH (:Entity {entity_id: '%s'})-[r]->(to:Entity) WHERE r.updated_at >= localDateTime('%s') RETURN r, to.entity_id AS tid", entity.ID, timeToNeo4jTime(since)) } - if err := result.Error; err != nil { + + result, err := neo4jdb.ExecuteQuery(ctx, neo.db, query, nil, + neo4jdb.EagerResultTransformer, + neo4jdb.ExecuteQueryWithDatabase(neo.dbname), + ) + if err != nil { return nil, err } - var results []Edge - if len(labels) > 0 { - for _, edge := range edges { - e := &edge - - if rel, err := e.Parse(); err == nil { - for _, label := range labels { - if label == rel.Label() { - results = append(results, edge) - break - } + var results []*types.Edge + for _, record := range result.Records { + r, isnil, err := neo4jdb.GetRecordValue[neo4jdb.Relationship](record, "r") + if err != nil { + continue + } + if isnil { + continue + } + + if len(labels) > 0 { + var found bool + + for _, label := range labels { + if strings.EqualFold(label, r.Type) { + found = true + break } } + + if !found { + continue + } + } + + tid, isnil, err := neo4jdb.GetRecordValue[string](record, "tid") + if err != nil { + continue + } + if isnil { + continue + } + + edge, err := relationshipToEdge(r) + if err != nil { + continue } - } else { - results = edges + edge.FromEntity = entity + edge.ToEntity = &types.Entity{ID: tid} + results = append(results, edge) } if len(results) == 0 { return nil, errors.New("zero edges found") } - return toEdges(results), nil + return results, nil } // DeleteEdge removes an edge in the database by its ID. diff --git a/repository/neo4j/edge_tag.go b/repository/neo4j/edge_tag.go index aeb177d..a98adea 100644 --- a/repository/neo4j/edge_tag.go +++ b/repository/neo4j/edge_tag.go @@ -27,7 +27,7 @@ func (neo *neoRepository) CreateEdgeTag(edge *types.Edge, input *types.EdgeTag) return nil, errors.New("the input edge tag is nil") } // ensure that duplicate entities are not entered into the database - if tags, err := neo.FindEdgeTagsByContent(input.Property, time.Time{}); err == nil && len(tags) == 1 { + if tags, err := neo.FindEdgeTagsByContent(input.Property, time.Time{}); err == nil && len(tags) > 0 { t := tags[0] if input.Property.PropertyType() != t.Property.PropertyType() { @@ -39,6 +39,7 @@ func (neo *neoRepository) CreateEdgeTag(edge *types.Edge, input *types.EdgeTag) return nil, err } + t.Edge = edge t.LastSeen = time.Now() props, err := edgeTagPropsMap(t) if err != nil { @@ -83,6 +84,7 @@ func (neo *neoRepository) CreateEdgeTag(edge *types.Edge, input *types.EdgeTag) input.LastSeen = time.Now() } + input.Edge = edge props, err := edgeTagPropsMap(input) if err != nil { return nil, err diff --git a/repository/neo4j/entity_tag.go b/repository/neo4j/entity_tag.go index 55e033d..9b8671d 100644 --- a/repository/neo4j/entity_tag.go +++ b/repository/neo4j/entity_tag.go @@ -27,7 +27,7 @@ func (neo *neoRepository) CreateEntityTag(entity *types.Entity, input *types.Ent return nil, errors.New("the input entity tag is nil") } // ensure that duplicate entities are not entered into the database - if tags, err := neo.FindEntityTagsByContent(input.Property, time.Time{}); err == nil && len(tags) == 1 { + if tags, err := neo.FindEntityTagsByContent(input.Property, time.Time{}); err == nil && len(tags) > 0 { t := tags[0] if input.Property.PropertyType() != t.Property.PropertyType() { @@ -39,6 +39,7 @@ func (neo *neoRepository) CreateEntityTag(entity *types.Entity, input *types.Ent return nil, err } + t.Entity = entity t.LastSeen = time.Now() props, err := entityTagPropsMap(t) if err != nil { @@ -83,6 +84,7 @@ func (neo *neoRepository) CreateEntityTag(entity *types.Entity, input *types.Ent input.LastSeen = time.Now() } + input.Entity = entity props, err := entityTagPropsMap(input) if err != nil { return nil, err diff --git a/repository/neo4j/entity_test.go b/repository/neo4j/entity_test.go index 8e5f727..3c90808 100644 --- a/repository/neo4j/entity_test.go +++ b/repository/neo4j/entity_test.go @@ -60,7 +60,7 @@ func TestCreateEntity(t *testing.T) { }) assert.NoError(t, err) - asset.Equal(t, entity.ID, second.ID) + assert.NotEqual(t, entity.ID, second.ID) if !second.CreatedAt.After(newer.LastSeen) { t.Errorf("Failed to assign the second entity an accurate creation time") } @@ -76,7 +76,7 @@ func TestFindEntityById(t *testing.T) { same, err := store.FindEntityById(entity.ID) assert.NoError(t, err) - asset.Equal(t, entity.ID, same.ID) + assert.Equal(t, entity.ID, same.ID) if fqdn1, ok := entity.Asset.(*domain.FQDN); !ok { t.Errorf("Failed to type assert the first asset") @@ -99,7 +99,7 @@ func TestFindEntitiesByContent(t *testing.T) { e, err := store.FindEntitiesByContent(fqdn, entity.CreatedAt.Add(-1*time.Second)) assert.NoError(t, err) same := e[0] - asset.Equal(t, entity.ID, same.ID) + assert.Equal(t, entity.ID, same.ID) if fqdn1, ok := entity.Asset.(*domain.FQDN); !ok { t.Errorf("Failed to type assert the first asset") @@ -158,9 +158,6 @@ func TestDeleteEntity(t *testing.T) { err = store.DeleteEntity(entity.ID) assert.NoError(t, err) - _, err := store.FindEntityById(entity.ID) - assert.Error(t, err) - - err = store.DeleteEntity(entity.ID) + _, err = store.FindEntityById(entity.ID) assert.Error(t, err) } diff --git a/repository/neo4j/extract_edge.go b/repository/neo4j/extract_edge.go new file mode 100644 index 0000000..1a7a99e --- /dev/null +++ b/repository/neo4j/extract_edge.go @@ -0,0 +1,202 @@ +// Copyright © by Jeff Foley 2017-2024. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. +// SPDX-License-Identifier: Apache-2.0 + +package neo4j + +import ( + "errors" + "strings" + + neo4jdb "github.com/neo4j/neo4j-go-driver/v5/neo4j" + "github.com/owasp-amass/asset-db/types" + oam "github.com/owasp-amass/open-asset-model" + "github.com/owasp-amass/open-asset-model/relation" +) + +func relationshipToEdge(rel neo4jdb.Relationship) (*types.Edge, error) { + t, err := neo4jdb.GetProperty[neo4jdb.LocalDateTime](rel, "created_at") + if err != nil { + return nil, err + } + created := neo4jTimeToTime(t) + + t, err = neo4jdb.GetProperty[neo4jdb.LocalDateTime](rel, "updated_at") + if err != nil { + return nil, err + } + updated := neo4jTimeToTime(t) + + etype, err := neo4jdb.GetProperty[string](rel, "etype") + if err != nil { + return nil, err + } + rtype := oam.RelationType(etype) + + var r oam.Relation + switch rtype { + case oam.BasicDNSRelation: + r, err = relationshipToBasicDNSRelation(rel) + case oam.PortRelation: + r, err = relationshipToPortRelation(rel) + case oam.PrefDNSRelation: + r, err = relationshipToPrefDNSRelation(rel) + case oam.SimpleRelation: + r, err = relationshipToSimpleRelation(rel) + case oam.SRVDNSRelation: + r, err = relationshipToSRVDNSRelation(rel) + } + if err != nil { + return nil, err + } + if r == nil { + return nil, errors.New("relation type not supported") + } + + return &types.Edge{ + ID: rel.GetElementId(), + CreatedAt: created, + LastSeen: updated, + Relation: r, + }, nil +} + +func relationshipToBasicDNSRelation(rel neo4jdb.Relationship) (*relation.BasicDNSRelation, error) { + num, err := neo4jdb.GetProperty[int64](rel, "header_rrtype") + if err != nil { + return nil, err + } + rrtype := int(num) + + num, err = neo4jdb.GetProperty[int64](rel, "header_class") + if err != nil { + return nil, err + } + class := int(num) + + num, err = neo4jdb.GetProperty[int64](rel, "header_ttl") + if err != nil { + return nil, err + } + ttl := int(num) + + return &relation.BasicDNSRelation{ + Name: strings.ToLower(rel.Type), + Header: relation.RRHeader{ + RRType: rrtype, + Class: class, + TTL: ttl, + }, + }, nil +} + +func relationshipToPortRelation(rel neo4jdb.Relationship) (*relation.PortRelation, error) { + num, err := neo4jdb.GetProperty[int64](rel, "port_number") + if err != nil { + return nil, err + } + port := int(num) + + protocol, err := neo4jdb.GetProperty[string](rel, "protocol") + if err != nil { + return nil, err + } + + return &relation.PortRelation{ + Name: strings.ToLower(rel.Type), + PortNumber: port, + Protocol: protocol, + }, nil +} + +func relationshipToPrefDNSRelation(rel neo4jdb.Relationship) (*relation.PrefDNSRelation, error) { + num, err := neo4jdb.GetProperty[int64](rel, "header_rrtype") + if err != nil { + return nil, err + } + rrtype := int(num) + + num, err = neo4jdb.GetProperty[int64](rel, "header_class") + if err != nil { + return nil, err + } + class := int(num) + + num, err = neo4jdb.GetProperty[int64](rel, "header_ttl") + if err != nil { + return nil, err + } + ttl := int(num) + + num, err = neo4jdb.GetProperty[int64](rel, "preference") + if err != nil { + return nil, err + } + pref := int(num) + + return &relation.PrefDNSRelation{ + Name: strings.ToLower(rel.Type), + Header: relation.RRHeader{ + RRType: rrtype, + Class: class, + TTL: ttl, + }, + Preference: pref, + }, nil +} + +func relationshipToSimpleRelation(rel neo4jdb.Relationship) (*relation.SimpleRelation, error) { + return &relation.SimpleRelation{ + Name: strings.ToLower(rel.Type), + }, nil +} + +func relationshipToSRVDNSRelation(rel neo4jdb.Relationship) (*relation.SRVDNSRelation, error) { + num, err := neo4jdb.GetProperty[int64](rel, "header_rrtype") + if err != nil { + return nil, err + } + rrtype := int(num) + + num, err = neo4jdb.GetProperty[int64](rel, "header_class") + if err != nil { + return nil, err + } + class := int(num) + + num, err = neo4jdb.GetProperty[int64](rel, "header_ttl") + if err != nil { + return nil, err + } + ttl := int(num) + + num, err = neo4jdb.GetProperty[int64](rel, "priority") + if err != nil { + return nil, err + } + priority := int(num) + + num, err = neo4jdb.GetProperty[int64](rel, "weight") + if err != nil { + return nil, err + } + weight := int(num) + + num, err = neo4jdb.GetProperty[int64](rel, "port") + if err != nil { + return nil, err + } + port := int(num) + + return &relation.SRVDNSRelation{ + Name: strings.ToLower(rel.Type), + Header: relation.RRHeader{ + RRType: rrtype, + Class: class, + TTL: ttl, + }, + Priority: priority, + Weight: weight, + Port: port, + }, nil +} diff --git a/repository/neo4j/query_relations.go b/repository/neo4j/query_relations.go new file mode 100644 index 0000000..ef71229 --- /dev/null +++ b/repository/neo4j/query_relations.go @@ -0,0 +1,55 @@ +// Copyright © by Jeff Foley 2017-2024. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. +// SPDX-License-Identifier: Apache-2.0 + +package neo4j + +import ( + "errors" + + "github.com/owasp-amass/asset-db/types" + "github.com/owasp-amass/open-asset-model/relation" +) + +func edgePropsMap(edge *types.Edge) (map[string]interface{}, error) { + if edge == nil { + return nil, errors.New("the edge is nil") + } + if edge.Relation == nil { + return nil, errors.New("the relation is nil") + } + + m := make(map[string]interface{}) + // begin populating the map of parameters + m["etype"] = edge.Relation.RelationType() + m["created_at"] = timeToNeo4jTime(edge.CreatedAt) + m["updated_at"] = timeToNeo4jTime(edge.LastSeen) + + // Add the properties of the relation + switch v := edge.Relation.(type) { + case *relation.BasicDNSRelation: + m["header_rrtype"] = v.Header.RRType + m["header_class"] = v.Header.Class + m["header_ttl"] = v.Header.TTL + case *relation.PortRelation: + m["port_number"] = v.PortNumber + m["protocol"] = v.Protocol + case *relation.PrefDNSRelation: + m["header_rrtype"] = v.Header.RRType + m["header_class"] = v.Header.Class + m["header_ttl"] = v.Header.TTL + m["preference"] = v.Preference + case *relation.SimpleRelation: + case *relation.SRVDNSRelation: + m["header_rrtype"] = v.Header.RRType + m["header_class"] = v.Header.Class + m["header_ttl"] = v.Header.TTL + m["priority"] = v.Priority + m["weight"] = v.Weight + m["port"] = v.Port + default: + return nil, errors.New("property type not supported") + } + + return m, nil +}