Skip to content

Commit

Permalink
Select optimizer in model-builder (tensorflow#114)
Browse files Browse the repository at this point in the history
* Select optimizer in model-builder

* Disable hyperparameters by given need* flags

* Fix lint issue

* Merge branch 'master' into select-optimizer

* Fix lint with 'Google' clang format

* Merge branch 'master' into select-optimizer

* Merge branch 'master' into select-optimizer
  • Loading branch information
Lewuathe authored and mnottheone committed Dec 1, 2018
1 parent 47c1a8e commit 9580c6e
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 4 deletions.
15 changes: 14 additions & 1 deletion demos/model-builder/model-builder.html
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,25 @@
<paper-input no-animations label="Learning Rate" id="learning-rate-input" disabled="[[!datasetDownloaded]]" value={{learningRate}}>
</paper-input>

<paper-input no-animations label="Momentum" id="momentum" disabled="[[!datasetDownloaded]]" value={{momentum}}>
<paper-input no-animations label="Momentum" id="momentum" disabled="[[!needMomentum]]" value={{momentum}}>
</paper-input>

<paper-input no-animations label="Gamma" id="gamma" disabled="[[!needGamma]]" value={{gamma}}>
</paper-input>

<paper-input no-animations label="Batch Size" id="batch-size" disabled="[[!datasetDownloaded]]" value={{batchSize}}>
</paper-input>

<paper-dropdown-menu no-animations label="Optimizer" id="optimizer-dropdown" disabled="[[!datasetDownloaded]]">
<paper-listbox attr-for-selected="value" class="dropdown-content" selected="{{selectedOptimizerName}}" slot="dropdown-content">
<template is="dom-repeat" items="[[optimizerNames]]">
<paper-item value="[[item]]" label="[[item]]">
[[item]]
</paper-item>
</template>
</paper-listbox>
</paper-dropdown-menu>

<div hidden$="[[isValid]]" class="model-error">
<div hidden$="[[!datasetDownloaded]]"">
<paper-tooltip animation-delay="0" fit-to-visible-bounds>
Expand Down
84 changes: 81 additions & 3 deletions demos/model-builder/model-builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import '../demo-header';
import '../demo-footer';

// tslint:disable-next-line:max-line-length
import {Array1D, Array3D, DataStats, FeedEntry, Graph, GraphRunner, GraphRunnerEventObserver, InCPUMemoryShuffledInputProviderBuilder, InMemoryDataset, MetricReduction, MomentumOptimizer, NDArray, NDArrayMath, NDArrayMathCPU, NDArrayMathGPU, Optimizer, Scalar, Session, Tensor, util, xhr_dataset, XhrDataset, XhrDatasetConfig} from '../deeplearn';
import {Array1D, Array3D, DataStats, FeedEntry, Graph, GraphRunner, GraphRunnerEventObserver, InCPUMemoryShuffledInputProviderBuilder, InMemoryDataset, MetricReduction, MomentumOptimizer, SGDOptimizer, RMSPropOptimizer, AdagradOptimizer, NDArray, NDArrayMath, NDArrayMathCPU, NDArrayMathGPU, Optimizer, Scalar, Session, Tensor, util, xhr_dataset, XhrDataset, XhrDatasetConfig} from '../deeplearn';
import {NDArrayImageVisualizer} from '../ndarray-image-visualizer';
import {NDArrayLogitsVisualizer} from '../ndarray-logits-visualizer';
import {PolymerElement, PolymerHTMLElement} from '../polymer-spec';
Expand Down Expand Up @@ -73,8 +73,13 @@ export let ModelBuilderPolymer: new () => PolymerHTMLElement = PolymerElement({
datasetNames: Array,
selectedDatasetName: String,
modelNames: Array,
selectedOptimizerName: String,
optimizerNames: Array,
learningRate: Number,
momentum: Number,
needMomentum: Boolean,
gamma: Number,
needGamma: Boolean,
batchSize: Number,
selectedModelName: String,
selectedNormalizationOption:
Expand Down Expand Up @@ -119,13 +124,18 @@ export class ModelBuilder extends ModelBuilderPolymer {
private selectedDatasetName: string;
private modelNames: string[];
private selectedModelName: string;
private optimizerNames: string[];
private selectedOptimizerName: string;
private loadedWeights: LayerWeightsDict[]|null;
private dataSets: {[datasetName: string]: InMemoryDataset};
private dataSet: InMemoryDataset;
private xhrDatasetConfigs: {[datasetName: string]: XhrDatasetConfig};
private datasetStats: DataStats[];
private learingRate: number;
private momentum: number;
private needMomentum: boolean;
private gamma: number;
private needGamma: boolean;
private batchSize: number;

// Stats.
Expand Down Expand Up @@ -223,9 +233,21 @@ export class ModelBuilder extends ModelBuilderPolymer {
this.setupDatasetStats();
});
}
this.querySelector("#optimizer-dropdown .dropdown-content")
// tslint:disable-next-line:no-any
.addEventListener('iron-activate', (event: any) => {
// Activate, deactivate hyper parameter inputs.
this.refreshHyperParamRequirements(event.detail.selected);
});
this.learningRate = 0.1;
this.momentum = 0.1;
this.needMomentum = true;
this.gamma = 0.1;
this.needGamma = false;
this.batchSize = 64;
// Default optimizer is momentum
this.selectedOptimizerName = "momentum";
this.optimizerNames = ["sgd", "momentum", "rmsprop", "adagrad"];

this.applicationState = ApplicationState.IDLE;
this.loadedWeights = null;
Expand Down Expand Up @@ -279,6 +301,8 @@ export class ModelBuilder extends ModelBuilderPolymer {
return applicationState === ApplicationState.IDLE;
}



private getTestData(): NDArray[][] {
const data = this.dataSet.getData();
if (data == null) {
Expand Down Expand Up @@ -322,12 +346,66 @@ export class ModelBuilder extends ModelBuilderPolymer {
}
}

private resetHyperParamRequirements() {
this.needMomentum = false;
this.needGamma = false;
}

/**
* Set flag to disable input by optimizer selection.
*/
private refreshHyperParamRequirements(optimizerName: string) {
this.resetHyperParamRequirements();
switch (optimizerName) {
case "sgd": {
// No additional hyper parameters
break;
}
case "momentum": {
this.needMomentum = true;
break;
}
case "rmsprop": {
this.needMomentum = true;
this.needGamma = true;
break;
}
case "adagrad": {
this.needMomentum = true;
break;
}
default: {
throw new Error(`Unknown optimizer "${this.selectedOptimizerName}"`);
}
}
}

private createOptimizer() {
switch (this.selectedOptimizerName) {
case 'sgd': {
return new SGDOptimizer(+this.learningRate);
}
case 'momentum': {
return new MomentumOptimizer(+this.learningRate, +this.momentum);
}
case 'rmsprop': {
return new RMSPropOptimizer(+this.learningRate, +this.gamma);
}
case 'adagrad': {
return new AdagradOptimizer(+this.learningRate, +this.momentum);
}
default: {
throw new Error(`Unknown optimizer "${this.selectedOptimizerName}"`);
}
}
}

private startTraining() {
const trainingData = this.getTrainingData();
const testData = this.getTestData();

// Recreate optimizer with the latest learning rate.
this.optimizer = new MomentumOptimizer(+this.learningRate, +this.momentum);
// Recreate optimizer with the selected optimizer and hyperparameters.
this.optimizer = this.createOptimizer();

if (this.isValid && (trainingData != null) && (testData != null)) {
this.recreateCharts();
Expand Down

0 comments on commit 9580c6e

Please sign in to comment.