Skip to content

Commit

Permalink
Included additional tests
Browse files Browse the repository at this point in the history
Signed-off-by: naveensrinivasan <[email protected]>
  • Loading branch information
naveensrinivasan committed Nov 21, 2024
1 parent bd33593 commit 40f05ae
Showing 1 changed file with 125 additions and 0 deletions.
125 changes: 125 additions & 0 deletions cmd/server/server_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package server

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/bitbomdev/minefield/pkg/graph"
Expand Down Expand Up @@ -200,3 +202,126 @@ func TestOptions_PersistentPreRunE(t *testing.T) {
})
}
}

func TestWithCORS(t *testing.T) {
// Define test cases
testCases := []struct {
name string
options options
requestOrigin string
expectedOrigin string
}{
{
name: "Allowed Origin",
options: options{
CORS: []string{"http://localhost:3000", "https://example.com"},
},
requestOrigin: "http://localhost:3000",
expectedOrigin: "http://localhost:3000",
},
{
name: "Disallowed Origin",
options: options{
CORS: []string{"http://localhost:3000", "https://example.com"},
},
requestOrigin: "http://malicious.com",
expectedOrigin: "",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create a dummy handler that writes a 200 OK status
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

// Wrap the dummy handler with CORS middleware
handler := withCORS(dummyHandler, &tc.options)

// Create a new HTTP request with the specified Origin header
req := httptest.NewRequest("GET", "http://localhost:8089/test", nil)
req.Header.Set("Origin", tc.requestOrigin)

// Create a ResponseRecorder to capture the response
rr := httptest.NewRecorder()

// Serve the HTTP request
handler.ServeHTTP(rr, req)

// Check the CORS headers
if tc.expectedOrigin != "" {
assert.Equal(t, tc.expectedOrigin, rr.Header().Get("Access-Control-Allow-Origin"), "Access-Control-Allow-Origin should match the allowed origin")
} else {
assert.Empty(t, rr.Header().Get("Access-Control-Allow-Origin"), "Access-Control-Allow-Origin should be empty for disallowed origins")
}

// Optionally, check other CORS headers if needed
if rr.Header().Get("Access-Control-Allow-Credentials") != "" {
assert.Equal(t, "true", rr.Header().Get("Access-Control-Allow-Credentials"), "Access-Control-Allow-Credentials should be true")
}
})
}
}

func TestNewServerCommand(t *testing.T) {
tests := []struct {
name string
storage graph.Storage
options *options
wantErr bool
wantCommand struct {
use string
short string
}
}{
{
name: "creates server command with valid storage and options",
storage: &mockStorage{},
options: &options{
concurrency: 10,
addr: "localhost:8089",
},
wantErr: false,
wantCommand: struct {
use string
short string
}{
use: "server",
short: "Start the minefield server for graph operations and queries",
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd, err := NewServerCommand(tt.storage, tt.options)

if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, cmd)
return
}

assert.NoError(t, err)
assert.NotNil(t, cmd)
assert.Equal(t, tt.wantCommand.use, cmd.Use)
assert.Equal(t, tt.wantCommand.short, cmd.Short)
assert.True(t, cmd.DisableAutoGenTag)

// Verify storage is set correctly in options
assert.Equal(t, tt.storage, tt.options.storage)

// Verify flags are added
flags := cmd.Flags()

concurrencyFlag := flags.Lookup("concurrency")
assert.NotNil(t, concurrencyFlag)
assert.Equal(t, "10", concurrencyFlag.DefValue)

addrFlag := flags.Lookup("addr")
assert.NotNil(t, addrFlag)
assert.Equal(t, "localhost:8089", addrFlag.DefValue)
})
}
}

0 comments on commit 40f05ae

Please sign in to comment.