Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract tensorflow-dependent classes into a separate module #412

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
10 changes: 0 additions & 10 deletions api/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,10 @@ project.setDescription("This module contains the Kotlin API for building, traini
dependencies {
api project(":dataset")
implementation 'org.jetbrains.kotlin:kotlin-stdlib-jdk8:1.6.21'
api group: 'org.tensorflow', name: 'tensorflow', version: '1.15.0'
api 'com.github.doyaaaaaken:kotlin-csv-jvm:0.7.3' // for csv parsing
api 'io.github.microutils:kotlin-logging:2.1.21' // for logging
api 'io.jhdf:jhdf:0.5.7' // for hdf5 parsing
api 'com.beust:klaxon:5.5'
testImplementation 'ch.qos.logback:logback-classic:1.2.11'
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.2'
testImplementation 'org.junit.jupiter:junit-jupiter-engine:5.8.2'
testImplementation 'org.junit.jupiter:junit-jupiter-params:5.8.2'
testImplementation 'org.junit.jupiter:junit-jupiter-engine:5.8.2'
}

compileKotlin {
Expand All @@ -27,10 +21,6 @@ kotlin {
explicitApiWarning()
}

test {
useJUnitPlatform()
}

task fatJar(type: Jar) {
duplicatesStrategy = DuplicatesStrategy.INCLUDE

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,6 @@

package org.jetbrains.kotlinx.dl.api.core.metric

import org.jetbrains.kotlinx.dl.api.core.loss.Losses
import org.jetbrains.kotlinx.dl.api.core.loss.ReductionType
import org.jetbrains.kotlinx.dl.api.core.loss.allAxes
import org.jetbrains.kotlinx.dl.api.core.loss.safeMean
import org.jetbrains.kotlinx.dl.api.core.util.getDType
import org.tensorflow.Operand
import org.tensorflow.op.Ops
import org.tensorflow.op.core.ReduceSum
import org.tensorflow.op.math.Mean

/**
* Metrics.
*/
Expand Down Expand Up @@ -46,146 +36,4 @@ public enum class Metrics {
* `loss = square(log(y_true + 1.) - log(y_pred + 1.))`
*/
MSLE;

public companion object {
/** Converts enum value to subclass of [Metric]. */
public fun convert(metricType: Metrics): Metric {
return when (metricType) {
ACCURACY -> Accuracy()
MAE -> MAE()
MSE -> MSE()
MSLE -> MSLE()
}
}

/** Converts subclass of [Metric] to enum value. */
public fun convertBack(metric: Metric): Metrics {
return when (metric) {
is Accuracy -> ACCURACY
is org.jetbrains.kotlinx.dl.api.core.metric.MAE -> MAE
is org.jetbrains.kotlinx.dl.api.core.metric.MSE -> MSE
is org.jetbrains.kotlinx.dl.api.core.metric.MSLE -> MSLE
else -> ACCURACY
}
}
}
}

/**
* @see [Metrics.ACCURACY]
*/
public class Accuracy : Metric(reductionType = ReductionType.SUM_OVER_BATCH_SIZE) {
override fun apply(
tf: Ops,
yPred: Operand<Float>,
yTrue: Operand<Float>,
numberOfLabels: Operand<Float>?
): Operand<Float> {
val predicted: Operand<Long> = tf.math.argMax(yPred, tf.constant(1))
val expected: Operand<Long> = tf.math.argMax(yTrue, tf.constant(1))

return tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), getDType()), tf.constant(0))
}
}

/**
* @see [Losses.MSE]
*/
public class MSE(reductionType: ReductionType = ReductionType.SUM_OVER_BATCH_SIZE) : Metric(reductionType) {
override fun apply(
tf: Ops,
yPred: Operand<Float>,
yTrue: Operand<Float>,
numberOfLabels: Operand<Float>?
): Operand<Float> {
val squaredError = tf.math.squaredDifference(yPred, yTrue)
return meanOfMetrics(tf, reductionType, squaredError, numberOfLabels, "Metric_MSE")
}
}

/**
* @see [Losses.MAE]
*/
public class MAE(reductionType: ReductionType = ReductionType.SUM_OVER_BATCH_SIZE) : Metric(reductionType) {
override fun apply(
tf: Ops,
yPred: Operand<Float>,
yTrue: Operand<Float>,
numberOfLabels: Operand<Float>?
): Operand<Float> {
val absoluteErrors = tf.math.abs(tf.math.sub(yPred, yTrue))
return meanOfMetrics(tf, reductionType, absoluteErrors, numberOfLabels, "Metric_MAE")
}
}

/**
* @see [Losses.MAPE]
*/
public class MAPE(reductionType: ReductionType = ReductionType.SUM_OVER_BATCH_SIZE) : Metric(reductionType) {
override fun apply(
tf: Ops,
yPred: Operand<Float>,
yTrue: Operand<Float>,
numberOfLabels: Operand<Float>?
): Operand<Float> {
val epsilon = 1e-7f

val diff = tf.math.abs(
tf.math.div(
tf.math.sub(yTrue, yPred),
tf.math.maximum(tf.math.abs(yTrue), tf.constant(epsilon))
)
)

return meanOfMetrics(tf, reductionType, tf.math.mul(diff, tf.constant(100f)), numberOfLabels, "Metric_MAPE")
}
}

/**
* @see [Losses.MSLE]
*/
public class MSLE(reductionType: ReductionType = ReductionType.SUM_OVER_BATCH_SIZE) : Metric(reductionType) {
override fun apply(
tf: Ops,
yPred: Operand<Float>,
yTrue: Operand<Float>,
numberOfLabels: Operand<Float>?
): Operand<Float> {
val epsilon = 1e-5f

val firstLog = tf.math.log(tf.math.add(tf.math.maximum(yPred, tf.constant(epsilon)), tf.constant(1.0f)))
val secondLog = tf.math.log(tf.math.add(tf.math.maximum(yTrue, tf.constant(epsilon)), tf.constant(1.0f)))

val loss = tf.math.squaredDifference(firstLog, secondLog)

return meanOfMetrics(tf, reductionType, loss, numberOfLabels, "Metric_MSLE")
}
}

internal fun meanOfMetrics(
tf: Ops,
reductionType: ReductionType,
metric: Operand<Float>,
numberOfLabels: Operand<Float>?,
metricName: String
): Operand<Float> {
val meanMetric = tf.math.mean(metric, tf.constant(-1), Mean.keepDims(false))

var totalMetric: Operand<Float> = tf.reduceSum(
meanMetric,
allAxes(tf, meanMetric),
ReduceSum.keepDims(false)
)

if (reductionType == ReductionType.SUM_OVER_BATCH_SIZE) {
check(numberOfLabels != null) { "Operand numberOfLosses must be not null." }

totalMetric = safeMean(
tf,
metric,
numberOfLabels
)
}

return tf.withName(metricName).identity(totalMetric)
}
}

This file was deleted.

Loading