Skip to content

Commit

Permalink
vulndb/client: make Get accept a single module path
Browse files Browse the repository at this point in the history
Current signature of Get accepts a list of module paths. A single
argument is a cleaner solution not affecting client library usability.
This CL makes the switch and cleans up unit testing.

Change-Id: Ic67fa02e0372f19882b75c47ced8f1a2a9b3a233
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/344870
Run-TryBot: Zvonimir Pavlinovic <[email protected]>
TryBot-Result: Go Bot <[email protected]>
Reviewed-by: Roland Shoemaker <[email protected]>
Reviewed-by: kokoro <[email protected]>
Trust: Zvonimir Pavlinovic <[email protected]>
  • Loading branch information
softdev050 committed Sep 2, 2021
1 parent 3e913ef commit ce2b095
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 115 deletions.
123 changes: 55 additions & 68 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,30 +47,26 @@ import (
)

type source interface {
Get([]string) ([]*osv.Entry, error)
Get(string) ([]*osv.Entry, error)
Index() (osv.DBIndex, error)
}

type localSource struct {
dir string
}

func (ls *localSource) Get(modules []string) ([]*osv.Entry, error) {
var entries []*osv.Entry
for _, p := range modules {
content, err := ioutil.ReadFile(filepath.Join(ls.dir, p+".json"))
if os.IsNotExist(err) {
continue
} else if err != nil {
return nil, err
}
var e []*osv.Entry
if err = json.Unmarshal(content, &e); err != nil {
return nil, err
}
entries = append(entries, e...)
func (ls *localSource) Get(module string) ([]*osv.Entry, error) {
content, err := ioutil.ReadFile(filepath.Join(ls.dir, module+".json"))
if os.IsNotExist(err) {
return nil, nil
} else if err != nil {
return nil, err
}
return entries, nil
var e []*osv.Entry
if err = json.Unmarshal(content, &e); err != nil {
return nil, err
}
return e, nil
}

func (ls *localSource) Index() (osv.DBIndex, error) {
Expand Down Expand Up @@ -147,69 +143,60 @@ func (hs *httpSource) Index() (osv.DBIndex, error) {
return index, nil
}

func (hs *httpSource) Get(modules []string) ([]*osv.Entry, error) {
var entries []*osv.Entry

func (hs *httpSource) Get(module string) ([]*osv.Entry, error) {
index, err := hs.Index()
if err != nil {
return nil, err
}

var stillNeed []string
for _, p := range modules {
lastModified, present := index[p]
if !present {
continue
}
if hs.cache != nil {
if cached, err := hs.cache.ReadEntries(hs.dbName, p); err != nil {
return nil, err
} else if len(cached) != 0 {
var stale bool
for _, c := range cached {
if c.Modified.Before(lastModified) {
stale = true
break
}
}
if !stale {
entries = append(entries, cached...)
continue
lastModified, present := index[module]
if !present {
return nil, nil
}

if hs.cache != nil {
if cached, err := hs.cache.ReadEntries(hs.dbName, module); err != nil {
return nil, err
} else if len(cached) != 0 {
var stale bool
for _, c := range cached {
if c.Modified.Before(lastModified) {
stale = true
break
}
}
if !stale {
return cached, nil
}
}
stillNeed = append(stillNeed, p)
}

for _, p := range stillNeed {
resp, err := hs.c.Get(fmt.Sprintf("%s/%s.json", hs.url, p))
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
continue
}
// might want this to be a LimitedReader
content, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var e []*osv.Entry
if err = json.Unmarshal(content, &e); err != nil {
return nil, err
}
// TODO: we may want to check that the returned entries actually match
// the module we asked about, so that the cache cannot be poisoned
entries = append(entries, e...)
resp, err := hs.c.Get(fmt.Sprintf("%s/%s.json", hs.url, module))
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, nil
}
// might want this to be a LimitedReader
content, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var e []*osv.Entry
// TODO: we may want to check that the returned entries actually match
// the module we asked about, so that the cache cannot be poisoned
if err = json.Unmarshal(content, &e); err != nil {
return nil, err
}

if hs.cache != nil {
if err := hs.cache.WriteEntries(hs.dbName, p, e); err != nil {
return nil, err
}
if hs.cache != nil {
if err := hs.cache.WriteEntries(hs.dbName, module, e); err != nil {
return nil, err
}
}
return entries, nil
return e, nil
}

type Client struct {
Expand Down Expand Up @@ -252,11 +239,11 @@ func NewClient(sources []string, opts Options) (*Client, error) {
return c, nil
}

func (c *Client) Get(modules []string) ([]*osv.Entry, error) {
func (c *Client) Get(module string) ([]*osv.Entry, error) {
var entries []*osv.Entry
// probably should be parallelized
for _, s := range c.sources {
e, err := s.Get(modules)
e, err := s.Get(module)
if err != nil {
return nil, err // be failure tolerant?
}
Expand Down
79 changes: 32 additions & 47 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,49 +21,38 @@ import (
"golang.org/x/vulndb/osv"
)

var testVuln1 string = `[
{"ID":"ID1","Package":{"Name":"golang.org/example/one","Ecosystem":"go"}, "Summary":"",
var testVuln string = `[
{"ID":"ID","Package":{"Name":"golang.org/example/one","Ecosystem":"go"}, "Summary":"",
"Severity":2,"Affects":{"Ranges":[{"Type":"SEMVER","Introduced":"","Fixed":"v2.2.0"}]},
"ecosystem_specific":{"Symbols":["some_symbol_1"]
}}]`

var testVuln2 string = `[
{"ID":"ID2","Package":{"Name":"golang.org/example/two","Ecosystem":"go"}, "Summary":"",
"Severity":2,"Affects":{"Ranges":[{"Type":"SEMVER","Introduced":"","Fixed":"v2.1.0"}]},
"ecosystem_specific":{"Symbols":["some_symbol_2"]
}}]`

// index containing timestamps for packages in testVuln1 and testVuln2.
// index containing timestamps for package in testVuln.
var index string = `{
"golang.org/example/one": "2020-03-09T10:00:00.81362141-07:00",
"golang.org/example/two": "2019-02-05T09:00:00.31561157-07:00"
"golang.org/example/one": "2020-03-09T10:00:00.81362141-07:00"
}`

func serveTestVuln1(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, testVuln1)
}

func serveTestVuln2(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, testVuln2)
func serveTestVuln(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, testVuln)
}

func serveIndex(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, index)
}

// cachedTestVuln2 returns a function creating a local cache
// for db with `dbName` with a version of testVuln2 where
// cachedTestVuln returns a function creating a local cache
// for db with `dbName` with a version of testVuln where
// Summary="cached" and LastModified happened after entry
// in the `index` for the same pkg.
func cachedTestVuln2(dbName string) func() Cache {
func cachedTestVuln(dbName string) func() Cache {
return func() Cache {
c := &fsCache{}
e := &osv.Entry{
ID: "ID2",
ID: "ID1",
Details: "cached",
Modified: time.Now(),
}
c.WriteEntries(dbName, "golang.org/example/two", []*osv.Entry{e})
c.WriteEntries(dbName, "golang.org/example/one", []*osv.Entry{e})
return c
}
}
Expand All @@ -81,10 +70,7 @@ func createDirAndFile(dir, file, content string) error {
func localDB(t *testing.T) (string, error) {
dbName := t.TempDir()

if err := createDirAndFile(path.Join(dbName, "/golang.org/example/"), "one.json", testVuln1); err != nil {
return "", err
}
if err := createDirAndFile(path.Join(dbName, "/golang.org/example/"), "two.json", testVuln2); err != nil {
if err := createDirAndFile(path.Join(dbName, "/golang.org/example/"), "one.json", testVuln); err != nil {
return "", err
}
if err := createDirAndFile(path.Join(dbName, ""), "index.json", index); err != nil {
Expand All @@ -99,8 +85,7 @@ func TestClient(t *testing.T) {
}

// Create a local http database.
http.HandleFunc("/golang.org/example/one.json", serveTestVuln1)
http.HandleFunc("/golang.org/example/two.json", serveTestVuln2)
http.HandleFunc("/golang.org/example/one.json", serveTestVuln)
http.HandleFunc("/index.json", serveIndex)

l, err := net.Listen("tcp", "127.0.0.1:")
Expand All @@ -121,20 +106,20 @@ func TestClient(t *testing.T) {
name string
source string
createCache func() Cache
noVulns int
summaries map[string]string
// cache summary for testVuln
summary string
}{
// Test the http client without any cache.
{name: "http-no-cache", source: "http://localhost:" + port, createCache: func() Cache { return nil }, noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": ""}},
{name: "http-no-cache", source: "http://localhost:" + port, createCache: func() Cache { return nil }, summary: ""},
// Test the http client with empty cache.
{name: "http-empty-cache", source: "http://localhost:" + port, createCache: func() Cache { return &fsCache{} }, noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": ""}},
{name: "http-empty-cache", source: "http://localhost:" + port, createCache: func() Cache { return &fsCache{} }, summary: ""},
// Test the client with non-stale cache containing a version of testVuln2 where Summary="cached".
{name: "http-cache", source: "http://localhost:" + port, createCache: cachedTestVuln2("localhost"), noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": "cached"}},
{name: "http-cache", source: "http://localhost:" + port, createCache: cachedTestVuln("localhost"), summary: "cached"},
// Repeat the same for local file client.
{name: "file-no-cache", source: "file://" + localDBName, createCache: func() Cache { return nil }, noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": ""}},
{name: "file-empty-cache", source: "file://" + localDBName, createCache: func() Cache { return &fsCache{} }, noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": ""}},
{name: "file-no-cache", source: "file://" + localDBName, createCache: func() Cache { return nil }, summary: ""},
{name: "file-empty-cache", source: "file://" + localDBName, createCache: func() Cache { return &fsCache{} }, summary: ""},
// Cache does not play a role in local file databases.
{name: "file-cache", source: "file://" + localDBName, createCache: cachedTestVuln2(localDBName), noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": ""}},
{name: "file-cache", source: "file://" + localDBName, createCache: cachedTestVuln(localDBName), summary: ""},
} {
// Create fresh cache location each time.
cacheRoot = t.TempDir()
Expand All @@ -144,18 +129,17 @@ func TestClient(t *testing.T) {
t.Fatal(err)
}

vulns, err := client.Get([]string{"golang.org/example/one", "golang.org/example/two"})
vulns, err := client.Get("golang.org/example/one")
if err != nil {
t.Fatal(err)
}
if len(vulns) != test.noVulns {
t.Errorf("want %v vulns for %s; got %v", test.noVulns, test.name, len(vulns))

if len(vulns) != 1 {
t.Errorf("%s: want 1 vuln for golang.org/example/one; got %v", test.name, len(vulns))
}

for _, v := range vulns {
if s, ok := test.summaries[v.ID]; !ok || v.Details != s {
t.Errorf("want '%s' summary for vuln with id %v in %s; got '%s'", s, v.ID, test.name, v.Details)
}
if v := vulns[0]; v.Details != test.summary {
t.Errorf("%s: want '%s' summary for testVuln; got '%s'", test.name, test.summary, v.Details)
}
}
}
Expand All @@ -177,11 +161,12 @@ func TestCorrectFetchesNoCache(t *testing.T) {
defer ts.Close()

hs := &httpSource{url: ts.URL, c: new(http.Client)}
_, err := hs.Get([]string{"a", "b", "c"})
if err != nil {
t.Fatalf("unexpected error: %s", err)
for _, module := range []string{"a", "b", "c"} {
if _, err := hs.Get(module); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
expectedFetches := map[string]int{"/index.json": 1, "/a.json": 1, "/b.json": 1}
expectedFetches := map[string]int{"/index.json": 3, "/a.json": 1, "/b.json": 1}
if !reflect.DeepEqual(fetches, expectedFetches) {
t.Errorf("unexpected fetches, got %v, want %v", fetches, expectedFetches)
}
Expand Down

0 comments on commit ce2b095

Please sign in to comment.