Skip to content

Commit

Permalink
refactor: use new VulnerabilityMatcher in guided remediation (#1503)
Browse files Browse the repository at this point in the history
Following up on #1470 
- Made `ResolutionClient` use the `VulnerabilityMatcher` interface (and
added helper function to convert deps.dev graphs into inventories)
- Deleted old `VulnerabilityClient`
- Created `CachedOSVMatcher` to re-implement performance improvements
from the original `VulnerabilityClient` w.r.t. repeated queries.
- Re-enabled local database capability in `osv-scanner fix`
  • Loading branch information
michaelkedar authored Jan 19, 2025
1 parent 6874aa6 commit fe4eaea
Show file tree
Hide file tree
Showing 13 changed files with 296 additions and 248 deletions.
39 changes: 34 additions & 5 deletions cmd/osv-scanner/fix/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,28 @@ import (
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"

"deps.dev/util/resolve"
"github.com/google/osv-scanner/internal/clients/clientimpl/localmatcher"
"github.com/google/osv-scanner/internal/clients/clientimpl/osvmatcher"
"github.com/google/osv-scanner/internal/depsdev"
"github.com/google/osv-scanner/internal/imodels/ecosystem"
"github.com/google/osv-scanner/internal/osvdev"
"github.com/google/osv-scanner/internal/remediation"
"github.com/google/osv-scanner/internal/remediation/upgrade"
"github.com/google/osv-scanner/internal/resolution"
"github.com/google/osv-scanner/internal/resolution/client"
"github.com/google/osv-scanner/internal/resolution/lockfile"
"github.com/google/osv-scanner/internal/resolution/manifest"
"github.com/google/osv-scanner/internal/resolution/util"
"github.com/google/osv-scanner/internal/version"
"github.com/google/osv-scanner/pkg/reporter"
"github.com/ossf/osv-schema/bindings/go/osvschema"
"github.com/urfave/cli/v2"
"golang.org/x/term"
)
Expand Down Expand Up @@ -364,18 +372,39 @@ func action(ctx *cli.Context, stdout, stderr io.Writer) (reporter.Reporter, erro
}
}

userAgent := "osv-scanner_fix/" + version.OSVVersion
if ctx.Bool("experimental-offline-vulnerabilities") {
var err error
opts.Client.VulnerabilityClient, err = client.NewOSVOfflineClient(
matcher, err := localmatcher.NewLocalMatcher(
r,
system,
ctx.String("experimental-local-db-path"),
userAgent,
ctx.Bool("experimental-download-offline-databases"),
ctx.String("experimental-local-db-path"))
)
if err != nil {
return nil, err
}

eco, ok := util.OSVEcosystem[system]
if !ok {
// Something's very wrong if we hit this
panic("unhandled resolve.Ecosystem: " + system.String())
}
if err := matcher.LoadEcosystem(ctx.Context, ecosystem.Parsed{Ecosystem: osvschema.Ecosystem(eco)}); err != nil {
return nil, err
}

opts.Client.VulnerabilityMatcher = matcher
} else {
opts.Client.VulnerabilityClient = client.NewOSVClient()
config := osvdev.DefaultConfig()
config.UserAgent = userAgent
opts.Client.VulnerabilityMatcher = &osvmatcher.CachedOSVMatcher{
Client: osvdev.OSVClient{
HTTPClient: http.DefaultClient,
Config: config,
BaseHostURL: osvdev.DefaultBaseURL,
},
InitialQueryTimeout: 5 * time.Minute,
}
}

if !ctx.Bool("non-interactive") {
Expand Down
2 changes: 1 addition & 1 deletion internal/clients/clientimpl/localmatcher/localmatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (matcher *LocalMatcher) MatchVulnerabilities(ctx context.Context, invs []*e
continue
}

results = append(results, db.VulnerabilitiesAffectingPackage(pkg))
results = append(results, VulnerabilitiesAffectingPackage(db.Vulnerabilities(false), pkg))
}

return results, nil
Expand Down
8 changes: 5 additions & 3 deletions internal/clients/clientimpl/localmatcher/zip.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ func (db *ZipDB) Vulnerabilities(includeWithdrawn bool) []models.Vulnerability {
return vulnerabilities
}

func (db *ZipDB) VulnerabilitiesAffectingPackage(pkg imodels.PackageInfo) []*models.Vulnerability {
// TODO: Move this to another file.
func VulnerabilitiesAffectingPackage(allVulns []models.Vulnerability, pkg imodels.PackageInfo) []*models.Vulnerability {
var vulnerabilities []*models.Vulnerability

// TODO (V2 Models): remove this once PackageDetails has been migrated
Expand All @@ -248,7 +249,7 @@ func (db *ZipDB) VulnerabilitiesAffectingPackage(pkg imodels.PackageInfo) []*mod
DepGroups: pkg.DepGroups(),
}

for _, vulnerability := range db.Vulnerabilities(false) {
for _, vulnerability := range allVulns {
if vulns.IsAffected(vulnerability, mappedPackageDetails) && !vulns.Include(vulnerabilities, vulnerability) {
vulnerabilities = append(vulnerabilities, &vulnerability)
}
Expand All @@ -258,10 +259,11 @@ func (db *ZipDB) VulnerabilitiesAffectingPackage(pkg imodels.PackageInfo) []*mod
}

func (db *ZipDB) Check(pkgs []imodels.PackageInfo) ([]*models.Vulnerability, error) {
allVulns := db.Vulnerabilities(false)
vulnerabilities := make([]*models.Vulnerability, 0, len(pkgs))

for _, pkg := range pkgs {
vulnerabilities = append(vulnerabilities, db.VulnerabilitiesAffectingPackage(pkg)...)
vulnerabilities = append(vulnerabilities, VulnerabilitiesAffectingPackage(allVulns, pkg)...)
}

return vulnerabilities, nil
Expand Down
151 changes: 151 additions & 0 deletions internal/clients/clientimpl/osvmatcher/cachedosvmatcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package osvmatcher

import (
"context"
"errors"
"maps"
"slices"
"sync"
"time"

"github.com/google/osv-scalibr/extractor"
"github.com/google/osv-scanner/internal/clients/clientimpl/localmatcher"
"github.com/google/osv-scanner/internal/imodels"
"github.com/google/osv-scanner/internal/osvdev"
"github.com/google/osv-scanner/pkg/models"
"golang.org/x/sync/errgroup"
)

// CachedOSVMatcher implements the VulnerabilityMatcher interface with an osv.dev client.
// It sends out requests for every vulnerability of each package, which get cached
// per package (name + ecosystem).
// Checking if a specific version matches an OSV record is done locally.
// This should be used when we know the same packages are going to be
// queried repeatedly, as in guided remediation.
// TODO: This does not support commit-based queries.
type CachedOSVMatcher struct {
// Client is the osv.dev client used for the initial batch query and for the
// per-vulnerability GetVulnByID lookups.
Client osvdev.OSVClient
// InitialQueryTimeout allows you to set a timeout specifically for the initial paging query.
// It is only applied when it is greater than zero.
// If the timeout runs out, the query step reports context.DeadlineExceeded,
// nothing is stored in the cache, and MatchVulnerabilities returns that error.
InitialQueryTimeout time.Duration

// vulnCache holds, for each package already queried, every vulnerability
// returned for that package.
vulnCache sync.Map // map[osvdev.Package][]models.Vulnerability
}

// MatchVulnerabilities returns, for each inventory in invs, the vulnerabilities
// affecting that exact package version. The returned slice has the same length
// and ordering as invs; entries for packages with no cached data are left nil.
//
// Any package not yet present in the cache is fetched first; if that fetch
// fails, or ctx is cancelled while matching, the error is returned and no
// results are produced.
func (matcher *CachedOSVMatcher) MatchVulnerabilities(ctx context.Context, invs []*extractor.Inventory) ([][]*models.Vulnerability, error) {
	// Fill the cache for every package we have not seen before.
	err := matcher.doQueries(ctx, invs)
	if err != nil {
		return nil, err
	}

	matched := make([][]*models.Vulnerability, len(invs))
	for idx := range invs {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}

		info := imodels.FromInventory(invs[idx])
		key := osvdev.Package{
			Name:      info.Name(),
			Ecosystem: info.Ecosystem().String(),
		}

		// The version check happens locally against the cached records.
		if cached, found := matcher.vulnCache.Load(key); found {
			matched[idx] = localmatcher.VulnerabilitiesAffectingPackage(cached.([]models.Vulnerability), info)
		}
	}

	return matched, nil
}

// doQueries populates vulnCache with every vulnerability of each package in
// invs that is not cached yet.
//
// Inventories without a name or ecosystem are skipped. Each distinct uncached
// package is queried exactly once, then every returned vulnerability ID is
// hydrated into a full record via GetVulnByID before being stored.
//
// It returns nil when everything needed is cached. It returns
// context.DeadlineExceeded when the initial batch query timed out, and nothing
// is cached in that case. Any other query or hydration error is returned as-is.
func (matcher *CachedOSVMatcher) doQueries(ctx context.Context, invs []*extractor.Inventory) error {
	// Determine which packages aren't already cached, building one Query per
	// distinct package. The map is keyed by the package value rather than by
	// the *Query pointer: every freshly allocated Query has a unique address,
	// so a pointer key would never collapse duplicates and the same package
	// would be queried (and hydrated) once per occurrence in invs.
	toQuery := make(map[osvdev.Package]*osvdev.Query)
	for _, inv := range invs {
		pkgInfo := imodels.FromInventory(inv)
		if pkgInfo.Name() == "" || pkgInfo.Ecosystem().IsEmpty() {
			continue
		}
		pkg := osvdev.Package{
			Name:      pkgInfo.Name(),
			Ecosystem: pkgInfo.Ecosystem().String(),
		}
		if _, queued := toQuery[pkg]; queued {
			continue
		}
		if _, cached := matcher.vulnCache.Load(pkg); !cached {
			toQuery[pkg] = &osvdev.Query{Package: pkg}
		}
	}

	if len(toQuery) == 0 {
		return nil
	}
	queries := slices.Collect(maps.Values(toQuery))

	var (
		batchResp *osvdev.BatchedResponse
		err       error
	)

	// If there is a timeout for the initial query, set an additional context deadline here.
	if matcher.InitialQueryTimeout > 0 {
		batchQueryCtx, cancelFunc := context.WithTimeout(ctx, matcher.InitialQueryTimeout)
		batchResp, err = queryForBatchWithPaging(batchQueryCtx, &matcher.Client, queries)
		cancelFunc()
	} else {
		batchResp, err = queryForBatchWithPaging(ctx, &matcher.Client, queries)
	}

	if err != nil {
		// Deadline being exceeded is likely caused by a long paging time.
		// Whatever was received is at best partial, and caching a partial list
		// would hide vulnerabilities on every later lookup, so nothing is
		// stored. Report the timeout right away instead of hydrating results
		// that would only be thrown away.
		if errors.Is(err, context.DeadlineExceeded) {
			return context.DeadlineExceeded
		}

		return err
	}

	// The batch response only identifies vulnerabilities; fetch the full
	// record for each one, with a bounded number of requests in flight.
	vulnerabilities := make([][]models.Vulnerability, len(batchResp.Results))
	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(maxConcurrentRequests)

	for batchIdx, resp := range batchResp.Results {
		vulnerabilities[batchIdx] = make([]models.Vulnerability, len(resp.Vulns))
		for resultIdx, vuln := range resp.Vulns {
			g.Go(func() error {
				// exit early if another hydration request has already failed
				// results are thrown away later, so avoid needless work
				if ctx.Err() != nil {
					return nil //nolint:nilerr // this value doesn't matter to errgroup.Wait()
				}
				hydrated, err := matcher.Client.GetVulnByID(ctx, vuln.ID)
				if err != nil {
					return err
				}
				// Each goroutine writes to its own pre-allocated slot.
				vulnerabilities[batchIdx][resultIdx] = *hydrated

				return nil
			})
		}
	}

	if err := g.Wait(); err != nil {
		return err
	}

	// Results are matched back to their query by position.
	// NOTE(review): assumes batchResp.Results[i] corresponds to queries[i] — confirm against queryForBatchWithPaging.
	for i, vulns := range vulnerabilities {
		matcher.vulnCache.Store(queries[i].Package, vulns)
	}

	return nil
}
14 changes: 10 additions & 4 deletions internal/remediation/in_place.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ import (
"deps.dev/util/resolve"
"deps.dev/util/resolve/dep"
"deps.dev/util/semver"
"github.com/google/osv-scanner/internal/clients/clientinterfaces"
"github.com/google/osv-scanner/internal/remediation/upgrade"
"github.com/google/osv-scanner/internal/resolution"
"github.com/google/osv-scanner/internal/resolution/client"
lf "github.com/google/osv-scanner/internal/resolution/lockfile"
"github.com/google/osv-scanner/internal/resolution/util"
"github.com/google/osv-scanner/internal/utility/vulns"
"github.com/google/osv-scanner/pkg/models"
"golang.org/x/exp/maps"
)

Expand Down Expand Up @@ -105,7 +107,7 @@ func (r InPlaceResult) VulnCount() VulnCount {
// ComputeInPlacePatches finds all possible targeting version changes that would fix vulnerabilities in a resolved graph.
// TODO: Check for introduced vulnerabilities
func ComputeInPlacePatches(ctx context.Context, cl client.ResolutionClient, graph *resolve.Graph, opts Options) (InPlaceResult, error) {
res, err := inPlaceVulnsNodes(cl, graph)
res, err := inPlaceVulnsNodes(ctx, cl, graph)
if err != nil {
return InPlaceResult{}, err
}
Expand Down Expand Up @@ -235,12 +237,16 @@ type inPlaceVulnsNodesResult struct {
vkNodes map[resolve.VersionKey][]resolve.NodeID
}

func inPlaceVulnsNodes(cl client.VulnerabilityClient, graph *resolve.Graph) (inPlaceVulnsNodesResult, error) {
nodeVulns, err := cl.FindVulns(graph)
func inPlaceVulnsNodes(ctx context.Context, m clientinterfaces.VulnerabilityMatcher, graph *resolve.Graph) (inPlaceVulnsNodesResult, error) {
nodeVulns, err := m.MatchVulnerabilities(ctx, client.GraphToInventory(graph))
if err != nil {
return inPlaceVulnsNodesResult{}, err
}

// GraphToInventory/MatchVulnerabilities excludes the root node of the graph.
// Prepend an element to nodeVulns so that the indices line up with graph.Nodes[i] <=> nodeVulns[i]
nodeVulns = append([][]*models.Vulnerability{nil}, nodeVulns...)

result := inPlaceVulnsNodesResult{
nodeDependencies: make(map[resolve.NodeID][]resolve.VersionKey),
vkVulns: make(map[resolve.VersionKey][]resolution.Vulnerability),
Expand Down Expand Up @@ -272,7 +278,7 @@ func inPlaceVulnsNodes(cl client.VulnerabilityClient, graph *resolve.Graph) (inP
result.vkNodes[vk] = append(result.vkNodes[vk], nID)
for _, vuln := range nodeVulns[nID] {
resVuln := resolution.Vulnerability{
OSV: vuln,
OSV: *vuln,
ProblemChains: slices.Clone(chains),
DevOnly: !slices.ContainsFunc(chains, func(dc resolution.DependencyChain) bool { return !resolution.ChainIsDev(dc, nil) }),
}
Expand Down
10 changes: 2 additions & 8 deletions internal/resolution/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,16 @@ import (
"deps.dev/util/resolve"
"deps.dev/util/resolve/dep"
"deps.dev/util/semver"
"github.com/google/osv-scanner/internal/clients/clientinterfaces"
"github.com/google/osv-scanner/internal/depsdev"
"github.com/google/osv-scanner/pkg/models"
"github.com/google/osv-scanner/pkg/osv"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

type ResolutionClient struct {
DependencyClient
VulnerabilityClient
}

type VulnerabilityClient interface {
// FindVulns finds the vulnerabilities affecting each of Nodes in the graph.
// The returned Vulnerabilities[i] corresponds to the vulnerabilities in g.Nodes[i].
FindVulns(g *resolve.Graph) ([]models.Vulnerabilities, error)
clientinterfaces.VulnerabilityMatcher
}

type DependencyClient interface {
Expand Down
46 changes: 46 additions & 0 deletions internal/resolution/client/helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package client

import (
"deps.dev/util/resolve"
"github.com/google/osv-scalibr/extractor"
"github.com/google/osv-scalibr/plugin"
"github.com/google/osv-scalibr/purl"
)

// GraphToInventory is a helper function to convert a Graph into an Inventory for use with VulnerabilityMatcher.
//
// The root node (g.Nodes[0]) is excluded, so the returned slice has one entry
// per non-root node and result[i] corresponds to g.Nodes[i+1]. Each entry
// carries the node's name and version, plus an extractor that reports the
// node's ecosystem.
//
// A nil graph, or a graph with no nodes besides the root, yields an empty slice.
func GraphToInventory(g *resolve.Graph) []*extractor.Inventory {
	// Without this guard a graph with zero nodes would make len(g.Nodes)-1
	// negative, and both make and g.Nodes[1:] would panic.
	if g == nil || len(g.Nodes) <= 1 {
		return []*extractor.Inventory{}
	}

	// g.Nodes[0] is the root node of the graph that should be excluded.
	nodes := g.Nodes[1:]
	inv := make([]*extractor.Inventory, len(nodes))
	for i, n := range nodes {
		inv[i] = &extractor.Inventory{
			Name:      n.Version.Name,
			Version:   n.Version.Version,
			Extractor: mockExtractor{n.Version.System},
		}
	}

	return inv
}

// mockExtractor is for GraphToInventory to get the ecosystem.
// It is a minimal stand-in extractor attached to each generated Inventory:
// it only remembers the resolve.System of the node so that Ecosystem can
// report it, and its remaining methods return zero values.
type mockExtractor struct {
// ecosystem is the resolve.System of the graph node this extractor was created for.
ecosystem resolve.System
}

// Ecosystem reports the ecosystem name for the stored resolve.System:
// "npm" for resolve.NPM and "Maven" for resolve.Maven. Any other system,
// including resolve.UnknownSystem, yields the empty string.
// The Inventory argument is ignored.
func (e mockExtractor) Ecosystem(*extractor.Inventory) string {
	var name string

	switch e.ecosystem {
	case resolve.NPM:
		name = "npm"
	case resolve.Maven:
		name = "Maven"
	case resolve.UnknownSystem:
		// no ecosystem name; leave empty
	default:
		// unsupported system; leave empty
	}

	return name
}

// Name returns an empty string; the mock has no extractor name.
func (e mockExtractor) Name() string {
	return ""
}

// Requirements returns nil; the mock declares no capabilities.
func (e mockExtractor) Requirements() *plugin.Capabilities {
	return nil
}

// ToPURL returns nil; the mock does not build package URLs.
func (e mockExtractor) ToPURL(*extractor.Inventory) *purl.PackageURL {
	return nil
}

// Version returns 0; the mock has no extractor version.
func (e mockExtractor) Version() int {
	return 0
}
Loading

0 comments on commit fe4eaea

Please sign in to comment.