Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use new VulnerabilityMatcher in guided remediation #1503

Merged
merged 6 commits into from
Jan 19, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
CachedOSVMatcher
michaelkedar committed Jan 16, 2025
commit 27718d3eece6d98367fd403cbfd161f53e96f4d1
15 changes: 11 additions & 4 deletions cmd/osv-scanner/fix/main.go
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
@@ -371,11 +372,12 @@ func action(ctx *cli.Context, stdout, stderr io.Writer) (reporter.Reporter, erro
}
}

userAgent := "osv-scanner_fix/" + version.OSVVersion
if ctx.Bool("experimental-offline-vulnerabilities") {
matcher, err := localmatcher.NewLocalMatcher(
r,
ctx.String("experimental-local-db-path"),
"osv-scanner_fix/"+version.OSVVersion,
userAgent,
ctx.Bool("experimental-download-offline-databases"),
)
if err != nil {
@@ -393,9 +395,14 @@ func action(ctx *cli.Context, stdout, stderr io.Writer) (reporter.Reporter, erro

opts.Client.VulnerabilityMatcher = matcher
} else {
// TODO: replace with cached client
opts.Client.VulnerabilityMatcher = &osvmatcher.OSVMatcher{
Client: *osvdev.DefaultClient(), // TODO: UserAgent
config := osvdev.DefaultConfig()
config.UserAgent = userAgent
opts.Client.VulnerabilityMatcher = &osvmatcher.CachedOSVMatcher{
Client: osvdev.OSVClient{
HTTPClient: http.DefaultClient,
Config: config,
BaseHostURL: osvdev.DefaultBaseURL,
},
InitialQueryTimeout: 5 * time.Minute,
}
}
2 changes: 1 addition & 1 deletion internal/clients/clientimpl/localmatcher/localmatcher.go
Original file line number Diff line number Diff line change
@@ -77,7 +77,7 @@ func (matcher *LocalMatcher) MatchVulnerabilities(ctx context.Context, invs []*e
continue
}

results = append(results, db.VulnerabilitiesAffectingPackage(pkg))
results = append(results, VulnerabilitiesAffectingPackage(db.Vulnerabilities(false), pkg))
}

return results, nil
8 changes: 5 additions & 3 deletions internal/clients/clientimpl/localmatcher/zip.go
Original file line number Diff line number Diff line change
@@ -235,7 +235,8 @@ func (db *ZipDB) Vulnerabilities(includeWithdrawn bool) []models.Vulnerability {
return vulnerabilities
}

func (db *ZipDB) VulnerabilitiesAffectingPackage(pkg imodels.PackageInfo) []*models.Vulnerability {
// TODO: Move this to another file.
func VulnerabilitiesAffectingPackage(allVulns []models.Vulnerability, pkg imodels.PackageInfo) []*models.Vulnerability {
var vulnerabilities []*models.Vulnerability

// TODO (V2 Models): remove this once PackageDetails has been migrated
@@ -248,7 +249,7 @@ func (db *ZipDB) VulnerabilitiesAffectingPackage(pkg imodels.PackageInfo) []*mod
DepGroups: pkg.DepGroups(),
}

for _, vulnerability := range db.Vulnerabilities(false) {
for _, vulnerability := range allVulns {
if vulns.IsAffected(vulnerability, mappedPackageDetails) && !vulns.Include(vulnerabilities, vulnerability) {
vulnerabilities = append(vulnerabilities, &vulnerability)
}
@@ -258,10 +259,11 @@ func (db *ZipDB) VulnerabilitiesAffectingPackage(pkg imodels.PackageInfo) []*mod
}

func (db *ZipDB) Check(pkgs []imodels.PackageInfo) ([]*models.Vulnerability, error) {
allVulns := db.Vulnerabilities(false)
vulnerabilities := make([]*models.Vulnerability, 0, len(pkgs))

for _, pkg := range pkgs {
vulnerabilities = append(vulnerabilities, db.VulnerabilitiesAffectingPackage(pkg)...)
vulnerabilities = append(vulnerabilities, VulnerabilitiesAffectingPackage(allVulns, pkg)...)
}

return vulnerabilities, nil
151 changes: 151 additions & 0 deletions internal/clients/clientimpl/osvmatcher/cachedosvmatcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package osvmatcher

import (
"context"
"errors"
"maps"
"slices"
"sync"
"time"

"github.com/google/osv-scalibr/extractor"
"github.com/google/osv-scanner/internal/clients/clientimpl/localmatcher"
"github.com/google/osv-scanner/internal/imodels"
"github.com/google/osv-scanner/internal/osvdev"
"github.com/google/osv-scanner/pkg/models"
"golang.org/x/sync/errgroup"
)

// CachedOSVMatcher implements the VulnerabilityMatcher interface with an osv.dev client.
// It requests every vulnerability of each package (queries carry only the
// package name and ecosystem, not a version) and caches the hydrated records
// per package.
// Checking if a specific version matches an OSV record is done locally.
// This should be used when we know the same packages are going to be queried
// repeatedly, as in guided remediation.
// TODO: This does not support commit-based queries.
type CachedOSVMatcher struct {
// Client is the osv.dev client used for the batched query and for fetching
// individual vulnerability records.
Client osvdev.OSVClient
// InitialQueryTimeout sets a timeout specifically for the initial paging query.
// A value of zero or less means no additional timeout is applied.
// If the timeout runs out, the matcher reports context.DeadlineExceeded and
// nothing from that query is added to the cache.
InitialQueryTimeout time.Duration

// vulnCache holds the vulnerabilities already fetched for each package.
vulnCache sync.Map // map[osvdev.Package][]models.Vulnerability
}

// MatchVulnerabilities returns, for every inventory in invs, the vulnerabilities
// affecting it. The result has the same length and order as invs.
//
// Packages missing from the cache are fetched first; any error from that step
// is returned as-is with a nil result. Matching is then done locally against
// the cached records. An entry is left nil when nothing is cached for the
// inventory's package. If ctx is cancelled while matching, its error is returned.
func (matcher *CachedOSVMatcher) MatchVulnerabilities(ctx context.Context, invs []*extractor.Inventory) ([][]*models.Vulnerability, error) {
	// Make sure every package we are about to look up has been fetched.
	err := matcher.doQueries(ctx, invs)
	if err != nil {
		return nil, err
	}

	matched := make([][]*models.Vulnerability, len(invs))

	for idx := range invs {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}

		info := imodels.FromInventory(invs[idx])
		cached, found := matcher.vulnCache.Load(osvdev.Package{
			Name:      info.Name(),
			Ecosystem: info.Ecosystem().String(),
		})
		if found {
			matched[idx] = localmatcher.VulnerabilitiesAffectingPackage(cached.([]models.Vulnerability), info)
		}
	}

	return matched, nil
}

// doQueries populates vulnCache with the vulnerabilities of every package in
// invs that is not cached yet.
//
// Inventories without a name or ecosystem are skipped. Each remaining uncached
// package is queried once (by name and ecosystem only), and every vulnerability
// ID in the response is hydrated through Client.GetVulnByID before the
// package's list is stored.
//
// Returns nil when nothing needs querying or once all results are cached.
// Returns context.DeadlineExceeded when the batched query runs out of time; in
// that case nothing is cached, since a partial list would wrongly be treated as
// complete on later lookups. Any other query or hydration error is returned
// as-is, also without caching anything.
func (matcher *CachedOSVMatcher) doQueries(ctx context.Context, invs []*extractor.Inventory) error {
	// Determine which packages aren't already cached.
	// The set is keyed by the Package value rather than by a *Query pointer:
	// every freshly allocated pointer is distinct, so a pointer-keyed set would
	// never deduplicate packages that appear in more than one inventory.
	toQuery := make(map[osvdev.Package]struct{})
	for _, inv := range invs {
		pkgInfo := imodels.FromInventory(inv)
		if pkgInfo.Name() == "" || pkgInfo.Ecosystem().IsEmpty() {
			continue
		}
		pkg := osvdev.Package{
			Name:      pkgInfo.Name(),
			Ecosystem: pkgInfo.Ecosystem().String(),
		}
		if _, ok := matcher.vulnCache.Load(pkg); !ok {
			toQuery[pkg] = struct{}{}
		}
	}

	if len(toQuery) == 0 {
		return nil
	}

	// Convert each unique package into a Query.
	pkgs := slices.Collect(maps.Keys(toQuery))
	queries := make([]*osvdev.Query, len(pkgs))
	for i, pkg := range pkgs {
		queries[i] = &osvdev.Query{Package: pkg}
	}

	var (
		batchResp *osvdev.BatchedResponse
		err       error
	)

	// If there is a timeout for the initial query, set an additional context deadline here.
	if matcher.InitialQueryTimeout > 0 {
		batchQueryCtx, cancelFunc := context.WithTimeout(ctx, matcher.InitialQueryTimeout)
		batchResp, err = queryForBatchWithPaging(batchQueryCtx, &matcher.Client, queries)
		cancelFunc()
	} else {
		batchResp, err = queryForBatchWithPaging(ctx, &matcher.Client, queries)
	}

	if err != nil {
		// Deadline being exceeded is likely caused by a long paging time.
		// Partial results are never cached, so hydrating them would be wasted
		// work; report the timeout straight away. This also avoids touching
		// batchResp, which is not guaranteed to be non-nil on error.
		if errors.Is(err, context.DeadlineExceeded) {
			return context.DeadlineExceeded
		}

		return err
	}

	// Hydrate every returned vulnerability ID into a full record, concurrently.
	vulnerabilities := make([][]models.Vulnerability, len(batchResp.Results))
	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(maxConcurrentRequests)

	for batchIdx, resp := range batchResp.Results {
		vulnerabilities[batchIdx] = make([]models.Vulnerability, len(resp.Vulns))
		for resultIdx, vuln := range resp.Vulns {
			g.Go(func() error {
				// exit early if another hydration request has already failed
				// results are thrown away later, so avoid needless work
				if ctx.Err() != nil {
					return nil //nolint:nilerr // this value doesn't matter to errgroup.Wait()
				}
				hydrated, err := matcher.Client.GetVulnByID(ctx, vuln.ID)
				if err != nil {
					return err
				}
				vulnerabilities[batchIdx][resultIdx] = *hydrated

				return nil
			})
		}
	}

	if err := g.Wait(); err != nil {
		return err
	}

	// Results are positionally aligned with queries.
	for i, vulns := range vulnerabilities {
		matcher.vulnCache.Store(queries[i].Package, vulns)
	}

	return nil
}
19 changes: 2 additions & 17 deletions internal/resolution/clienttest/mock_resolution_client.go
Original file line number Diff line number Diff line change
@@ -9,10 +9,9 @@ import (
"deps.dev/util/resolve"
"deps.dev/util/resolve/schema"
"github.com/google/osv-scalibr/extractor"
"github.com/google/osv-scanner/internal/clients/clientimpl/localmatcher"
"github.com/google/osv-scanner/internal/imodels"
"github.com/google/osv-scanner/internal/resolution/client"
"github.com/google/osv-scanner/internal/utility/vulns"
"github.com/google/osv-scanner/pkg/lockfile"
"github.com/google/osv-scanner/pkg/models"
"gopkg.in/yaml.v3"
)
@@ -28,21 +27,7 @@ type mockVulnerabilityMatcher []models.Vulnerability
func (mvc mockVulnerabilityMatcher) MatchVulnerabilities(_ context.Context, invs []*extractor.Inventory) ([][]*models.Vulnerability, error) {
result := make([][]*models.Vulnerability, len(invs))
for i, inv := range invs {
pkg := imodels.FromInventory(inv)
// TODO (V2 Models): remove this once PackageDetails has been migrated
mappedPackageDetails := lockfile.PackageDetails{
Name: pkg.Name(),
Version: pkg.Version(),
Commit: pkg.Commit(),
Ecosystem: lockfile.Ecosystem(pkg.Ecosystem().String()),
CompareAs: lockfile.Ecosystem(pkg.Ecosystem().String()),
DepGroups: pkg.DepGroups(),
}
for _, v := range mvc {
if vulns.IsAffected(v, mappedPackageDetails) {
result[i] = append(result[i], &v)
}
}
result[i] = localmatcher.VulnerabilitiesAffectingPackage(mvc, imodels.FromInventory(inv))
}

return result, nil
6 changes: 3 additions & 3 deletions scripts/generate_mock_resolution_universe/main.go
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ var remediationOpts = remediation.Options{

func doRelockRelax(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename string) error {
cl := client.ResolutionClient{
VulnerabilityMatcher: &osvmatcher.OSVMatcher{
VulnerabilityMatcher: &osvmatcher.CachedOSVMatcher{
Client: *osvdev.DefaultClient(),
InitialQueryTimeout: 5 * time.Minute,
},
@@ -83,7 +83,7 @@ func doRelockRelax(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename

func doOverride(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename string) error {
cl := client.ResolutionClient{
VulnerabilityMatcher: &osvmatcher.OSVMatcher{
VulnerabilityMatcher: &osvmatcher.CachedOSVMatcher{
Client: *osvdev.DefaultClient(),
InitialQueryTimeout: 5 * time.Minute,
},
@@ -113,7 +113,7 @@ func doOverride(ddCl *client.DepsDevClient, rw manifest.ReadWriter, filename str

func doInPlace(ddCl *client.DepsDevClient, rw lockfile.ReadWriter, filename string) error {
cl := client.ResolutionClient{
VulnerabilityMatcher: &osvmatcher.OSVMatcher{
VulnerabilityMatcher: &osvmatcher.CachedOSVMatcher{
Client: *osvdev.DefaultClient(),
InitialQueryTimeout: 5 * time.Minute,
},