forked from kaldi-asr/kaldi
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'dropout_schedule' into nnet3-dropout
- Loading branch information
Showing
15 changed files
with
180 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -389,44 +389,50 @@ def _parse_dropout_string(num_archives_to_process, dropout_str): | |
""" | ||
dropout_values = [] | ||
parts = dropout_str.strip().split(',') | ||
|
||
try: | ||
if len(parts) < 2: | ||
raise Exception("dropout proportion string must specify " | ||
"at least the start and end dropouts") | ||
|
||
# Starting dropout proportion | ||
dropout_values.append((0, float(parts[0]))) | ||
data_fraction_one_previous='' # used to control situations like: [email protected],[email protected] | ||
for i in range(1, len(parts) - 1): | ||
value_x_pair = parts[i].split('@') | ||
if len(value_x_pair) == 1: | ||
# Dropout proportion at half of training | ||
dropout_proportion = float(parts[i]) | ||
dropout_values.append((0.5 * num_archives_to_process, | ||
dropout_proportion)) | ||
dropout_proportion = float(value_x_pair) | ||
num_archives = int(0.5 * num_archives_to_process) | ||
else: | ||
assert len(value_x_pair) == 2 | ||
dropout_proportion, data_fraction = value_x_pair | ||
if data_fraction == data_fraction_one_previous : | ||
dropout_values.append( | ||
(float(data_fraction) * num_archives_to_process + 1.0, | ||
float(dropout_proportion))) | ||
else: | ||
dropout_values.append( | ||
(float(data_fraction) * num_archives_to_process, | ||
float(dropout_proportion))) | ||
_, data_fraction_one_previous = value_x_pair | ||
|
||
dropout_proportion = float(value_x_pair[0]) | ||
data_fraction = float(value_x_pair[1]) | ||
num_archives = round(float(data_fraction) | ||
* num_archives_to_process) | ||
|
||
if (num_archives < dropout_values[-1][0] | ||
or num_archives >= num_archives_to_process): | ||
logger.error( | ||
"Failed while parsing value %s in dropout-schedule. " | ||
"dropout-schedule must be in incresing " | ||
"order of data fractions.", value_x_pair) | ||
raise ValueError | ||
elif num_archives == dropout_values[-1][0]: | ||
num_archives += 1.0 | ||
|
||
dropout_values.append(num_archives, float(dropout_proportion)) | ||
|
||
dropout_values.append((num_archives_to_process, float(parts[-1]))) | ||
except Exception: | ||
logger.error("Unable to parse dropout proportion string {0}. " | ||
logger.error("Unable to parse dropout proportion string %s. " | ||
"See help for option " | ||
"--trainer.dropout-schedule.".format(dropout_str)) | ||
"--trainer.dropout-schedule.", dropout_str) | ||
raise | ||
|
||
# reverse sort so that its easy to retrieve the dropout proportion | ||
# for a particular data fraction | ||
dropout_values.sort(key=lambda x: x[0], reverse=True) | ||
dropout_values.reverse() | ||
for num_archives, proportion in dropout_values: | ||
assert num_archives <= num_archives_to_process and num_archives >= 0 | ||
assert proportion <= 1 and proportion >= 0 | ||
|
@@ -738,7 +744,8 @@ def __init__(self): | |
doesn't increase the effective learning | ||
rate.""") | ||
self.parser.add_argument("--trainer.dropout-schedule", type=str, | ||
dest='dropout_schedule', default='', | ||
action=common_lib.NullstrToNoneAction, | ||
dest='dropout_schedule', default=None, | ||
help="""Use this to specify the dropout | ||
schedule. You specify a piecewise linear | ||
function on the domain [0,1], where 0 is the | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
|
||
ifndef DOUBLE_PRECISION | ||
$(error DOUBLE_PRECISION not defined.) | ||
endif | ||
|
||
|
||
CUDA_INCLUDE= -I$(CUDATKDIR)/include | ||
CUDA_FLAGS = -g -Xcompiler -fPIC --verbose --machine 64 -DHAVE_CUDA \ | ||
-DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) | ||
CXXFLAGS += -DHAVE_CUDA -I$(CUDATKDIR)/include | ||
CUDA_LDFLAGS += -L$(CUDATKDIR)/lib64 -Wl,-rpath,$(CUDATKDIR)/lib64 | ||
CUDA_LDLIBS += -lcublas -lcudart -lcurand #LDLIBS : The libs are loaded later than static libs in implicit rule |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# You have to make sure ATLASLIBS is set... | ||
|
||
ifndef FSTROOT | ||
$(error FSTROOT not defined.) | ||
endif | ||
|
||
ifndef ATLASINC | ||
$(error ATLASINC not defined.) | ||
endif | ||
|
||
ifndef ATLASLIBS | ||
$(error ATLASLIBS not defined.) | ||
endif | ||
|
||
|
||
DOUBLE_PRECISION = 0 | ||
CXXFLAGS = -m64 -maltivec -mcpu=power8 -Wall -I.. \ | ||
-mtune=power8 -mpower8-vector -mvsx -pthread \ | ||
-DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ | ||
-Wno-sign-compare -Wno-unused-local-typedefs -Winit-self \ | ||
-DHAVE_EXECINFO_H=1 -rdynamic -DHAVE_CXXABI_H \ | ||
-DHAVE_ATLAS -I$(ATLASINC) \ | ||
-I$(FSTROOT)/include \ | ||
$(EXTRA_CXXFLAGS) \ | ||
-g # -O0 -DKALDI_PARANOID | ||
|
||
ifeq ($(KALDI_FLAVOR), dynamic) | ||
CXXFLAGS += -fPIC | ||
endif | ||
|
||
LDFLAGS = -rdynamic $(OPENFSTLDFLAGS) | ||
LDLIBS = $(EXTRA_LDLIBS) $(OPENFSTLIBS) $(ATLASLIBS) -lm -lpthread -ldl | ||
CC = g++ | ||
CXX = g++ | ||
AR = ar | ||
AS = as | ||
RANLIB = ranlib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# You have to make sure FSTROOT,OPENBLASROOT,OPENBLASLIBS are set... | ||
|
||
ifndef FSTROOT | ||
$(error FSTROOT not defined.) | ||
endif | ||
|
||
ifndef OPENBLASLIBS | ||
$(error OPENBLASLIBS not defined.) | ||
endif | ||
|
||
ifndef OPENBLASROOT | ||
$(error OPENBLASROOT not defined.) | ||
endif | ||
|
||
|
||
DOUBLE_PRECISION = 0 | ||
CXXFLAGS = -m64 -maltivec -mcpu=power8 -Wall -I.. \ | ||
-mtune=power8 -mpower8-vector -mvsx -pthread \ | ||
-DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ | ||
-Wno-sign-compare -Wno-unused-local-typedefs -Winit-self \ | ||
-DHAVE_EXECINFO_H=1 -rdynamic -DHAVE_CXXABI_H \ | ||
-DHAVE_OPENBLAS -I $(OPENBLASROOT)/include \ | ||
-I $(FSTROOT)/include \ | ||
$(EXTRA_CXXFLAGS) \ | ||
-g # -O0 -DKALDI_PARANOID | ||
|
||
ifeq ($(KALDI_FLAVOR), dynamic) | ||
CXXFLAGS += -fPIC | ||
endif | ||
|
||
LDFLAGS = -rdynamic $(OPENFSTLDFLAGS) | ||
LDLIBS = $(EXTRA_LDLIBS) $(OPENFSTLIBS) $(OPENBLASLIBS) -lm -lpthread -ldl | ||
CC = g++ | ||
CXX = g++ | ||
AR = ar | ||
AS = as | ||
RANLIB = ranlib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.