Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIK-4444 Check all matched endpoints for forceProtectionOff #123

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ public static boolean shouldSkipVulnerabilityScan(ContextObject context) {
}
List<Endpoint> matchedEndpoints = matchEndpoints(context.getRouteMetadata(), threadCache.getEndpoints());
if (matchedEndpoints != null && !matchedEndpoints.isEmpty()) {
return matchedEndpoints.get(0).protectionForcedOff();
if (matchedEndpoints.stream().anyMatch(Endpoint::protectionForcedOff)) {
// Protection is forced off on one of more of the matched endpoints :
return true;
}
}
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public void testFetchNewConfig() {
Optional<APIResponse> res = api.fetchNewConfig("token", 2);
assertTrue(res.isPresent());
assertTrue(res.get().block());
assertEquals(1, res.get().endpoints().size());
assertEquals(3, res.get().endpoints().size());
}
@Test
@StdIo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public void renewWithLinkedUpBackgroundProcess() throws InterruptedException {
assertNotNull(threadCacheObject);

// Test the endpoints :
assertEquals(1, threadCacheObject.getEndpoints().size());
assertEquals(3, threadCacheObject.getEndpoints().size());
Endpoint endpoint1 = threadCacheObject.getEndpoints().get(0);
assertEquals("*", endpoint1.getMethod());
assertEquals("/test_ratelimiting_1", endpoint1.getRoute());
Expand Down
6 changes: 6 additions & 0 deletions agent_api/src/test/java/utils/EmptySampleContextObject.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,10 @@ public EmptySampleContextObject(String argument) {
this();
this.query.put("arg", List.of(argument));
}
public EmptySampleContextObject(String argument, String route, String method) {
this();
this.query.put("arg", List.of(argument));
this.route = route;
this.method = method;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
package vulnerabilities;

import dev.aikido.agent_api.background.Endpoint;
import dev.aikido.agent_api.context.ContextObject;
import dev.aikido.agent_api.thread_cache.ThreadCache;
import dev.aikido.agent_api.thread_cache.ThreadCacheObject;
import dev.aikido.agent_api.vulnerabilities.SkipVulnerabilityScanDecider;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import utils.EmptySampleContextObject;
import utils.EmtpyThreadCacheObject;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;

public class SkipVulnerabilityScanDeciderTest {

private ContextObject context;
private ThreadCacheObject threadCacheObject;
private List<Endpoint> createEndpoints(boolean protectionForcedOff1, boolean protectionForcedOff2) {
List<Endpoint> endpoints = new ArrayList<>();
endpoints.add(new Endpoint("POST", "/api/login", 3, 1000, Collections.emptyList(), false, protectionForcedOff1, true));
endpoints.add(new Endpoint("POST", "/api/*", 1, 1000, Collections.emptyList(), false, protectionForcedOff2, true));
endpoints.add(new Endpoint("GET", "/", 3, 1000, Collections.emptyList(), false, false, false));
return endpoints;
}
private List<Endpoint> createEndpointsWildcardMethod(boolean protectionForcedOff1, boolean protectionForcedOff2) {
List<Endpoint> endpoints = new ArrayList<>();
endpoints.add(new Endpoint("*", "/api/login", 3, 1000, Collections.emptyList(), false, protectionForcedOff1, true));
endpoints.add(new Endpoint("*", "/api/*", 1, 1000, Collections.emptyList(), false, protectionForcedOff2, true));
endpoints.add(new Endpoint("GET", "/", 3, 1000, Collections.emptyList(), false, false, false));
return endpoints;
}

@BeforeEach
public void setUp() {
context = new EmptySampleContextObject();
threadCacheObject = EmtpyThreadCacheObject.getEmptyThreadCacheObject();
// Mock the ThreadCache to return our empty thread cache object
ThreadCache.set(threadCacheObject);
}

@Test
public void testShouldSkipVulnerabilityScan_NullContext() {
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(null));
}

@Test
public void testShouldSkipVulnerabilityScan_BypassedIP() {

// Mock the ThreadCacheObject to return a bypassed IP
ThreadCacheObject mockThreadCache = mock(ThreadCacheObject.class);
when(mockThreadCache.isBypassedIP(context.getRemoteAddress())).thenReturn(true);
ThreadCache.set(mockThreadCache);

assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(context));
}

@Test
public void testShouldSkipVulnerabilityScan_ProtectionForcedOff_1() {
// Mock the ThreadCacheObject to return a matched endpoint
ThreadCacheObject mockThreadCache = mock(ThreadCacheObject.class);
when(mockThreadCache.getEndpoints()).thenReturn(createEndpoints(true, false));

ThreadCache.set(mockThreadCache);
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login", "POST")
));

ThreadCache.set(mockThreadCache);
assertFalse(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login", "GET")
));

ThreadCache.set(mockThreadCache);
assertFalse(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login2", "POST")
));

when(mockThreadCache.getEndpoints()).thenReturn(createEndpoints(false, false));
ThreadCache.set(mockThreadCache);
assertFalse(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login", "POST")
));
}

@Test
public void testShouldSkipVulnerabilityScan_ProtectionForcedOff_1_WildCard() {
// Mock the ThreadCacheObject to return a matched endpoint
ThreadCacheObject mockThreadCache = mock(ThreadCacheObject.class);
when(mockThreadCache.getEndpoints()).thenReturn(createEndpointsWildcardMethod(true, false));

ThreadCache.set(mockThreadCache);
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login", "POST")
));

ThreadCache.set(mockThreadCache);
assertFalse(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login2", "POST")
));

when(mockThreadCache.getEndpoints()).thenReturn(createEndpointsWildcardMethod(false, false));
ThreadCache.set(mockThreadCache);
assertFalse(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login", "POST")
));
}

@Test
public void testShouldSkipVulnerabilityScan_ProtectionForcedOff_2() {
// Mock the ThreadCacheObject to return a matched endpoint
ThreadCacheObject mockThreadCache = mock(ThreadCacheObject.class);
when(mockThreadCache.getEndpoints()).thenReturn(createEndpoints(false, true));

ThreadCache.set(mockThreadCache);
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login", "POST")
));

ThreadCache.set(mockThreadCache);
assertFalse(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login", "GET")
));

ThreadCache.set(mockThreadCache);
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login2", "POST")
));

ThreadCache.set(mockThreadCache);
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/", "POST")
));

when(mockThreadCache.getEndpoints()).thenReturn(createEndpoints(true, true));
ThreadCache.set(mockThreadCache);
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login", "POST")
));

ThreadCache.set(mockThreadCache);
assertFalse(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login", "GET")
));

ThreadCache.set(mockThreadCache);
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login2", "POST")
));

ThreadCache.set(mockThreadCache);
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/", "POST")
));
}

@Test
public void testShouldSkipVulnerabilityScan_ProtectionForcedOff_WildcardMethod() {
// Mock the ThreadCacheObject to return a matched endpoint
ThreadCacheObject mockThreadCache = mock(ThreadCacheObject.class);
when(mockThreadCache.getEndpoints()).thenReturn(createEndpointsWildcardMethod(false, true));

ThreadCache.set(mockThreadCache);
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login", "POST")
));

ThreadCache.set(mockThreadCache);
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login2", "POST")
));

ThreadCache.set(mockThreadCache);
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/", "POST")
));

when(mockThreadCache.getEndpoints()).thenReturn(createEndpointsWildcardMethod(true, true));
ThreadCache.set(mockThreadCache);
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login", "POST")
));

ThreadCache.set(mockThreadCache);
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login2", "POST")
));

ThreadCache.set(mockThreadCache);
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/", "POST")
));
}

@Test
public void testShouldSkipVulnerabilityScan_NoConditionsMet() {

// Create a mocked endpoint with protection not forced off
Endpoint endpoint = mock(Endpoint.class);
when(endpoint.protectionForcedOff()).thenReturn(false);

// Mock the ThreadCacheObject to return a matched endpoint
ThreadCacheObject mockThreadCache = mock(ThreadCacheObject.class);
when(mockThreadCache.getEndpoints()).thenReturn(createEndpoints(false, false));
when(mockThreadCache.isBypassedIP(context.getRemoteAddress())).thenReturn(false);
ThreadCache.set(mockThreadCache);

assertFalse(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(
new EmptySampleContextObject("", "/api/login", "POST")
));
}
}
31 changes: 30 additions & 1 deletion end2end/server/mock_aikido_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,24 @@
"windowSizeInMS": 1000 * 5,
},
"graphql": False,
},
{
"route": "/api/pets/create",
"method": "POST",
"forceProtectionOff": False,
"rateLimiting": {
"enabled": False,
},
"graphql": False,
},
{
"route": "/api/*",
"method": "*",
"forceProtectionOff": False,
"rateLimiting": {
"enabled": False,
},
"graphql": False,
}
],
"blockedUserIds": ["12345"],
Expand Down Expand Up @@ -75,7 +93,7 @@ def mock_set_config():
configUpdatedAt = int(time.time())
responses["config"] = request.get_json()
responses["config"]["configUpdatedAt"] = configUpdatedAt
responses["configUpdatedAt"] = { "serviceId": 1, "configUpdatedAt": configUpdatedAt }
responses["configUpdatedAt"] = configUpdatedAt
return jsonify({})


Expand All @@ -89,6 +107,17 @@ def mock_reset():
events = [] # Reset events
return jsonify({})

@app.route('/mock/set_protection', methods=['POST'])
def mock_set_protection():
req = request.get_json()
global responses
responses["config"]["endpoints"][1]["forceProtectionOff"] = bool(req["api_pets_create"])
responses["config"]["endpoints"][2]["forceProtectionOff"] = bool(req["api"])
responses["config"]["configUpdatedAt"] = int(time.time()*1000)
responses["configUpdatedAt"] = int(time.time()*1000)

return jsonify({})

if __name__ == '__main__':
if len(sys.argv) < 2 or len(sys.argv) > 3:
print("Usage: python mock_server.py <port> [config_file]")
Expand Down
33 changes: 32 additions & 1 deletion end2end/spring_boot_mysql.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import time
from utils.test_safe_vs_unsafe_payloads import test_safe_vs_unsafe_payloads
from spring_boot_mysql.test_two_sql_attacks import test_two_sql_attacks
from spring_boot_mysql.test_ip_blocking import test_ip_blocking
from spring_boot_mysql.test_bot_blocking import test_bot_blocking
from spring_boot_mysql.test_ratelimiting import test_ratelimiting_per_user, test_ratelimiting
from utils.EventHandler import EventHandler
from utils.make_requests import make_post_request

payloads = {
"safe": { "name": "Bobby" },
Expand All @@ -16,15 +18,44 @@

event_handler = EventHandler()
event_handler.reset()

# Test SQL attacks :
test_safe_vs_unsafe_payloads(payloads, urls, user_id="123") # Test MySQL driver
print("✅ MySQL Driver tested")

test_safe_vs_unsafe_payloads(payloads, urls, "/mariadb", user_id="456") # Also test MariaDB driver
print("✅ MariaDB Driver tested")

# Test blocklists :
test_ip_blocking("http://localhost:8082/")
print("✅ IP Blocking tested")

test_bot_blocking("http://localhost:8082/")
print("✅ Bot Blocking tested")


# Test ratelimiting (we can use a header to set user) :
test_ratelimiting("http://localhost:8082/test_ratelimiting_1")
print("✅ Rate-limiting tested (IP Based)")

test_ratelimiting_per_user("http://localhost:8082/test_ratelimiting_1")
print("✅ Rate-limiting tested (User Based)")

test_two_sql_attacks(event_handler)
print("✅ Attack reporting tested (2x)")

# Test forceProtectionOff
make_post_request(urls["enabled"], payloads["unsafe"], status_code=500)

# Tests with /api/* and method * protection forced off.
event_handler.set_protection(False, True)
time.sleep(70) # Wait for config to be fetched
make_post_request(urls["enabled"], payloads["unsafe"], status_code=200)

# Tests with /api/pets/create protection forced off.
event_handler.set_protection(True, False)
time.sleep(70) # Wait for config to be fetched
make_post_request(urls["enabled"], payloads["unsafe"], status_code=200)


test_two_sql_attacks(event_handler)
print("✅ Tested force protection off")
6 changes: 6 additions & 0 deletions end2end/utils/EventHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def fetch_events_from_mock(self):
return json_events
def fetch_attacks(self):
return filter_on_event_type(self.fetch_events_from_mock(), "detected_attack")
def set_protection(self, api_pets_create_protection, api_protection):
print("Setting forceProtectionOff")
res = requests.post(self.url + "/mock/set_protection", json={
"api_pets_create": api_pets_create_protection,
"api": api_protection
}, timeout=5)

def filter_on_event_type(events, type):
return [event for event in events if event["type"] == type]