-
Notifications
You must be signed in to change notification settings - Fork 397
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add max balance middleware and tests
- Loading branch information
Your Name
committed
Feb 23, 2025
1 parent
c8a8456
commit dd50e7b
Showing
6 changed files
with
306 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package main | ||
|
||
import ( | ||
"bytes" | ||
"encoding/json" | ||
"io" | ||
"net/http" | ||
|
||
tm2Client "github.com/gnolang/faucet/client/http" | ||
"github.com/gnolang/gno/tm2/pkg/crypto" | ||
) | ||
|
||
func getAccountBalanceMiddleware(tm2Client *tm2Client.Client, maxBalance int64) func(next http.Handler) http.Handler { | ||
type request struct { | ||
To string `json:"to"` | ||
} | ||
return func(next http.Handler) http.Handler { | ||
return http.HandlerFunc( | ||
func(w http.ResponseWriter, r *http.Request) { | ||
var data request | ||
body, err := io.ReadAll(r.Body) | ||
if err != nil { | ||
http.Error(w, err.Error(), http.StatusBadRequest) | ||
return | ||
} | ||
|
||
err = json.Unmarshal(body, &data) | ||
r.Body = io.NopCloser(bytes.NewBuffer(body)) | ||
balance, err := checkAccountBalance(tm2Client, data.To) | ||
if err != nil { | ||
http.Error(w, err.Error(), http.StatusBadRequest) | ||
return | ||
} | ||
if balance >= maxBalance { | ||
http.Error(w, "accounts is already topped up", http.StatusBadRequest) | ||
return | ||
} | ||
next.ServeHTTP(w, r) | ||
}, | ||
) | ||
} | ||
} | ||
|
||
var checkAccountBalance = func(tm2Client *tm2Client.Client, walletAddress string) (int64, error) { | ||
address, err := crypto.AddressFromString(walletAddress) | ||
if err != nil { | ||
return 0, err | ||
} | ||
acc, err := tm2Client.GetAccount(address) | ||
if err != nil { | ||
return 0, err | ||
} | ||
return acc.GetCoins().AmountOf("ugnot"), nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
package main | ||
|
||
import ( | ||
"bytes" | ||
"encoding/json" | ||
"errors" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
tm2Client "github.com/gnolang/faucet/client/http" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func mockedCheckAccountBalance(amount int64, err error) func(tm2Client *tm2Client.Client, walletAddress string) (int64, error) { | ||
return func(tm2Client *tm2Client.Client, walletAddress string) (int64, error) { | ||
return amount, err | ||
} | ||
} | ||
|
||
func TestGetAccountBalanceMiddleware(t *testing.T) { | ||
maxBalance := int64(1000) | ||
|
||
tests := []struct { | ||
name string | ||
requestBody map[string]string | ||
expectedStatus int | ||
expectedBody string | ||
checkBalanceFunc func(tm2Client *tm2Client.Client, walletAddress string) (int64, error) | ||
}{ | ||
{ | ||
name: "Valid address with low balance (should pass)", | ||
requestBody: map[string]string{"to": "valid_address_low_balance"}, | ||
expectedStatus: http.StatusOK, | ||
expectedBody: "next handler reached", | ||
checkBalanceFunc: mockedCheckAccountBalance(500, nil), | ||
}, | ||
{ | ||
name: "Valid address with high balance (should fail)", | ||
requestBody: map[string]string{"To": "valid_address_high_balance"}, | ||
expectedStatus: http.StatusBadRequest, | ||
expectedBody: "accounts is already topped up", | ||
checkBalanceFunc: mockedCheckAccountBalance(2*maxBalance, nil), | ||
}, | ||
{ | ||
name: "Invalid address (should fail)", | ||
requestBody: map[string]string{"To": "invalid_address"}, | ||
expectedStatus: http.StatusBadRequest, | ||
expectedBody: "account not found", | ||
checkBalanceFunc: mockedCheckAccountBalance(2*maxBalance, errors.New("account not found")), | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
checkAccountBalance = tt.checkBalanceFunc | ||
// Convert request body to JSON | ||
reqBody, _ := json.Marshal(tt.requestBody) | ||
|
||
// Create request | ||
req := httptest.NewRequest(http.MethodPost, "/claim", bytes.NewReader(reqBody)) | ||
req.Header.Set("Content-Type", "application/json") | ||
|
||
// Create ResponseRecorder | ||
rr := httptest.NewRecorder() | ||
|
||
// Mock next handler | ||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
w.Write([]byte("next handler reached")) | ||
}) | ||
|
||
// Apply middleware | ||
handler := getAccountBalanceMiddleware(nil, maxBalance)(nextHandler) | ||
handler.ServeHTTP(rr, req) | ||
|
||
// Check response | ||
assert.Equal(t, tt.expectedStatus, rr.Code) | ||
assert.Contains(t, rr.Body.String(), tt.expectedBody) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package main | ||
|
||
import ( | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestCooldownLimiter(t *testing.T) { | ||
cooldownDuration := time.Second | ||
limiter := NewCooldownLimiter(cooldownDuration) | ||
user := "testUser" | ||
|
||
// First check should be allowed | ||
if !limiter.CheckCooldown(user) { | ||
t.Errorf("Expected first CheckCooldown to return true, but got false") | ||
} | ||
|
||
// Second check immediately should be denied | ||
if limiter.CheckCooldown(user) { | ||
t.Errorf("Expected second CheckCooldown to return false, but got true") | ||
} | ||
|
||
require.Eventually(t, func() bool { | ||
return limiter.CheckCooldown(user) | ||
}, 2*cooldownDuration, 10*time.Millisecond, "Expected CheckCooldown to return true after cooldown period") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
"time" | ||
|
||
"github.com/google/go-github/v64/github" | ||
) | ||
|
||
// Mock function for exchangeCodeForToken | ||
func mockExchangeCodeForToken(ctx context.Context, secret, clientID, code string) (*github.User, error) { | ||
login := "mock_login" | ||
if code == "valid" { | ||
fmt.Println("mockExchangeCodeForToken: valid") | ||
return &github.User{Login: &login}, nil | ||
} | ||
return nil, errors.New("invalid code") | ||
} | ||
|
||
// Mock function for GitHub client | ||
func mockGetUser(token string) (*github.User, error) { | ||
if token == "mock_token" { | ||
return &github.User{Login: github.String("testUser")}, nil | ||
} | ||
return nil, errors.New("invalid token") | ||
} | ||
|
||
func TestGitHubMiddleware(t *testing.T) { | ||
cooldown := 2 * time.Minute | ||
exchangeCodeForUser = mockExchangeCodeForToken | ||
t.Run("Midleware without credentials", func(t *testing.T) { | ||
middleware := getGithubMiddleware("", "", cooldown) | ||
// Test missing clientID and secret, middleware does nothing | ||
req := httptest.NewRequest("GET", "http://localhost", nil) | ||
rec := httptest.NewRecorder() | ||
|
||
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
})) | ||
|
||
handler.ServeHTTP(rec, req) | ||
|
||
if rec.Code != http.StatusOK { | ||
t.Errorf("Expected status OK, got %d", rec.Code) | ||
} | ||
|
||
}) | ||
t.Run("request without code", func(t *testing.T) { | ||
middleware := getGithubMiddleware("mockClientID", "mockSecret", cooldown) | ||
req := httptest.NewRequest("GET", "http://localhost?code=", nil) | ||
rec := httptest.NewRecorder() | ||
|
||
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
})) | ||
|
||
handler.ServeHTTP(rec, req) | ||
|
||
if rec.Code != http.StatusBadRequest { | ||
t.Errorf("Expected status BadRequest, got %d", rec.Code) | ||
} | ||
|
||
}) | ||
|
||
t.Run("request invalid code", func(t *testing.T) { | ||
middleware := getGithubMiddleware("mockClientID", "mockSecret", cooldown) | ||
req := httptest.NewRequest("GET", "http://localhost?code=invalid", nil) | ||
rec := httptest.NewRecorder() | ||
|
||
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
})) | ||
|
||
handler.ServeHTTP(rec, req) | ||
|
||
if rec.Code != http.StatusBadRequest { | ||
t.Errorf("Expected status BadRequest, got %d", rec.Code) | ||
} | ||
}) | ||
|
||
t.Run("OK", func(t *testing.T) { | ||
middleware := getGithubMiddleware("mockClientID", "mockSecret", cooldown) | ||
req := httptest.NewRequest("GET", "http://localhost?code=valid", nil) | ||
rec := httptest.NewRecorder() | ||
|
||
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
})) | ||
|
||
handler.ServeHTTP(rec, req) | ||
|
||
if rec.Code != http.StatusOK { | ||
t.Errorf("Expected status OK, got %d", rec.Code) | ||
} | ||
}) | ||
|
||
t.Run("Cooldown active", func(t *testing.T) { | ||
middleware := getGithubMiddleware("mockClientID", "mockSecret", cooldown) | ||
req := httptest.NewRequest("GET", "http://localhost?code=valid", nil) | ||
rec := httptest.NewRecorder() | ||
|
||
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
})) | ||
|
||
handler.ServeHTTP(rec, req) | ||
|
||
if rec.Code != http.StatusOK { | ||
t.Errorf("Expected status OK, got %d", rec.Code) | ||
} | ||
|
||
req = httptest.NewRequest("GET", "http://localhost?code=valid", nil) | ||
rec = httptest.NewRecorder() | ||
|
||
handler.ServeHTTP(rec, req) | ||
if rec.Code != http.StatusTooManyRequests { | ||
t.Errorf("Expected status TooManyRequest, got %d", rec.Code) | ||
} | ||
}) | ||
|
||
} | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters