diff --git a/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMEIT.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormat.java similarity index 54% rename from opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMEIT.java rename to opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormat.java index 9b521ce89..ddb9cc5f5 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMEIT.java +++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormat.java @@ -14,38 +14,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package opennlp.tools.postag; -import java.io.IOException; - -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; - -public class POSTaggerMEIT { - - private static POSTagger tagger; - - @BeforeAll - public static void prepare() throws IOException { - tagger = new POSTaggerME("en"); - } - - @Test - void testPOSTagger() { - - String[] tags = tagger.tag(new String[] { - "The", - "driver", - "got", - "badly", - "injured", - "."}); - - // TODO OPENNLP-1539 Adjust this depending on the POSFormat - String[] expected = {"DET", "NOUN", "VERB", "ADV", "VERB", "PUNCT"}; - Assertions.assertArrayEquals(expected, tags); - } +/** + * Defines the format for part-of-speech tagging, i.e. + * PENN + * or UD format. + */ +public enum POSTagFormat { + UD, PENN, UNKNOWN } diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormatMapper.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormatMapper.java new file mode 100644 index 000000000..3001dfd18 --- /dev/null +++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormatMapper.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package opennlp.tools.postag; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A mapping implementation for converting between different POS tag formats. + * This class supports conversion between Penn Treebank (PENN) and Universal Dependencies (UD) formats. + * The conversion is based on the Universal Dependencies conversion table. + * Please note that when converting from UD to Penn format, there may be ambiguity in some cases. + */ +public class POSTagFormatMapper { + + private static final Logger logger = LoggerFactory.getLogger(POSTagFormatMapper.class); + + private static final Map CONVERSION_TABLE_PENN_TO_UD = new HashMap<>(); + private static final Map CONVERSION_TABLE_UD_TO_PENN = new HashMap<>(); + + static { + /* + * This is a conversion table to convert PENN to UD format as described in + * https://universaldependencies.org/tagset-conversion/en-penn-uposf.html + */ + CONVERSION_TABLE_PENN_TO_UD.put("#", "SYM"); + CONVERSION_TABLE_PENN_TO_UD.put("$", "SYM"); + CONVERSION_TABLE_PENN_TO_UD.put("''", "PUNCT"); + CONVERSION_TABLE_PENN_TO_UD.put(",", "PUNCT"); + CONVERSION_TABLE_PENN_TO_UD.put("-LRB-", "PUNCT"); + CONVERSION_TABLE_PENN_TO_UD.put("-RRB-", "PUNCT"); + CONVERSION_TABLE_PENN_TO_UD.put(".", "PUNCT"); + CONVERSION_TABLE_PENN_TO_UD.put(":", "PUNCT"); + CONVERSION_TABLE_PENN_TO_UD.put("AFX", "ADJ"); + CONVERSION_TABLE_PENN_TO_UD.put("CC", "CCONJ"); + CONVERSION_TABLE_PENN_TO_UD.put("CD", "NUM"); + CONVERSION_TABLE_PENN_TO_UD.put("DT", "DET"); + CONVERSION_TABLE_PENN_TO_UD.put("EX", "PRON"); + CONVERSION_TABLE_PENN_TO_UD.put("FW", "X"); + CONVERSION_TABLE_PENN_TO_UD.put("HYPH", "PUNCT"); + CONVERSION_TABLE_PENN_TO_UD.put("IN", "ADP"); + CONVERSION_TABLE_PENN_TO_UD.put("JJ", "ADJ"); + CONVERSION_TABLE_PENN_TO_UD.put("JJR", "ADJ"); + CONVERSION_TABLE_PENN_TO_UD.put("JJS", "ADJ"); + CONVERSION_TABLE_PENN_TO_UD.put("LS", "X"); + CONVERSION_TABLE_PENN_TO_UD.put("MD", "VERB"); + CONVERSION_TABLE_PENN_TO_UD.put("NIL", "X"); + CONVERSION_TABLE_PENN_TO_UD.put("NN", "NOUN"); + CONVERSION_TABLE_PENN_TO_UD.put("NNP", "PROPN"); + CONVERSION_TABLE_PENN_TO_UD.put("NNPS", "PROPN"); + CONVERSION_TABLE_PENN_TO_UD.put("NNS", "NOUN"); + CONVERSION_TABLE_PENN_TO_UD.put("PDT", "DET"); + CONVERSION_TABLE_PENN_TO_UD.put("POS", "PART"); + CONVERSION_TABLE_PENN_TO_UD.put("PRP", "PRON"); + CONVERSION_TABLE_PENN_TO_UD.put("PRP$", "DET"); + CONVERSION_TABLE_PENN_TO_UD.put("RB", "ADV"); + CONVERSION_TABLE_PENN_TO_UD.put("RBR", "ADV"); + CONVERSION_TABLE_PENN_TO_UD.put("RBS", "ADV"); + CONVERSION_TABLE_PENN_TO_UD.put("RP", "ADP"); + CONVERSION_TABLE_PENN_TO_UD.put("SYM", "SYM"); + CONVERSION_TABLE_PENN_TO_UD.put("TO", "PART"); + CONVERSION_TABLE_PENN_TO_UD.put("UH", "INTJ"); + CONVERSION_TABLE_PENN_TO_UD.put("VB", "VERB"); + CONVERSION_TABLE_PENN_TO_UD.put("VBD", "VERB"); + CONVERSION_TABLE_PENN_TO_UD.put("VBG", "VERB"); + CONVERSION_TABLE_PENN_TO_UD.put("VBN", "VERB"); + CONVERSION_TABLE_PENN_TO_UD.put("VBP", "VERB"); + CONVERSION_TABLE_PENN_TO_UD.put("VBZ", "VERB"); + CONVERSION_TABLE_PENN_TO_UD.put("WDT", "DET"); + CONVERSION_TABLE_PENN_TO_UD.put("WP", "PRON"); + CONVERSION_TABLE_PENN_TO_UD.put("WP$", "DET"); + CONVERSION_TABLE_PENN_TO_UD.put("WRB", "ADV"); + + /* + * Note: The back conversion might lose information. + */ + CONVERSION_TABLE_UD_TO_PENN.put("ADJ", "JJ"); + CONVERSION_TABLE_UD_TO_PENN.put("ADP", "IN"); + CONVERSION_TABLE_UD_TO_PENN.put("ADV", "RB"); + CONVERSION_TABLE_UD_TO_PENN.put("AUX", "MD"); + CONVERSION_TABLE_UD_TO_PENN.put("CCONJ", "CC"); + CONVERSION_TABLE_UD_TO_PENN.put("DET", "DT"); + CONVERSION_TABLE_UD_TO_PENN.put("INTJ", "UH"); + CONVERSION_TABLE_UD_TO_PENN.put("NOUN", "NN"); + CONVERSION_TABLE_UD_TO_PENN.put("NUM", "CD"); + CONVERSION_TABLE_UD_TO_PENN.put("PART", "RP"); + CONVERSION_TABLE_UD_TO_PENN.put("PRON", "PRP"); + CONVERSION_TABLE_UD_TO_PENN.put("PROPN", "NNP"); + CONVERSION_TABLE_UD_TO_PENN.put("PUNCT", "."); + CONVERSION_TABLE_UD_TO_PENN.put("SCONJ", "IN"); + CONVERSION_TABLE_UD_TO_PENN.put("SYM", "SYM"); + CONVERSION_TABLE_UD_TO_PENN.put("VERB", "VB"); + CONVERSION_TABLE_UD_TO_PENN.put("X", "FW"); + } + + private final POSTagFormat modelFormat; + + protected POSTagFormatMapper(final String[] possibleOutcomes) { + Objects.requireNonNull(possibleOutcomes, "Outcomes must not be NULL."); + this.modelFormat = guessModelTagFormat(possibleOutcomes); + } + + /** + * Converts a given tag to the specified format. + * + * @param tags a list of tags to be converted. + * @return the converted tag. + */ + public String[] convertTags(List tags) { + Objects.requireNonNull(tags, "Supplied tags must not be NULL."); + return tags.stream() + .map(this::convertTag) + .toArray(String[]::new); + } + + /** + * Converts a given tag to the specified format. + * + * @param tag no restrictions on this parameter. + * @return the converted tag. + */ + public String convertTag(String tag) { + switch (modelFormat) { + case UD -> { + return CONVERSION_TABLE_UD_TO_PENN.getOrDefault(tag, "?"); + } + case PENN -> { + if ("NOUN".equals(tag)) { + logger.warn("Ambiguity detected: NN can be 'NN' or 'NNS' depending on the number. " + + "Returning 'NN'."); + } + if ("PART".equals(tag)) { + logger.warn("Ambiguity detected: PART can be 'RP' or 'TO'. Returning 'TO'."); + } + if ("PROPN".equals(tag)) { + logger.warn("Ambiguity detected: Can be 'NNP' or 'NNPS. Returning 'NNP'"); + } + if ("PUNCT".equals(tag)) { + logger.warn("Ambiguity detected: PUNCT needs specific punctuation mapping. Returning '.'"); + } + if ("VERB".equals(tag)) { + logger.warn("Ambiguity detected: VERB can be 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ'. " + + "Returning 'VERB'."); + } + return CONVERSION_TABLE_PENN_TO_UD.getOrDefault(tag, "?"); + } + default -> { + return tag; + } + } + } + + /** + * + * @return The guessed {@link POSTagFormat}. Guaranteed to be not {@code null}. + */ + public POSTagFormat getGuessedFormat() { + return this.modelFormat; + } + + /** + * Guesses the {@link POSTagFormat} by using majority quorum. + * @param outcomes must not be {@code null}. + * @return the guessed {@link POSTagFormat}. + */ + private POSTagFormat guessModelTagFormat(final String[] outcomes) { + int udMatches = 0; + int pennMatches = 0; + + for (String outcome : outcomes) { + if (CONVERSION_TABLE_UD_TO_PENN.containsKey(outcome)) { + udMatches++; + } + if (CONVERSION_TABLE_PENN_TO_UD.containsKey(outcome)) { + pennMatches++; + } + } + + if (udMatches > pennMatches) { + return POSTagFormat.UD; + } else if (pennMatches > udMatches) { + return POSTagFormat.PENN; + } else { + logger.warn("Detected an unknown POS format."); + return POSTagFormat.UNKNOWN; + } + } +} diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java index 0268f48b5..56b77c329 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java +++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerME.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -83,16 +84,30 @@ public class POSTaggerME implements POSTagger { private final SequenceValidator sequenceValidator; + private final POSTagFormat posTagFormat; + private final POSTagFormatMapper posTagFormatMapper; + /** * Initializes a {@link POSTaggerME} by downloading a default model for a given * {@code language}. * * @param language An ISO conform language code. - * * @throws IOException Thrown if the model could not be downloaded or saved. */ public POSTaggerME(String language) throws IOException { - this(DownloadUtil.downloadModel(language, DownloadUtil.ModelType.POS, POSModel.class)); + this(language, POSTagFormat.UD); + } + + /** + * Initializes a {@link POSTaggerME} by downloading a default model for a given + * {@code language}. + * + * @param language An ISO conform language code. + * @param format A valid {@link POSTagFormat}. + * @throws IOException Thrown if the model could not be downloaded or saved. + */ + public POSTaggerME(String language, POSTagFormat format) throws IOException { + this(DownloadUtil.downloadModel(language, DownloadUtil.ModelType.POS, POSModel.class), format); } /** @@ -101,6 +116,17 @@ public POSTaggerME(String language) throws IOException { * @param model A valid {@link POSModel}. */ public POSTaggerME(POSModel model) { + this(model, POSTagFormat.UD); + } + + /** + * Initializes a {@link POSTaggerME} with the provided {@link POSModel model}. + * + * @param model A valid {@link POSModel}. + * @param format A valid {@link POSTagFormat}. + */ + public POSTaggerME(POSModel model, POSTagFormat format) { + this.posTagFormat = format; POSTaggerFactory factory = model.getFactory(); int beamSize = POSTaggerME.DEFAULT_BEAM_SIZE; @@ -121,12 +147,13 @@ public POSTaggerME(POSModel model) { if (model.getPosSequenceModel() != null) { this.model = model.getPosSequenceModel(); - } - else { + } else { this.model = new opennlp.tools.ml.BeamSearch<>(beamSize, model.getPosModel(), 0); } + this.posTagFormatMapper = new POSTagFormatMapper(getAllPosTags()); + } /** @@ -144,16 +171,15 @@ public String[] tag(String[] sentence) { @Override public String[] tag(String[] sentence, Object[] additionalContext) { bestSequence = model.bestSequence(sentence, additionalContext, contextGen, sequenceValidator); - List t = bestSequence.getOutcomes(); - return t.toArray(new String[0]); + final List t = bestSequence.getOutcomes(); + return convertTags(t); } /** * Returns at most the specified {@code numTaggings} for the specified {@code sentence}. * * @param numTaggings The number of tagging to be returned. - * @param sentence An array of tokens which make up a sentence. - * + * @param sentence An array of tokens which make up a sentence. * @return At most the specified number of taggings for the specified {@code sentence}. */ public String[][] tag(int numTaggings, String[] sentence) { @@ -162,11 +188,19 @@ public String[][] tag(int numTaggings, String[] sentence) { String[][] tags = new String[bestSequences.length][]; for (int si = 0; si < tags.length; si++) { List t = bestSequences[si].getOutcomes(); - tags[si] = t.toArray(new String[0]); + tags[si] = convertTags(t); } return tags; } + private String[] convertTags(List t) { + if (posTagFormatMapper.getGuessedFormat() == posTagFormat) { + return t.toArray(new String[0]); + } else { + return posTagFormatMapper.convertTags(t); + } + } + @Override public Sequence[] topKSequences(String[] sentence) { return this.topKSequences(sentence, null); @@ -194,10 +228,10 @@ public double[] probs() { } public String[] getOrderedTags(List words, List tags, int index) { - return getOrderedTags(words,tags,index,null); + return getOrderedTags(words, tags, index, null); } - public String[] getOrderedTags(List words, List tags, int index,double[] tprobs) { + public String[] getOrderedTags(List words, List tags, int index, double[] tprobs) { if (modelPackage.getPosModel() != null) { @@ -205,7 +239,7 @@ public String[] getOrderedTags(List words, List tags, int index, double[] probs = posModel.eval(contextGen.getContext(index, words.toArray(new String[0]), - tags.toArray(new String[0]),null)); + tags.toArray(new String[0]), null)); String[] orderedTags = new String[probs.length]; for (int i = 0; i < probs.length; i++) { @@ -221,17 +255,16 @@ public String[] getOrderedTags(List words, List tags, int index, } probs[max] = 0; } - return orderedTags; - } - else { + return convertTags(Arrays.stream(orderedTags).toList()); + } else { throw new UnsupportedOperationException("This method can only be called if the " + "classification model is an event model!"); } } public static POSModel train(String languageCode, - ObjectStream samples, TrainingParameters trainParams, - POSTaggerFactory posFactory) throws IOException { + ObjectStream samples, TrainingParameters trainParams, + POSTaggerFactory posFactory) throws IOException { int beamSize = trainParams.getIntParameter(BeamSearch.BEAM_SIZE_PARAMETER, POSTaggerME.DEFAULT_BEAM_SIZE); @@ -249,14 +282,12 @@ public static POSModel train(String languageCode, EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams, manifestInfoEntries); posModel = trainer.train(es); - } - else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) { + } else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) { POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator); EventModelSequenceTrainer trainer = TrainerFactory.getEventModelSequenceTrainer(trainParams, manifestInfoEntries); posModel = trainer.train(ss); - } - else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) { + } else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) { SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer( trainParams, manifestInfoEntries); @@ -264,15 +295,13 @@ else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) { POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator); seqPosModel = trainer.train(ss); - } - else { + } else { throw new IllegalArgumentException("Trainer type is not supported: " + trainerType); } if (posModel != null) { return new POSModel(languageCode, posModel, beamSize, manifestInfoEntries, posFactory); - } - else { + } else { return new POSModel(languageCode, seqPosModel, manifestInfoEntries, posFactory); } } @@ -282,9 +311,7 @@ else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) { * * @param samples The {@link ObjectStream} to process. * @param cutoff A non-negative cut-off value. - * * @return A valid {@link Dictionary} instance holding nGrams. - * * @throws IOException Thrown if IO errors occurred during dictionary construction. */ public static Dictionary buildNGramDictionary(ObjectStream samples, int cutoff) @@ -295,8 +322,9 @@ public static Dictionary buildNGramDictionary(ObjectStream samples, i while ((sample = samples.read()) != null) { String[] words = sample.getSentence(); - if (words.length > 0) + if (words.length > 0) { ngramModel.add(new StringList(words), 1, 1); + } } ngramModel.cutoff(cutoff, Integer.MAX_VALUE); @@ -308,13 +336,12 @@ public static Dictionary buildNGramDictionary(ObjectStream samples, i * Populates a {@link POSDictionary} from an {@link ObjectStream} of samples. * * @param samples The {@link ObjectStream} to process. - * @param dict The {@link MutableTagDictionary} to use during population. + * @param dict The {@link MutableTagDictionary} to use during population. * @param cutoff A non-negative cut-off value. - * * @throws IOException Thrown if IO errors occurred during dictionary construction. */ public static void populatePOSDictionary(ObjectStream samples, - MutableTagDictionary dict, int cutoff) throws IOException { + MutableTagDictionary dict, int cutoff) throws IOException { logger.info("Expanding POS Dictionary ..."); long start = System.nanoTime(); @@ -377,6 +404,7 @@ public static void populatePOSDictionary(ObjectStream samples, } } - logger.info("... finished expanding POS Dictionary. [ {} ms]", (System.nanoTime() - start) / 1000000 ); + logger.info("... finished expanding POS Dictionary. [ {} ms]", (System.nanoTime() - start) / 1000000); } + } diff --git a/opennlp-tools/src/test/java/opennlp/tools/namefind/TokenNameFinderModelTest.java b/opennlp-tools/src/test/java/opennlp/tools/namefind/TokenNameFinderModelTest.java index 5379a1f39..c77b28743 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/namefind/TokenNameFinderModelTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/namefind/TokenNameFinderModelTest.java @@ -57,7 +57,7 @@ void testNERWithPOSModel() throws IOException { Path resourcesFolder = Files.createTempDirectory("resources").toAbsolutePath(); // save a POS model there - POSModel posModel = POSTaggerMETest.trainPOSModel(ModelType.MAXENT); + POSModel posModel = POSTaggerMETest.trainPennFormatPOSModel(ModelType.MAXENT); Assertions.assertNotNull(posModel); File posModelFile = new File(resourcesFolder.toFile(), "pos-model.bin"); diff --git a/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java b/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java index 14086c16a..1565c45ba 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/postag/POSModelTest.java @@ -30,7 +30,7 @@ public class POSModelTest { @Test void testPOSModelSerializationMaxent() throws IOException { - POSModel posModel = POSTaggerMETest.trainPOSModel(ModelType.MAXENT); + POSModel posModel = POSTaggerMETest.trainPennFormatPOSModel(ModelType.MAXENT); Assertions.assertFalse(posModel.isLoadedFromSerialized()); try (ByteArrayOutputStream out = new ByteArrayOutputStream()) { @@ -45,7 +45,7 @@ void testPOSModelSerializationMaxent() throws IOException { @Test void testPOSModelSerializationPerceptron() throws IOException { - POSModel posModel = POSTaggerMETest.trainPOSModel(ModelType.PERCEPTRON); + POSModel posModel = POSTaggerMETest.trainPennFormatPOSModel(ModelType.PERCEPTRON); Assertions.assertFalse(posModel.isLoadedFromSerialized()); try (ByteArrayOutputStream out = new ByteArrayOutputStream()) { diff --git a/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMETest.java b/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMETest.java index 945de120d..9f9b6dd12 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMETest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerMETest.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import opennlp.tools.EnabledWhenCDNAvailable; import opennlp.tools.formats.ResourceAsStreamFactory; import opennlp.tools.util.InputStreamFactory; import opennlp.tools.util.InsufficientTrainingDataException; @@ -39,7 +40,7 @@ public class POSTaggerMETest { private static ObjectStream createSampleStream() throws IOException { InputStreamFactory in = new ResourceAsStreamFactory(POSTaggerMETest.class, - "/opennlp/tools/postag/AnnotatedSentences.txt"); + "/opennlp/tools/postag/AnnotatedSentences.txt"); //PENN FORMAT return new WordTagSampleStream(new PlainTextByLineStream(in, StandardCharsets.UTF_8)); } @@ -49,7 +50,7 @@ private static ObjectStream createSampleStream() throws IOException { * * @return {@link POSModel} */ - public static POSModel trainPOSModel(ModelType type) throws IOException { + public static POSModel trainPennFormatPOSModel(ModelType type) throws IOException { TrainingParameters params = new TrainingParameters(); params.put(TrainingParameters.ALGORITHM_PARAM, type.toString()); params.put(TrainingParameters.ITERATIONS_PARAM, 100); @@ -61,25 +62,53 @@ public static POSModel trainPOSModel(ModelType type) throws IOException { @Test void testPOSTagger() throws IOException { - POSModel posModel = trainPOSModel(ModelType.MAXENT); + final String[] sentence = { + "The", + "driver", + "got", + "badly", + "injured", + "."}; - POSTagger tagger = new POSTaggerME(posModel); + final String[] expected = {"DT", "NN", "VBD", "RB", "VBN", "."}; + testPOSTagger(new POSTaggerME(trainPennFormatPOSModel(ModelType.MAXENT), + POSTagFormat.PENN), sentence, expected); + } - String[] tags = tagger.tag(new String[] { + @Test + void testPOSTaggerPENNtoUD() throws IOException { + final String[] sentence = { "The", "driver", "got", "badly", "injured", - "."}); - - Assertions.assertEquals(6, tags.length); - Assertions.assertEquals("DT", tags[0]); - Assertions.assertEquals("NN", tags[1]); - Assertions.assertEquals("VBD", tags[2]); - Assertions.assertEquals("RB", tags[3]); - Assertions.assertEquals("VBN", tags[4]); - Assertions.assertEquals(".", tags[5]); + "."}; + + final String[] expected = {"DET", "NOUN", "VERB", "ADV", "VERB", "PUNCT"}; + //convert PENN to UD on the fly. + testPOSTagger(new POSTaggerME(trainPennFormatPOSModel(ModelType.MAXENT), + POSTagFormat.UD), sentence, expected); + } + + @Test + @EnabledWhenCDNAvailable(hostname = "dlcdn.apache.org") + void testPOSTaggerDefault() throws IOException { + final String[] sentence = { + "The", + "driver", + "got", + "badly", + "injured", + "."}; + + final String[] expected = {"DET", "NOUN", "VERB", "ADV", "VERB", "PUNCT"}; + //this downloads a UD model + testPOSTagger(new POSTaggerME("en"), sentence, expected); + } + + private void testPOSTagger(POSTagger tagger, String[] sentences, String[] expectedTags) { + Assertions.assertArrayEquals(expectedTags, tagger.tag(sentences)); } @Test diff --git a/opennlp-tools/src/test/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGeneratorTest.java b/opennlp-tools/src/test/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGeneratorTest.java index ccd1f9f24..c349ee71a 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGeneratorTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGeneratorTest.java @@ -33,7 +33,7 @@ public class POSTaggerNameFeatureGeneratorTest { @Test void testFeatureGeneration() throws IOException { POSTaggerNameFeatureGenerator fg = new POSTaggerNameFeatureGenerator( - POSTaggerMETest.trainPOSModel(ModelType.MAXENT)); + POSTaggerMETest.trainPennFormatPOSModel(ModelType.MAXENT)); String[] tokens = {"Hi", "Mike", ",", "it", "'s", "Stefanie", "Schmidt", "."}; for (int i = 0; i < tokens.length; i++) {