Skip to content

Commit

Permalink
Support --foo=bar syntax (#98)
Browse files Browse the repository at this point in the history
Fixes #97

Relatively straightforward change. We do `head.split("=", 2) match` to
see if it contains any `=`, and if so treat the portion of the string
after the first `=` as the value. Added some unit test to cover both
success cases and failure cases: use with short names or with flags is
unsupported.
  • Loading branch information
lihaoyi authored Sep 16, 2023
1 parent 81c7eb6 commit b25a9fb
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 11 deletions.
43 changes: 32 additions & 11 deletions mainargs/src/TokenGrouping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,53 @@ object TokenGrouping {
}

val flatArgs = flatArgs0.toList
val keywordArgMap = argSigs
def makeKeywordArgMap(getNames: ArgSig => Iterable[String]) = argSigs
.collect {
case (a, r: TokensReader.Simple[_]) if !a.positional => a
case (a, r: TokensReader.Flag) => a
}
.flatMap { x => (x.name.map("--" + _) ++ x.shortName.map("-" + _)).map(_ -> x) }
.flatMap { x => getNames(x).map(_ -> x) }
.toMap[String, ArgSig]

lazy val keywordArgMap = makeKeywordArgMap(
x => x.name.map("--" + _) ++ x.shortName.map("-" + _)
)

lazy val longKeywordArgMap = makeKeywordArgMap(x => x.name.map("--" + _))

@tailrec def rec(
remaining: List[String],
current: Map[ArgSig, Vector[String]]
): Result[TokenGrouping[B]] = {
remaining match {
case head :: rest =>

def lookupArgMap(k: String, m: Map[String, ArgSig]): Option[(ArgSig, mainargs.TokensReader[_])] = {
m.get(k).map(a => (a, a.reader))
}

if (head.startsWith("-") && head.exists(_ != '-')) {
keywordArgMap.get(head) match {
case Some(cliArg: ArgSig) if cliArg.reader.isFlag =>
rec(rest, Util.appendMap(current, cliArg, ""))
case Some(cliArg: ArgSig) if !cliArg.reader.isLeftover =>
rest match {
case next :: rest2 => rec(rest2, Util.appendMap(current, cliArg, next))
case Nil =>
Result.Failure.MismatchedArguments(Nil, Nil, Nil, incomplete = Some(cliArg))
head.split("=", 2) match{
case Array(first, second) =>
lookupArgMap(first, longKeywordArgMap) match {
case Some((cliArg, _: TokensReader.Simple[_])) =>
rec(rest, Util.appendMap(current, cliArg, second))

case _ => complete(remaining, current)
}

case _ => complete(remaining, current)
case _ =>
lookupArgMap(head, keywordArgMap) match {
case Some((cliArg, _: TokensReader.Flag)) =>
rec(rest, Util.appendMap(current, cliArg, ""))
case Some((cliArg, _: TokensReader.Simple[_])) =>
rest match {
case next :: rest2 => rec(rest2, Util.appendMap(current, cliArg, next))
case Nil =>
Result.Failure.MismatchedArguments(Nil, Nil, Nil, incomplete = Some(cliArg))
}
case _ => complete(remaining, current)
}
}
} else {
positionalArgSigs.find(!current.contains(_)) match {
Expand Down
55 changes: 55 additions & 0 deletions mainargs/test/src/EqualsSyntaxTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package mainargs
import utest._

object EqualsSyntaxTests extends TestSuite {

object Main {
@main
def run(
@arg(short = 'f', doc = "String to print repeatedly")
foo: String,
@arg(name = "my-num", doc = "How many times to print string")
myNum: Int = 2,
@arg(doc = "Example flag")
bool: Flag
) = {
foo * myNum + " " + bool.value
}
}

val tests = Tests {
test("simple") {
ParserForMethods(Main).runOrThrow(Array("--foo=bar", "--my-num=3")) ==>
"barbarbar false"
}
test("multipleEquals") {
// --foo=bar syntax still works when there's an `=` on the right
ParserForMethods(Main).runOrThrow(Array("--foo=bar=qux")) ==>
"bar=quxbar=qux false"
}
test("shortName") {
// -f=bar syntax doesn't work for short names
ParserForMethods(Main).runEither(Array("-f=bar")) ==>
Left(
"""Missing argument: -f --foo <str>
|Unknown argument: "-f=bar"
|Expected Signature: run
| -f --foo <str> String to print repeatedly
| --my-num <int> How many times to print string
| --bool Example flag
|
|""".stripMargin)
}
test("notFlags") {
// -f=bar syntax doesn't work for flags
ParserForMethods(Main).runEither(Array("--foo=bar", "--bool=true")) ==>
Left("""Unknown argument: "--bool=true"
|Expected Signature: run
| -f --foo <str> String to print repeatedly
| --my-num <int> How many times to print string
| --bool Example flag
|
|""".stripMargin)
}
}
}

0 comments on commit b25a9fb

Please sign in to comment.