Skip to content

Commit

Permalink
In-line calculation of Meteor scores if you have installed multeval
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottd committed May 4, 2016
1 parent 01b48ce commit bafe112
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 8 deletions.
81 changes: 76 additions & 5 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,20 @@

# Dimensionality of image feature vector
IMG_FEATS = 4096
# Directory holding the multeval installation. It is a relative path, so it
# is resolved against the current working directory at the moment
# multeval_scores() changes into it to run ./multeval.sh.
MULTEVAL_DIR = 'multeval-0.5.1'

class cd(object):
    """Context manager for changing the current working directory.

    Adapted from
    http://stackoverflow.com/questions/431684/how-do-i-cd-in-python

    On entry the process changes into ``newPath``; on exit (normal or via an
    exception) it changes back to whatever directory was current when the
    ``with`` block was entered. Exceptions raised inside the block are not
    suppressed.
    """

    def __init__(self, newPath):
        # Directory to switch into when the with-block is entered.
        self.newPath = newPath
        # Filled in by __enter__; None until the manager has been entered.
        self.savedPath = None

    def __enter__(self):
        # Remember where we came from *before* moving, so __exit__ can
        # restore it even if the body of the with-block raises.
        self.savedPath = os.getcwd()
        os.chdir(self.newPath)
        # Returning self makes "with cd(path) as c:" bind something useful.
        return self

    def __exit__(self, etype, value, traceback):
        # Implicitly returns None, so any in-flight exception propagates.
        os.chdir(self.savedPath)

class GroundedTranslationGenerator:

Expand Down Expand Up @@ -79,12 +92,17 @@ def generate(self):
self.hsn_size = 0

if self.model == None:
self.build_model()
self.build_model(generate=True)

self.generate_sentences(self.args.checkpoint, val=not self.args.test)
if not self.args.without_scores:
score = self.bleu_score(self.args.checkpoint, val=not self.args.test)
if self.args.multeval:
score, _, _ = self.multeval_scores(self.args.checkpoint,
val=not self.args.test)
self.build_model(generate=False)
self.calculate_pplx(self.args.checkpoint, val=not self.args.test)
return self.bleu_score(self.args.checkpoint, val=not self.args.test)
return score

def generate_sentences(self, filepath, val=True):
"""
Expand Down Expand Up @@ -286,6 +304,11 @@ def generate_sentences(self, filepath, val=True):
sys.stdout.flush()
# print/extract each sentence until it hits the first end-of-string token
for s in complete_sentences:
if self.args.verbose:
logger.info("%s",' '.join([x for x
in itertools.takewhile(
lambda n: n != "<E>",
complete_sentences[i])]))
decoded_str = ' '.join([x for x
in itertools.takewhile(
lambda n: n != "<E>", s[1:])])
Expand Down Expand Up @@ -461,6 +484,47 @@ def bleu_score(self, directory, val=True):
bleu = float(bleuscore.lstrip())
return bleu

def multeval_scores(self, directory, val=True):
    '''
    Evaluate the generated sentences with multeval (Meteor, BLEU and TER).

    Calls self.extract_references(directory, val) first, then runs
    ./multeval.sh from inside MULTEVAL_DIR with the reference files
    ../<directory>/<prefix>_reference.* and the hypotheses file
    ../<directory>/<prefix>Generated, where <prefix> is "val" or "test".
    The combined stdout/stderr of multeval is written to the file
    "multevaloutput" inside MULTEVAL_DIR and then parsed for the
    "RESULT: baseline: <METRIC>: AVG:" lines.

    Arguments:
        directory: checkpoint directory holding the generated sentences.
                   It is interpolated into paths of the form
                   ../<directory>/..., i.e. relative to MULTEVAL_DIR.
        val: score the validation split if True, the test split otherwise.

    Returns:
        (mmeteor, mbleu, mter). Each score is truncated (not rounded) to
        two decimal places. mmeteor is a float; mbleu and mter are
        returned as strings, as they always have been.

    Raises:
        subprocess.CalledProcessError: multeval exited with non-zero status.
        ValueError: the multeval output lacks one of the expected RESULT
                    lines.
    '''

    metrics = ("BLEU", "METEOR", "TER")

    def truncated_average(result_line):
        # The score is the fifth colon-separated field of a line such as
        # "RESULT: baseline: BLEU: AVG: <score>". Keep only two decimals.
        value = result_line.split(":")[4].strip()
        whole, _, fraction = value.partition(".")
        return whole + "." + fraction[:2]

    prefix = "val" if val else "test"
    self.extract_references(directory, val)

    # shell=True is needed so the shell expands the _reference.* glob.
    # "> file 2>&1" makes stdout and stderr share one open file; opening
    # the same file twice ("2> f 1> f") lets the streams overwrite each
    # other because each descriptor keeps its own offset.
    # NOTE(review): directory is interpolated unquoted into a shell
    # command; only pass trusted paths here.
    command = ("./multeval.sh eval --refs ../%s/%s_reference.* "
               "--hyps-baseline ../%s/%sGenerated "
               "--meteor.language de "
               "> multevaloutput 2>&1"
               % (directory, prefix, directory, prefix))

    with cd(MULTEVAL_DIR):
        subprocess.check_call(command, shell=True)
        # The output file lives in MULTEVAL_DIR, so read it before the
        # working directory is restored.
        with open("multevaloutput") as handle:
            multdata = handle.readlines()

    averages = {}
    for line in multdata:
        for metric in metrics:
            if line.startswith("RESULT: baseline: %s: AVG:" % metric):
                # A later matching line overrides an earlier one.
                averages[metric] = truncated_average(line)

    missing = [metric for metric in metrics if metric not in averages]
    if missing:
        raise ValueError("multeval output in %s/multevaloutput has no "
                         "RESULT line for: %s"
                         % (MULTEVAL_DIR, ", ".join(missing)))

    mmeteor = float(averages["METEOR"])
    mbleu = averages["BLEU"]
    mter = averages["TER"]

    logger.info("Meteor = %.2f", mmeteor)

    return mmeteor, mbleu, mter

def extract_references(self, directory, val=True):
"""
Get reference descriptions for split we are generating outputs for.
Expand All @@ -474,12 +538,17 @@ def extract_references(self, directory, val=True):
codecs.open('%s/%s_reference.ref%d'
% (directory, prefix, refid), 'w', 'utf-8').write('\n'.join([x[refid] for x in references]))

def build_model(self):
def build_model(self, generate=False):
'''
Build a Keras model if one does not yet exist.
Helper function for generate().
'''

if generate:
t = self.args.generation_timesteps
else:
t = self.data_gen.max_seq_len
if self.args.mrnn:
m = models.MRNN(self.args.embed_size, self.args.hidden_size,
self.vocab_len,
Expand All @@ -489,7 +558,7 @@ def build_model(self):
weights=self.args.checkpoint,
gru=self.args.gru,
clipnorm=self.args.clipnorm,
t=self.data_gen.max_seq_len)
t=t)
else:
m = models.NIC(self.args.embed_size, self.args.hidden_size,
self.vocab_len,
Expand All @@ -499,7 +568,7 @@ def build_model(self):
weights=self.args.checkpoint,
gru=self.args.gru,
clipnorm=self.args.clipnorm,
t=self.data_gen.max_seq_len)
t=t)

self.model = m.buildKerasModel(use_sourcelang=self.use_sourcelang,
use_image=self.use_image)
Expand Down Expand Up @@ -619,6 +688,8 @@ def build_model(self):
help="Verbose output while decoding? If you choose\
verbose output then you'll see the total beam search\
decoding process. (Default = False)")
parser.add_argument("--multeval", action="store_true",
help="Evaluate using multeval?")

# Legacy options
parser.add_argument("--generate_from_N_words", type=int, default=0,
Expand Down
8 changes: 5 additions & 3 deletions util/timesteps.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def dual_search(self):
sampler = GroundedTranslationGenerator(self.args)

handle = open("../logs/timesteps-%s.log" % self.args.run_string, "w")
handle.write("{:3} | {:3} | {:3} | {:10}\n".format("Run", "T", "Beam", "BLEU"))
handle.write("{:3} | {:3} | {:3} | {:10}\n".format("Run", "T", "Beam", "Meteor"))
handle.close()
run = 0
for t in xrange(self.args.min_timesteps, self.args.max_timesteps+1):
Expand All @@ -65,10 +65,12 @@ def dual_search(self):
logger.info("Setting beam_width to: %d", b)
sampler.args.generation_timesteps = t
sampler.args.beam_width = b
bleu = sampler.generate()
sampler.model = None
meteor = sampler.generate()

handle.write("{:3d} | {:5} | {:5} | {:1.5f} \n".format(run,
sampler.args.generation_timesteps, sampler.args.beam_width, bleu))
sampler.args.generation_timesteps,
sampler.args.beam_width, meteor))
handle.close()
run += 1

Expand Down

0 comments on commit bafe112

Please sign in to comment.