Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1383] Java new use of ParamObject #14645

Merged
merged 1 commit into from
Apr 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ public void testGenerated(){
NDArray$ NDArray = NDArray$.MODULE$;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you still need this line?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it didn't save the world of the method likeNDArray.<method_name> popping out

float[] arr = new float[]{1.0f, 2.0f, 3.0f};
NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0));
float result = NDArray.norm(NDArray.new normParam(nd))[0].toArray()[0];
float result = NDArray.norm(new normParam(nd))[0].toArray()[0];
float cal = 0.0f;
for (float ele : arr) {
cal += ele * ele;
}
cal = (float) Math.sqrt(cal);
assertTrue(Math.abs(result - cal) < 1e-5);
NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0));
NDArray.dot(NDArray.new dotParam(nd, nd).setOut(dotResult));
NDArray.dot(new dotParam(nd, nd).setOut(dotResult));
assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f}));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ private static int argmax(float[] prob) {
*/
static List<String> postProcessing(NDArray result, List<String> tokens) {
NDArray[] output = NDArray.split(
NDArray.new splitParam(result, 2).setAxis(2));
new splitParam(result, 2).setAxis(2));
// Get the formatted logits result
NDArray startLogits = output[0].reshape(new int[]{0, -3});
NDArray endLogits = output[1].reshape(new int[]{0, -3});
// Get Probability distribution
float[] startProb = NDArray.softmax(
NDArray.new softmaxParam(startLogits))[0].toArray();
new softmaxParam(startLogits))[0].toArray();
float[] endProb = NDArray.softmax(
NDArray.new softmaxParam(endLogits))[0].toArray();
new softmaxParam(endLogits))[0].toArray();
int startIdx = argmax(startProb);
int endIdx = argmax(endProb);
return tokens.subList(startIdx, endIdx + 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
def javaClassGen(FILE_PATH : String) : String = {
val notGenerated = Set("Custom")
val absClassFunctions = functionsToGenerate(false, false, true)
val absFuncs = absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
val (absFuncs, paramClassUncleaned) =
absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
.groupBy(_.name.toLowerCase).map(ele => {
/* Pattern matching for not generating deprecated method
* Group all method name in lowercase
Expand All @@ -166,15 +167,16 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
}
}).map(absClassFunction => {
generateJavaAPISignature(absClassFunction)
}).toSeq
}).toSeq.unzip
val paramClass = paramClassUncleaned.filterNot(_.isEmpty)
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
writeFile(
FILE_PATH + "javaapi/",
packageDef,
packageName,
"import org.apache.mxnet.annotation.Experimental",
absFuncs)
absFuncs, Some(paramClass))
}

/**
Expand Down Expand Up @@ -248,7 +250,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
* @param func The function case class
* @return A formatted string for the function
*/
def generateJavaAPISignature(func : Func) : String = {
def generateJavaAPISignature(func : Func) : (String, String) = {
val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2
var argDef = ListBuffer[String]()
var classDef = ListBuffer[String]()
Expand Down Expand Up @@ -287,22 +289,23 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
| }
| def getOut() = this.out
| """.stripMargin
s"""$scalaDocNoParam
(s"""$scalaDocNoParam
| $experimentalTag
| def ${func.name}(po: ${func.name}Param) : $returnType
| /**
| """.stripMargin,
s"""/**
| * This Param Object is specifically used for ${func.name}
| ${requiredParam.mkString("\n")}
| */
| class ${func.name}Param(${argDef.mkString(",")}) {
| ${classDef.mkString("\n ")}
| }""".stripMargin
| }""".stripMargin)
} else {
argDef += "out : NDArray"
s"""$scalaDoc
(s"""$scalaDoc
|$experimentalTag
| def ${func.name}(${argDef.mkString(", ")}) : $returnType
| """.stripMargin
| """.stripMargin, "")
}
}

Expand All @@ -316,7 +319,8 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
* @return A MD5 string
*/
def writeFile(FILE_PATH: String, packageDef: String, className: String,
imports: String, absFuncs: Seq[String]): String = {
imports: String, absFuncs: Seq[String],
paramClass: Option[Seq[String]] = None): String = {

val finalStr =
s"""/*
Expand All @@ -343,7 +347,9 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
|// scalastyle:off
|abstract class $className {
|${absFuncs.mkString("\n")}
|}""".stripMargin
|}
|${paramClass.getOrElse(Seq()).mkString("\n")}
|""".stripMargin


val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public static void main(String[] args) {

// random
NDArray random = NDArray.random_uniform(
NDArray.new random_uniformParam()
new random_uniformParam()
.setLow(0.0f)
.setHigh(2.0f)
.setShape(new Shape(new int[]{10, 10}))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public static void main(String[] args) {
System.out.println(eleAdd);

// norm (L2 Norm)
NDArray normed = NDArray.norm(NDArray.new normParam(nd))[0];
NDArray normed = NDArray.norm(new normParam(nd))[0];
System.out.println(normed);
}
}