From 24e17f10ce3a8c217e4b822b1f94fbd9caf1c350 Mon Sep 17 00:00:00 2001
From: Richard Zowalla <13417392+rzo1@users.noreply.github.com>
Date: Wed, 29 May 2024 08:09:54 +0200
Subject: [PATCH] OPENNLP-1539 - Introduce parameter for POSTaggerME to
configure output POS tag format (#601)
---
.../opennlp/tools/postag/POSTagFormat.java} | 38 +---
.../tools/postag/POSTagFormatMapper.java | 209 ++++++++++++++++++
.../opennlp/tools/postag/POSTaggerME.java | 92 +++++---
.../namefind/TokenNameFinderModelTest.java | 2 +-
.../opennlp/tools/postag/POSModelTest.java | 4 +-
.../opennlp/tools/postag/POSTaggerMETest.java | 135 +++++++++--
.../POSTaggerNameFeatureGeneratorTest.java | 2 +-
7 files changed, 400 insertions(+), 82 deletions(-)
rename opennlp-tools/src/{test/java/opennlp/tools/postag/POSTaggerMEIT.java => main/java/opennlp/tools/postag/POSTagFormat.java} (54%)
create mode 100644 opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormatMapper.java
diff --git a/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMEIT.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormat.java
similarity index 54%
rename from opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMEIT.java
rename to opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormat.java
index 9b521ce89..ddb9cc5f5 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMEIT.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormat.java
@@ -14,38 +14,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package opennlp.tools.postag;
-import java.io.IOException;
-
-import org.junit.jupiter.api.Assertions;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Test;
-
-public class POSTaggerMEIT {
-
- private static POSTagger tagger;
-
- @BeforeAll
- public static void prepare() throws IOException {
- tagger = new POSTaggerME("en");
- }
-
- @Test
- void testPOSTagger() {
-
- String[] tags = tagger.tag(new String[] {
- "The",
- "driver",
- "got",
- "badly",
- "injured",
- "."});
-
- // TODO OPENNLP-1539 Adjust this depending on the POSFormat
- String[] expected = {"DET", "NOUN", "VERB", "ADV", "VERB", "PUNCT"};
- Assertions.assertArrayEquals(expected, tags);
- }
+/**
+ * Defines the tag set in which part-of-speech tags are expressed:
+ * Universal Dependencies ({@link #UD}), Penn Treebank ({@link #PENN}),
+ * or {@link #UNKNOWN} if the tag set could not be determined.
+ */
+public enum POSTagFormat {
+ UD, PENN, UNKNOWN
}
diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormatMapper.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormatMapper.java
new file mode 100644
index 000000000..e02cb5204
--- /dev/null
+++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormatMapper.java
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package opennlp.tools.postag;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A mapping implementation for converting between different POS tag formats.
+ * This class supports conversion between Penn Treebank (PENN) and Universal Dependencies (UD) formats.
+ * The conversion is based on the Universal Dependencies conversion table.
+ * Please note that when converting from UD to Penn format, there may be ambiguity in some cases.
+ */
+public class POSTagFormatMapper {
+
+ private static final Logger logger = LoggerFactory.getLogger(POSTagFormatMapper.class);
+
+  private static final Map<String, String> CONVERSION_TABLE_PENN_TO_UD = new HashMap<>();
+  private static final Map<String, String> CONVERSION_TABLE_UD_TO_PENN = new HashMap<>();
+
+  static {
+    /*
+     * Conversion table from PENN to UD tags, following
+     * https://universaldependencies.org/tagset-conversion/en-penn-uposf.html
+     */
+    CONVERSION_TABLE_PENN_TO_UD.put("#", "SYM");
+    CONVERSION_TABLE_PENN_TO_UD.put("$", "SYM");
+    CONVERSION_TABLE_PENN_TO_UD.put("''", "PUNCT");
+    CONVERSION_TABLE_PENN_TO_UD.put(",", "PUNCT");
+    CONVERSION_TABLE_PENN_TO_UD.put("-LRB-", "PUNCT");
+    CONVERSION_TABLE_PENN_TO_UD.put("-RRB-", "PUNCT");
+    CONVERSION_TABLE_PENN_TO_UD.put(".", "PUNCT");
+    CONVERSION_TABLE_PENN_TO_UD.put(":", "PUNCT");
+    CONVERSION_TABLE_PENN_TO_UD.put("AFX", "ADJ");
+    CONVERSION_TABLE_PENN_TO_UD.put("CC", "CCONJ");
+    CONVERSION_TABLE_PENN_TO_UD.put("CD", "NUM");
+    CONVERSION_TABLE_PENN_TO_UD.put("DT", "DET");
+    CONVERSION_TABLE_PENN_TO_UD.put("EX", "PRON");
+    CONVERSION_TABLE_PENN_TO_UD.put("FW", "X");
+    CONVERSION_TABLE_PENN_TO_UD.put("HYPH", "PUNCT");
+    CONVERSION_TABLE_PENN_TO_UD.put("IN", "ADP");
+    CONVERSION_TABLE_PENN_TO_UD.put("JJ", "ADJ");
+    CONVERSION_TABLE_PENN_TO_UD.put("JJR", "ADJ");
+    CONVERSION_TABLE_PENN_TO_UD.put("JJS", "ADJ");
+    CONVERSION_TABLE_PENN_TO_UD.put("LS", "X");
+    CONVERSION_TABLE_PENN_TO_UD.put("MD", "VERB");
+    CONVERSION_TABLE_PENN_TO_UD.put("NIL", "X");
+    CONVERSION_TABLE_PENN_TO_UD.put("NN", "NOUN");
+    CONVERSION_TABLE_PENN_TO_UD.put("NNP", "PROPN");
+    CONVERSION_TABLE_PENN_TO_UD.put("NNPS", "PROPN");
+    CONVERSION_TABLE_PENN_TO_UD.put("NNS", "NOUN");
+    CONVERSION_TABLE_PENN_TO_UD.put("PDT", "DET");
+    CONVERSION_TABLE_PENN_TO_UD.put("POS", "PART");
+    CONVERSION_TABLE_PENN_TO_UD.put("PRP", "PRON");
+    CONVERSION_TABLE_PENN_TO_UD.put("PRP$", "DET");
+    CONVERSION_TABLE_PENN_TO_UD.put("RB", "ADV");
+    CONVERSION_TABLE_PENN_TO_UD.put("RBR", "ADV");
+    CONVERSION_TABLE_PENN_TO_UD.put("RBS", "ADV");
+    CONVERSION_TABLE_PENN_TO_UD.put("RP", "ADP");
+    CONVERSION_TABLE_PENN_TO_UD.put("SYM", "SYM");
+    CONVERSION_TABLE_PENN_TO_UD.put("TO", "PART");
+    CONVERSION_TABLE_PENN_TO_UD.put("UH", "INTJ");
+    CONVERSION_TABLE_PENN_TO_UD.put("VB", "VERB");
+    CONVERSION_TABLE_PENN_TO_UD.put("VBD", "VERB");
+    CONVERSION_TABLE_PENN_TO_UD.put("VBG", "VERB");
+    CONVERSION_TABLE_PENN_TO_UD.put("VBN", "VERB");
+    CONVERSION_TABLE_PENN_TO_UD.put("VBP", "VERB");
+    CONVERSION_TABLE_PENN_TO_UD.put("VBZ", "VERB");
+    CONVERSION_TABLE_PENN_TO_UD.put("WDT", "DET");
+    CONVERSION_TABLE_PENN_TO_UD.put("WP", "PRON");
+    CONVERSION_TABLE_PENN_TO_UD.put("WP$", "DET");
+    CONVERSION_TABLE_PENN_TO_UD.put("WRB", "ADV");
+
+    /*
+     * Note: UD is coarser than PENN; the back conversion yields one representative PENN tag.
+     */
+    CONVERSION_TABLE_UD_TO_PENN.put("ADJ", "JJ");
+    CONVERSION_TABLE_UD_TO_PENN.put("ADP", "IN");
+    CONVERSION_TABLE_UD_TO_PENN.put("ADV", "RB");
+    CONVERSION_TABLE_UD_TO_PENN.put("AUX", "MD");
+    CONVERSION_TABLE_UD_TO_PENN.put("CCONJ", "CC");
+    CONVERSION_TABLE_UD_TO_PENN.put("DET", "DT");
+    CONVERSION_TABLE_UD_TO_PENN.put("INTJ", "UH");
+    CONVERSION_TABLE_UD_TO_PENN.put("NOUN", "NN");
+    CONVERSION_TABLE_UD_TO_PENN.put("NUM", "CD");
+    CONVERSION_TABLE_UD_TO_PENN.put("PART", "RP");
+    CONVERSION_TABLE_UD_TO_PENN.put("PRON", "PRP");
+    CONVERSION_TABLE_UD_TO_PENN.put("PROPN", "NNP");
+    CONVERSION_TABLE_UD_TO_PENN.put("PUNCT", ".");
+    CONVERSION_TABLE_UD_TO_PENN.put("SCONJ", "IN");
+    CONVERSION_TABLE_UD_TO_PENN.put("SYM", "SYM");
+    CONVERSION_TABLE_UD_TO_PENN.put("VERB", "VB");
+    CONVERSION_TABLE_UD_TO_PENN.put("X", "FW");
+  }
+
+ private final POSTagFormat modelFormat;
+ // The format is guessed once, from the given outcomes, and never changes afterwards.
+ protected POSTagFormatMapper(final String[] possibleOutcomes) {
+ this.modelFormat = guessModelTagFormat(possibleOutcomes);
+ }
+
+  /**
+   * Converts a list of tags from the guessed model format to the opposite format.
+   *
+   * @param tags a list of tags to be converted. Must not be {@code null}.
+   * @return an array containing the converted tags with the same order and size as the given input list.
+   *         Note: A given tag might be {@code ?} if no mapping for the given {@code tag} could be found.
+   */
+  public String[] convertTags(List<String> tags) {
+    Objects.requireNonNull(tags, "Supplied tags must not be NULL.");
+    return tags.stream()
+        .map(this::convertTag)
+        .toArray(String[]::new);
+  }
+
+  /**
+   * Converts a tag from the guessed model format to the other one; only UD to PENN is ambiguous.
+   *
+   * @param tag no restrictions on this parameter.
+   * @return the converted tag, {@code ?} if no mapping exists, or {@code tag} itself for an unknown format.
+   */
+  public String convertTag(String tag) {
+    switch (modelFormat) {
+      case UD -> {
+        if ("NOUN".equals(tag)) {
+          logger.warn("Ambiguity detected: NOUN can be 'NN' or 'NNS' depending on the number. " +
+              "Returning 'NN'.");
+        }
+        if ("PART".equals(tag)) {
+          logger.warn("Ambiguity detected: PART can be 'RP' or 'TO'. Returning 'RP'.");
+        }
+        if ("PROPN".equals(tag)) {
+          logger.warn("Ambiguity detected: PROPN can be 'NNP' or 'NNPS'. Returning 'NNP'.");
+        }
+        if ("PUNCT".equals(tag)) {
+          logger.warn("Ambiguity detected: PUNCT needs specific punctuation mapping. Returning '.'");
+        }
+        if ("VERB".equals(tag)) {
+          logger.warn("Ambiguity detected: VERB can be 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ'. " +
+              "Returning 'VB'.");
+        }
+        return CONVERSION_TABLE_UD_TO_PENN.getOrDefault(tag, "?");
+      }
+      case PENN -> {
+        return CONVERSION_TABLE_PENN_TO_UD.getOrDefault(tag, "?");
+      }
+      default -> {
+        return tag;
+      }
+    }
+  }
+
+ /**
+ * Returns the tag format that was guessed from the outcomes given at construction.
+ * @return The guessed {@link POSTagFormat}. Guaranteed to be not {@code null}.
+ */
+ public POSTagFormat getGuessedFormat() {
+ return this.modelFormat;
+ }
+
+  /**
+   * Guesses the {@link POSTagFormat} by majority vote over the given outcomes.
+   * <p>
+   * Each outcome that is a known UD tag is a vote for UD; each outcome that is a
+   * known PENN tag is a vote for PENN.
+   *
+   * @param outcomes must not be {@code null}.
+   * @return the guessed {@link POSTagFormat}; {@link POSTagFormat#UNKNOWN} if neither
+   *         format gets more votes than the other, e.g. for an empty input.
+   */
+  private POSTagFormat guessModelTagFormat(final String[] outcomes) {
+    Objects.requireNonNull(outcomes, "Outcomes must not be NULL.");
+    int udVotes = 0;
+    int pennVotes = 0;
+
+    for (final String candidate : outcomes) {
+      // Not mutually exclusive: e.g. "SYM" is a valid tag in both formats.
+      udVotes += CONVERSION_TABLE_UD_TO_PENN.containsKey(candidate) ? 1 : 0;
+      pennVotes += CONVERSION_TABLE_PENN_TO_UD.containsKey(candidate) ? 1 : 0;
+    }
+
+    if (udVotes == pennVotes) {
+      // A tie, which includes the empty input, cannot be resolved.
+      logger.warn("Detected an unknown POS format.");
+      return POSTagFormat.UNKNOWN;
+    }
+    // Otherwise the strict majority decides.
+    return udVotes > pennVotes ? POSTagFormat.UD : POSTagFormat.PENN;
+  }
+}
diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
index 0268f48b5..56b77c329 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java
@@ -19,6 +19,7 @@
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -83,16 +84,30 @@ public class POSTaggerME implements POSTagger {
private final SequenceValidator sequenceValidator;
+ private final POSTagFormat posTagFormat;
+ private final POSTagFormatMapper posTagFormatMapper;
+
/**
* Initializes a {@link POSTaggerME} by downloading a default model for a given
* {@code language}.
*
* @param language An ISO conform language code.
- *
* @throws IOException Thrown if the model could not be downloaded or saved.
*/
public POSTaggerME(String language) throws IOException {
- this(DownloadUtil.downloadModel(language, DownloadUtil.ModelType.POS, POSModel.class));
+ this(language, POSTagFormat.UD);
+ }
+
+ /**
+ * Initializes a {@link POSTaggerME} by downloading a default model for a given
+ * {@code language}; tags are returned in the given {@code format}.
+ *
+ * @param language An ISO conform language code.
+ * @param format The {@link POSTagFormat} in which tags are returned.
+ * @throws IOException Thrown if the model could not be downloaded or saved.
+ */
+ public POSTaggerME(String language, POSTagFormat format) throws IOException {
+ this(DownloadUtil.downloadModel(language, DownloadUtil.ModelType.POS, POSModel.class), format);
}
/**
@@ -101,6 +116,17 @@ public POSTaggerME(String language) throws IOException {
* @param model A valid {@link POSModel}.
*/
public POSTaggerME(POSModel model) {
+ this(model, POSTagFormat.UD);
+ }
+
+ /**
+ * Initializes a {@link POSTaggerME} with the provided {@link POSModel model}.
+ *
+ * @param model A valid {@link POSModel}.
+ * @param format A valid {@link POSTagFormat}.
+ */
+ public POSTaggerME(POSModel model, POSTagFormat format) {
+ this.posTagFormat = format;
POSTaggerFactory factory = model.getFactory();
int beamSize = POSTaggerME.DEFAULT_BEAM_SIZE;
@@ -121,12 +147,13 @@ public POSTaggerME(POSModel model) {
if (model.getPosSequenceModel() != null) {
this.model = model.getPosSequenceModel();
- }
- else {
+ } else {
this.model = new opennlp.tools.ml.BeamSearch<>(beamSize,
model.getPosModel(), 0);
}
+ this.posTagFormatMapper = new POSTagFormatMapper(getAllPosTags());
+
}
/**
@@ -144,16 +171,15 @@ public String[] tag(String[] sentence) {
@Override
public String[] tag(String[] sentence, Object[] additionalContext) {
bestSequence = model.bestSequence(sentence, additionalContext, contextGen, sequenceValidator);
- List t = bestSequence.getOutcomes();
- return t.toArray(new String[0]);
+ final List t = bestSequence.getOutcomes();
+ return convertTags(t);
}
/**
* Returns at most the specified {@code numTaggings} for the specified {@code sentence}.
*
* @param numTaggings The number of tagging to be returned.
- * @param sentence An array of tokens which make up a sentence.
- *
+ * @param sentence An array of tokens which make up a sentence.
* @return At most the specified number of taggings for the specified {@code sentence}.
*/
public String[][] tag(int numTaggings, String[] sentence) {
@@ -162,11 +188,19 @@ public String[][] tag(int numTaggings, String[] sentence) {
String[][] tags = new String[bestSequences.length][];
for (int si = 0; si < tags.length; si++) {
List t = bestSequences[si].getOutcomes();
- tags[si] = t.toArray(new String[0]);
+ tags[si] = convertTags(t);
}
return tags;
}
+  // Returns the tags as-is if the model already emits the requested format, else converts them.
+  private String[] convertTags(List<String> t) {
+    if (posTagFormatMapper.getGuessedFormat() == posTagFormat) {
+      return t.toArray(new String[0]);
+    }
+    return posTagFormatMapper.convertTags(t);
+  }
+
@Override
public Sequence[] topKSequences(String[] sentence) {
return this.topKSequences(sentence, null);
@@ -194,10 +228,10 @@ public double[] probs() {
}
public String[] getOrderedTags(List words, List tags, int index) {
- return getOrderedTags(words,tags,index,null);
+ return getOrderedTags(words, tags, index, null);
}
- public String[] getOrderedTags(List words, List tags, int index,double[] tprobs) {
+ public String[] getOrderedTags(List words, List tags, int index, double[] tprobs) {
if (modelPackage.getPosModel() != null) {
@@ -205,7 +239,7 @@ public String[] getOrderedTags(List words, List tags, int index,
double[] probs = posModel.eval(contextGen.getContext(index,
words.toArray(new String[0]),
- tags.toArray(new String[0]),null));
+ tags.toArray(new String[0]), null));
String[] orderedTags = new String[probs.length];
for (int i = 0; i < probs.length; i++) {
@@ -221,17 +255,16 @@ public String[] getOrderedTags(List words, List tags, int index,
}
probs[max] = 0;
}
- return orderedTags;
- }
- else {
+ return convertTags(Arrays.stream(orderedTags).toList());
+ } else {
throw new UnsupportedOperationException("This method can only be called if the "
+ "classification model is an event model!");
}
}
public static POSModel train(String languageCode,
- ObjectStream samples, TrainingParameters trainParams,
- POSTaggerFactory posFactory) throws IOException {
+ ObjectStream samples, TrainingParameters trainParams,
+ POSTaggerFactory posFactory) throws IOException {
int beamSize = trainParams.getIntParameter(BeamSearch.BEAM_SIZE_PARAMETER, POSTaggerME.DEFAULT_BEAM_SIZE);
@@ -249,14 +282,12 @@ public static POSModel train(String languageCode,
EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams,
manifestInfoEntries);
posModel = trainer.train(es);
- }
- else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
+ } else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator);
EventModelSequenceTrainer trainer =
TrainerFactory.getEventModelSequenceTrainer(trainParams, manifestInfoEntries);
posModel = trainer.train(ss);
- }
- else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
+ } else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(
trainParams, manifestInfoEntries);
@@ -264,15 +295,13 @@ else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator);
seqPosModel = trainer.train(ss);
- }
- else {
+ } else {
throw new IllegalArgumentException("Trainer type is not supported: " + trainerType);
}
if (posModel != null) {
return new POSModel(languageCode, posModel, beamSize, manifestInfoEntries, posFactory);
- }
- else {
+ } else {
return new POSModel(languageCode, seqPosModel, manifestInfoEntries, posFactory);
}
}
@@ -282,9 +311,7 @@ else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
*
* @param samples The {@link ObjectStream} to process.
* @param cutoff A non-negative cut-off value.
- *
* @return A valid {@link Dictionary} instance holding nGrams.
- *
* @throws IOException Thrown if IO errors occurred during dictionary construction.
*/
public static Dictionary buildNGramDictionary(ObjectStream samples, int cutoff)
@@ -295,8 +322,9 @@ public static Dictionary buildNGramDictionary(ObjectStream samples, i
while ((sample = samples.read()) != null) {
String[] words = sample.getSentence();
- if (words.length > 0)
+ if (words.length > 0) {
ngramModel.add(new StringList(words), 1, 1);
+ }
}
ngramModel.cutoff(cutoff, Integer.MAX_VALUE);
@@ -308,13 +336,12 @@ public static Dictionary buildNGramDictionary(ObjectStream samples, i
* Populates a {@link POSDictionary} from an {@link ObjectStream} of samples.
*
* @param samples The {@link ObjectStream} to process.
- * @param dict The {@link MutableTagDictionary} to use during population.
+ * @param dict The {@link MutableTagDictionary} to use during population.
* @param cutoff A non-negative cut-off value.
- *
* @throws IOException Thrown if IO errors occurred during dictionary construction.
*/
public static void populatePOSDictionary(ObjectStream samples,
- MutableTagDictionary dict, int cutoff) throws IOException {
+ MutableTagDictionary dict, int cutoff) throws IOException {
logger.info("Expanding POS Dictionary ...");
long start = System.nanoTime();
@@ -377,6 +404,7 @@ public static void populatePOSDictionary(ObjectStream samples,
}
}
- logger.info("... finished expanding POS Dictionary. [ {} ms]", (System.nanoTime() - start) / 1000000 );
+ logger.info("... finished expanding POS Dictionary. [ {} ms]", (System.nanoTime() - start) / 1000000);
}
+
}
diff --git a/opennlp-tools/src/test/java/opennlp/tools/namefind/TokenNameFinderModelTest.java b/opennlp-tools/src/test/java/opennlp/tools/namefind/TokenNameFinderModelTest.java
index 5379a1f39..c77b28743 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/namefind/TokenNameFinderModelTest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/namefind/TokenNameFinderModelTest.java
@@ -57,7 +57,7 @@ void testNERWithPOSModel() throws IOException {
Path resourcesFolder = Files.createTempDirectory("resources").toAbsolutePath();
// save a POS model there
- POSModel posModel = POSTaggerMETest.trainPOSModel(ModelType.MAXENT);
+ POSModel posModel = POSTaggerMETest.trainPennFormatPOSModel(ModelType.MAXENT);
Assertions.assertNotNull(posModel);
File posModelFile = new File(resourcesFolder.toFile(), "pos-model.bin");
diff --git a/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java b/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java
index 14086c16a..1565c45ba 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java
@@ -30,7 +30,7 @@ public class POSModelTest {
@Test
void testPOSModelSerializationMaxent() throws IOException {
- POSModel posModel = POSTaggerMETest.trainPOSModel(ModelType.MAXENT);
+ POSModel posModel = POSTaggerMETest.trainPennFormatPOSModel(ModelType.MAXENT);
Assertions.assertFalse(posModel.isLoadedFromSerialized());
try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
@@ -45,7 +45,7 @@ void testPOSModelSerializationMaxent() throws IOException {
@Test
void testPOSModelSerializationPerceptron() throws IOException {
- POSModel posModel = POSTaggerMETest.trainPOSModel(ModelType.PERCEPTRON);
+ POSModel posModel = POSTaggerMETest.trainPennFormatPOSModel(ModelType.PERCEPTRON);
Assertions.assertFalse(posModel.isLoadedFromSerialized());
try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
diff --git a/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMETest.java b/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMETest.java
index 945de120d..d01fe3ab7 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMETest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMETest.java
@@ -19,10 +19,13 @@
import java.io.IOException;
import java.nio.charset.StandardCharsets;
+import java.nio.file.Path;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
+import opennlp.tools.AbstractModelLoaderTest;
+import opennlp.tools.EnabledWhenCDNAvailable;
import opennlp.tools.formats.ResourceAsStreamFactory;
import opennlp.tools.util.InputStreamFactory;
import opennlp.tools.util.InsufficientTrainingDataException;
@@ -35,11 +38,11 @@
* Tests for the {@link POSTaggerME} class.
*/
-public class POSTaggerMETest {
+public class POSTaggerMETest extends AbstractModelLoaderTest {
private static ObjectStream createSampleStream() throws IOException {
InputStreamFactory in = new ResourceAsStreamFactory(POSTaggerMETest.class,
- "/opennlp/tools/postag/AnnotatedSentences.txt");
+ "/opennlp/tools/postag/AnnotatedSentences.txt"); //PENN FORMAT
return new WordTagSampleStream(new PlainTextByLineStream(in, StandardCharsets.UTF_8));
}
@@ -49,7 +52,7 @@ private static ObjectStream createSampleStream() throws IOException {
*
* @return {@link POSModel}
*/
- public static POSModel trainPOSModel(ModelType type) throws IOException {
+ public static POSModel trainPennFormatPOSModel(ModelType type) throws IOException {
TrainingParameters params = new TrainingParameters();
params.put(TrainingParameters.ALGORITHM_PARAM, type.toString());
params.put(TrainingParameters.ITERATIONS_PARAM, 100);
@@ -61,25 +64,127 @@ public static POSModel trainPOSModel(ModelType type) throws IOException {
@Test
void testPOSTagger() throws IOException {
- POSModel posModel = trainPOSModel(ModelType.MAXENT);
+ final String[] sentence = {
+ "The",
+ "driver",
+ "got",
+ "badly",
+ "injured",
+ "."};
- POSTagger tagger = new POSTaggerME(posModel);
+ final String[] expected = {"DT", "NN", "VBD", "RB", "VBN", "."};
+ testPOSTagger(new POSTaggerME(trainPennFormatPOSModel(ModelType.MAXENT),
+ POSTagFormat.PENN), sentence, expected);
+ }
- String[] tags = tagger.tag(new String[] {
+ @Test
+ void testPOSTaggerPENNtoUD() throws IOException {
+ final String[] sentence = {
"The",
"driver",
"got",
"badly",
"injured",
- "."});
-
- Assertions.assertEquals(6, tags.length);
- Assertions.assertEquals("DT", tags[0]);
- Assertions.assertEquals("NN", tags[1]);
- Assertions.assertEquals("VBD", tags[2]);
- Assertions.assertEquals("RB", tags[3]);
- Assertions.assertEquals("VBN", tags[4]);
- Assertions.assertEquals(".", tags[5]);
+ "."};
+
+ final String[] expected = {"DET", "NOUN", "VERB", "ADV", "VERB", "PUNCT"};
+ //convert PENN to UD on the fly.
+ testPOSTagger(new POSTaggerME(trainPennFormatPOSModel(ModelType.MAXENT),
+ POSTagFormat.UD), sentence, expected);
+ }
+
+ @Test
+ @EnabledWhenCDNAvailable(hostname = "dlcdn.apache.org")
+ void testPOSTaggerDefault() throws IOException {
+ final String[] sentence = {
+ "The",
+ "driver",
+ "got",
+ "badly",
+ "injured",
+ "."};
+
+ final String[] expected = {"DET", "NOUN", "VERB", "ADV", "VERB", "PUNCT"};
+ //this downloads a UD model
+ testPOSTagger(new POSTaggerME("en"), sentence, expected);
+ }
+
+ @Test
+ @EnabledWhenCDNAvailable(hostname = "opennlp.sourceforge.net")
+ void testPOSTaggerLegacyPerceptronPennToUD() throws IOException {
+ final String[] sentence = {
+ "The",
+ "driver",
+ "got",
+ "badly",
+ "injured",
+ "."};
+
+ final String[] expected = {"DET", "NOUN", "VERB", "ADV", "VERB", "PUNCT"};
+ //convert PENN to UD on the fly.
+ testPOSTagger(new POSTaggerME(getVersion15Model("en-pos-perceptron.bin"),
+ POSTagFormat.UD), sentence, expected);
+ }
+
+ @Test
+ @EnabledWhenCDNAvailable(hostname = "opennlp.sourceforge.net")
+ void testPOSTaggerLegacyPerceptronPenn() throws IOException {
+ final String[] sentence = {
+ "The",
+ "driver",
+ "got",
+ "badly",
+ "injured",
+ "."};
+
+ final String[] expected = {"DT", "NN", "VBD", "RB", "VBN", "."};
+ //PENN output is requested and PENN tags are expected: no conversion to UD here.
+ testPOSTagger(new POSTaggerME(getVersion15Model("en-pos-perceptron.bin"),
+ POSTagFormat.PENN), sentence, expected);
+ }
+
+ @Test
+ @EnabledWhenCDNAvailable(hostname = "opennlp.sourceforge.net")
+ void testPOSTaggerLegacyMaxentPennToUD() throws IOException {
+ final String[] sentence = {
+ "The",
+ "driver",
+ "got",
+ "badly",
+ "injured",
+ "."};
+
+ final String[] expected = {"DET", "NOUN", "VERB", "ADV", "VERB", "PUNCT"};
+ //convert PENN to UD on the fly.
+ testPOSTagger(new POSTaggerME(getVersion15Model("en-pos-maxent.bin"),
+ POSTagFormat.UD), sentence, expected);
+ }
+
+ @Test
+ @EnabledWhenCDNAvailable(hostname = "opennlp.sourceforge.net")
+ void testPOSTaggerLegacyMaxentPenn() throws IOException {
+ final String[] sentence = {
+ "The",
+ "driver",
+ "got",
+ "badly",
+ "injured",
+ "."};
+
+ final String[] expected = {"DT", "NN", "VBD", "RB", "VBN", "."};
+ //PENN output is requested and PENN tags are expected: no conversion to UD here.
+ testPOSTagger(new POSTaggerME(getVersion15Model("en-pos-maxent.bin"),
+ POSTagFormat.PENN), sentence, expected);
+ }
+
+ private POSModel getVersion15Model(String modelName) throws IOException {
+ downloadVersion15Model(modelName);
+ final Path modelPath = OPENNLP_DIR.resolve(modelName);
+ return new POSModel(modelPath);
+ }
+
+ private void testPOSTagger(POSTagger tagger, String[] sentences, String[] expectedTags) {
+ Assertions.assertArrayEquals(expectedTags, tagger.tag(sentences));
}
@Test
diff --git a/opennlp-tools/src/test/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGeneratorTest.java b/opennlp-tools/src/test/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGeneratorTest.java
index ccd1f9f24..c349ee71a 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGeneratorTest.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGeneratorTest.java
@@ -33,7 +33,7 @@ public class POSTaggerNameFeatureGeneratorTest {
@Test
void testFeatureGeneration() throws IOException {
POSTaggerNameFeatureGenerator fg = new POSTaggerNameFeatureGenerator(
- POSTaggerMETest.trainPOSModel(ModelType.MAXENT));
+ POSTaggerMETest.trainPennFormatPOSModel(ModelType.MAXENT));
String[] tokens = {"Hi", "Mike", ",", "it", "'s", "Stefanie", "Schmidt", "."};
for (int i = 0; i < tokens.length; i++) {