diff --git a/go.mod b/go.mod index 220f341676..e3d17a6a95 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/go-git/go-git/v5 v5.12.0 github.com/google/go-cmp v0.6.0 github.com/google/go-containerregistry v0.20.2 - github.com/google/osv-scalibr v0.1.6-0.20241219225011-fd6877f0b783 + github.com/google/osv-scalibr v0.1.6-0.20250105222824-56e5c3bfb149 github.com/ianlancetaylor/demangle v0.0.0-20240912202439-0a2b6291aafd github.com/jedib0t/go-pretty/v6 v6.6.5 github.com/muesli/reflow v0.3.0 diff --git a/go.sum b/go.sum index 2947904abc..ffc9a39f33 100644 --- a/go.sum +++ b/go.sum @@ -123,8 +123,8 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-containerregistry v0.20.2 h1:B1wPJ1SN/S7pB+ZAimcciVD+r+yV/l/DSArMxlbwseo= github.com/google/go-containerregistry v0.20.2/go.mod h1:z38EKdKh4h7IP2gSfUUqEvalZBqs6AoLeWfUy34nQC8= -github.com/google/osv-scalibr v0.1.6-0.20241219225011-fd6877f0b783 h1:YzLIdmgxXdnYO0oGnS+i0s7kC3uwlVBZe53YIfvtrh4= -github.com/google/osv-scalibr v0.1.6-0.20241219225011-fd6877f0b783/go.mod h1:S8mrRjoWESAOOTq25lJqzxiKR6tbWSFYG8SVb5EFLHk= +github.com/google/osv-scalibr v0.1.6-0.20250105222824-56e5c3bfb149 h1:NR/j8m7lWb1V/izQi7oJlCZ5U/Z6GqM8hkoHghABdTQ= +github.com/google/osv-scalibr v0.1.6-0.20250105222824-56e5c3bfb149/go.mod h1:S8mrRjoWESAOOTq25lJqzxiKR6tbWSFYG8SVb5EFLHk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= diff --git a/internal/local/fixtures/db/file.json b/internal/clients/clientimpl/localmatcher/fixtures/db/file.json similarity index 100% rename from internal/local/fixtures/db/file.json rename to internal/clients/clientimpl/localmatcher/fixtures/db/file.json diff --git a/internal/local/fixtures/db/file.yaml b/internal/clients/clientimpl/localmatcher/fixtures/db/file.yaml similarity index 100% rename from internal/local/fixtures/db/file.yaml rename to internal/clients/clientimpl/localmatcher/fixtures/db/file.yaml diff --git a/internal/local/fixtures/db/nested-1/osv-1.json b/internal/clients/clientimpl/localmatcher/fixtures/db/nested-1/osv-1.json similarity index 100% rename from internal/local/fixtures/db/nested-1/osv-1.json rename to internal/clients/clientimpl/localmatcher/fixtures/db/nested-1/osv-1.json diff --git a/internal/local/fixtures/db/nested-2/invalid.json b/internal/clients/clientimpl/localmatcher/fixtures/db/nested-2/invalid.json similarity index 100% rename from internal/local/fixtures/db/nested-2/invalid.json rename to internal/clients/clientimpl/localmatcher/fixtures/db/nested-2/invalid.json diff --git a/internal/local/fixtures/db/nested-2/osv-2.json b/internal/clients/clientimpl/localmatcher/fixtures/db/nested-2/osv-2.json similarity index 100% rename from internal/local/fixtures/db/nested-2/osv-2.json rename to internal/clients/clientimpl/localmatcher/fixtures/db/nested-2/osv-2.json diff --git a/internal/clients/clientimpl/localmatcher/localmatcher.go b/internal/clients/clientimpl/localmatcher/localmatcher.go new file mode 100644 index 0000000000..69ce879fd5 --- /dev/null +++ b/internal/clients/clientimpl/localmatcher/localmatcher.go @@ -0,0 +1,164 @@ +package localmatcher + +import ( + "context" + "errors" + "fmt" + "os" + "path" + "slices" + "strings" + + "github.com/google/osv-scalibr/extractor" 
+ "github.com/google/osv-scanner/internal/imodels" + "github.com/google/osv-scanner/internal/imodels/ecosystem" + "github.com/google/osv-scanner/pkg/models" + "github.com/google/osv-scanner/pkg/reporter" + "github.com/ossf/osv-schema/bindings/go/osvschema" +) + +const zippedDBRemoteHost = "https://osv-vulnerabilities.storage.googleapis.com" +const envKeyLocalDBCacheDirectory = "OSV_SCANNER_LOCAL_DB_CACHE_DIRECTORY" + +// LocalMatcher implements the VulnerabilityMatcher interface by downloading the osv export zip files, +// and performing the matching locally. +type LocalMatcher struct { + dbBasePath string + dbs map[osvschema.Ecosystem]*ZipDB + downloadDB bool + // TODO(v2 logging): Remove this reporter + r reporter.Reporter +} + +func NewLocalMatcher(r reporter.Reporter, localDBPath string, downloadDB bool) (*LocalMatcher, error) { + dbBasePath, err := setupLocalDBDirectory(localDBPath) + if err != nil { + return nil, fmt.Errorf("could not create %s: %w", dbBasePath, err) + } + + return &LocalMatcher{ + dbBasePath: dbBasePath, + dbs: make(map[osvschema.Ecosystem]*ZipDB), + downloadDB: downloadDB, + r: r, + }, nil +} + +func (matcher *LocalMatcher) Match(ctx context.Context, invs []*extractor.Inventory) ([][]*models.Vulnerability, error) { + results := make([][]*models.Vulnerability, 0, len(invs)) + + // slice to track ecosystems that did not have an offline database available + var missingDBs []string + + for _, inv := range invs { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + pkg := imodels.FromInventory(inv) + if pkg.Ecosystem.IsEmpty() { + if pkg.Commit == "" { + // This should never happen, as those results will be filtered out before matching + return nil, errors.New("ecosystem is empty and there is no commit hash") + } + + // Is a commit based query, skip local scanning + results = append(results, []*models.Vulnerability{}) + // TODO (V2 logging): + matcher.r.Infof("Skipping commit scanning for: %s\n", pkg.Commit) + + continue + } + + db, err := matcher.loadDBFromCache(ctx, pkg.Ecosystem) + + if err != nil { + if errors.Is(err, ErrOfflineDatabaseNotFound) { + missingDBs = append(missingDBs, string(pkg.Ecosystem.Ecosystem)) + } else { + // TODO(V2 logging): + // the most likely error at this point is that the PURL could not be parsed + matcher.r.Errorf("could not load db for %s ecosystem: %v\n", pkg.Ecosystem, err) + } + + results = append(results, []*models.Vulnerability{}) + + continue + } + + results = append(results, db.VulnerabilitiesAffectingPackage(pkg)) + } + + if len(missingDBs) > 0 { + missingDBs = slices.Compact(missingDBs) + slices.Sort(missingDBs) + + // TODO(v2 logging): + matcher.r.Errorf("could not find local databases for ecosystems: %s\n", strings.Join(missingDBs, ", ")) + } + + return results, nil +} + +func (matcher *LocalMatcher) loadDBFromCache(ctx context.Context, ecosystem ecosystem.Parsed) (*ZipDB, error) { + if db, ok := matcher.dbs[ecosystem.Ecosystem]; ok { + return db, nil + } + + db, err := NewZippedDB(ctx, matcher.dbBasePath, string(ecosystem.Ecosystem), fmt.Sprintf("%s/%s/all.zip", zippedDBRemoteHost, ecosystem.Ecosystem), !matcher.downloadDB) + + if err != nil { + return nil, err + } + + // TODO(v2 logging): Replace with slog / another logger + matcher.r.Infof("Loaded %s local db from %s\n", db.Name, db.StoredAt) + + matcher.dbs[ecosystem.Ecosystem] = db + + return db, nil +} + +// setupLocalDBDirectory attempts to set up the directory the scanner should +// use to store local databases. 
+//
+// if a local path is explicitly provided, either by the localDBPath parameter
+// or via the envKeyLocalDBCacheDirectory environment variable, the scanner will
+// use that directory; otherwise it will attempt to use the user cache directory,
+// falling back to the temp directory if that is unavailable
+//
+// if an error occurs at any point when a local path is not explicitly provided,
+// the scanner will fall back to the temp directory first before finally erroring
+func setupLocalDBDirectory(localDBPath string) (string, error) {
+	var err error
+
+	// fall back to the env variable if a local database path has not been provided
+	if localDBPath == "" {
+		if p, envSet := os.LookupEnv(envKeyLocalDBCacheDirectory); envSet {
+			localDBPath = p
+		}
+	}
+
+	implicitPath := localDBPath == ""
+
+	// if we're implicitly picking a path, use the user cache directory if available
+	if implicitPath {
+		localDBPath, err = os.UserCacheDir()
+
+		if err != nil {
+			localDBPath = os.TempDir()
+		}
+	}
+
+	altPath := path.Join(localDBPath, "osv-scanner")
+	err = os.MkdirAll(altPath, 0750)
+	if err == nil {
+		return altPath, nil
+	}
+
+	// if we're implicitly picking a path, try the temp directory before giving up
+	if implicitPath && localDBPath != os.TempDir() {
+		return setupLocalDBDirectory(os.TempDir())
+	}
+
+	return "", err
+}
diff --git a/internal/local/zip.go b/internal/clients/clientimpl/localmatcher/zip.go
similarity index 77%
rename from internal/local/zip.go
rename to internal/clients/clientimpl/localmatcher/zip.go
index 6d500fd447..fc3951b401 100644
--- a/internal/local/zip.go
+++ b/internal/clients/clientimpl/localmatcher/zip.go
@@ -1,4 +1,4 @@
-package local
+package localmatcher
 
 import (
 	"archive/zip"
@@ -16,6 +16,7 @@ import (
 	"path"
 	"strings"
 
+	"github.com/google/osv-scanner/internal/imodels"
 	"github.com/google/osv-scanner/internal/utility/vulns"
 	"github.com/google/osv-scanner/pkg/lockfile"
 	"github.com/google/osv-scanner/pkg/models"
@@ -37,8 +38,8 @@ type ZipDB struct {
 
 var ErrOfflineDatabaseNotFound = errors.New("no offline version of the OSV database is available")
 
-func fetchRemoteArchiveCRC32CHash(url string) (uint32, error) {
-	req, err := http.NewRequestWithContext(context.Background(), http.MethodHead, url, nil)
+func fetchRemoteArchiveCRC32CHash(ctx context.Context, url string) (uint32, error) {
+	req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
 
 	if err != nil {
 		return 0, err
@@ -75,7 +76,7 @@ func fetchLocalArchiveCRC32CHash(data []byte) uint32 {
 	return crc32.Checksum(data, crc32.MakeTable(crc32.Castagnoli))
 }
 
-func (db *ZipDB) fetchZip() ([]byte, error) {
+func (db *ZipDB) fetchZip(ctx context.Context) ([]byte, error) {
 	cache, err := os.ReadFile(db.StoredAt)
 
 	if db.Offline {
@@ -87,7 +88,7 @@ func (db *ZipDB) fetchZip() ([]byte, error) {
 	}
 
 	if err == nil {
-		remoteHash, err := fetchRemoteArchiveCRC32CHash(db.ArchiveURL)
+		remoteHash, err := fetchRemoteArchiveCRC32CHash(ctx, db.ArchiveURL)
 
 		if err != nil {
 			return nil, err
@@ -98,7 +99,7 @@ func (db *ZipDB) fetchZip() ([]byte, error) {
 		}
 	}
 
-	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, db.ArchiveURL, nil)
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, db.ArchiveURL, nil)
 
 	if err != nil {
 		return nil, fmt.Errorf("could not retrieve OSV database archive: %w", err)
@@ -176,10 +177,10 @@ func (db *ZipDB) loadZipFile(zipFile *zip.File) {
 // Internally, the archive is cached along with the date that it was fetched
 // so that a new version of the archive is only downloaded if it has been
 // modified, per HTTP caching standards.
-func (db *ZipDB) load() error { +func (db *ZipDB) load(ctx context.Context) error { db.vulnerabilities = []models.Vulnerability{} - body, err := db.fetchZip() + body, err := db.fetchZip(ctx) if err != nil { return err @@ -202,14 +203,14 @@ func (db *ZipDB) load() error { return nil } -func NewZippedDB(dbBasePath, name, url string, offline bool) (*ZipDB, error) { +func NewZippedDB(ctx context.Context, dbBasePath, name, url string, offline bool) (*ZipDB, error) { db := &ZipDB{ Name: name, ArchiveURL: url, Offline: offline, StoredAt: path.Join(dbBasePath, name, "all.zip"), } - if err := db.load(); err != nil { + if err := db.load(ctx); err != nil { return nil, fmt.Errorf("unable to fetch OSV database: %w", err) } @@ -232,20 +233,30 @@ func (db *ZipDB) Vulnerabilities(includeWithdrawn bool) []models.Vulnerability { return vulnerabilities } -func (db *ZipDB) VulnerabilitiesAffectingPackage(pkg lockfile.PackageDetails) models.Vulnerabilities { - var vulnerabilities models.Vulnerabilities +func (db *ZipDB) VulnerabilitiesAffectingPackage(pkg imodels.PackageInfo) []*models.Vulnerability { + var vulnerabilities []*models.Vulnerability + + // TODO (V2 Models): remove this once PackageDetails has been migrated + mappedPackageDetails := lockfile.PackageDetails{ + Name: pkg.Name, + Version: pkg.Version, + Commit: pkg.Commit, + Ecosystem: lockfile.Ecosystem(pkg.Ecosystem.String()), + CompareAs: lockfile.Ecosystem(pkg.Ecosystem.String()), + DepGroups: pkg.DepGroups, + } for _, vulnerability := range db.Vulnerabilities(false) { - if vulns.IsAffected(vulnerability, pkg) && !vulns.Include(vulnerabilities, vulnerability) { - vulnerabilities = append(vulnerabilities, vulnerability) + if vulns.IsAffected(vulnerability, mappedPackageDetails) && !vulns.Include(vulnerabilities, vulnerability) { + vulnerabilities = append(vulnerabilities, &vulnerability) } } return vulnerabilities } -func (db *ZipDB) Check(pkgs []lockfile.PackageDetails) (models.Vulnerabilities, error) { - vulnerabilities := make(models.Vulnerabilities, 0, len(pkgs)) +func (db *ZipDB) Check(pkgs []imodels.PackageInfo) ([]*models.Vulnerability, error) { + vulnerabilities := make([]*models.Vulnerability, 0, len(pkgs)) for _, pkg := range pkgs { vulnerabilities = append(vulnerabilities, db.VulnerabilitiesAffectingPackage(pkg)...) 
diff --git a/internal/local/zip_test.go b/internal/clients/clientimpl/localmatcher/zip_test.go similarity index 87% rename from internal/local/zip_test.go rename to internal/clients/clientimpl/localmatcher/zip_test.go index 17169be6b9..63525a23f9 100644 --- a/internal/local/zip_test.go +++ b/internal/clients/clientimpl/localmatcher/zip_test.go @@ -1,8 +1,9 @@ -package local_test +package localmatcher_test import ( "archive/zip" "bytes" + "context" "encoding/base64" "encoding/binary" "encoding/json" @@ -16,7 +17,7 @@ import ( "sort" "testing" - "github.com/google/osv-scanner/internal/local" + "github.com/google/osv-scanner/internal/clients/clientimpl/localmatcher" "github.com/google/osv-scanner/internal/testutility" "github.com/google/osv-scanner/pkg/models" ) @@ -145,10 +146,10 @@ func TestNewZippedDB_Offline_WithoutCache(t *testing.T) { t.Errorf("a server request was made when running offline") }) - _, err := local.NewZippedDB(testDir, "my-db", ts.URL, true) + _, err := localmatcher.NewZippedDB(context.Background(), testDir, "my-db", ts.URL, true) - if !errors.Is(err, local.ErrOfflineDatabaseNotFound) { - t.Errorf("expected \"%v\" error but got \"%v\"", local.ErrOfflineDatabaseNotFound, err) + if !errors.Is(err, localmatcher.ErrOfflineDatabaseNotFound) { + t.Errorf("expected \"%v\" error but got \"%v\"", localmatcher.ErrOfflineDatabaseNotFound, err) } } @@ -177,7 +178,7 @@ func TestNewZippedDB_Offline_WithCache(t *testing.T) { "GHSA-5.json": {ID: "GHSA-5"}, })) - db, err := local.NewZippedDB(testDir, "my-db", ts.URL, true) + db, err := localmatcher.NewZippedDB(context.Background(), testDir, "my-db", ts.URL, true) if err != nil { t.Fatalf("unexpected error \"%v\"", err) @@ -195,7 +196,7 @@ func TestNewZippedDB_BadZip(t *testing.T) { _, _ = w.Write([]byte("this is not a zip")) }) - _, err := local.NewZippedDB(testDir, "my-db", ts.URL, false) + _, err := localmatcher.NewZippedDB(context.Background(), testDir, "my-db", ts.URL, false) if err == nil { t.Errorf("expected an error but did not get one") @@ -207,7 +208,7 @@ func TestNewZippedDB_UnsupportedProtocol(t *testing.T) { testDir := testutility.CreateTestDir(t) - _, err := local.NewZippedDB(testDir, "my-db", "file://hello-world", false) + _, err := localmatcher.NewZippedDB(context.Background(), testDir, "my-db", "file://hello-world", false) if err == nil { t.Errorf("expected an error but did not get one") @@ -237,7 +238,7 @@ func TestNewZippedDB_Online_WithoutCache(t *testing.T) { }) }) - db, err := local.NewZippedDB(testDir, "my-db", ts.URL, false) + db, err := localmatcher.NewZippedDB(context.Background(), testDir, "my-db", ts.URL, false) if err != nil { t.Fatalf("unexpected error \"%v\"", err) @@ -269,7 +270,7 @@ func TestNewZippedDB_Online_WithoutCacheAndNoHashHeader(t *testing.T) { })) }) - db, err := local.NewZippedDB(testDir, "my-db", ts.URL, false) + db, err := localmatcher.NewZippedDB(context.Background(), testDir, "my-db", ts.URL, false) if err != nil { t.Fatalf("unexpected error \"%v\"", err) @@ -307,7 +308,7 @@ func TestNewZippedDB_Online_WithSameCache(t *testing.T) { cacheWrite(t, determineStoredAtPath(testDir, "my-db"), cache) - db, err := local.NewZippedDB(testDir, "my-db", ts.URL, false) + db, err := localmatcher.NewZippedDB(context.Background(), testDir, "my-db", ts.URL, false) if err != nil { t.Fatalf("unexpected error \"%v\"", err) @@ -345,7 +346,7 @@ func TestNewZippedDB_Online_WithDifferentCache(t *testing.T) { "GHSA-3.json": {ID: "GHSA-3"}, })) - db, err := local.NewZippedDB(testDir, "my-db", ts.URL, false) + db, err 
:= localmatcher.NewZippedDB(context.Background(), testDir, "my-db", ts.URL, false)
 
 	if err != nil {
 		t.Fatalf("unexpected error \"%v\"", err)
@@ -375,7 +376,7 @@ func TestNewZippedDB_Online_WithCacheButNoHashHeader(t *testing.T) {
 		"GHSA-3.json": {ID: "GHSA-3"},
 	}))
 
-	_, err := local.NewZippedDB(testDir, "my-db", ts.URL, false)
+	_, err := localmatcher.NewZippedDB(context.Background(), testDir, "my-db", ts.URL, false)
 
 	if err == nil {
 		t.Errorf("expected an error but did not get one")
@@ -403,7 +404,7 @@ func TestNewZippedDB_Online_WithBadCache(t *testing.T) {
 
 	cacheWriteBad(t, determineStoredAtPath(testDir, "my-db"), "this is not json!")
 
-	db, err := local.NewZippedDB(testDir, "my-db", ts.URL, false)
+	db, err := localmatcher.NewZippedDB(context.Background(), testDir, "my-db", ts.URL, false)
 
 	if err != nil {
 		t.Fatalf("unexpected error \"%v\"", err)
@@ -429,7 +430,7 @@ func TestNewZippedDB_FileChecks(t *testing.T) {
 		})
 	})
 
-	db, err := local.NewZippedDB(testDir, "my-db", ts.URL, false)
+	db, err := localmatcher.NewZippedDB(context.Background(), testDir, "my-db", ts.URL, false)
 
 	if err != nil {
 		t.Fatalf("unexpected error \"%v\"", err)
diff --git a/internal/clients/clientimpl/osvmatcher/errors.go b/internal/clients/clientimpl/osvmatcher/errors.go
new file mode 100644
index 0000000000..7fa5262a3c
--- /dev/null
+++ b/internal/clients/clientimpl/osvmatcher/errors.go
@@ -0,0 +1,16 @@
+package osvmatcher
+
+import "fmt"
+
+type DuringPagingError struct {
+	PageDepth int
+	Inner     error
+}
+
+func (e *DuringPagingError) Error() string {
+	return fmt.Sprintf("error during paging at depth %d - %s", e.PageDepth, e.Inner)
+}
+
+func (e *DuringPagingError) Unwrap() error {
+	return e.Inner
+}
diff --git a/internal/clients/clientimpl/osvmatcher/osvmatcher.go b/internal/clients/clientimpl/osvmatcher/osvmatcher.go
new file mode 100644
index 0000000000..8c67053221
--- /dev/null
+++ b/internal/clients/clientimpl/osvmatcher/osvmatcher.go
@@ -0,0 +1,218 @@
+package osvmatcher
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"time"
+
+	"github.com/google/osv-scalibr/extractor"
+	"github.com/google/osv-scalibr/log"
+	"github.com/google/osv-scanner/internal/imodels"
+	"github.com/google/osv-scanner/internal/osvdev"
+	"github.com/google/osv-scanner/internal/semantic"
+	"github.com/google/osv-scanner/pkg/models"
+	"github.com/ossf/osv-schema/bindings/go/osvschema"
+	"golang.org/x/sync/errgroup"
+)
+
+const (
+	maxConcurrentRequests = 1000
+)
+
+// OSVMatcher implements the VulnerabilityMatcher interface with an osv.dev client.
+// It sends out requests for every package version and does not perform caching.
+type OSVMatcher struct {
+	Client osvdev.OSVClient
+	// InitialQueryTimeout allows you to set a timeout specifically for the initial paging query.
+	// If the timeout runs out, whatever pages have been successfully queried within the timeout will
+	// still be returned fully hydrated.
+	InitialQueryTimeout time.Duration
+}
+
+func (matcher *OSVMatcher) Match(ctx context.Context, pkgs []*extractor.Inventory) ([][]*models.Vulnerability, error) {
+	var batchResp *osvdev.BatchedResponse
+	deadlineExceeded := false
+
+	{
+		var err error
+
+		// convert Inventory to Query for each pkgs element
+		queries := invsToQueries(pkgs)
+		// If there is a timeout for the initial query, set an additional context deadline here.
+		if matcher.InitialQueryTimeout > 0 {
+			batchQueryCtx, cancelFunc := context.WithDeadline(ctx, time.Now().Add(matcher.InitialQueryTimeout))
+			batchResp, err = queryForBatchWithPaging(batchQueryCtx, &matcher.Client, queries)
+			cancelFunc()
+		} else {
+			batchResp, err = queryForBatchWithPaging(ctx, &matcher.Client, queries)
+		}
+
+		if err != nil {
+			// Deadline being exceeded is likely caused by a long paging time;
+			// if that's the case, we should return what we already got, and
+			// then let the caller know it is not all the results.
+			if errors.Is(err, context.DeadlineExceeded) {
+				deadlineExceeded = true
+			} else {
+				return nil, err
+			}
+		}
+	}
+
+	vulnerabilities := make([][]*models.Vulnerability, len(batchResp.Results))
+	g, ctx := errgroup.WithContext(ctx)
+	g.SetLimit(maxConcurrentRequests)
+
+	for batchIdx, resp := range batchResp.Results {
+		vulnerabilities[batchIdx] = make([]*models.Vulnerability, len(resp.Vulns))
+		for resultIdx, vuln := range resp.Vulns {
+			g.Go(func() error {
+				// exit early if another hydration request has already failed
+				// results are thrown away later, so avoid needless work
+				if ctx.Err() != nil {
+					return nil //nolint:nilerr // this value doesn't matter to errgroup.Wait()
+				}
+				vuln, err := matcher.Client.GetVulnByID(ctx, vuln.ID)
+				if err != nil {
+					return err
+				}
+				vulnerabilities[batchIdx][resultIdx] = vuln
+
+				return nil
+			})
+		}
+	}
+
+	if err := g.Wait(); err != nil {
+		return nil, err
+	}
+
+	if deadlineExceeded {
+		return vulnerabilities, context.DeadlineExceeded
+	}
+
+	return vulnerabilities, nil
+}
+
+func queryForBatchWithPaging(ctx context.Context, c *osvdev.OSVClient, queries []*osvdev.Query) (*osvdev.BatchedResponse, error) {
+	batchResp, err := c.QueryBatch(ctx, queries)
+
+	if err != nil {
+		return nil, err
+	}
+	// --- Paging logic ---
+	var errToReturn error
+	nextPageQueries := []*osvdev.Query{}
+	nextPageIndexMap := []int{}
+	for i, res := range batchResp.Results {
+		if res.NextPageToken == "" {
+			continue
+		}
+
+		query := *queries[i]
+		query.PageToken = res.NextPageToken
+		nextPageQueries = append(nextPageQueries, &query)
+		nextPageIndexMap = append(nextPageIndexMap, i)
+	}
+
+	if len(nextPageQueries) > 0 {
+		// If context is cancelled or deadline exceeded, return now
+		if ctx.Err() != nil {
+			return batchResp, &DuringPagingError{
+				PageDepth: 1,
+				Inner:     ctx.Err(),
+			}
+		}
+
+		nextPageResp, err := c.QueryBatch(ctx, nextPageQueries)
+		if err != nil {
+			var dpr *DuringPagingError
+			if ok := errors.As(err, &dpr); ok {
+				dpr.PageDepth += 1
+				errToReturn = dpr
+			} else {
+				errToReturn = &DuringPagingError{
+					PageDepth: 1,
+					Inner:     err,
+				}
+			}
+		}
+
+		// Whether there is an error or not, if there is any data,
+		// we want to save and return what we got.
+		if nextPageResp != nil {
+			for i, res := range nextPageResp.Results {
+				batchResp.Results[nextPageIndexMap[i]].Vulns = append(batchResp.Results[nextPageIndexMap[i]].Vulns, res.Vulns...)
+				// Set the next page token so the caller knows whether this is all of the results,
+				// even if it is being cancelled.
+				batchResp.Results[nextPageIndexMap[i]].NextPageToken = res.NextPageToken
+			}
+		}
+	}
+
+	return batchResp, errToReturn
+}
+
+func pkgToQuery(pkg imodels.PackageInfo) *osvdev.Query {
+	if pkg.Name != "" && !pkg.Ecosystem.IsEmpty() && pkg.Version != "" {
+		return &osvdev.Query{
+			Package: osvdev.Package{
+				Name:      pkg.Name,
+				Ecosystem: pkg.Ecosystem.String(),
+			},
+			Version: pkg.Version,
+		}
+	}
+
+	if pkg.Commit != "" {
+		return &osvdev.Query{
+			Commit: pkg.Commit,
+		}
+	}
+
+	// This should have been filtered out before reaching this point
+	log.Errorf("invalid query element: %#v", pkg)
+
+	return nil
+}
+
+// invsToQueries converts inventories to queries via the osv-scanner internal imodels
+// to perform the necessary transformations
+func invsToQueries(invs []*extractor.Inventory) []*osvdev.Query {
+	queries := make([]*osvdev.Query, len(invs))
+
+	for i, inv := range invs {
+		pkg := imodels.FromInventory(inv)
+		pkg = patchPackageForRequest(pkg)
+		queries[i] = pkgToQuery(pkg)
+	}
+
+	return queries
+}
+
+// patchPackageForRequest modifies packages before they are sent to osv.dev to
+// account for edge cases.
+func patchPackageForRequest(pkg imodels.PackageInfo) imodels.PackageInfo {
+	// Assume Go stdlib patch version as the latest version
+	//
+	// This is done because go1.20 and earlier do not support patch
+	// version in go.mod file, and will fail to build.
+	//
+	// However, if we assume patch version as .0, this will cause a lot of
+	// false positives. This compromise still allows osv-scanner to pick up
+	// when the user is using a minor version that is out-of-support.
+	if pkg.Name == "stdlib" && pkg.Ecosystem.Ecosystem == osvschema.EcosystemGo {
+		v := semantic.ParseSemverLikeVersion(pkg.Version, 3)
+		if len(v.Components) == 2 {
+			pkg.Version = fmt.Sprintf(
+				"%d.%d.%d",
+				v.Components.Fetch(0),
+				v.Components.Fetch(1),
+				9999,
+			)
+		}
+	}
+
+	return pkg
+}
diff --git a/internal/clients/clientinterfaces/vulnerabilitymatcher.go b/internal/clients/clientinterfaces/vulnerabilitymatcher.go
new file mode 100644
index 0000000000..68f16ab17d
--- /dev/null
+++ b/internal/clients/clientinterfaces/vulnerabilitymatcher.go
@@ -0,0 +1,12 @@
+package clientinterfaces
+
+import (
+	"context"
+
+	"github.com/google/osv-scalibr/extractor"
+	"github.com/google/osv-scanner/pkg/models"
+)
+
+type VulnerabilityMatcher interface {
+	Match(ctx context.Context, invs []*extractor.Inventory) ([][]*models.Vulnerability, error)
+}
diff --git a/internal/imodels/imodels.go b/internal/imodels/imodels.go
index 2453005a59..e936ac010d 100644
--- a/internal/imodels/imodels.go
+++ b/internal/imodels/imodels.go
@@ -53,6 +53,8 @@ type PackageInfo struct {
 	OSPackageName string
 
 	AdditionalLocations []string // Contains Inventory.Locations[1..]
+
+	OriginalInventory *extractor.Inventory
 }
 
 // FromInventory converts an extractor.Inventory into a PackageInfo.
@@ -65,6 +67,7 @@ func FromInventory(inventory *extractor.Inventory) PackageInfo { Version: inventory.Version, Location: inventory.Locations[0], AdditionalLocations: inventory.Locations[1:], + OriginalInventory: inventory, // TODO: SourceType } @@ -136,7 +139,7 @@ func FromInventory(inventory *extractor.Inventory) PackageInfo { type PackageScanResult struct { PackageInfo PackageInfo // TODO: Use osvschema.Vulnerability instead - Vulnerabilities []models.Vulnerability + Vulnerabilities []*models.Vulnerability Licenses []models.License ImageOriginLayerID string diff --git a/internal/local/check.go b/internal/local/check.go deleted file mode 100644 index b8c2b4e95c..0000000000 --- a/internal/local/check.go +++ /dev/null @@ -1,174 +0,0 @@ -package local - -import ( - "errors" - "fmt" - "os" - "path" - "slices" - "strings" - - "github.com/google/osv-scanner/pkg/lockfile" - "github.com/google/osv-scanner/pkg/models" - "github.com/google/osv-scanner/pkg/osv" - "github.com/google/osv-scanner/pkg/reporter" -) - -const zippedDBRemoteHost = "https://osv-vulnerabilities.storage.googleapis.com" -const envKeyLocalDBCacheDirectory = "OSV_SCANNER_LOCAL_DB_CACHE_DIRECTORY" - -func loadDB(dbBasePath string, ecosystem lockfile.Ecosystem, offline bool) (*ZipDB, error) { - return NewZippedDB(dbBasePath, string(ecosystem), fmt.Sprintf("%s/%s/all.zip", zippedDBRemoteHost, ecosystem), offline) -} - -func toPackageDetails(query *osv.Query) (lockfile.PackageDetails, error) { - if query.Package.PURL != "" { - pkg, err := models.PURLToPackage(query.Package.PURL) - - if err != nil { - return lockfile.PackageDetails{}, err - } - - return lockfile.PackageDetails{ - Name: pkg.Name, - Version: pkg.Version, - Ecosystem: lockfile.Ecosystem(pkg.Ecosystem), - CompareAs: lockfile.Ecosystem(pkg.Ecosystem), - }, nil - } - - return lockfile.PackageDetails{ - Name: query.Package.Name, - Version: query.Version, - Commit: query.Commit, - Ecosystem: lockfile.Ecosystem(query.Package.Ecosystem), - CompareAs: lockfile.Ecosystem(query.Package.Ecosystem), - }, nil -} - -// setupLocalDBDirectory attempts to set up the directory the scanner should -// use to store local databases. 
-// -// if a local path is explicitly provided either by the localDBPath parameter -// or via the envKeyLocalDBCacheDirectory environment variable, the scanner will -// attempt to use the user cache directory if possible or otherwise the temp directory -// -// if an error occurs at any point when a local path is not explicitly provided, -// the scanner will fall back to the temp directory first before finally erroring -func setupLocalDBDirectory(localDBPath string) (string, error) { - var err error - - // fallback to the env variable if a local database path has not been provided - if localDBPath == "" { - if p, envSet := os.LookupEnv(envKeyLocalDBCacheDirectory); envSet { - localDBPath = p - } - } - - implicitPath := localDBPath == "" - - // if we're implicitly picking a path, use the user cache directory if available - if implicitPath { - localDBPath, err = os.UserCacheDir() - - if err != nil { - localDBPath = os.TempDir() - } - } - - altPath := path.Join(localDBPath, "osv-scanner") - err = os.MkdirAll(altPath, 0750) - if err == nil { - return altPath, nil - } - - // if we're implicitly picking a path, try the temp directory before giving up - if implicitPath && localDBPath != os.TempDir() { - return setupLocalDBDirectory(os.TempDir()) - } - - return "", err -} - -func MakeRequest(r reporter.Reporter, query osv.BatchedQuery, offline bool, localDBPath string) (*osv.HydratedBatchedResponse, error) { - results := make([]osv.Response, 0, len(query.Queries)) - dbs := make(map[lockfile.Ecosystem]*ZipDB) - - dbBasePath, err := setupLocalDBDirectory(localDBPath) - - if err != nil { - return &osv.HydratedBatchedResponse{}, fmt.Errorf("could not create %s: %w", dbBasePath, err) - } - - loadDBFromCache := func(ecosystem lockfile.Ecosystem) (*ZipDB, error) { - if db, ok := dbs[ecosystem]; ok { - return db, nil - } - - db, err := loadDB(dbBasePath, ecosystem, offline) - - if err != nil { - return nil, err - } - - r.Infof("Loaded %s local db from %s\n", db.Name, db.StoredAt) - - dbs[ecosystem] = db - - return db, nil - } - - // slice to track ecosystems that did not have an offline database available - var missingDbs []string - - for _, query := range query.Queries { - pkg, err := toPackageDetails(query) - - if err != nil { - // currently, this will actually only error if the PURL cannot be parses - r.Errorf("skipping %s as it is not a valid PURL: %v\n", query.Package.PURL, err) - results = append(results, osv.Response{Vulns: []models.Vulnerability{}}) - - continue - } - - if pkg.Ecosystem == "" { - if pkg.Commit == "" { - // The only time this can happen should be when someone passes in their own OSV-Scanner-Results file. 
- return nil, errors.New("ecosystem is empty and there is no commit hash") - } - - // Is a commit based query, skip local scanning - results = append(results, osv.Response{}) - r.Infof("Skipping commit scanning for: %s\n", pkg.Commit) - - continue - } - - db, err := loadDBFromCache(pkg.Ecosystem) - - if err != nil { - if errors.Is(err, ErrOfflineDatabaseNotFound) { - missingDbs = append(missingDbs, string(pkg.Ecosystem)) - } else { - // the most likely error at this point is that the PURL could not be parsed - r.Errorf("could not load db for %s ecosystem: %v\n", pkg.Ecosystem, err) - } - - results = append(results, osv.Response{Vulns: []models.Vulnerability{}}) - - continue - } - - results = append(results, osv.Response{Vulns: db.VulnerabilitiesAffectingPackage(pkg)}) - } - - if len(missingDbs) > 0 { - missingDbs = slices.Compact(missingDbs) - slices.Sort(missingDbs) - - r.Errorf("could not find local databases for ecosystems: %s\n", strings.Join(missingDbs, ", ")) - } - - return &osv.HydratedBatchedResponse{Results: results}, nil -} diff --git a/internal/osvdev/config.go b/internal/osvdev/config.go index 0a8fd76397..3d73fdaf09 100644 --- a/internal/osvdev/config.go +++ b/internal/osvdev/config.go @@ -3,9 +3,7 @@ package osvdev import "github.com/google/osv-scanner/internal/version" type ClientConfig struct { - MaxConcurrentRequests int MaxConcurrentBatchRequests int - MaxRetryAttempts int JitterMultiplier float64 BackoffDurationExponential float64 @@ -21,7 +19,6 @@ func DefaultConfig() ClientConfig { BackoffDurationExponential: 2, BackoffDurationMultiplier: 1, UserAgent: "osv-scanner/" + version.OSVVersion, - MaxConcurrentRequests: 1000, MaxConcurrentBatchRequests: 10, } } diff --git a/internal/osvdev/models.go b/internal/osvdev/models.go index 7f4563cb27..5839a9ef18 100644 --- a/internal/osvdev/models.go +++ b/internal/osvdev/models.go @@ -11,9 +11,10 @@ type Package struct { // Query represents a query to OSV. type Query struct { - Commit string `json:"commit,omitempty"` - Package Package `json:"package,omitempty"` - Version string `json:"version,omitempty"` + Commit string `json:"commit,omitempty"` + Package Package `json:"package,omitempty"` + Version string `json:"version,omitempty"` + PageToken string `json:"page_token,omitempty"` } // BatchedQuery represents a batched query to OSV. @@ -28,12 +29,14 @@ type MinimalVulnerability struct { // Response represents a full response from OSV. type Response struct { - Vulns []models.Vulnerability `json:"vulns"` + Vulns []models.Vulnerability `json:"vulns"` + NextPageToken string `json:"next_page_token"` } // MinimalResponse represents an unhydrated response from OSV. type MinimalResponse struct { - Vulns []MinimalVulnerability `json:"vulns"` + Vulns []MinimalVulnerability `json:"vulns"` + NextPageToken string `json:"next_page_token"` } // BatchedResponse represents an unhydrated batched response from OSV. 
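
A note on the paging fields added above: `PageToken` on `Query` and `NextPageToken` on the responses follow osv.dev's paging protocol, where a non-empty `next_page_token` means the server has more results, and re-sending the same query with `PageToken` set resumes from where the previous response stopped. Below is a minimal sketch of that resume loop; the package and version are purely illustrative, and the loop itself is not part of this change (`queryForBatchWithPaging` earlier in this diff is the batched equivalent):

```go
package main

import (
	"context"
	"fmt"

	"github.com/google/osv-scanner/internal/osvdev"
)

func main() {
	c := osvdev.DefaultClient()
	query := &osvdev.Query{
		Package: osvdev.Package{Name: "lodash", Ecosystem: "npm"}, // illustrative
		Version: "4.17.0",
	}

	for {
		resp, err := c.Query(context.Background(), query)
		if err != nil {
			panic(err) // example only; handle properly in real code
		}
		for _, v := range resp.Vulns {
			fmt.Println(v.ID)
		}
		// An empty next_page_token means there are no further pages.
		if resp.NextPageToken == "" {
			break
		}
		// Resume the same query from where the previous response stopped.
		query.PageToken = resp.NextPageToken
	}
}
```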
diff --git a/internal/osvdev/osvdev.go b/internal/osvdev/osvdev.go
index 51efcd472c..6d0074bf64 100644
--- a/internal/osvdev/osvdev.go
+++ b/internal/osvdev/osvdev.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"math"
@@ -42,8 +43,8 @@ func DefaultClient() *OSVClient {
 	}
 }
 
-// GetVulnsByID is an interface to this endpoint: https://google.github.io/osv.dev/get-v1-vulns/
-func (c *OSVClient) GetVulnsByID(ctx context.Context, id string) (*models.Vulnerability, error) {
+// GetVulnByID is an interface to this endpoint: https://google.github.io/osv.dev/get-v1-vulns/
+func (c *OSVClient) GetVulnByID(ctx context.Context, id string) (*models.Vulnerability, error) {
 	resp, err := c.makeRetryRequest(func(client *http.Client) (*http.Response, error) {
 		req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.BaseHostURL+GetEndpoint+"/"+id, nil)
 		if err != nil {
@@ -73,12 +74,16 @@ func (c *OSVClient) GetVulnsByID(ctx context.Context, id string) (*models.Vulner
 }
 
 // QueryBatch is an interface to this endpoint: https://google.github.io/osv.dev/post-v1-querybatch/
+// This function performs paging invisibly until the context expires, after which all pages that have already
+// been retrieved are returned.
+//
+// Check whether the next_page_token field in the response is set to determine if there are extra pages remaining.
 func (c *OSVClient) QueryBatch(ctx context.Context, queries []*Query) (*BatchedResponse, error) {
 	// API has a limit of how many queries are in one batch
 	queryChunks := chunkBy(queries, MaxQueriesPerQueryBatchRequest)
 	totalOsvRespBatched := make([][]MinimalResponse, len(queryChunks))
 
-	g, ctx := errgroup.WithContext(ctx)
+	g, errGrpCtx := errgroup.WithContext(ctx)
 	g.SetLimit(c.Config.MaxConcurrentBatchRequests)
 	for batchIndex, queries := range queryChunks {
 		requestBytes, err := json.Marshal(BatchedQuery{Queries: queries})
@@ -89,7 +94,7 @@
 		g.Go(func() error {
 			// exit early if another hydration request has already failed
 			// results are thrown away later, so avoid needless work
-			if ctx.Err() != nil {
+			if errGrpCtx.Err() != nil {
 				return nil
 			}
@@ -97,7 +102,7 @@
 			// Make sure request buffer is inside retry, if outside
 			// http request would finish the buffer, and retried requests would be empty
 			requestBuf := bytes.NewBuffer(requestBytes)
-			req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseHostURL+QueryBatchEndpoint, requestBuf)
+			req, err := http.NewRequestWithContext(errGrpCtx, http.MethodPost, c.BaseHostURL+QueryBatchEndpoint, requestBuf)
 			if err != nil {
 				return nil, err
 			}
@@ -142,6 +147,10 @@
 }
 
 // Query is an interface to this endpoint: https://google.github.io/osv.dev/post-v1-query/
+// This function performs paging invisibly until the context expires, after which all pages that have already
+// been retrieved are returned.
+//
+// Check whether the next_page_token field in the response is set to determine if there are extra pages remaining.
 func (c *OSVClient) Query(ctx context.Context, query *Query) (*Response, error) {
 	requestBytes, err := json.Marshal(query)
 	if err != nil {
@@ -232,6 +241,11 @@ func (c *OSVClient) makeRetryRequest(action func(client *http.Client) (*http.Res
 
 		resp, err = action(c.HTTPClient)
 
+		// Don't retry, since the deadline has already been exceeded
+		if errors.Is(err, context.DeadlineExceeded) {
+			return nil, err
+		}
+
 		// The network request itself failed, did not even get a response
 		if err != nil {
 			lastErr = fmt.Errorf("attempt %d: request failed: %w", i+1, err)
diff --git a/internal/osvdev/osvdev_test.go b/internal/osvdev/osvdev_test.go
index 8069e23b22..0cb7059a80 100644
--- a/internal/osvdev/osvdev_test.go
+++ b/internal/osvdev/osvdev_test.go
@@ -3,11 +3,13 @@ package osvdev_test
 import (
 	"context"
 	"testing"
+	"time"
 
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
 	"github.com/google/osv-scalibr/testing/extracttest"
 	"github.com/google/osv-scanner/internal/osvdev"
+	"github.com/google/osv-scanner/internal/testutility"
 	"github.com/ossf/osv-schema/bindings/go/osvschema"
 )
 
@@ -45,7 +47,7 @@ func TestOSVClient_GetVulnsByID(t *testing.T) {
 			c := osvdev.DefaultClient()
 			c.Config.UserAgent = "osv-scanner-api-test"
 
-			got, err := c.GetVulnsByID(context.Background(), tt.id)
+			got, err := c.GetVulnByID(context.Background(), tt.id)
 
 			if diff := cmp.Diff(tt.wantErr, err, cmpopts.EquateErrors()); diff != "" {
 				t.Fatalf("Unexpected error (-want +got):\n%s", diff)
@@ -159,6 +161,75 @@ func TestOSVClient_QueryBatch(t *testing.T) {
 	}
 }
 
+func TestOSVClient_QueryBatchDeadline(t *testing.T) {
+	t.Parallel()
+	testutility.SkipIfNotAcceptanceTesting(t, "Takes a long time to run")
+
+	tests := []struct {
+		name    string
+		queries []*osvdev.Query
+		wantIDs [][]string
+		wantErr error
+	}{
+		{
+			name: "linux package lookup",
+			queries: []*osvdev.Query{
+				{
+					Commit: "60e572dbf7b4ded66b488f54773f66aaf6184321",
+				},
+				{
+					Package: osvdev.Package{
+						Name:      "linux",
+						Ecosystem: "Ubuntu:22.04:LTS",
+					},
+					Version: "5.15.0-17.17",
+				},
+				{
+					Package: osvdev.Package{
+						Name:      "abcd-definitely-does-not-exist",
+						Ecosystem: string(osvschema.EcosystemNPM),
+					},
+					Version: "1.0.0",
+				},
+			},
+			wantErr: context.DeadlineExceeded,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			c := osvdev.DefaultClient()
+			c.Config.UserAgent = "osv-scanner-api-test"
+			ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*1))
+
+			got, err := c.QueryBatch(ctx, tt.queries)
+			cancel()
+			if diff := cmp.Diff(tt.wantErr, err, cmpopts.EquateErrors()); diff != "" {
+				t.Fatalf("Unexpected error (-want +got):\n%s", diff)
+			}
+
+			if err != nil {
+				return
+			}
+
+			gotResults := make([][]string, 0, len(got.Results))
+			for _, res := range got.Results {
+				gotVulnIDs := make([]string, 0, len(res.Vulns))
+				for _, vuln := range res.Vulns {
+					gotVulnIDs = append(gotVulnIDs, vuln.ID)
+				}
+				gotResults = append(gotResults, gotVulnIDs)
+			}
+
+			if diff := cmp.Diff(tt.wantIDs, gotResults); diff != "" {
+				t.Errorf("Unexpected vuln IDs (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
 func TestOSVClient_Query(t *testing.T) {
 	t.Parallel()
 
@@ -211,7 +282,8 @@ func TestOSVClient_Query(t *testing.T) {
 			},
 			wantErr: extracttest.ContainsErrStr{
 				Str: `client error: status="400 Bad Request" body={"code":3,"message":"Invalid query."}`,
-			}},
+			},
+		},
 	}
 
 	for _, tt := range tests {
 		t.Run(tt.name,
func(t *testing.T) { @@ -242,6 +314,59 @@ func TestOSVClient_Query(t *testing.T) { } } +func TestOSVClient_QueryDeadline(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + query osvdev.Query + wantIDs []string + wantErr error + }{ + { + name: "linux Package lookup", + query: osvdev.Query{ + Package: osvdev.Package{ + // Use a deleted package as it is less likely new vulns will be published for it + Name: "linux", + Ecosystem: "Ubuntu:22.04:LTS", + }, + Version: "5.15.0-17.17", + }, + wantErr: context.DeadlineExceeded, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + c := osvdev.DefaultClient() + c.Config.UserAgent = "osv-scanner-api-test" + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*1)) + got, err := c.Query(ctx, &tt.query) + cancel() + + if diff := cmp.Diff(tt.wantErr, err, cmpopts.EquateErrors()); diff != "" { + t.Fatalf("Unexpected error (-want +got):\n%s", diff) + } + + if err != nil { + return + } + + gotVulnIDs := make([]string, 0, len(got.Vulns)) + for _, vuln := range got.Vulns { + gotVulnIDs = append(gotVulnIDs, vuln.ID) + } + + if diff := cmp.Diff(tt.wantIDs, gotVulnIDs); diff != "" { + t.Errorf("Unexpected vuln IDs (-want +got):\n%s", diff) + } + }) + } +} + func TestOSVClient_ExperimentalDetermineVersion(t *testing.T) { t.Parallel() diff --git a/internal/resolution/client/osv_offline_client.go b/internal/resolution/client/osv_offline_client.go index 0779ac0c81..96505f143f 100644 --- a/internal/resolution/client/osv_offline_client.go +++ b/internal/resolution/client/osv_offline_client.go @@ -1,3 +1,4 @@ +//nolint:unused,revive package client import ( @@ -6,10 +7,7 @@ import ( "strings" "deps.dev/util/resolve" - "github.com/google/osv-scanner/internal/local" - "github.com/google/osv-scanner/internal/resolution/util" "github.com/google/osv-scanner/pkg/models" - "github.com/google/osv-scanner/pkg/osv" "github.com/google/osv-scanner/pkg/reporter" ) @@ -22,59 +20,63 @@ func NewOSVOfflineClient(r reporter.Reporter, system resolve.System, downloadDBs if system == resolve.UnknownSystem { return nil, errors.New("osv offline client created with unknown ecosystem") } - // Make a dummy request to the local client to log and make sure the database is downloaded without error. - q := osv.BatchedQuery{Queries: []*osv.Query{{ - Package: osv.Package{ - Name: "foo", - Ecosystem: string(util.OSVEcosystem[system]), - }, - Version: "1.0.0", - }}} - _, err := local.MakeRequest(r, q, !downloadDBs, localDBPath) - if err != nil { - return nil, err - } - - if r.HasErrored() { - return nil, errors.New("error creating osv offline client") - } - return &OSVOfflineClient{localDBPath: localDBPath}, nil + panic("TODO: Temporarily disabled") + + // // Make a dummy request to the local client to log and make sure the database is downloaded without error. 
+ // q := osv.BatchedQuery{Queries: []*osv.Query{{ + // Package: osv.Package{ + // Name: "foo", + // Ecosystem: string(util.OSVEcosystem[system]), + // }, + // Version: "1.0.0", + // }}} + // _, err := local.MakeRequest(r, q, !downloadDBs, localDBPath) + // if err != nil { + // return nil, err + // } + + // if r.HasErrored() { + // return nil, errors.New("error creating osv offline client") + // } + + // return &OSVOfflineClient{localDBPath: localDBPath}, nil } func (c *OSVOfflineClient) FindVulns(g *resolve.Graph) ([]models.Vulnerabilities, error) { - var query osv.BatchedQuery - query.Queries = make([]*osv.Query, len(g.Nodes)-1) - for i, node := range g.Nodes[1:] { - query.Queries[i] = &osv.Query{ - Package: osv.Package{ - Name: node.Version.Name, - Ecosystem: string(util.OSVEcosystem[node.Version.System]), - }, - Version: node.Version.Version, - } - } - - // If local.MakeRequest logs an error, it's probably fatal for guided remediation. - // Set up a reporter to capture error logs and return the logs as an error. - r := &errorReporter{} - // DB should already be downloaded, set offline to true. - hydrated, err := local.MakeRequest(r, query, true, c.localDBPath) - - if err != nil { - return nil, err - } - - if r.HasErrored() { - return nil, r.GetError() - } - - nodeVulns := make([]models.Vulnerabilities, len(g.Nodes)) - for i, res := range hydrated.Results { - nodeVulns[i+1] = res.Vulns - } - - return nodeVulns, nil + return []models.Vulnerabilities{}, nil + // var query osv.BatchedQuery + // query.Queries = make([]*osv.Query, len(g.Nodes)-1) + // for i, node := range g.Nodes[1:] { + // query.Queries[i] = &osv.Query{ + // Package: osv.Package{ + // Name: node.Version.Name, + // Ecosystem: string(util.OSVEcosystem[node.Version.System]), + // }, + // Version: node.Version.Version, + // } + // } + + // // If local.MakeRequest logs an error, it's probably fatal for guided remediation. + // // Set up a reporter to capture error logs and return the logs as an error. + // r := &errorReporter{} + // // DB should already be downloaded, set offline to true. + // hydrated, err := local.MakeRequest(r, query, true, c.localDBPath) + + // if err != nil { + // return nil, err + // } + + // if r.HasErrored() { + // return nil, r.GetError() + // } + + // nodeVulns := make([]models.Vulnerabilities, len(g.Nodes)) + // for i, res := range hydrated.Results { + // nodeVulns[i+1] = res.Vulns + // } + + // return nodeVulns, nil } // errorReporter is a reporter.Reporter to capture error logs and pack them into an error. 
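
While the offline resolution client above is stubbed out, the `VulnerabilityMatcher` contract it will eventually target (added in `internal/clients/clientinterfaces` earlier in this diff) is small enough to fake. A hypothetical test double might look like the following; the type and its canned data are illustrative and not part of this change:

```go
package example

import (
	"context"

	"github.com/google/osv-scalibr/extractor"
	"github.com/google/osv-scanner/pkg/models"
)

// fixedMatcher satisfies clientinterfaces.VulnerabilityMatcher by returning
// the same canned vulnerabilities for every inventory, one slice per input.
type fixedMatcher struct {
	canned []*models.Vulnerability
}

func (m *fixedMatcher) Match(_ context.Context, invs []*extractor.Inventory) ([][]*models.Vulnerability, error) {
	results := make([][]*models.Vulnerability, len(invs))
	for i := range invs {
		results[i] = m.canned
	}

	return results, nil
}
```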
diff --git a/internal/utility/vulns/vulnerabilities.go b/internal/utility/vulns/vulnerabilities.go index 5e45e2df82..0c9539e6b8 100644 --- a/internal/utility/vulns/vulnerabilities.go +++ b/internal/utility/vulns/vulnerabilities.go @@ -2,16 +2,16 @@ package vulns import "github.com/google/osv-scanner/pkg/models" -func Include(vs models.Vulnerabilities, vulnerability models.Vulnerability) bool { +func Include(vs []*models.Vulnerability, vulnerability models.Vulnerability) bool { for _, vuln := range vs { if vuln.ID == vulnerability.ID { return true } - if isAliasOf(vuln, vulnerability) { + if isAliasOf(*vuln, vulnerability) { return true } - if isAliasOf(vulnerability, vuln) { + if isAliasOf(vulnerability, *vuln) { return true } } diff --git a/internal/utility/vulns/vulnerabilities_test.go b/internal/utility/vulns/vulnerabilities_test.go index 9ba61f1173..39ebbea424 100644 --- a/internal/utility/vulns/vulnerabilities_test.go +++ b/internal/utility/vulns/vulnerabilities_test.go @@ -15,14 +15,14 @@ func TestVulnerabilities_Includes(t *testing.T) { } tests := []struct { name string - vs models.Vulnerabilities + vs []*models.Vulnerability args args want bool }{ { name: "", - vs: models.Vulnerabilities{ - models.Vulnerability{ + vs: []*models.Vulnerability{ + { ID: "GHSA-1", Aliases: []string{}, }, @@ -37,8 +37,8 @@ func TestVulnerabilities_Includes(t *testing.T) { }, { name: "", - vs: models.Vulnerabilities{ - models.Vulnerability{ + vs: []*models.Vulnerability{ + { ID: "GHSA-1", Aliases: []string{}, }, @@ -53,8 +53,8 @@ func TestVulnerabilities_Includes(t *testing.T) { }, { name: "", - vs: models.Vulnerabilities{ - models.Vulnerability{ + vs: []*models.Vulnerability{ + { ID: "GHSA-1", Aliases: []string{"GHSA-2"}, }, @@ -69,8 +69,8 @@ func TestVulnerabilities_Includes(t *testing.T) { }, { name: "", - vs: models.Vulnerabilities{ - models.Vulnerability{ + vs: []*models.Vulnerability{ + { ID: "GHSA-1", Aliases: []string{}, }, @@ -85,8 +85,8 @@ func TestVulnerabilities_Includes(t *testing.T) { }, { name: "", - vs: models.Vulnerabilities{ - models.Vulnerability{ + vs: []*models.Vulnerability{ + { ID: "GHSA-1", Aliases: []string{"CVE-1"}, }, @@ -101,8 +101,8 @@ func TestVulnerabilities_Includes(t *testing.T) { }, { name: "", - vs: models.Vulnerabilities{ - models.Vulnerability{ + vs: []*models.Vulnerability{ + { ID: "GHSA-1", Aliases: []string{"CVE-2"}, }, diff --git a/pkg/osvscanner/osvscanner.go b/pkg/osvscanner/osvscanner.go index 811e63272e..3d94ea7821 100644 --- a/pkg/osvscanner/osvscanner.go +++ b/pkg/osvscanner/osvscanner.go @@ -1,19 +1,22 @@ package osvscanner import ( + "context" "errors" "fmt" + "time" + "github.com/google/osv-scalibr/extractor" + "github.com/google/osv-scanner/internal/clients/clientimpl/localmatcher" + "github.com/google/osv-scanner/internal/clients/clientimpl/osvmatcher" + "github.com/google/osv-scanner/internal/clients/clientinterfaces" "github.com/google/osv-scanner/internal/config" "github.com/google/osv-scanner/internal/depsdev" "github.com/google/osv-scanner/internal/imodels" "github.com/google/osv-scanner/internal/imodels/results" - "github.com/google/osv-scanner/internal/local" + "github.com/google/osv-scanner/internal/osvdev" "github.com/google/osv-scanner/internal/output" - "github.com/google/osv-scanner/internal/semantic" - "github.com/google/osv-scanner/internal/version" "github.com/google/osv-scanner/pkg/models" - "github.com/google/osv-scanner/pkg/osv" "github.com/google/osv-scanner/pkg/reporter" "github.com/ossf/osv-schema/bindings/go/osvschema" @@ -114,15 
+117,33 @@ func DoScan(actions ScannerActions, r reporter.Reporter) (models.VulnerabilityRe scanResult.PackageScanResults = packages + // ----- Filtering ----- filterUnscannablePackages(r, &scanResult) - filterIgnoredPackages(r, &scanResult) + // ----- Custom Overrides ----- overrideGoVersion(r, &scanResult) - err = makeRequest(r, scanResult.PackageScanResults, actions.CompareOffline, actions.DownloadDatabases, actions.LocalDBPath) - if err != nil { - return models.VulnerabilityResults{}, err + // --- Make Vulnerability Requests --- + { + var matcher clientinterfaces.VulnerabilityMatcher + var err error + if actions.CompareOffline { + matcher, err = localmatcher.NewLocalMatcher(r, actions.LocalDBPath, actions.DownloadDatabases) + if err != nil { + return models.VulnerabilityResults{}, err + } + } else { + matcher = &osvmatcher.OSVMatcher{ + Client: *osvdev.DefaultClient(), + InitialQueryTimeout: 5 * time.Minute, + } + } + + err = makeRequestWithMatcher(r, scanResult.PackageScanResults, matcher) + if err != nil { + return models.VulnerabilityResults{}, err + } } if len(actions.ScanLicensesAllowlist) > 0 || actions.ScanLicensesSummary { @@ -176,82 +197,27 @@ func DoScan(actions ScannerActions, r reporter.Reporter) (models.VulnerabilityRe return results, nil } -// patchPackageForRequest modifies packages before they are sent to osv.dev to -// account for edge cases. -func patchPackageForRequest(pkg imodels.PackageInfo) imodels.PackageInfo { - // Assume Go stdlib patch version as the latest version - // - // This is done because go1.20 and earlier do not support patch - // version in go.mod file, and will fail to build. - // - // However, if we assume patch version as .0, this will cause a lot of - // false positives. This compromise still allows osv-scanner to pick up - // when the user is using a minor version that is out-of-support. - if pkg.Name == "stdlib" && pkg.Ecosystem.Ecosystem == osvschema.EcosystemGo { - v := semantic.ParseSemverLikeVersion(pkg.Version, 3) - if len(v.Components) == 2 { - pkg.Version = fmt.Sprintf( - "%d.%d.%d", - v.Components.Fetch(0), - v.Components.Fetch(1), - 9999, - ) - } - } - - return pkg -} - -// TODO(V2): This will be replaced by the new client interface -func makeRequest( +// TODO(V2): Add context +func makeRequestWithMatcher( r reporter.Reporter, packages []imodels.PackageScanResult, - compareOffline bool, - downloadDBs bool, - localDBPath string) error { - // Make OSV queries from the packages. - var query osv.BatchedQuery - for _, psr := range packages { - p := psr.PackageInfo - p = patchPackageForRequest(p) - switch { - // Prefer making package requests where possible. - case !p.Ecosystem.IsEmpty() && p.Name != "" && p.Version != "": - query.Queries = append(query.Queries, osv.MakePkgRequest(p)) - case p.Commit != "": - query.Queries = append(query.Queries, osv.MakeCommitRequest(p.Commit)) - default: - return fmt.Errorf("package %v does not have a commit, PURL or ecosystem/name/version identifier", p) - } + matcher clientinterfaces.VulnerabilityMatcher) error { + invs := make([]*extractor.Inventory, 0, len(packages)) + for _, pkgs := range packages { + invs = append(invs, pkgs.PackageInfo.OriginalInventory) } - var err error - var hydratedResp *osv.HydratedBatchedResponse - - if compareOffline { - // TODO(v2): Stop depending on lockfile.PackageDetails and use imodels.PackageInfo - // Downloading databases requires network access. 
- hydratedResp, err = local.MakeRequest(r, query, !downloadDBs, localDBPath) - if err != nil { - return fmt.Errorf("local comparison failed %w", err) - } - } else { - if osv.RequestUserAgent == "" { - osv.RequestUserAgent = "osv-scanner-api/v" + version.OSVVersion - } - - resp, err := osv.MakeRequest(query) - if err != nil { - return fmt.Errorf("%w: osv.dev query failed: %w", ErrAPIFailed, err) - } - hydratedResp, err = osv.Hydrate(resp) - if err != nil { - return fmt.Errorf("%w: failed to hydrate OSV response: %w", ErrAPIFailed, err) + res, err := matcher.Match(context.Background(), invs) + if err != nil { + // TODO: Handle error here + r.Errorf("error when retrieving vulns: %v", err) + if res == nil { + return err } } - for i, result := range hydratedResp.Results { - packages[i].Vulnerabilities = result.Vulns + for i, vulns := range res { + packages[i].Vulnerabilities = vulns } return nil @@ -288,6 +254,8 @@ func overrideGoVersion(r reporter.Reporter, scanResults *results.ScanResults) { configToUse := scanResults.ConfigManager.Get(r, pkg.Location) if configToUse.GoVersionOverride != "" { scanResults.PackageScanResults[i].PackageInfo.Version = configToUse.GoVersionOverride + // Also patch it in the inventory, as we have to use the original inventory to make requests + scanResults.PackageScanResults[i].PackageInfo.OriginalInventory.Version = configToUse.GoVersionOverride } continue diff --git a/pkg/osvscanner/vulnerability_result.go b/pkg/osvscanner/vulnerability_result.go index 1963159cf5..0129aa24a9 100644 --- a/pkg/osvscanner/vulnerability_result.go +++ b/pkg/osvscanner/vulnerability_result.go @@ -53,7 +53,9 @@ func buildVulnerabilityResults( if len(psr.Vulnerabilities) > 0 { if !configToUse.ShouldIgnorePackageVulnerabilities(p) { includePackage = true - pkg.Vulnerabilities = psr.Vulnerabilities + for _, vuln := range psr.Vulnerabilities { + pkg.Vulnerabilities = append(pkg.Vulnerabilities, *vuln) + } pkg.Groups = grouper.Group(grouper.ConvertVulnerabilityToIDAliases(pkg.Vulnerabilities)) for i, group := range pkg.Groups { pkg.Groups[i].MaxSeverity = output.MaxSeverity(group, pkg) diff --git a/pkg/osvscanner/vulnerability_result_internal_test.go b/pkg/osvscanner/vulnerability_result_internal_test.go index 2ef604bfd5..00710c95b1 100644 --- a/pkg/osvscanner/vulnerability_result_internal_test.go +++ b/pkg/osvscanner/vulnerability_result_internal_test.go @@ -81,9 +81,13 @@ func Test_assembleResult(t *testing.T) { } licensesResp = makeLicensesResp() for i := range packages { + vulnPointers := []*models.Vulnerability{} + for _, vuln := range vulnsResp.Results[i].Vulns { + vulnPointers = append(vulnPointers, &vuln) + } scanResults.PackageScanResults = append(scanResults.PackageScanResults, imodels.PackageScanResult{ PackageInfo: packages[i], - Vulnerabilities: vulnsResp.Results[i].Vulns, + Vulnerabilities: vulnPointers, Licenses: licensesResp[i], }) }
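
A final note on the pointer-taking loops introduced in this change (`&vulnerability` in `VulnerabilitiesAffectingPackage` and `&vuln` in the test above): they are only correct when built with Go 1.22 or later, where each loop iteration gets a fresh variable. The diff does not show the module's `go` directive; if older toolchains had to be supported, the classic shadow copy would keep the appended pointers distinct. A hypothetical helper showing that pattern:

```go
package example

import "github.com/google/osv-scanner/pkg/models"

// toPointers shows the shadow-copy pattern that the pointer-taking loops in
// this diff would need on toolchains older than Go 1.22.
func toPointers(vulns []models.Vulnerability) []*models.Vulnerability {
	out := make([]*models.Vulnerability, 0, len(vulns))
	for _, vuln := range vulns {
		vuln := vuln // per-iteration copy; without this, pre-1.22 Go aliases a single variable
		out = append(out, &vuln)
	}

	return out
}
```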