Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jvm-packages] Supports external memory #11186

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jvm-packages/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ build.sh
xgboost4j-tester/pom.xml
xgboost4j-tester/iris.csv
dependency-reduced-pom.xml
.factorypath
3 changes: 3 additions & 0 deletions jvm-packages/create_jni.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def native_build(cli_args: argparse.Namespace) -> None:
os.environ["JAVA_HOME"] = (
subprocess.check_output("/usr/libexec/java_home").strip().decode()
)
if cli_args.use_debug == "ON":
CONFIG["CMAKE_BUILD_TYPE"] = "Debug"

print("building Java wrapper", flush=True)
with cd(".."):
Expand Down Expand Up @@ -187,5 +189,6 @@ def native_build(cli_args: argparse.Namespace) -> None:
)
parser.add_argument("--use-cuda", type=str, choices=["ON", "OFF"], default="OFF")
parser.add_argument("--use-openmp", type=str, choices=["ON", "OFF"], default="ON")
parser.add_argument("--use-debug", type=str, choices=["ON", "OFF"], default="OFF")
cli_args = parser.parse_args()
native_build(cli_args)
1 change: 1 addition & 0 deletions jvm-packages/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
<log.capi.invocation>OFF</log.capi.invocation>
<use.cuda>OFF</use.cuda>
<use.openmp>ON</use.openmp>
<use.debug>OFF</use.debug>
<cudf.version>24.10.0</cudf.version>
<spark.rapids.version>24.10.0</spark.rapids.version>
<spark.rapids.classifier>cuda12</spark.rapids.classifier>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,17 @@ private List<CudfColumn> initializeCudfColumns(Table table) {
.collect(Collectors.toList());
}

/**
 * Returns the {@code featureTable} field held by this object.
 * Visible for testing.
 *
 * @return the feature table
 */
public Table getFeatureTable() {
return this.featureTable;
}

/**
 * Returns the {@code labelTable} field held by this object.
 * Visible for testing.
 *
 * @return the label table
 */
public Table getLabelTable() {
return this.labelTable;
}


/**
 * Returns the {@code features} column list held by this object.
 *
 * @return the list of feature columns
 */
public List<CudfColumn> getFeatures() {
return this.features;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
Copyright (c) 2025 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;

import java.util.Iterator;
import java.util.Map;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;

/**
 * QuantileDMatrix variant created through
 * {@code XGExtMemQuantileDMatrixCreateFromCallback}.
 *
 * <p>The JSON configuration handed to the native call always carries
 * {@code on_host=true} (only GPU is supported at the moment) and a
 * {@code cache_prefix} of {@code "."}, which is not used yet because
 * {@code on_host} is true.
 */
public class ExtMemQuantileDMatrix extends QuantileDMatrix {

  /**
   * Creates the matrix from an iterator of column batches.
   *
   * @param iter               iterator over the input batches
   * @param missing            value that denotes a missing entry
   * @param maxBin             value written to the {@code max_bin} config key
   * @param ref                optional reference matrix; when {@code null} no reference
   *                           handle is passed to the native call
   * @param nthread            value written to the {@code nthread} config key
   * @param maxNumDevicePages  written to {@code max_num_device_pages} only when positive
   * @param maxQuantileBatches written to {@code max_quantile_batches} only when positive
   * @param minCachePageBytes  written to {@code min_cache_page_bytes} only when positive
   * @throws XGBoostError if the native call reports a failure
   */
  public ExtMemQuantileDMatrix(Iterator<ColumnBatch> iter,
                               float missing,
                               int maxBin,
                               DMatrix ref,
                               int nthread,
                               int maxNumDevicePages,
                               int maxQuantileBatches,
                               int minCachePageBytes) throws XGBoostError {
    // The native side receives either null or a one-element array with the reference handle.
    long[] refHandle = (ref == null) ? null : new long[] {ref.getHandle()};
    String jsonConfig = this.getConfig(missing, maxBin, nthread, maxNumDevicePages,
        maxQuantileBatches, minCachePageBytes);

    long[] created = new long[1];
    XGBoostJNI.checkCall(XGBoostJNI.XGExtMemQuantileDMatrixCreateFromCallback(
        iter, refHandle, jsonConfig, created));
    handle = created[0];
  }

  /**
   * Creates the matrix with {@code nthread = 1} and all optional limits passed as
   * {@code -1}, so they are left out of the configuration.
   *
   * @param iter    iterator over the input batches
   * @param missing value that denotes a missing entry
   * @param maxBin  value written to the {@code max_bin} config key
   * @param ref     optional reference matrix, may be {@code null}
   * @throws XGBoostError if the native call reports a failure
   */
  public ExtMemQuantileDMatrix(
      Iterator<ColumnBatch> iter,
      float missing,
      int maxBin,
      DMatrix ref) throws XGBoostError {
    this(iter, missing, maxBin, ref, 1, -1, -1, -1);
  }

  /**
   * Creates the matrix without a reference matrix.
   *
   * @param iter    iterator over the input batches
   * @param missing value that denotes a missing entry
   * @param maxBin  value written to the {@code max_bin} config key
   * @throws XGBoostError if the native call reports a failure
   */
  public ExtMemQuantileDMatrix(
      Iterator<ColumnBatch> iter,
      float missing,
      int maxBin) throws XGBoostError {
    this(iter, missing, maxBin, null);
  }

  /**
   * Builds the JSON configuration string for the native create call.
   *
   * @return the configuration serialized as JSON
   * @throws RuntimeException if Jackson fails to serialize the configuration
   */
  private String getConfig(float missing, int maxBin, int nthread, int maxNumDevicePages,
                           int maxQuantileBatches, int minCachePageBytes) {
    Map<String, Object> settings = new java.util.HashMap<>();
    settings.put("missing", missing);
    settings.put("max_bin", maxBin);
    settings.put("nthread", nthread);

    // Non-positive values mean "not set": the key is omitted entirely.
    putIfPositive(settings, "max_num_device_pages", maxNumDevicePages);
    putIfPositive(settings, "max_quantile_batches", maxQuantileBatches);
    putIfPositive(settings, "min_cache_page_bytes", minCachePageBytes);

    settings.put("on_host", true);
    settings.put("cache_prefix", ".");

    // Jackson serializes NaN as a string by default; the custom serializers handle NaN instead.
    SimpleModule nanModule = new SimpleModule();
    nanModule.addSerializer(Double.class, new F64NaNSerializer());
    nanModule.addSerializer(Float.class, new F32NaNSerializer());

    ObjectMapper mapper = new ObjectMapper();
    mapper.registerModule(nanModule);

    try {
      return mapper.writeValueAsString(settings);
    } catch (JsonProcessingException e) {
      throw new RuntimeException("Failed to serialize configuration", e);
    }
  }

  /** Stores {@code value} under {@code key} only when it is strictly positive. */
  private static void putIfPositive(Map<String, Object> settings, String key, int value) {
    if (value > 0) {
      settings.put(key, value);
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ public void serialize(Float value, JsonGenerator gen,
* QuantileDMatrix will only be used to train
*/
public class QuantileDMatrix extends DMatrix {
// No-argument constructor for subclasses such as the external-memory version of the QDM,
// whose constructors invoke it implicitly. It passes 0 to the DMatrix constructor
// (presumably a placeholder handle -- confirm against DMatrix); ExtMemQuantileDMatrix
// assigns the real `handle` itself after its native create call.
protected QuantileDMatrix() {
super(0);
}

/**
* Create QuantileDMatrix from iterator based on the cuda array interface
*
Expand Down Expand Up @@ -158,8 +163,7 @@ private String getConfig(float missing, int maxBin, int nthread) {
mapper.registerModule(module);

try {
String config = mapper.writeValueAsString(conf);
return config;
return mapper.writeValueAsString(conf);
} catch (JsonProcessingException e) {
throw new RuntimeException("Failed to serialize configuration", e);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
Copyright (c) 2025 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala

import scala.collection.JavaConverters._

import ml.dmlc.xgboost4j.java.{ColumnBatch, ExtMemQuantileDMatrix => jExtMemQuantileDMatrix}

/**
 * Scala wrapper around the Java `ExtMemQuantileDMatrix`.
 *
 * Every auxiliary constructor converts the Scala iterator with `asJava` and
 * delegates to the matching Java constructor.
 */
class ExtMemQuantileDMatrix private[scala](
    private[scala] override val jDMatrix: jExtMemQuantileDMatrix)
  extends QuantileDMatrix(jDMatrix) {

  /**
   * Create the matrix with every option given explicitly.
   *
   * @param iter               iterator over the input column batches
   * @param missing            the missing value
   * @param maxBin             the max bin
   * @param ref                optional reference QuantileDMatrix; `None` is passed to
   *                           the Java constructor as `null`
   * @param nthread            forwarded unchanged to the Java constructor
   * @param maxNumDevicePages  forwarded unchanged to the Java constructor
   * @param maxQuantileBatches forwarded unchanged to the Java constructor
   * @param minCachePageBytes  forwarded unchanged to the Java constructor
   */
  def this(iter: Iterator[ColumnBatch],
           missing: Float,
           maxBin: Int,
           ref: Option[QuantileDMatrix],
           nthread: Int,
           maxNumDevicePages: Int,
           maxQuantileBatches: Int,
           minCachePageBytes: Int) =
    this(new jExtMemQuantileDMatrix(iter.asJava, missing, maxBin,
      ref.map(_.jDMatrix).orNull,
      nthread, maxNumDevicePages, maxQuantileBatches, minCachePageBytes))

  /**
   * Create the matrix without a reference matrix, using the three-argument
   * Java constructor.
   *
   * @param iter    iterator over the input column batches
   * @param missing the missing value
   * @param maxBin  the max bin
   */
  def this(iter: Iterator[ColumnBatch], missing: Float, maxBin: Int) =
    this(new jExtMemQuantileDMatrix(iter.asJava, missing, maxBin))

  /**
   * Create the matrix using another ExtMemQuantileDMatrix as reference.
   *
   * @param iter    iterator over the input column batches
   * @param ref     the reference matrix; must not be null because its `jDMatrix`
   *                is dereferenced
   * @param missing the missing value
   * @param maxBin  the max bin
   */
  def this(
      iter: Iterator[ColumnBatch],
      ref: ExtMemQuantileDMatrix,
      missing: Float,
      maxBin: Int) =
    this(new jExtMemQuantileDMatrix(iter.asJava, missing, maxBin, ref.jDMatrix))
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,22 @@ class QuantileDMatrix private[scala](
/**
* Create QuantileDMatrix from iterator based on the array interface
*
* @param iter the XGBoost ColumnBatch batch to provide the corresponding array interface
* @param refDMatrix The reference QuantileDMatrix that provides quantile information, needed
* when creating validation/test dataset with QuantileDMatrix. Supplying the
* training DMatrix as a reference means that the same quantisation applied
* to the training data is applied to the validation/test data
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
* @param iter iterator of XGBoost ColumnBatch objects that provide the corresponding array interface
* @param ref The reference QuantileDMatrix that provides quantile information, needed
* when creating validation/test dataset with QuantileDMatrix. Supplying the
* training DMatrix as a reference means that the same quantisation applied
* to the training data is applied to the validation/test data
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
* @throws XGBoostError
*/
def this(iter: Iterator[ColumnBatch],
ref: QuantileDMatrix,
ref: Option[QuantileDMatrix],
missing: Float,
maxBin: Int,
nthread: Int) {
this(new JQuantileDMatrix(iter.asJava, ref.jDMatrix, missing, maxBin, nthread))
this(new JQuantileDMatrix(iter.asJava, ref.map(_.jDMatrix).orNull, missing, maxBin, nthread))
}

/**
Expand Down
Loading
Loading