From 4bef858e23a2348b0a41683da370520a9a850864 Mon Sep 17 00:00:00 2001 From: Michael Kedar Date: Wed, 8 Jan 2025 15:09:23 +1100 Subject: [PATCH 1/6] WIP client migration --- cmd/osv-scanner/fix/main.go | 25 ++++++++-- internal/remediation/in_place.go | 18 ++++--- internal/resolution/client/client.go | 10 +--- internal/resolution/client/helper.go | 48 +++++++++++++++++++ .../clienttest/mock_resolution_client.go | 30 ++++++++---- internal/resolution/resolve.go | 10 ++-- .../generate_mock_resolution_universe/main.go | 23 ++++++--- 7 files changed, 123 insertions(+), 41 deletions(-) create mode 100644 internal/resolution/client/helper.go diff --git a/cmd/osv-scanner/fix/main.go b/cmd/osv-scanner/fix/main.go index 0e5348042a6..7982ff9ab63 100644 --- a/cmd/osv-scanner/fix/main.go +++ b/cmd/osv-scanner/fix/main.go @@ -7,9 +7,13 @@ import ( "os" "path/filepath" "strings" + "time" "deps.dev/util/resolve" + "github.com/google/osv-scanner/internal/clients/clientimpl/localmatcher" + "github.com/google/osv-scanner/internal/clients/clientimpl/osvmatcher" "github.com/google/osv-scanner/internal/depsdev" + "github.com/google/osv-scanner/internal/osvdev" "github.com/google/osv-scanner/internal/remediation" "github.com/google/osv-scanner/internal/remediation/upgrade" "github.com/google/osv-scanner/internal/resolution" @@ -365,17 +369,28 @@ func action(ctx *cli.Context, stdout, stderr io.Writer) (reporter.Reporter, erro } if ctx.Bool("experimental-offline-vulnerabilities") { - var err error - opts.Client.VulnerabilityClient, err = client.NewOSVOfflineClient( + matcher, err := localmatcher.NewLocalMatcher( r, - system, + ctx.String("experimental-local-db-path"), + "osv-scanner_fix/"+version.OSVVersion, ctx.Bool("experimental-download-offline-databases"), - ctx.String("experimental-local-db-path")) + ) if err != nil { return nil, err } + + // TODO: check system is downloaded + // if err := matcher.LoadEcosystem(ctx.Context, system?); err != nil { + // return nil, err + // } + + opts.Client.VulnerabilityMatcher = matcher } else { - opts.Client.VulnerabilityClient = client.NewOSVClient() + // TODO: replace with cached client + opts.Client.VulnerabilityMatcher = &osvmatcher.OSVMatcher{ + Client: *osvdev.DefaultClient(), // TODO: UserAgent + InitialQueryTimeout: 5 * time.Minute, + } } if !ctx.Bool("non-interactive") { diff --git a/internal/remediation/in_place.go b/internal/remediation/in_place.go index 086e0598875..bf2155bb39e 100644 --- a/internal/remediation/in_place.go +++ b/internal/remediation/in_place.go @@ -9,6 +9,7 @@ import ( "deps.dev/util/resolve" "deps.dev/util/resolve/dep" "deps.dev/util/semver" + "github.com/google/osv-scanner/internal/clients/clientinterfaces" "github.com/google/osv-scanner/internal/remediation/upgrade" "github.com/google/osv-scanner/internal/resolution" "github.com/google/osv-scanner/internal/resolution/client" @@ -105,7 +106,7 @@ func (r InPlaceResult) VulnCount() VulnCount { // ComputeInPlacePatches finds all possible targeting version changes that would fix vulnerabilities in a resolved graph. // TODO: Check for introduced vulnerabilities func ComputeInPlacePatches(ctx context.Context, cl client.ResolutionClient, graph *resolve.Graph, opts Options) (InPlaceResult, error) { - res, err := inPlaceVulnsNodes(cl, graph) + res, err := inPlaceVulnsNodes(ctx, cl, graph) if err != nil { return InPlaceResult{}, err } @@ -235,8 +236,8 @@ type inPlaceVulnsNodesResult struct { vkNodes map[resolve.VersionKey][]resolve.NodeID } -func inPlaceVulnsNodes(cl client.VulnerabilityClient, graph *resolve.Graph) (inPlaceVulnsNodesResult, error) { - nodeVulns, err := cl.FindVulns(graph) +func inPlaceVulnsNodes(ctx context.Context, m clientinterfaces.VulnerabilityMatcher, graph *resolve.Graph) (inPlaceVulnsNodesResult, error) { + nodeVulns, err := m.MatchVulnerabilities(ctx, client.GraphAsInventory(graph)) if err != nil { return inPlaceVulnsNodesResult{}, err } @@ -249,7 +250,10 @@ func inPlaceVulnsNodes(cl client.VulnerabilityClient, graph *resolve.Graph) (inP // Find all direct dependencies of vulnerable nodes. for _, e := range graph.Edges { - if len(nodeVulns[e.From]) > 0 { + if e.From == 0 { + continue + } + if len(nodeVulns[e.From-1]) > 0 { result.nodeDependencies[e.From] = append(result.nodeDependencies[e.From], graph.Nodes[e.To].Version) } } @@ -259,7 +263,7 @@ func inPlaceVulnsNodes(cl client.VulnerabilityClient, graph *resolve.Graph) (inP var nodeIDs []resolve.NodeID for nID, vulns := range nodeVulns { if len(vulns) > 0 { - nodeIDs = append(nodeIDs, resolve.NodeID(nID)) + nodeIDs = append(nodeIDs, resolve.NodeID(nID+1)) } } nodeChains := resolution.ComputeChains(graph, nodeIDs) @@ -270,9 +274,9 @@ func inPlaceVulnsNodes(cl client.VulnerabilityClient, graph *resolve.Graph) (inP chains := nodeChains[i] vk := graph.Nodes[nID].Version result.vkNodes[vk] = append(result.vkNodes[vk], nID) - for _, vuln := range nodeVulns[nID] { + for _, vuln := range nodeVulns[nID-1] { resVuln := resolution.Vulnerability{ - OSV: vuln, + OSV: *vuln, ProblemChains: slices.Clone(chains), DevOnly: !slices.ContainsFunc(chains, func(dc resolution.DependencyChain) bool { return !resolution.ChainIsDev(dc, nil) }), } diff --git a/internal/resolution/client/client.go b/internal/resolution/client/client.go index 5341d1b2dd6..63620990aea 100644 --- a/internal/resolution/client/client.go +++ b/internal/resolution/client/client.go @@ -8,8 +8,8 @@ import ( "deps.dev/util/resolve" "deps.dev/util/resolve/dep" "deps.dev/util/semver" + "github.com/google/osv-scanner/internal/clients/clientinterfaces" "github.com/google/osv-scanner/internal/depsdev" - "github.com/google/osv-scanner/pkg/models" "github.com/google/osv-scanner/pkg/osv" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -17,13 +17,7 @@ import ( type ResolutionClient struct { DependencyClient - VulnerabilityClient -} - -type VulnerabilityClient interface { - // FindVulns finds the vulnerabilities affecting each of Nodes in the graph. - // The returned Vulnerabilities[i] corresponds to the vulnerabilities in g.Nodes[i]. - FindVulns(g *resolve.Graph) ([]models.Vulnerabilities, error) + clientinterfaces.VulnerabilityMatcher } type DependencyClient interface { diff --git a/internal/resolution/client/helper.go b/internal/resolution/client/helper.go new file mode 100644 index 00000000000..6fabaa80c19 --- /dev/null +++ b/internal/resolution/client/helper.go @@ -0,0 +1,48 @@ +package client + +import ( + "deps.dev/util/resolve" + "github.com/google/osv-scalibr/extractor" + "github.com/google/osv-scalibr/plugin" + "github.com/google/osv-scalibr/purl" + "github.com/ossf/osv-schema/bindings/go/osvschema" +) + +// GraphAsInventory is a helper function to convert a Graph into an Inventory for use with VulnerabilityMatcher. +func GraphAsInventory(g *resolve.Graph) []*extractor.Inventory { + // g.Nodes[0] is the root node of the graph that should be excluded. + inv := make([]*extractor.Inventory, len(g.Nodes)-1) + for i, n := range g.Nodes[1:] { + inv[i] = &extractor.Inventory{ + Name: n.Version.Name, + Version: n.Version.Version, + Metadata: n, + Extractor: graphExtractor{}, + Locations: []string{g.Nodes[0].Version.Name}, + } + } + return inv +} + +// graphExtractor is for GraphAsInventory to get the ecosystem. +type graphExtractor struct{} + +func (e graphExtractor) Ecosystem(i *extractor.Inventory) string { + n, ok := i.Metadata.(resolve.Node) + if !ok { + return "" + } + switch n.Version.System { + case resolve.NPM: + return string(osvschema.EcosystemNPM) + case resolve.Maven: + return string(osvschema.EcosystemMaven) + default: + return "" + } +} + +func (e graphExtractor) Name() string { return "" } +func (e graphExtractor) Requirements() *plugin.Capabilities { return nil } +func (e graphExtractor) ToPURL(_ *extractor.Inventory) *purl.PackageURL { return nil } +func (e graphExtractor) Version() int { return 0 } diff --git a/internal/resolution/clienttest/mock_resolution_client.go b/internal/resolution/clienttest/mock_resolution_client.go index 27e4c40d228..6704f2a1ec3 100644 --- a/internal/resolution/clienttest/mock_resolution_client.go +++ b/internal/resolution/clienttest/mock_resolution_client.go @@ -1,15 +1,18 @@ package clienttest import ( + "context" "os" "strings" "testing" "deps.dev/util/resolve" "deps.dev/util/resolve/schema" + "github.com/google/osv-scalibr/extractor" + "github.com/google/osv-scanner/internal/imodels" "github.com/google/osv-scanner/internal/resolution/client" - "github.com/google/osv-scanner/internal/resolution/util" "github.com/google/osv-scanner/internal/utility/vulns" + "github.com/google/osv-scanner/pkg/lockfile" "github.com/google/osv-scanner/pkg/models" "gopkg.in/yaml.v3" ) @@ -20,17 +23,24 @@ type ResolutionUniverse struct { Vulns []models.Vulnerability `yaml:"vulns"` } -type mockVulnerabilityClient []models.Vulnerability +type mockVulnerabilityMatcher []models.Vulnerability -func (mvc mockVulnerabilityClient) FindVulns(g *resolve.Graph) ([]models.Vulnerabilities, error) { - result := make([]models.Vulnerabilities, len(g.Nodes)) - for i, n := range g.Nodes { - if i == 0 { - continue // skip root node +func (mvc mockVulnerabilityMatcher) MatchVulnerabilities(_ context.Context, invs []*extractor.Inventory) ([][]*models.Vulnerability, error) { + result := make([][]*models.Vulnerability, len(invs)) + for i, inv := range invs { + pkg := imodels.FromInventory(inv) + // TODO (V2 Models): remove this once PackageDetails has been migrated + mappedPackageDetails := lockfile.PackageDetails{ + Name: pkg.Name(), + Version: pkg.Version(), + Commit: pkg.Commit(), + Ecosystem: lockfile.Ecosystem(pkg.Ecosystem().String()), + CompareAs: lockfile.Ecosystem(pkg.Ecosystem().String()), + DepGroups: pkg.DepGroups(), } for _, v := range mvc { - if vulns.IsAffected(v, util.VKToPackageDetails(n.Version)) { - result[i] = append(result[i], v) + if vulns.IsAffected(v, mappedPackageDetails) { + result[i] = append(result[i], &v) } } } @@ -61,7 +71,7 @@ func NewMockResolutionClient(t *testing.T, universeYAML string) client.Resolutio } cl := client.ResolutionClient{ - VulnerabilityClient: mockVulnerabilityClient(universe.Vulns), + VulnerabilityMatcher: mockVulnerabilityMatcher(universe.Vulns), } var sys resolve.System diff --git a/internal/resolution/resolve.go b/internal/resolution/resolve.go index bb702d14746..818fb98a1c8 100644 --- a/internal/resolution/resolve.go +++ b/internal/resolution/resolve.go @@ -204,7 +204,7 @@ func resolvePostProcess(ctx context.Context, cl client.ResolutionClient, m manif // computeVulns scans for vulnerabilities in a resolved graph and populates res.Vulns func (res *Result) computeVulns(ctx context.Context, cl client.ResolutionClient) error { - nodeVulns, err := cl.FindVulns(res.Graph) + nodeVulns, err := cl.MatchVulnerabilities(ctx, client.GraphAsInventory(res.Graph)) if err != nil { return err } @@ -213,17 +213,17 @@ func (res *Result) computeVulns(ctx context.Context, cl client.ResolutionClient) vulnInfo := make(map[string]models.Vulnerability) for i, vulns := range nodeVulns { if len(vulns) > 0 { - vulnerableNodes = append(vulnerableNodes, resolve.NodeID(i)) + vulnerableNodes = append(vulnerableNodes, resolve.NodeID(i+1)) } for _, vuln := range vulns { - vulnInfo[vuln.ID] = vuln + vulnInfo[vuln.ID] = *vuln } } nodeChains := ComputeChains(res.Graph, vulnerableNodes) vulnChains := make(map[string][]DependencyChain) - for i, idx := range vulnerableNodes { - for _, vuln := range nodeVulns[idx] { + for i, nID := range vulnerableNodes { + for _, vuln := range nodeVulns[nID-1] { vulnChains[vuln.ID] = append(vulnChains[vuln.ID], nodeChains[i]...) } } diff --git a/scripts/generate_mock_resolution_universe/main.go b/scripts/generate_mock_resolution_universe/main.go index c926293fb0a..d482713bbb1 100644 --- a/scripts/generate_mock_resolution_universe/main.go +++ b/scripts/generate_mock_resolution_universe/main.go @@ -22,7 +22,9 @@ import ( pb "deps.dev/api/v3" "deps.dev/util/resolve" "deps.dev/util/resolve/dep" + "github.com/google/osv-scanner/internal/clients/clientimpl/osvmatcher" "github.com/google/osv-scanner/internal/depsdev" + "github.com/google/osv-scanner/internal/osvdev" "github.com/google/osv-scanner/internal/remediation" "github.com/google/osv-scanner/internal/remediation/upgrade" "github.com/google/osv-scanner/internal/resolution" @@ -51,8 +53,11 @@ var remediationOpts = remediation.Options{ func doRelockRelax(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename string) error { cl := client.ResolutionClient{ - VulnerabilityClient: client.NewOSVClient(), - DependencyClient: ddCl, + VulnerabilityMatcher: &osvmatcher.OSVMatcher{ + Client: *osvdev.DefaultClient(), + InitialQueryTimeout: 5 * time.Minute, + }, + DependencyClient: ddCl, } f, err := lf.OpenLocalDepFile(filename) @@ -78,8 +83,11 @@ func doRelockRelax(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename func doOverride(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename string) error { cl := client.ResolutionClient{ - VulnerabilityClient: client.NewOSVClient(), - DependencyClient: ddCl, + VulnerabilityMatcher: &osvmatcher.OSVMatcher{ + Client: *osvdev.DefaultClient(), + InitialQueryTimeout: 5 * time.Minute, + }, + DependencyClient: ddCl, } f, err := lf.OpenLocalDepFile(filename) @@ -105,8 +113,11 @@ func doOverride(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename str func doInPlace(ddCl *client.DepsDevClient, rw lockfile.ReadWriter, filename string) error { cl := client.ResolutionClient{ - VulnerabilityClient: client.NewOSVClient(), - DependencyClient: ddCl, + VulnerabilityMatcher: &osvmatcher.OSVMatcher{ + Client: *osvdev.DefaultClient(), + InitialQueryTimeout: 5 * time.Minute, + }, + DependencyClient: ddCl, } f, err := lf.OpenLocalDepFile(filename) From 298e1f74dac02eaffa678ff4df32032502a92afd Mon Sep 17 00:00:00 2001 From: Michael Kedar Date: Wed, 15 Jan 2025 17:08:45 +1100 Subject: [PATCH 2/6] fix up local client a bit --- cmd/osv-scanner/fix/main.go | 15 ++++++++---- internal/resolution/client/helper.go | 34 +++++++++++++--------------- internal/resolution/util/depsdev.go | 1 + 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/cmd/osv-scanner/fix/main.go b/cmd/osv-scanner/fix/main.go index 7982ff9ab63..bd7019f7f43 100644 --- a/cmd/osv-scanner/fix/main.go +++ b/cmd/osv-scanner/fix/main.go @@ -13,6 +13,7 @@ import ( "github.com/google/osv-scanner/internal/clients/clientimpl/localmatcher" "github.com/google/osv-scanner/internal/clients/clientimpl/osvmatcher" "github.com/google/osv-scanner/internal/depsdev" + "github.com/google/osv-scanner/internal/imodels/ecosystem" "github.com/google/osv-scanner/internal/osvdev" "github.com/google/osv-scanner/internal/remediation" "github.com/google/osv-scanner/internal/remediation/upgrade" @@ -20,8 +21,10 @@ import ( "github.com/google/osv-scanner/internal/resolution/client" "github.com/google/osv-scanner/internal/resolution/lockfile" "github.com/google/osv-scanner/internal/resolution/manifest" + "github.com/google/osv-scanner/internal/resolution/util" "github.com/google/osv-scanner/internal/version" "github.com/google/osv-scanner/pkg/reporter" + "github.com/ossf/osv-schema/bindings/go/osvschema" "github.com/urfave/cli/v2" "golang.org/x/term" ) @@ -379,10 +382,14 @@ func action(ctx *cli.Context, stdout, stderr io.Writer) (reporter.Reporter, erro return nil, err } - // TODO: check system is downloaded - // if err := matcher.LoadEcosystem(ctx.Context, system?); err != nil { - // return nil, err - // } + eco, ok := util.OSVEcosystem[system] + if !ok { + // Something's very wrong if we hit this + panic("unhandled resolve.Ecosystem: " + system.String()) + } + if err := matcher.LoadEcosystem(ctx.Context, ecosystem.Parsed{Ecosystem: osvschema.Ecosystem(eco)}); err != nil { + return nil, err + } opts.Client.VulnerabilityMatcher = matcher } else { diff --git a/internal/resolution/client/helper.go b/internal/resolution/client/helper.go index 6fabaa80c19..aae4c4988c0 100644 --- a/internal/resolution/client/helper.go +++ b/internal/resolution/client/helper.go @@ -5,7 +5,6 @@ import ( "github.com/google/osv-scalibr/extractor" "github.com/google/osv-scalibr/plugin" "github.com/google/osv-scalibr/purl" - "github.com/ossf/osv-schema/bindings/go/osvschema" ) // GraphAsInventory is a helper function to convert a Graph into an Inventory for use with VulnerabilityMatcher. @@ -16,33 +15,32 @@ func GraphAsInventory(g *resolve.Graph) []*extractor.Inventory { inv[i] = &extractor.Inventory{ Name: n.Version.Name, Version: n.Version.Version, - Metadata: n, - Extractor: graphExtractor{}, - Locations: []string{g.Nodes[0].Version.Name}, + Extractor: mockExtractor{n.Version.System}, } } + return inv } -// graphExtractor is for GraphAsInventory to get the ecosystem. -type graphExtractor struct{} +// mockExtractor is for GraphAsInventory to get the ecosystem. +type mockExtractor struct { + ecosystem resolve.System +} -func (e graphExtractor) Ecosystem(i *extractor.Inventory) string { - n, ok := i.Metadata.(resolve.Node) - if !ok { - return "" - } - switch n.Version.System { +func (e mockExtractor) Ecosystem(*extractor.Inventory) string { + switch e.ecosystem { case resolve.NPM: - return string(osvschema.EcosystemNPM) + return "npm" case resolve.Maven: - return string(osvschema.EcosystemMaven) + return "Maven" + case resolve.UnknownSystem: + return "" default: return "" } } -func (e graphExtractor) Name() string { return "" } -func (e graphExtractor) Requirements() *plugin.Capabilities { return nil } -func (e graphExtractor) ToPURL(_ *extractor.Inventory) *purl.PackageURL { return nil } -func (e graphExtractor) Version() int { return 0 } +func (e mockExtractor) Name() string { return "" } +func (e mockExtractor) Requirements() *plugin.Capabilities { return nil } +func (e mockExtractor) ToPURL(*extractor.Inventory) *purl.PackageURL { return nil } +func (e mockExtractor) Version() int { return 0 } diff --git a/internal/resolution/util/depsdev.go b/internal/resolution/util/depsdev.go index f573163c548..6068492e651 100644 --- a/internal/resolution/util/depsdev.go +++ b/internal/resolution/util/depsdev.go @@ -6,6 +6,7 @@ import ( "github.com/google/osv-scanner/pkg/models" ) +// TODO: use osvschema.Ecosystem or imodel's ecosystem.Parsed var OSVEcosystem = map[resolve.System]models.Ecosystem{ resolve.NPM: models.EcosystemNPM, resolve.Maven: models.EcosystemMaven, From 27718d3eece6d98367fd403cbfd161f53e96f4d1 Mon Sep 17 00:00:00 2001 From: Michael Kedar Date: Thu, 16 Jan 2025 11:56:48 +1100 Subject: [PATCH 3/6] CachedOSVMatcher --- cmd/osv-scanner/fix/main.go | 15 +- .../clientimpl/localmatcher/localmatcher.go | 2 +- .../clients/clientimpl/localmatcher/zip.go | 8 +- .../clientimpl/osvmatcher/cachedosvmatcher.go | 151 ++++++++++++++++++ .../clienttest/mock_resolution_client.go | 19 +-- .../generate_mock_resolution_universe/main.go | 6 +- 6 files changed, 173 insertions(+), 28 deletions(-) create mode 100644 internal/clients/clientimpl/osvmatcher/cachedosvmatcher.go diff --git a/cmd/osv-scanner/fix/main.go b/cmd/osv-scanner/fix/main.go index bd7019f7f43..8828895682d 100644 --- a/cmd/osv-scanner/fix/main.go +++ b/cmd/osv-scanner/fix/main.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io" + "net/http" "os" "path/filepath" "strings" @@ -371,11 +372,12 @@ func action(ctx *cli.Context, stdout, stderr io.Writer) (reporter.Reporter, erro } } + userAgent := "osv-scanner_fix/" + version.OSVVersion if ctx.Bool("experimental-offline-vulnerabilities") { matcher, err := localmatcher.NewLocalMatcher( r, ctx.String("experimental-local-db-path"), - "osv-scanner_fix/"+version.OSVVersion, + userAgent, ctx.Bool("experimental-download-offline-databases"), ) if err != nil { @@ -393,9 +395,14 @@ func action(ctx *cli.Context, stdout, stderr io.Writer) (reporter.Reporter, erro opts.Client.VulnerabilityMatcher = matcher } else { - // TODO: replace with cached client - opts.Client.VulnerabilityMatcher = &osvmatcher.OSVMatcher{ - Client: *osvdev.DefaultClient(), // TODO: UserAgent + config := osvdev.DefaultConfig() + config.UserAgent = userAgent + opts.Client.VulnerabilityMatcher = &osvmatcher.CachedOSVMatcher{ + Client: osvdev.OSVClient{ + HTTPClient: http.DefaultClient, + Config: config, + BaseHostURL: osvdev.DefaultBaseURL, + }, InitialQueryTimeout: 5 * time.Minute, } } diff --git a/internal/clients/clientimpl/localmatcher/localmatcher.go b/internal/clients/clientimpl/localmatcher/localmatcher.go index c4606582740..e90452e5a67 100644 --- a/internal/clients/clientimpl/localmatcher/localmatcher.go +++ b/internal/clients/clientimpl/localmatcher/localmatcher.go @@ -77,7 +77,7 @@ func (matcher *LocalMatcher) MatchVulnerabilities(ctx context.Context, invs []*e continue } - results = append(results, db.VulnerabilitiesAffectingPackage(pkg)) + results = append(results, VulnerabilitiesAffectingPackage(db.Vulnerabilities(false), pkg)) } return results, nil diff --git a/internal/clients/clientimpl/localmatcher/zip.go b/internal/clients/clientimpl/localmatcher/zip.go index 12433c395a9..46424f86ec8 100644 --- a/internal/clients/clientimpl/localmatcher/zip.go +++ b/internal/clients/clientimpl/localmatcher/zip.go @@ -235,7 +235,8 @@ func (db *ZipDB) Vulnerabilities(includeWithdrawn bool) []models.Vulnerability { return vulnerabilities } -func (db *ZipDB) VulnerabilitiesAffectingPackage(pkg imodels.PackageInfo) []*models.Vulnerability { +// TODO: Move this to another file. +func VulnerabilitiesAffectingPackage(allVulns []models.Vulnerability, pkg imodels.PackageInfo) []*models.Vulnerability { var vulnerabilities []*models.Vulnerability // TODO (V2 Models): remove this once PackageDetails has been migrated @@ -248,7 +249,7 @@ func (db *ZipDB) VulnerabilitiesAffectingPackage(pkg imodels.PackageInfo) []*mod DepGroups: pkg.DepGroups(), } - for _, vulnerability := range db.Vulnerabilities(false) { + for _, vulnerability := range allVulns { if vulns.IsAffected(vulnerability, mappedPackageDetails) && !vulns.Include(vulnerabilities, vulnerability) { vulnerabilities = append(vulnerabilities, &vulnerability) } @@ -258,10 +259,11 @@ func (db *ZipDB) VulnerabilitiesAffectingPackage(pkg imodels.PackageInfo) []*mod } func (db *ZipDB) Check(pkgs []imodels.PackageInfo) ([]*models.Vulnerability, error) { + allVulns := db.Vulnerabilities(false) vulnerabilities := make([]*models.Vulnerability, 0, len(pkgs)) for _, pkg := range pkgs { - vulnerabilities = append(vulnerabilities, db.VulnerabilitiesAffectingPackage(pkg)...) + vulnerabilities = append(vulnerabilities, VulnerabilitiesAffectingPackage(allVulns, pkg)...) } return vulnerabilities, nil diff --git a/internal/clients/clientimpl/osvmatcher/cachedosvmatcher.go b/internal/clients/clientimpl/osvmatcher/cachedosvmatcher.go new file mode 100644 index 00000000000..7061b6e0193 --- /dev/null +++ b/internal/clients/clientimpl/osvmatcher/cachedosvmatcher.go @@ -0,0 +1,151 @@ +package osvmatcher + +import ( + "context" + "errors" + "maps" + "slices" + "sync" + "time" + + "github.com/google/osv-scalibr/extractor" + "github.com/google/osv-scanner/internal/clients/clientimpl/localmatcher" + "github.com/google/osv-scanner/internal/imodels" + "github.com/google/osv-scanner/internal/osvdev" + "github.com/google/osv-scanner/pkg/models" + "golang.org/x/sync/errgroup" +) + +// CachedOSVMatcher implements the VulnerabilityMatcher interface with a osv.dev client. +// It sends out requests for every vulnerability of each package, which get cached. +// Checking if a specific version matches an OSV record is done locally. +// This should be used when we know the same packages are going to be repeatedly +// queried multiple times, as in guided remediation. +// TODO: This does not support commit-based queries. +type CachedOSVMatcher struct { + Client osvdev.OSVClient + // InitialQueryTimeout allows you to set a timeout specifically for the initial paging query + // If timeout runs out, whatever pages that has been successfully queried within the timeout will + // still return fully hydrated. + InitialQueryTimeout time.Duration + + vulnCache sync.Map // map[osvdev.Package][]models.Vulnerability +} + +func (matcher *CachedOSVMatcher) MatchVulnerabilities(ctx context.Context, invs []*extractor.Inventory) ([][]*models.Vulnerability, error) { + // populate vulnCache with missing packages + if err := matcher.doQueries(ctx, invs); err != nil { + return nil, err + } + + results := make([][]*models.Vulnerability, len(invs)) + + for i, inv := range invs { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + pkgInfo := imodels.FromInventory(inv) + pkg := osvdev.Package{ + Name: pkgInfo.Name(), + Ecosystem: pkgInfo.Ecosystem().String(), + } + vulns, ok := matcher.vulnCache.Load(pkg) + if !ok { + continue + } + results[i] = localmatcher.VulnerabilitiesAffectingPackage(vulns.([]models.Vulnerability), pkgInfo) + } + + return results, nil +} + +func (matcher *CachedOSVMatcher) doQueries(ctx context.Context, invs []*extractor.Inventory) error { + var batchResp *osvdev.BatchedResponse + deadlineExceeded := false + + var queries []*osvdev.Query + { + // determine which packages aren't already cached + // convert Inventory to Query for each pkgs element + toQuery := make(map[*osvdev.Query]struct{}) + for _, inv := range invs { + pkgInfo := imodels.FromInventory(inv) + if pkgInfo.Name() == "" || pkgInfo.Ecosystem().IsEmpty() { + continue + } + pkg := osvdev.Package{ + Name: pkgInfo.Name(), + Ecosystem: pkgInfo.Ecosystem().String(), + } + if _, ok := matcher.vulnCache.Load(pkg); !ok { + toQuery[&osvdev.Query{Package: pkg}] = struct{}{} + } + } + queries = slices.Collect(maps.Keys(toQuery)) + } + + if len(queries) == 0 { + return nil + } + + var err error + + // If there is a timeout for the initial query, set an additional context deadline here. + if matcher.InitialQueryTimeout > 0 { + batchQueryCtx, cancelFunc := context.WithDeadline(ctx, time.Now().Add(matcher.InitialQueryTimeout)) + batchResp, err = queryForBatchWithPaging(batchQueryCtx, &matcher.Client, queries) + cancelFunc() + } else { + batchResp, err = queryForBatchWithPaging(ctx, &matcher.Client, queries) + } + + if err != nil { + // Deadline being exceeded is likely caused by a long paging time + // if that's the case, we can should return what we already got, and + // then let the caller know it is not all the results. + if errors.Is(err, context.DeadlineExceeded) { + deadlineExceeded = true + } else { + return err + } + } + + vulnerabilities := make([][]models.Vulnerability, len(batchResp.Results)) + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(maxConcurrentRequests) + + for batchIdx, resp := range batchResp.Results { + vulnerabilities[batchIdx] = make([]models.Vulnerability, len(resp.Vulns)) + for resultIdx, vuln := range resp.Vulns { + g.Go(func() error { + // exit early if another hydration request has already failed + // results are thrown away later, so avoid needless work + if ctx.Err() != nil { + return nil //nolint:nilerr // this value doesn't matter to errgroup.Wait() + } + vuln, err := matcher.Client.GetVulnByID(ctx, vuln.ID) + if err != nil { + return err + } + vulnerabilities[batchIdx][resultIdx] = *vuln + + return nil + }) + } + } + + if err := g.Wait(); err != nil { + return err + } + + if deadlineExceeded { + return context.DeadlineExceeded + } + + for i, vulns := range vulnerabilities { + matcher.vulnCache.Store(queries[i].Package, vulns) + } + + return nil +} diff --git a/internal/resolution/clienttest/mock_resolution_client.go b/internal/resolution/clienttest/mock_resolution_client.go index 6704f2a1ec3..87243b19226 100644 --- a/internal/resolution/clienttest/mock_resolution_client.go +++ b/internal/resolution/clienttest/mock_resolution_client.go @@ -9,10 +9,9 @@ import ( "deps.dev/util/resolve" "deps.dev/util/resolve/schema" "github.com/google/osv-scalibr/extractor" + "github.com/google/osv-scanner/internal/clients/clientimpl/localmatcher" "github.com/google/osv-scanner/internal/imodels" "github.com/google/osv-scanner/internal/resolution/client" - "github.com/google/osv-scanner/internal/utility/vulns" - "github.com/google/osv-scanner/pkg/lockfile" "github.com/google/osv-scanner/pkg/models" "gopkg.in/yaml.v3" ) @@ -28,21 +27,7 @@ type mockVulnerabilityMatcher []models.Vulnerability func (mvc mockVulnerabilityMatcher) MatchVulnerabilities(_ context.Context, invs []*extractor.Inventory) ([][]*models.Vulnerability, error) { result := make([][]*models.Vulnerability, len(invs)) for i, inv := range invs { - pkg := imodels.FromInventory(inv) - // TODO (V2 Models): remove this once PackageDetails has been migrated - mappedPackageDetails := lockfile.PackageDetails{ - Name: pkg.Name(), - Version: pkg.Version(), - Commit: pkg.Commit(), - Ecosystem: lockfile.Ecosystem(pkg.Ecosystem().String()), - CompareAs: lockfile.Ecosystem(pkg.Ecosystem().String()), - DepGroups: pkg.DepGroups(), - } - for _, v := range mvc { - if vulns.IsAffected(v, mappedPackageDetails) { - result[i] = append(result[i], &v) - } - } + result[i] = localmatcher.VulnerabilitiesAffectingPackage(mvc, imodels.FromInventory(inv)) } return result, nil diff --git a/scripts/generate_mock_resolution_universe/main.go b/scripts/generate_mock_resolution_universe/main.go index d482713bbb1..348cfc50eb6 100644 --- a/scripts/generate_mock_resolution_universe/main.go +++ b/scripts/generate_mock_resolution_universe/main.go @@ -53,7 +53,7 @@ var remediationOpts = remediation.Options{ func doRelockRelax(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename string) error { cl := client.ResolutionClient{ - VulnerabilityMatcher: &osvmatcher.OSVMatcher{ + VulnerabilityMatcher: &osvmatcher.CachedOSVMatcher{ Client: *osvdev.DefaultClient(), InitialQueryTimeout: 5 * time.Minute, }, @@ -83,7 +83,7 @@ func doRelockRelax(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename func doOverride(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename string) error { cl := client.ResolutionClient{ - VulnerabilityMatcher: &osvmatcher.OSVMatcher{ + VulnerabilityMatcher: &osvmatcher.CachedOSVMatcher{ Client: *osvdev.DefaultClient(), InitialQueryTimeout: 5 * time.Minute, }, @@ -113,7 +113,7 @@ func doOverride(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename str func doInPlace(ddCl *client.DepsDevClient, rw lockfile.ReadWriter, filename string) error { cl := client.ResolutionClient{ - VulnerabilityMatcher: &osvmatcher.OSVMatcher{ + VulnerabilityMatcher: &osvmatcher.CachedOSVMatcher{ Client: *osvdev.DefaultClient(), InitialQueryTimeout: 5 * time.Minute, }, From 76e3dd2d0eaa85684c6b7fc7a4ab2bc71d5c6497 Mon Sep 17 00:00:00 2001 From: Michael Kedar Date: Thu, 16 Jan 2025 12:58:31 +1100 Subject: [PATCH 4/6] delete old clients --- internal/resolution/client/osv_client.go | 94 --------------- .../resolution/client/osv_offline_client.go | 107 ------------------ .../generate_mock_resolution_universe/main.go | 41 ++++--- 3 files changed, 25 insertions(+), 217 deletions(-) delete mode 100644 internal/resolution/client/osv_client.go delete mode 100644 internal/resolution/client/osv_offline_client.go diff --git a/internal/resolution/client/osv_client.go b/internal/resolution/client/osv_client.go deleted file mode 100644 index ac2f9d453d3..00000000000 --- a/internal/resolution/client/osv_client.go +++ /dev/null @@ -1,94 +0,0 @@ -package client - -import ( - "sync" - - "deps.dev/util/resolve" - "github.com/google/osv-scanner/internal/resolution/util" - "github.com/google/osv-scanner/internal/utility/vulns" - "github.com/google/osv-scanner/pkg/models" - "github.com/google/osv-scanner/pkg/osv" - "golang.org/x/exp/maps" -) - -type OSVClient struct { - // vulnCache caches all vulnerabilities affecting any versions of particular packages. - // We cache call vulns & manually check affected, rather than querying the affected versions directly - // since remediation needs to query for OSV vulnerabilities multiple times for the same packages. - vulnCache sync.Map // map[resolve.PackageKey][]models.Vulnerability - // TODO: This tends to get the full info of a lot of vulns that never show up in the dependency graphs. - // Worst case is something like PyPI:tensorflow, which has >600 vulns across all versions, but a specific version may be affected by 0. -} - -func NewOSVClient() *OSVClient { - return &OSVClient{} -} - -func (c *OSVClient) FindVulns(g *resolve.Graph) ([]models.Vulnerabilities, error) { - // Determine which packages we don't already have cached - toQuery := make(map[resolve.PackageKey]struct{}) - for _, node := range g.Nodes[1:] { // skipping the root node - pk := node.Version.PackageKey - if _, ok := c.vulnCache.Load(pk); !ok { - toQuery[pk] = struct{}{} - } - } - - // Query OSV for the missing records - if len(toQuery) > 0 { - pks := maps.Keys(toQuery) - var batchRequest osv.BatchedQuery - batchRequest.Queries = make([]*osv.Query, len(pks)) - for i, pk := range pks { - batchRequest.Queries[i] = &osv.Query{ - Package: osv.Package{ - Name: pk.Name, - Ecosystem: string(util.OSVEcosystem[pk.System]), - }, - // Omitting the Version from the query gets all vulns affecting any version of the package - // (I'm not actually sure if this behaviour is explicitly documented anywhere) - } - } - batchResponse, err := osv.MakeRequest(batchRequest) - if err != nil { - return nil, err - } - hydrated, err := osv.Hydrate(batchResponse) - if err != nil { - return nil, err - } - // fill in the cache with the responses - for i, pk := range pks { - c.vulnCache.Store(pk, hydrated.Results[i].Vulns) - } - } - - // Compute the actual affected vulnerabilities for each node - nodeVulns := make([]models.Vulnerabilities, len(g.Nodes)) - // For convenience, include the root node as an empty slice in the results - for i, n := range g.Nodes { - if i == 0 { - continue - } - pkgVulnsAny, ok := c.vulnCache.Load(n.Version.PackageKey) - if !ok { - // This should be impossible - panic("vulnerability caching failed") - } - pkgVulns, ok := pkgVulnsAny.([]models.Vulnerability) - if !ok { - panic("vulnerability caching failed") - } - - var affectedVulns []models.Vulnerability - pkgDetails := util.VKToPackageDetails(n.Version) - for _, vuln := range pkgVulns { - if vulns.IsAffected(vuln, pkgDetails) { - affectedVulns = append(affectedVulns, vuln) - } - } - nodeVulns[i] = affectedVulns - } - - return nodeVulns, nil -} diff --git a/internal/resolution/client/osv_offline_client.go b/internal/resolution/client/osv_offline_client.go deleted file mode 100644 index 96505f143f0..00000000000 --- a/internal/resolution/client/osv_offline_client.go +++ /dev/null @@ -1,107 +0,0 @@ -//nolint:unused,revive -package client - -import ( - "errors" - "fmt" - "strings" - - "deps.dev/util/resolve" - "github.com/google/osv-scanner/pkg/models" - "github.com/google/osv-scanner/pkg/reporter" -) - -type OSVOfflineClient struct { - // TODO: OSV-Scanner v2 plans to make vulnerability clients that can be used here. - localDBPath string -} - -func NewOSVOfflineClient(r reporter.Reporter, system resolve.System, downloadDBs bool, localDBPath string) (*OSVOfflineClient, error) { - if system == resolve.UnknownSystem { - return nil, errors.New("osv offline client created with unknown ecosystem") - } - - panic("TODO: Temporarily disabled") - - // // Make a dummy request to the local client to log and make sure the database is downloaded without error. - // q := osv.BatchedQuery{Queries: []*osv.Query{{ - // Package: osv.Package{ - // Name: "foo", - // Ecosystem: string(util.OSVEcosystem[system]), - // }, - // Version: "1.0.0", - // }}} - // _, err := local.MakeRequest(r, q, !downloadDBs, localDBPath) - // if err != nil { - // return nil, err - // } - - // if r.HasErrored() { - // return nil, errors.New("error creating osv offline client") - // } - - // return &OSVOfflineClient{localDBPath: localDBPath}, nil -} - -func (c *OSVOfflineClient) FindVulns(g *resolve.Graph) ([]models.Vulnerabilities, error) { - return []models.Vulnerabilities{}, nil - // var query osv.BatchedQuery - // query.Queries = make([]*osv.Query, len(g.Nodes)-1) - // for i, node := range g.Nodes[1:] { - // query.Queries[i] = &osv.Query{ - // Package: osv.Package{ - // Name: node.Version.Name, - // Ecosystem: string(util.OSVEcosystem[node.Version.System]), - // }, - // Version: node.Version.Version, - // } - // } - - // // If local.MakeRequest logs an error, it's probably fatal for guided remediation. - // // Set up a reporter to capture error logs and return the logs as an error. - // r := &errorReporter{} - // // DB should already be downloaded, set offline to true. - // hydrated, err := local.MakeRequest(r, query, true, c.localDBPath) - - // if err != nil { - // return nil, err - // } - - // if r.HasErrored() { - // return nil, r.GetError() - // } - - // nodeVulns := make([]models.Vulnerabilities, len(g.Nodes)) - // for i, res := range hydrated.Results { - // nodeVulns[i+1] = res.Vulns - // } - - // return nodeVulns, nil -} - -// errorReporter is a reporter.Reporter to capture error logs and pack them into an error. -type errorReporter struct { - s strings.Builder -} - -func (r *errorReporter) Errorf(format string, a ...any) { - fmt.Fprintf(&r.s, format, a...) -} - -func (r *errorReporter) HasErrored() bool { - return r.s.Len() > 0 -} - -func (r *errorReporter) GetError() error { - str := strings.TrimSpace(r.s.String()) - if str == "" { - return nil - } - - return errors.New(str) -} - -func (r *errorReporter) Warnf(string, ...any) {} -func (r *errorReporter) Infof(string, ...any) {} -func (r *errorReporter) Verbosef(string, ...any) {} -func (r *errorReporter) PrintResult(*models.VulnerabilityResults) error { return nil } diff --git a/scripts/generate_mock_resolution_universe/main.go b/scripts/generate_mock_resolution_universe/main.go index 348cfc50eb6..b45b2fe979e 100644 --- a/scripts/generate_mock_resolution_universe/main.go +++ b/scripts/generate_mock_resolution_universe/main.go @@ -13,6 +13,7 @@ import ( "encoding/gob" "errors" "fmt" + "net/http" "os" "path/filepath" "slices" @@ -23,6 +24,7 @@ import ( "deps.dev/util/resolve" "deps.dev/util/resolve/dep" "github.com/google/osv-scanner/internal/clients/clientimpl/osvmatcher" + "github.com/google/osv-scanner/internal/clients/clientinterfaces" "github.com/google/osv-scanner/internal/depsdev" "github.com/google/osv-scanner/internal/osvdev" "github.com/google/osv-scanner/internal/remediation" @@ -51,13 +53,26 @@ var remediationOpts = remediation.Options{ UpgradeConfig: upgrade.NewConfig(), } +const userAgent = "osv-scanner_generate_mock/" + version.OSVVersion + +func vulnMatcher() clientinterfaces.VulnerabilityMatcher { + config := osvdev.DefaultConfig() + config.UserAgent = userAgent + + return &osvmatcher.CachedOSVMatcher{ + Client: osvdev.OSVClient{ + HTTPClient: http.DefaultClient, + Config: config, + BaseHostURL: osvdev.DefaultBaseURL, + }, + InitialQueryTimeout: 5 * time.Minute, + } +} + func doRelockRelax(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename string) error { cl := client.ResolutionClient{ - VulnerabilityMatcher: &osvmatcher.CachedOSVMatcher{ - Client: *osvdev.DefaultClient(), - InitialQueryTimeout: 5 * time.Minute, - }, - DependencyClient: ddCl, + VulnerabilityMatcher: vulnMatcher(), + DependencyClient: ddCl, } f, err := lf.OpenLocalDepFile(filename) @@ -83,11 +98,8 @@ func doRelockRelax(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename func doOverride(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename string) error { cl := client.ResolutionClient{ - VulnerabilityMatcher: &osvmatcher.CachedOSVMatcher{ - Client: *osvdev.DefaultClient(), - InitialQueryTimeout: 5 * time.Minute, - }, - DependencyClient: ddCl, + VulnerabilityMatcher: vulnMatcher(), + DependencyClient: ddCl, } f, err := lf.OpenLocalDepFile(filename) @@ -113,11 +125,8 @@ func doOverride(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename str func doInPlace(ddCl *client.DepsDevClient, rw lockfile.ReadWriter, filename string) error { cl := client.ResolutionClient{ - VulnerabilityMatcher: &osvmatcher.CachedOSVMatcher{ - Client: *osvdev.DefaultClient(), - InitialQueryTimeout: 5 * time.Minute, - }, - DependencyClient: ddCl, + VulnerabilityMatcher: vulnMatcher(), + DependencyClient: ddCl, } f, err := lf.OpenLocalDepFile(filename) @@ -294,7 +303,7 @@ func typeString(t dep.Type) string { } func main() { - cl, err := client.NewDepsDevClient(depsdev.DepsdevAPI, "osv-scanner_generate_mock/"+version.OSVVersion) + cl, err := client.NewDepsDevClient(depsdev.DepsdevAPI, userAgent) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) From ba3561f82280e2200a2274e4575b4451848681d2 Mon Sep 17 00:00:00 2001 From: Michael Kedar Date: Thu, 16 Jan 2025 14:57:22 +1100 Subject: [PATCH 5/6] As -> To --- internal/remediation/in_place.go | 2 +- internal/resolution/client/helper.go | 6 +++--- internal/resolution/resolve.go | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/remediation/in_place.go b/internal/remediation/in_place.go index bf2155bb39e..6b5524bc05b 100644 --- a/internal/remediation/in_place.go +++ b/internal/remediation/in_place.go @@ -237,7 +237,7 @@ type inPlaceVulnsNodesResult struct { } func inPlaceVulnsNodes(ctx context.Context, m clientinterfaces.VulnerabilityMatcher, graph *resolve.Graph) (inPlaceVulnsNodesResult, error) { - nodeVulns, err := m.MatchVulnerabilities(ctx, client.GraphAsInventory(graph)) + nodeVulns, err := m.MatchVulnerabilities(ctx, client.GraphToInventory(graph)) if err != nil { return inPlaceVulnsNodesResult{}, err } diff --git a/internal/resolution/client/helper.go b/internal/resolution/client/helper.go index aae4c4988c0..6438f1306fd 100644 --- a/internal/resolution/client/helper.go +++ b/internal/resolution/client/helper.go @@ -7,8 +7,8 @@ import ( "github.com/google/osv-scalibr/purl" ) -// GraphAsInventory is a helper function to convert a Graph into an Inventory for use with VulnerabilityMatcher. -func GraphAsInventory(g *resolve.Graph) []*extractor.Inventory { +// GraphToInventory is a helper function to convert a Graph into an Inventory for use with VulnerabilityMatcher. +func GraphToInventory(g *resolve.Graph) []*extractor.Inventory { // g.Nodes[0] is the root node of the graph that should be excluded. inv := make([]*extractor.Inventory, len(g.Nodes)-1) for i, n := range g.Nodes[1:] { @@ -22,7 +22,7 @@ func GraphAsInventory(g *resolve.Graph) []*extractor.Inventory { return inv } -// mockExtractor is for GraphAsInventory to get the ecosystem. +// mockExtractor is for GraphToInventory to get the ecosystem. type mockExtractor struct { ecosystem resolve.System } diff --git a/internal/resolution/resolve.go b/internal/resolution/resolve.go index 818fb98a1c8..978ca45dd55 100644 --- a/internal/resolution/resolve.go +++ b/internal/resolution/resolve.go @@ -204,7 +204,7 @@ func resolvePostProcess(ctx context.Context, cl client.ResolutionClient, m manif // computeVulns scans for vulnerabilities in a resolved graph and populates res.Vulns func (res *Result) computeVulns(ctx context.Context, cl client.ResolutionClient) error { - nodeVulns, err := cl.MatchVulnerabilities(ctx, client.GraphAsInventory(res.Graph)) + nodeVulns, err := cl.MatchVulnerabilities(ctx, client.GraphToInventory(res.Graph)) if err != nil { return err } From 5e87da6fcd3a0e9c999a81dbcb749fa8ff41d47e Mon Sep 17 00:00:00 2001 From: Michael Kedar Date: Thu, 16 Jan 2025 16:15:02 +1100 Subject: [PATCH 6/6] add element back into nodeVulns --- internal/remediation/in_place.go | 14 ++++++++------ internal/resolution/resolve.go | 9 +++++++-- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/internal/remediation/in_place.go b/internal/remediation/in_place.go index 6b5524bc05b..f7547723eb2 100644 --- a/internal/remediation/in_place.go +++ b/internal/remediation/in_place.go @@ -16,6 +16,7 @@ import ( lf "github.com/google/osv-scanner/internal/resolution/lockfile" "github.com/google/osv-scanner/internal/resolution/util" "github.com/google/osv-scanner/internal/utility/vulns" + "github.com/google/osv-scanner/pkg/models" "golang.org/x/exp/maps" ) @@ -242,6 +243,10 @@ func inPlaceVulnsNodes(ctx context.Context, m clientinterfaces.VulnerabilityMatc return inPlaceVulnsNodesResult{}, err } + // GraphToInventory/MatchVulnerabilities excludes the root node of the graph. + // Prepend an element to nodeVulns so that the indices line up with graph.Nodes[i] <=> nodeVulns[i] + nodeVulns = append([][]*models.Vulnerability{nil}, nodeVulns...) + result := inPlaceVulnsNodesResult{ nodeDependencies: make(map[resolve.NodeID][]resolve.VersionKey), vkVulns: make(map[resolve.VersionKey][]resolution.Vulnerability), @@ -250,10 +255,7 @@ func inPlaceVulnsNodes(ctx context.Context, m clientinterfaces.VulnerabilityMatc // Find all direct dependencies of vulnerable nodes. for _, e := range graph.Edges { - if e.From == 0 { - continue - } - if len(nodeVulns[e.From-1]) > 0 { + if len(nodeVulns[e.From]) > 0 { result.nodeDependencies[e.From] = append(result.nodeDependencies[e.From], graph.Nodes[e.To].Version) } } @@ -263,7 +265,7 @@ func inPlaceVulnsNodes(ctx context.Context, m clientinterfaces.VulnerabilityMatc var nodeIDs []resolve.NodeID for nID, vulns := range nodeVulns { if len(vulns) > 0 { - nodeIDs = append(nodeIDs, resolve.NodeID(nID+1)) + nodeIDs = append(nodeIDs, resolve.NodeID(nID)) } } nodeChains := resolution.ComputeChains(graph, nodeIDs) @@ -274,7 +276,7 @@ func inPlaceVulnsNodes(ctx context.Context, m clientinterfaces.VulnerabilityMatc chains := nodeChains[i] vk := graph.Nodes[nID].Version result.vkNodes[vk] = append(result.vkNodes[vk], nID) - for _, vuln := range nodeVulns[nID-1] { + for _, vuln := range nodeVulns[nID] { resVuln := resolution.Vulnerability{ OSV: *vuln, ProblemChains: slices.Clone(chains), diff --git a/internal/resolution/resolve.go b/internal/resolution/resolve.go index 978ca45dd55..5ba2b167dc8 100644 --- a/internal/resolution/resolve.go +++ b/internal/resolution/resolve.go @@ -208,12 +208,17 @@ func (res *Result) computeVulns(ctx context.Context, cl client.ResolutionClient) if err != nil { return err } + + // GraphToInventory/MatchVulnerabilities excludes the root node of the graph. + // Prepend an element to nodeVulns so that the indices line up with graph.Nodes[i] <=> nodeVulns[i] + nodeVulns = append([][]*models.Vulnerability{nil}, nodeVulns...) + // Find all dependency paths to the vulnerable dependencies var vulnerableNodes []resolve.NodeID vulnInfo := make(map[string]models.Vulnerability) for i, vulns := range nodeVulns { if len(vulns) > 0 { - vulnerableNodes = append(vulnerableNodes, resolve.NodeID(i+1)) + vulnerableNodes = append(vulnerableNodes, resolve.NodeID(i)) } for _, vuln := range vulns { vulnInfo[vuln.ID] = *vuln @@ -223,7 +228,7 @@ func (res *Result) computeVulns(ctx context.Context, cl client.ResolutionClient) nodeChains := ComputeChains(res.Graph, vulnerableNodes) vulnChains := make(map[string][]DependencyChain) for i, nID := range vulnerableNodes { - for _, vuln := range nodeVulns[nID-1] { + for _, vuln := range nodeVulns[nID] { vulnChains[vuln.ID] = append(vulnChains[vuln.ID], nodeChains[i]...) } }