From 982faf7052ece3ab02c7cd59ab62e38cf416838a Mon Sep 17 00:00:00 2001 From: Britta Weber Date: Tue, 7 Oct 2014 13:14:08 +0200 Subject: [PATCH] Remove setNextScore in SearchScript Due to a change in elasticsearch 1.4.0, we need to apply a similar patch here. See elasticsearch/elasticsearch#6864 See elasticsearch/elasticsearch#7819 Closes #16. Closes #21. (cherry picked from commit cd7756c) --- .../python/PythonScriptEngineService.java | 3 +- .../python/PythonScriptSearchTests.java | 53 ++++++++++++++++++- 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/elasticsearch/script/python/PythonScriptEngineService.java b/src/main/java/org/elasticsearch/script/python/PythonScriptEngineService.java index d71d3fb..8e48c4d 100644 --- a/src/main/java/org/elasticsearch/script/python/PythonScriptEngineService.java +++ b/src/main/java/org/elasticsearch/script/python/PythonScriptEngineService.java @@ -26,6 +26,7 @@ import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.script.ExecutableScript; +import org.elasticsearch.script.ScoreAccessor; import org.elasticsearch.script.ScriptEngineService; import org.elasticsearch.script.SearchScript; import org.elasticsearch.search.lookup.SearchLookup; @@ -164,7 +165,7 @@ public PythonSearchScript(PyCode code, Map vars, SearchLookup lo @Override public void setScorer(Scorer scorer) { - lookup.setScorer(scorer); + pyVars.__setitem__("_score", Py.java2py(new ScoreAccessor(scorer))); } @Override diff --git a/src/test/java/org/elasticsearch/script/python/PythonScriptSearchTests.java b/src/test/java/org/elasticsearch/script/python/PythonScriptSearchTests.java index b24bc06..24778c9 100644 --- a/src/test/java/org/elasticsearch/script/python/PythonScriptSearchTests.java +++ b/src/test/java/org/elasticsearch/script/python/PythonScriptSearchTests.java @@ -26,12 +26,15 @@ import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders; import org.elasticsearch.plugins.PluginsService; import org.elasticsearch.script.ScriptService; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.test.ElasticsearchIntegrationTest; import org.hamcrest.CoreMatchers; +import org.hamcrest.Matchers; import org.junit.After; import org.junit.Test; +import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -40,7 +43,10 @@ import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; import static org.elasticsearch.index.query.FilterBuilders.scriptFilter; import static org.elasticsearch.index.query.QueryBuilders.*; +import static org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders.scriptFunction; +import static org.elasticsearch.search.aggregations.AggregationBuilders.terms; import static org.elasticsearch.search.builder.SearchSourceBuilder.searchSource; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; @@ -193,7 +199,7 @@ public void testCustomScriptBoost() throws Exception { response = client().search(searchRequest() .searchType(SearchType.QUERY_THEN_FETCH) .source(searchSource().explain(true).query(functionScoreQuery(termQuery("test", "value")) - .add(ScoreFunctionBuilders.scriptFunction("doc['num1'].value * _score").lang("python")))) + .add(ScoreFunctionBuilders.scriptFunction("doc['num1'].value * _score.doubleValue()").lang("python")))) ).actionGet(); assertThat("Failures " + Arrays.toString(response.getShardFailures()), response.getShardFailures().length, equalTo(0)); @@ -208,7 +214,7 @@ public void testCustomScriptBoost() throws Exception { response = client().search(searchRequest() .searchType(SearchType.QUERY_THEN_FETCH) .source(searchSource().explain(true).query(functionScoreQuery(termQuery("test", "value")) - .add(ScoreFunctionBuilders.scriptFunction("param1 * param2 * _score").param("param1", 2).param("param2", 2).lang("python")))) + .add(ScoreFunctionBuilders.scriptFunction("param1 * param2 * _score.doubleValue()").param("param1", 2).param("param2", 2).lang("python")))) ).actionGet(); assertThat("Failures " + Arrays.toString(response.getShardFailures()), response.getShardFailures().length, equalTo(0)); @@ -239,4 +245,47 @@ public void testPythonEmptyParameters() throws Exception { assertThat((String) value, CoreMatchers.equalTo("bar")); } + + @Test + public void testScriptScoresNested() throws IOException { + createIndex("index"); + ensureYellow(); + index("index", "testtype", "1", jsonBuilder().startObject().field("dummy_field", 1).endObject()); + refresh(); + SearchResponse response = client().search( + searchRequest().source( + searchSource().query( + functionScoreQuery( + functionScoreQuery( + functionScoreQuery().add(scriptFunction("1").lang("python"))) + .add(scriptFunction("_score.doubleValue()").lang("python"))) + .add(scriptFunction("_score.doubleValue()").lang("python") + ) + ) + ) + ).actionGet(); + assertSearchResponse(response); + assertThat(response.getHits().getAt(0).score(), equalTo(1.0f)); + } + + @Test + public void testScriptScoresWithAgg() throws IOException { + createIndex("index"); + ensureYellow(); + index("index", "testtype", "1", jsonBuilder().startObject().field("dummy_field", 1).endObject()); + refresh(); + SearchResponse response = client().search( + searchRequest().source( + searchSource().query( + functionScoreQuery() + .add(scriptFunction("_score.doubleValue()").lang("python") + ) + ).aggregation(terms("score_agg").script("_score.doubleValue()").lang("python")) + ) + ).actionGet(); + assertSearchResponse(response); + assertThat(response.getHits().getAt(0).score(), equalTo(1.0f)); + assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getKeyAsNumber().floatValue(), Matchers.is(1f)); + assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getDocCount(), Matchers.is(1l)); + } }