Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use new VulnerabilityMatcher in guided remediation #1503

Merged
merged 6 commits into from
Jan 19, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
WIP client migration
michaelkedar committed Jan 15, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 4bef858e23a2348b0a41683da370520a9a850864
25 changes: 20 additions & 5 deletions cmd/osv-scanner/fix/main.go
Original file line number Diff line number Diff line change
@@ -7,9 +7,13 @@ import (
"os"
"path/filepath"
"strings"
"time"

"deps.dev/util/resolve"
"github.com/google/osv-scanner/internal/clients/clientimpl/localmatcher"
"github.com/google/osv-scanner/internal/clients/clientimpl/osvmatcher"
"github.com/google/osv-scanner/internal/depsdev"
"github.com/google/osv-scanner/internal/osvdev"
"github.com/google/osv-scanner/internal/remediation"
"github.com/google/osv-scanner/internal/remediation/upgrade"
"github.com/google/osv-scanner/internal/resolution"
@@ -365,17 +369,28 @@ func action(ctx *cli.Context, stdout, stderr io.Writer) (reporter.Reporter, erro
}

if ctx.Bool("experimental-offline-vulnerabilities") {
var err error
opts.Client.VulnerabilityClient, err = client.NewOSVOfflineClient(
matcher, err := localmatcher.NewLocalMatcher(
r,
system,
ctx.String("experimental-local-db-path"),
"osv-scanner_fix/"+version.OSVVersion,
ctx.Bool("experimental-download-offline-databases"),
ctx.String("experimental-local-db-path"))
)
if err != nil {
return nil, err
}

// TODO: check system is downloaded
// if err := matcher.LoadEcosystem(ctx.Context, system?); err != nil {
// return nil, err
// }

opts.Client.VulnerabilityMatcher = matcher
} else {
opts.Client.VulnerabilityClient = client.NewOSVClient()
// TODO: replace with cached client
opts.Client.VulnerabilityMatcher = &osvmatcher.OSVMatcher{
Client: *osvdev.DefaultClient(), // TODO: UserAgent
InitialQueryTimeout: 5 * time.Minute,
}
}

if !ctx.Bool("non-interactive") {
18 changes: 11 additions & 7 deletions internal/remediation/in_place.go
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ import (
"deps.dev/util/resolve"
"deps.dev/util/resolve/dep"
"deps.dev/util/semver"
"github.com/google/osv-scanner/internal/clients/clientinterfaces"
"github.com/google/osv-scanner/internal/remediation/upgrade"
"github.com/google/osv-scanner/internal/resolution"
"github.com/google/osv-scanner/internal/resolution/client"
@@ -105,7 +106,7 @@ func (r InPlaceResult) VulnCount() VulnCount {
// ComputeInPlacePatches finds all possible targeting version changes that would fix vulnerabilities in a resolved graph.
// TODO: Check for introduced vulnerabilities
func ComputeInPlacePatches(ctx context.Context, cl client.ResolutionClient, graph *resolve.Graph, opts Options) (InPlaceResult, error) {
res, err := inPlaceVulnsNodes(cl, graph)
res, err := inPlaceVulnsNodes(ctx, cl, graph)
if err != nil {
return InPlaceResult{}, err
}
@@ -235,8 +236,8 @@ type inPlaceVulnsNodesResult struct {
vkNodes map[resolve.VersionKey][]resolve.NodeID
}

func inPlaceVulnsNodes(cl client.VulnerabilityClient, graph *resolve.Graph) (inPlaceVulnsNodesResult, error) {
nodeVulns, err := cl.FindVulns(graph)
func inPlaceVulnsNodes(ctx context.Context, m clientinterfaces.VulnerabilityMatcher, graph *resolve.Graph) (inPlaceVulnsNodesResult, error) {
nodeVulns, err := m.MatchVulnerabilities(ctx, client.GraphAsInventory(graph))
if err != nil {
return inPlaceVulnsNodesResult{}, err
}
@@ -249,7 +250,10 @@ func inPlaceVulnsNodes(cl client.VulnerabilityClient, graph *resolve.Graph) (inP

// Find all direct dependencies of vulnerable nodes.
for _, e := range graph.Edges {
if len(nodeVulns[e.From]) > 0 {
if e.From == 0 {
continue
}
if len(nodeVulns[e.From-1]) > 0 {
result.nodeDependencies[e.From] = append(result.nodeDependencies[e.From], graph.Nodes[e.To].Version)
}
}
@@ -259,7 +263,7 @@ func inPlaceVulnsNodes(cl client.VulnerabilityClient, graph *resolve.Graph) (inP
var nodeIDs []resolve.NodeID
for nID, vulns := range nodeVulns {
if len(vulns) > 0 {
nodeIDs = append(nodeIDs, resolve.NodeID(nID))
nodeIDs = append(nodeIDs, resolve.NodeID(nID+1))
}
}
nodeChains := resolution.ComputeChains(graph, nodeIDs)
@@ -270,9 +274,9 @@ func inPlaceVulnsNodes(cl client.VulnerabilityClient, graph *resolve.Graph) (inP
chains := nodeChains[i]
vk := graph.Nodes[nID].Version
result.vkNodes[vk] = append(result.vkNodes[vk], nID)
for _, vuln := range nodeVulns[nID] {
for _, vuln := range nodeVulns[nID-1] {
resVuln := resolution.Vulnerability{
OSV: vuln,
OSV: *vuln,
ProblemChains: slices.Clone(chains),
DevOnly: !slices.ContainsFunc(chains, func(dc resolution.DependencyChain) bool { return !resolution.ChainIsDev(dc, nil) }),
}
10 changes: 2 additions & 8 deletions internal/resolution/client/client.go
Original file line number Diff line number Diff line change
@@ -8,22 +8,16 @@ import (
"deps.dev/util/resolve"
"deps.dev/util/resolve/dep"
"deps.dev/util/semver"
"github.com/google/osv-scanner/internal/clients/clientinterfaces"
"github.com/google/osv-scanner/internal/depsdev"
"github.com/google/osv-scanner/pkg/models"
"github.com/google/osv-scanner/pkg/osv"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

type ResolutionClient struct {
DependencyClient
VulnerabilityClient
}

type VulnerabilityClient interface {
// FindVulns finds the vulnerabilities affecting each of Nodes in the graph.
// The returned Vulnerabilities[i] corresponds to the vulnerabilities in g.Nodes[i].
FindVulns(g *resolve.Graph) ([]models.Vulnerabilities, error)
clientinterfaces.VulnerabilityMatcher
}

type DependencyClient interface {
48 changes: 48 additions & 0 deletions internal/resolution/client/helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package client

import (
"deps.dev/util/resolve"
"github.com/google/osv-scalibr/extractor"
"github.com/google/osv-scalibr/plugin"
"github.com/google/osv-scalibr/purl"
"github.com/ossf/osv-schema/bindings/go/osvschema"
)

// GraphAsInventory is a helper function to convert a Graph into an Inventory for use with VulnerabilityMatcher.
func GraphAsInventory(g *resolve.Graph) []*extractor.Inventory {
// g.Nodes[0] is the root node of the graph that should be excluded.
inv := make([]*extractor.Inventory, len(g.Nodes)-1)
for i, n := range g.Nodes[1:] {
inv[i] = &extractor.Inventory{
Name: n.Version.Name,
Version: n.Version.Version,
Metadata: n,
Extractor: graphExtractor{},
Locations: []string{g.Nodes[0].Version.Name},
}
}
return inv
}

// graphExtractor is for GraphAsInventory to get the ecosystem.
type graphExtractor struct{}

func (e graphExtractor) Ecosystem(i *extractor.Inventory) string {
n, ok := i.Metadata.(resolve.Node)
if !ok {
return ""
}
switch n.Version.System {
case resolve.NPM:
return string(osvschema.EcosystemNPM)
case resolve.Maven:
return string(osvschema.EcosystemMaven)
default:
return ""
}
}

func (e graphExtractor) Name() string { return "" }
func (e graphExtractor) Requirements() *plugin.Capabilities { return nil }
func (e graphExtractor) ToPURL(_ *extractor.Inventory) *purl.PackageURL { return nil }
func (e graphExtractor) Version() int { return 0 }
30 changes: 20 additions & 10 deletions internal/resolution/clienttest/mock_resolution_client.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package clienttest

import (
"context"
"os"
"strings"
"testing"

"deps.dev/util/resolve"
"deps.dev/util/resolve/schema"
"github.com/google/osv-scalibr/extractor"
"github.com/google/osv-scanner/internal/imodels"
"github.com/google/osv-scanner/internal/resolution/client"
"github.com/google/osv-scanner/internal/resolution/util"
"github.com/google/osv-scanner/internal/utility/vulns"
"github.com/google/osv-scanner/pkg/lockfile"
"github.com/google/osv-scanner/pkg/models"
"gopkg.in/yaml.v3"
)
@@ -20,17 +23,24 @@ type ResolutionUniverse struct {
Vulns []models.Vulnerability `yaml:"vulns"`
}

type mockVulnerabilityClient []models.Vulnerability
type mockVulnerabilityMatcher []models.Vulnerability

func (mvc mockVulnerabilityClient) FindVulns(g *resolve.Graph) ([]models.Vulnerabilities, error) {
result := make([]models.Vulnerabilities, len(g.Nodes))
for i, n := range g.Nodes {
if i == 0 {
continue // skip root node
func (mvc mockVulnerabilityMatcher) MatchVulnerabilities(_ context.Context, invs []*extractor.Inventory) ([][]*models.Vulnerability, error) {
result := make([][]*models.Vulnerability, len(invs))
for i, inv := range invs {
pkg := imodels.FromInventory(inv)
// TODO (V2 Models): remove this once PackageDetails has been migrated
mappedPackageDetails := lockfile.PackageDetails{
Name: pkg.Name(),
Version: pkg.Version(),
Commit: pkg.Commit(),
Ecosystem: lockfile.Ecosystem(pkg.Ecosystem().String()),
CompareAs: lockfile.Ecosystem(pkg.Ecosystem().String()),
DepGroups: pkg.DepGroups(),
}
for _, v := range mvc {
if vulns.IsAffected(v, util.VKToPackageDetails(n.Version)) {
result[i] = append(result[i], v)
if vulns.IsAffected(v, mappedPackageDetails) {
result[i] = append(result[i], &v)
}
}
}
@@ -61,7 +71,7 @@ func NewMockResolutionClient(t *testing.T, universeYAML string) client.Resolutio
}

cl := client.ResolutionClient{
VulnerabilityClient: mockVulnerabilityClient(universe.Vulns),
VulnerabilityMatcher: mockVulnerabilityMatcher(universe.Vulns),
}

var sys resolve.System
10 changes: 5 additions & 5 deletions internal/resolution/resolve.go
Original file line number Diff line number Diff line change
@@ -204,7 +204,7 @@ func resolvePostProcess(ctx context.Context, cl client.ResolutionClient, m manif

// computeVulns scans for vulnerabilities in a resolved graph and populates res.Vulns
func (res *Result) computeVulns(ctx context.Context, cl client.ResolutionClient) error {
nodeVulns, err := cl.FindVulns(res.Graph)
nodeVulns, err := cl.MatchVulnerabilities(ctx, client.GraphAsInventory(res.Graph))
if err != nil {
return err
}
@@ -213,17 +213,17 @@ func (res *Result) computeVulns(ctx context.Context, cl client.ResolutionClient)
vulnInfo := make(map[string]models.Vulnerability)
for i, vulns := range nodeVulns {
if len(vulns) > 0 {
vulnerableNodes = append(vulnerableNodes, resolve.NodeID(i))
vulnerableNodes = append(vulnerableNodes, resolve.NodeID(i+1))
}
for _, vuln := range vulns {
vulnInfo[vuln.ID] = vuln
vulnInfo[vuln.ID] = *vuln
}
}

nodeChains := ComputeChains(res.Graph, vulnerableNodes)
vulnChains := make(map[string][]DependencyChain)
for i, idx := range vulnerableNodes {
for _, vuln := range nodeVulns[idx] {
for i, nID := range vulnerableNodes {
for _, vuln := range nodeVulns[nID-1] {
vulnChains[vuln.ID] = append(vulnChains[vuln.ID], nodeChains[i]...)
}
}
23 changes: 17 additions & 6 deletions scripts/generate_mock_resolution_universe/main.go
Original file line number Diff line number Diff line change
@@ -22,7 +22,9 @@ import (
pb "deps.dev/api/v3"
"deps.dev/util/resolve"
"deps.dev/util/resolve/dep"
"github.com/google/osv-scanner/internal/clients/clientimpl/osvmatcher"
"github.com/google/osv-scanner/internal/depsdev"
"github.com/google/osv-scanner/internal/osvdev"
"github.com/google/osv-scanner/internal/remediation"
"github.com/google/osv-scanner/internal/remediation/upgrade"
"github.com/google/osv-scanner/internal/resolution"
@@ -51,8 +53,11 @@ var remediationOpts = remediation.Options{

func doRelockRelax(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename string) error {
cl := client.ResolutionClient{
VulnerabilityClient: client.NewOSVClient(),
DependencyClient: ddCl,
VulnerabilityMatcher: &osvmatcher.OSVMatcher{
Client: *osvdev.DefaultClient(),
InitialQueryTimeout: 5 * time.Minute,
},
DependencyClient: ddCl,
}

f, err := lf.OpenLocalDepFile(filename)
@@ -78,8 +83,11 @@ func doRelockRelax(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename

func doOverride(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename string) error {
cl := client.ResolutionClient{
VulnerabilityClient: client.NewOSVClient(),
DependencyClient: ddCl,
VulnerabilityMatcher: &osvmatcher.OSVMatcher{
Client: *osvdev.DefaultClient(),
InitialQueryTimeout: 5 * time.Minute,
},
DependencyClient: ddCl,
}

f, err := lf.OpenLocalDepFile(filename)
@@ -105,8 +113,11 @@ func doOverride(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename str

func doInPlace(ddCl *client.DepsDevClient, rw lockfile.ReadWriter, filename string) error {
cl := client.ResolutionClient{
VulnerabilityClient: client.NewOSVClient(),
DependencyClient: ddCl,
VulnerabilityMatcher: &osvmatcher.OSVMatcher{
Client: *osvdev.DefaultClient(),
InitialQueryTimeout: 5 * time.Minute,
},
DependencyClient: ddCl,
}

f, err := lf.OpenLocalDepFile(filename)