Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Input Validation for Monitor Fields #1779

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ import org.opensearch.commons.alerting.model.ScheduledJob
import org.opensearch.commons.alerting.util.AlertingException
import org.opensearch.commons.alerting.util.isMonitorOfStandardType
import org.opensearch.commons.utils.getInvalidNameChars
import org.opensearch.commons.utils.isValidId
import org.opensearch.commons.utils.isValidName
import org.opensearch.commons.utils.isValidQueryName
import org.opensearch.core.rest.RestStatus
import org.opensearch.core.xcontent.ToXContent
import org.opensearch.core.xcontent.XContentParser.Token
Expand Down Expand Up @@ -86,6 +88,14 @@ class RestIndexMonitorAction : BaseRestHandler() {
throw AlertingException.wrap(IllegalArgumentException("Missing monitor ID"))
}

// Check if the ID is valid
if (request.method() == PUT && !isValidId(id)) {
throw IllegalArgumentException(
"Invalid monitor ID [$id]. " +
"Monitor ID should be alphanumeric string with +, /, _, or - characters only."
)
}

// Validate request by parsing JSON to Monitor
val xcp = request.contentParser()
ensureExpectedToken(Token.START_OBJECT, xcp.nextToken(), xcp)
Expand All @@ -95,6 +105,14 @@ class RestIndexMonitorAction : BaseRestHandler() {
try {
monitor = Monitor.parse(xcp, id).copy(lastUpdateTime = Instant.now())

// Validate if the monitor name is valid
if (!isValidName(monitor.name)) {
throw IllegalArgumentException(
"Invalid monitor name [${monitor.name}]. " +
"Monitor Name should be alphanumeric (4-50 chars) starting with letter or underscore."
)
}

rbacRoles = request.contentParser().map()["rbac_roles"] as List<String>?

validateDataSources(monitor)
Expand All @@ -106,23 +124,68 @@ class RestIndexMonitorAction : BaseRestHandler() {
Monitor.MonitorType.QUERY_LEVEL_MONITOR -> {
triggers.forEach {
if (it !is QueryLevelTrigger) {
throw (IllegalArgumentException("Illegal trigger type, ${it.javaClass.name}, for query level monitor"))
throw (IllegalArgumentException("Illegal trigger type, ${it.javaClass.name}, for query level monitor."))
}
if (!isValidName(it.name)) {
throw IllegalArgumentException(
"Invalid trigger name [${it.name}]. " +
"Trigger Name should be alphanumeric (4-50 chars) starting with letter or underscore."
)
}
it.actions.forEach { action ->
val destinationId = action.destinationId
if (!isValidId(destinationId)) {
throw IllegalArgumentException(
"Invalid destination ID [$destinationId]. " +
"Destination ID should be alphanumeric string with +, /, _, or - characters only."
)
}
}
}
}

Monitor.MonitorType.BUCKET_LEVEL_MONITOR -> {
triggers.forEach {
if (it !is BucketLevelTrigger) {
throw IllegalArgumentException("Illegal trigger type, ${it.javaClass.name}, for bucket level monitor")
throw IllegalArgumentException("Illegal trigger type, ${it.javaClass.name}, for bucket level monitor.")
}
if (!isValidName(it.name)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed with @vikhy-aws offline that it would be helpful to add a validate function to the trigger interface to validate the various fields rather than adding each check here. Specific trigger types (e.g., bucket level trigger) can then override the interface's validate function as needed to accommodate their specific validation scenarios.

throw IllegalArgumentException(
"Invalid trigger name [${it.name}]. " +
"Trigger Name should be alphanumeric (4-50 chars) starting with letter or underscore."
)
}
it.actions.forEach { action ->
val destinationId = action.destinationId
if (!isValidId(destinationId)) {
throw IllegalArgumentException(
"Invalid destination ID [$destinationId]. " +
"Destination ID should be alphanumeric string with +, /, _, or - characters only."
)
}
}
}
}

Monitor.MonitorType.CLUSTER_METRICS_MONITOR -> {
triggers.forEach {
if (it !is QueryLevelTrigger) {
throw IllegalArgumentException("Illegal trigger type, ${it.javaClass.name}, for cluster metrics monitor")
throw IllegalArgumentException("Illegal trigger type, ${it.javaClass.name}, for cluster metrics monitor.")
}
if (!isValidName(it.name)) {
throw IllegalArgumentException(
"Invalid trigger name [${it.name}]. " +
"Trigger Name should be alphanumeric (4-50 chars) starting with letter or underscore."
)
}
it.actions.forEach { action ->
val destinationId = action.destinationId
if (!isValidId(destinationId)) {
throw IllegalArgumentException(
"Invalid destination ID [$destinationId]. " +
"Destination ID should be alphanumeric string with +, /, _, or - characters only."
)
}
}
}
}
Expand All @@ -131,7 +194,22 @@ class RestIndexMonitorAction : BaseRestHandler() {
validateDocLevelQueryName(monitor)
triggers.forEach {
if (it !is DocumentLevelTrigger) {
throw IllegalArgumentException("Illegal trigger type, ${it.javaClass.name}, for document level monitor")
throw IllegalArgumentException("Illegal trigger type, ${it.javaClass.name}, for document level monitor.")
}
if (!isValidName(it.name)) {
throw IllegalArgumentException(
"Invalid trigger name [${it.name}]. " +
"Trigger Name should be alphanumeric (4-50 chars) starting with letter or underscore."
)
}
it.actions.forEach { action ->
val destinationId = action.destinationId
if (!isValidId(destinationId)) {
throw IllegalArgumentException(
"Invalid destination ID [$destinationId]. " +
"Destination ID should be alphanumeric string with +, /, _, or - characters only."
)
}
}
}
}
Expand All @@ -158,7 +236,7 @@ class RestIndexMonitorAction : BaseRestHandler() {
private fun validateDocLevelQueryName(monitor: Monitor) {
monitor.inputs.filterIsInstance<DocLevelMonitorInput>().forEach { docLevelMonitorInput ->
docLevelMonitorInput.queries.forEach { dlq ->
if (!isValidName(dlq.name)) {
if (!isValidQueryName(dlq.name)) {
throw IllegalArgumentException(
"Doc level query name may not start with [_, +, -], contain '..', or contain: " +
getInvalidNameChars().replace("\\", "")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ fun randomQueryLevelTrigger(
severity: String = "1",
condition: Script = randomScript(),
actions: List<Action> = mutableListOf(),
destinationId: String = ""
destinationId: String = "sample"
): QueryLevelTrigger {
return QueryLevelTrigger(
id = id,
Expand All @@ -315,7 +315,7 @@ fun randomBucketLevelTrigger(
severity: String = "1",
bucketSelector: BucketSelectorExtAggregationBuilder = randomBucketSelectorExtAggregationBuilder(name = id),
actions: List<Action> = mutableListOf(),
destinationId: String = ""
destinationId: String = "sample"
): BucketLevelTrigger {
return BucketLevelTrigger(
id = id,
Expand All @@ -326,7 +326,7 @@ fun randomBucketLevelTrigger(
)
}

fun randomActionsForBucketLevelTrigger(min: Int = 0, max: Int = 10, destinationId: String = ""): List<Action> =
fun randomActionsForBucketLevelTrigger(min: Int = 0, max: Int = 10, destinationId: String = "sample"): List<Action> =
(min..randomInt(max)).map { randomActionWithPolicy(destinationId = destinationId) }

fun randomDocumentLevelTrigger(
Expand All @@ -335,7 +335,7 @@ fun randomDocumentLevelTrigger(
severity: String = "1",
condition: Script = randomScript(),
actions: List<Action> = mutableListOf(),
destinationId: String = ""
destinationId: String = "sample"
): DocumentLevelTrigger {
return DocumentLevelTrigger(
id = id,
Expand Down Expand Up @@ -424,15 +424,15 @@ fun randomTemplateScript(
fun randomAction(
name: String = OpenSearchRestTestCase.randomUnicodeOfLength(10),
template: Script = randomTemplateScript("Hello World"),
destinationId: String = "",
destinationId: String = "sample",
throttleEnabled: Boolean = false,
throttle: Throttle = randomThrottle()
) = Action(name, destinationId, template, template, throttleEnabled, throttle, actionExecutionPolicy = null)

fun randomActionWithPolicy(
name: String = OpenSearchRestTestCase.randomUnicodeOfLength(10),
template: Script = randomTemplateScript("Hello World"),
destinationId: String = "",
destinationId: String = "sample",
throttleEnabled: Boolean = false,
throttle: Throttle = randomThrottle(),
actionExecutionPolicy: ActionExecutionPolicy? = randomActionExecutionPolicy()
Expand Down Expand Up @@ -773,7 +773,7 @@ fun randomChainedAlertTrigger(
severity: String = "1",
condition: Script = randomScript(),
actions: List<Action> = mutableListOf(),
destinationId: String = ""
destinationId: String = "sample"
): ChainedAlertTrigger {
return ChainedAlertTrigger(
id = id,
Expand Down
Loading
Loading