-
Notifications
You must be signed in to change notification settings - Fork 373
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: use new
VulnerabilityMatcher
in guided remediation (#1503)
Following up on #1470:
- Made `ResolutionClient` use the `VulnerabilityMatcher` interface (and added a helper function to convert deps.dev graphs into inventories).
- Deleted the old `VulnerabilityClient`.
- Created `CachedOSVMatcher` to re-implement the performance improvements from the original `VulnerabilityClient` w.r.t. repeated queries.
- Re-enabled local database capability in `osv-scanner fix`.
- Loading branch information
1 parent
6874aa6
commit fe4eaea
Showing
13 changed files
with
296 additions
and
248 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
151 changes: 151 additions & 0 deletions
151
internal/clients/clientimpl/osvmatcher/cachedosvmatcher.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
package osvmatcher | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"maps" | ||
"slices" | ||
"sync" | ||
"time" | ||
|
||
"github.com/google/osv-scalibr/extractor" | ||
"github.com/google/osv-scanner/internal/clients/clientimpl/localmatcher" | ||
"github.com/google/osv-scanner/internal/imodels" | ||
"github.com/google/osv-scanner/internal/osvdev" | ||
"github.com/google/osv-scanner/pkg/models" | ||
"golang.org/x/sync/errgroup" | ||
) | ||
|
||
// CachedOSVMatcher implements the VulnerabilityMatcher interface with a osv.dev client. | ||
// It sends out requests for every vulnerability of each package, which get cached. | ||
// Checking if a specific version matches an OSV record is done locally. | ||
// This should be used when we know the same packages are going to be repeatedly | ||
// queried multiple times, as in guided remediation. | ||
// TODO: This does not support commit-based queries. | ||
type CachedOSVMatcher struct { | ||
Client osvdev.OSVClient | ||
// InitialQueryTimeout allows you to set a timeout specifically for the initial paging query | ||
// If timeout runs out, whatever pages that has been successfully queried within the timeout will | ||
// still return fully hydrated. | ||
InitialQueryTimeout time.Duration | ||
|
||
vulnCache sync.Map // map[osvdev.Package][]models.Vulnerability | ||
} | ||
|
||
func (matcher *CachedOSVMatcher) MatchVulnerabilities(ctx context.Context, invs []*extractor.Inventory) ([][]*models.Vulnerability, error) { | ||
// populate vulnCache with missing packages | ||
if err := matcher.doQueries(ctx, invs); err != nil { | ||
return nil, err | ||
} | ||
|
||
results := make([][]*models.Vulnerability, len(invs)) | ||
|
||
for i, inv := range invs { | ||
if ctx.Err() != nil { | ||
return nil, ctx.Err() | ||
} | ||
|
||
pkgInfo := imodels.FromInventory(inv) | ||
pkg := osvdev.Package{ | ||
Name: pkgInfo.Name(), | ||
Ecosystem: pkgInfo.Ecosystem().String(), | ||
} | ||
vulns, ok := matcher.vulnCache.Load(pkg) | ||
if !ok { | ||
continue | ||
} | ||
results[i] = localmatcher.VulnerabilitiesAffectingPackage(vulns.([]models.Vulnerability), pkgInfo) | ||
} | ||
|
||
return results, nil | ||
} | ||
|
||
func (matcher *CachedOSVMatcher) doQueries(ctx context.Context, invs []*extractor.Inventory) error { | ||
var batchResp *osvdev.BatchedResponse | ||
deadlineExceeded := false | ||
|
||
var queries []*osvdev.Query | ||
{ | ||
// determine which packages aren't already cached | ||
// convert Inventory to Query for each pkgs element | ||
toQuery := make(map[*osvdev.Query]struct{}) | ||
for _, inv := range invs { | ||
pkgInfo := imodels.FromInventory(inv) | ||
if pkgInfo.Name() == "" || pkgInfo.Ecosystem().IsEmpty() { | ||
continue | ||
} | ||
pkg := osvdev.Package{ | ||
Name: pkgInfo.Name(), | ||
Ecosystem: pkgInfo.Ecosystem().String(), | ||
} | ||
if _, ok := matcher.vulnCache.Load(pkg); !ok { | ||
toQuery[&osvdev.Query{Package: pkg}] = struct{}{} | ||
} | ||
} | ||
queries = slices.Collect(maps.Keys(toQuery)) | ||
} | ||
|
||
if len(queries) == 0 { | ||
return nil | ||
} | ||
|
||
var err error | ||
|
||
// If there is a timeout for the initial query, set an additional context deadline here. | ||
if matcher.InitialQueryTimeout > 0 { | ||
batchQueryCtx, cancelFunc := context.WithDeadline(ctx, time.Now().Add(matcher.InitialQueryTimeout)) | ||
batchResp, err = queryForBatchWithPaging(batchQueryCtx, &matcher.Client, queries) | ||
cancelFunc() | ||
} else { | ||
batchResp, err = queryForBatchWithPaging(ctx, &matcher.Client, queries) | ||
} | ||
|
||
if err != nil { | ||
// Deadline being exceeded is likely caused by a long paging time | ||
// if that's the case, we can should return what we already got, and | ||
// then let the caller know it is not all the results. | ||
if errors.Is(err, context.DeadlineExceeded) { | ||
deadlineExceeded = true | ||
} else { | ||
return err | ||
} | ||
} | ||
|
||
vulnerabilities := make([][]models.Vulnerability, len(batchResp.Results)) | ||
g, ctx := errgroup.WithContext(ctx) | ||
g.SetLimit(maxConcurrentRequests) | ||
|
||
for batchIdx, resp := range batchResp.Results { | ||
vulnerabilities[batchIdx] = make([]models.Vulnerability, len(resp.Vulns)) | ||
for resultIdx, vuln := range resp.Vulns { | ||
g.Go(func() error { | ||
// exit early if another hydration request has already failed | ||
// results are thrown away later, so avoid needless work | ||
if ctx.Err() != nil { | ||
return nil //nolint:nilerr // this value doesn't matter to errgroup.Wait() | ||
} | ||
vuln, err := matcher.Client.GetVulnByID(ctx, vuln.ID) | ||
if err != nil { | ||
return err | ||
} | ||
vulnerabilities[batchIdx][resultIdx] = *vuln | ||
|
||
return nil | ||
}) | ||
} | ||
} | ||
|
||
if err := g.Wait(); err != nil { | ||
return err | ||
} | ||
|
||
if deadlineExceeded { | ||
return context.DeadlineExceeded | ||
} | ||
|
||
for i, vulns := range vulnerabilities { | ||
matcher.vulnCache.Store(queries[i].Package, vulns) | ||
} | ||
|
||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
package client | ||
|
||
import ( | ||
"deps.dev/util/resolve" | ||
"github.com/google/osv-scalibr/extractor" | ||
"github.com/google/osv-scalibr/plugin" | ||
"github.com/google/osv-scalibr/purl" | ||
) | ||
|
||
// GraphToInventory is a helper function to convert a Graph into an Inventory for use with VulnerabilityMatcher. | ||
func GraphToInventory(g *resolve.Graph) []*extractor.Inventory { | ||
// g.Nodes[0] is the root node of the graph that should be excluded. | ||
inv := make([]*extractor.Inventory, len(g.Nodes)-1) | ||
for i, n := range g.Nodes[1:] { | ||
inv[i] = &extractor.Inventory{ | ||
Name: n.Version.Name, | ||
Version: n.Version.Version, | ||
Extractor: mockExtractor{n.Version.System}, | ||
} | ||
} | ||
|
||
return inv | ||
} | ||
|
||
// mockExtractor is for GraphToInventory to get the ecosystem. | ||
type mockExtractor struct { | ||
ecosystem resolve.System | ||
} | ||
|
||
func (e mockExtractor) Ecosystem(*extractor.Inventory) string { | ||
switch e.ecosystem { | ||
case resolve.NPM: | ||
return "npm" | ||
case resolve.Maven: | ||
return "Maven" | ||
case resolve.UnknownSystem: | ||
return "" | ||
default: | ||
return "" | ||
} | ||
} | ||
|
||
func (e mockExtractor) Name() string { return "" } | ||
func (e mockExtractor) Requirements() *plugin.Capabilities { return nil } | ||
func (e mockExtractor) ToPURL(*extractor.Inventory) *purl.PackageURL { return nil } | ||
func (e mockExtractor) Version() int { return 0 } |
Oops, something went wrong.