diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/QuantileDMatrix.java b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/QuantileDMatrix.java index ce752aec046e..d254da979dcc 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/QuantileDMatrix.java +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/QuantileDMatrix.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2021-2024 by Contributors + Copyright (c) 2021-2025 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,40 @@ */ package ml.dmlc.xgboost4j.java; +import java.io.IOException; import java.util.Iterator; import java.util.Map; +import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.module.SimpleModule; + +class F64NaNSerializer extends JsonSerializer { + @Override + public void serialize(Double value, JsonGenerator gen, + SerializerProvider serializers) throws IOException { + if (value.isNaN()) { + gen.writeRawValue("NaN"); // Write NaN without quotes + } else { + gen.writeNumber(value); + } + } +} + +class F32NaNSerializer extends JsonSerializer { + @Override + public void serialize(Float value, JsonGenerator gen, + SerializerProvider serializers) throws IOException { + if (value.isNaN()) { + gen.writeRawValue("NaN"); // Write NaN without quotes + } else { + gen.writeNumber(value); + } + } +} /** * QuantileDMatrix will only be used to train @@ -121,8 +150,17 @@ private String getConfig(float missing, int maxBin, int nthread) { conf.put("max_bin", maxBin); conf.put("nthread", nthread); ObjectMapper mapper = new ObjectMapper(); + + // Handle NaN values. Jackson by default serializes NaN values into strings. + SimpleModule module = new SimpleModule(); + module.addSerializer(Double.class, new F64NaNSerializer()); + module.addSerializer(Float.class, new F32NaNSerializer()); + mapper.registerModule(module); + try { - return mapper.writeValueAsString(conf); + String config = mapper.writeValueAsString(conf); + System.out.println(config); + return config; } catch (JsonProcessingException e) { throw new RuntimeException("Failed to serialize configuration", e); } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java index ee1bc7b4a5a9..aaf4517c7934 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java @@ -3,8 +3,6 @@ import java.io.Serializable; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.util.LinkedList; -import java.util.List; import java.util.Map; import com.fasterxml.jackson.core.JsonProcessingException;