Skip to content

Commit

Permalink
[WIP] Small cleanup for the jvm interface.
Browse files Browse the repository at this point in the history
Fix create jni.

serialization.

lint.

cleanup.

[wip][jvm-packages] Add java class for `ExtMemQdm`.

cleanup.

cleanup.

Debug build.

Fix CPU build.

Cleanup.
  • Loading branch information
trivialfis committed Jan 22, 2025
1 parent 2e1626c commit 73c95c7
Show file tree
Hide file tree
Showing 11 changed files with 215 additions and 42 deletions.
3 changes: 3 additions & 0 deletions jvm-packages/create_jni.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def native_build(cli_args: argparse.Namespace) -> None:
os.environ["JAVA_HOME"] = (
subprocess.check_output("/usr/libexec/java_home").strip().decode()
)
if cli_args.use_debug == "ON":
CONFIG["CMAKE_BUILD_TYPE"] = "Debug"

print("building Java wrapper", flush=True)
with cd(".."):
Expand Down Expand Up @@ -187,5 +189,6 @@ def native_build(cli_args: argparse.Namespace) -> None:
)
parser.add_argument("--use-cuda", type=str, choices=["ON", "OFF"], default="OFF")
parser.add_argument("--use-openmp", type=str, choices=["ON", "OFF"], default="ON")
parser.add_argument("--use-debug", type=str, choices=["ON", "OFF"], default="OFF")
cli_args = parser.parse_args()
native_build(cli_args)
1 change: 1 addition & 0 deletions jvm-packages/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
<log.capi.invocation>OFF</log.capi.invocation>
<use.cuda>OFF</use.cuda>
<use.openmp>ON</use.openmp>
<use.debug>OFF</use.debug>
<cudf.version>24.10.0</cudf.version>
<spark.rapids.version>24.10.0</spark.rapids.version>
<spark.rapids.classifier>cuda12</spark.rapids.classifier>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) 2025, XGBoost Contributors
*/
package ml.dmlc.xgboost4j.java;

import java.util.Iterator;
import java.util.Map;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;


/**
 * QuantileDMatrix variant created through the external-memory JNI entry point
 * ({@code XGExtMemQuantileDMatrixCreateFromCallback}).
 *
 * <p>{@code on_host} is always set to true as only GPU is supported at the moment, and
 * {@code cache_prefix} is not used yet because of {@code on_host=true}.
 */
public class ExtMemQuantileDMatrix extends QuantileDMatrix {
  /**
   * Build the matrix from an iterator of column batches.
   *
   * @param iter                 iterator handed to the native side to pull batches from
   * @param missing              written to the native config under "missing"
   * @param maxBin               written to the native config under "max_bin"
   * @param ref                  optional reference DMatrix; when null no reference handle is passed
   * @param nthread              written to the native config under "nthread"
   * @param max_num_device_pages written to the native config under "max_num_device_pages"
   * @param max_quantile_batches written to the native config under "max_quantile_batches"
   * @param min_cache_page_bytes written to the native config under "min_cache_page_bytes"
   * @throws XGBoostError if the native call reports a failure
   */
  public ExtMemQuantileDMatrix(Iterator<ColumnBatch> iter,
                               float missing,
                               int maxBin,
                               DMatrix ref,
                               int nthread,
                               int max_num_device_pages,
                               int max_quantile_batches,
                               int min_cache_page_bytes) throws XGBoostError {
    // The JNI layer takes the reference as a one-element array, or null when absent.
    long[] refHandle = (ref == null) ? null : new long[] {ref.getHandle()};
    String jsonConfig = this.getConfig(missing, maxBin, nthread, max_num_device_pages,
                                       max_quantile_batches, min_cache_page_bytes);
    // Single-slot output buffer that receives the native handle.
    long[] nativeOut = new long[1];
    XGBoostJNI.checkCall(XGBoostJNI.XGExtMemQuantileDMatrixCreateFromCallback(
        iter, refHandle, jsonConfig, nativeOut));
    handle = nativeOut[0];
  }

  /**
   * Serialize the construction parameters into the JSON string passed to the native call.
   *
   * @return the JSON encoded configuration
   * @throws RuntimeException if Jackson fails to serialize the configuration
   */
  private String getConfig(float missing, int maxBin, int nthread, int max_num_device_pages,
                           int max_quantile_batches,
                           int min_cache_page_bytes) {
    Map<String, Object> params = new java.util.HashMap<>();
    params.put("missing", missing);
    params.put("max_bin", maxBin);
    params.put("nthread", nthread);
    params.put("max_num_device_pages", max_num_device_pages);
    params.put("max_quantile_batches", max_quantile_batches);
    params.put("min_cache_page_bytes", min_cache_page_bytes);
    params.put("on_host", true);
    params.put("cache_prefix", ".");

    // Jackson serializes NaN into a string by default; the custom serializers handle NaN
    // for both floating point widths.
    SimpleModule nanModule = new SimpleModule()
        .addSerializer(Double.class, new F64NaNSerializer())
        .addSerializer(Float.class, new F32NaNSerializer());
    ObjectMapper jsonMapper = new ObjectMapper();
    jsonMapper.registerModule(nanModule);

    try {
      return jsonMapper.writeValueAsString(params);
    } catch (JsonProcessingException e) {
      throw new RuntimeException("Failed to serialize configuration", e);
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ public void serialize(Float value, JsonGenerator gen,
* QuantileDMatrix will only be used to train
*/
public class QuantileDMatrix extends DMatrix {
// No-argument constructor reserved for subclasses; the external memory version of the QDM
// (ExtMemQuantileDMatrix) creates its native object itself and assigns `handle` afterwards.
// Passes 0 to the DMatrix constructor as a placeholder.
// NOTE(review): presumably 0 stands for "no native handle yet" -- confirm against DMatrix.
protected QuantileDMatrix() {
super(0);
}

/**
* Create QuantileDMatrix from iterator based on the cuda array interface
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,5 +198,4 @@ private float[] convertFloatTofloat(Float[]... datas) {
}
return floatArray;
}

}
2 changes: 2 additions & 0 deletions jvm-packages/xgboost4j/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@
<argument>${use.cuda}</argument>
<argument>--use-openmp</argument>
<argument>${use.openmp}</argument>
<argument>--use-debug</argument>
<argument>${use.debug}</argument>
</arguments>
<workingDirectory>${user.dir}</workingDirectory>
<skip>${skip.native.build}</skip>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ public final static native int XGDMatrixSetInfoFromInterface(
long handle, String field, String json);

public final static native int XGQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, long[] ref, String config, long[] out);
java.util.Iterator<ColumnBatch> iter, long[] ref, String config, long[] out);

public final static native int XGExtMemQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, long[] ref, String config, long[] out);

public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
String featureJson, float missing, int nthread, long[] out);
Expand Down
4 changes: 1 addition & 3 deletions jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
#include "../../../../src/common/common.h"

namespace xgboost::jni {
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls, jobject jdata_iter,
jlongArray jref, char const *config,
jlongArray jout) {
int QdmFromCallback(JNIEnv *, jobject, jlongArray, char const *, bool, jlongArray) {
API_BEGIN();
common::AssertGPUSupport();
API_END();
Expand Down
139 changes: 107 additions & 32 deletions jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <jni.h>
#include <xgboost/c_api.h>

#include "../../../../src/common/common.h"
#include "../../../../src/common/cuda_pinned_allocator.h"
#include "../../../../src/common/device_vector.cuh" // for device_vector
#include "../../../../src/data/array_interface.h"
Expand Down Expand Up @@ -58,7 +59,7 @@ void CopyColumnMask(xgboost::ArrayInterface<1> const &interface, std::vector<Jso
LOG(FATAL) << "Invalid shape of mask";
}
out["mask"]["typestr"] = String("<t1");
out["mask"]["version"] = Integer(3);
out["mask"]["version"] = Integer{3};
}

template <typename DCont, typename VCont>
Expand Down Expand Up @@ -201,7 +202,7 @@ class DMatrixProxy {
}
};

class DataIteratorProxy {
class HostMemProxy {
DMatrixProxy proxy_;
JvmIter jiter_;

Expand Down Expand Up @@ -237,24 +238,24 @@ class DataIteratorProxy {
cudaStream_t copy_stream_;

public:
explicit DataIteratorProxy(jobject jiter) : jiter_{jiter} {
explicit HostMemProxy(jobject jiter) : jiter_{jiter} {
this->Reset();
dh::safe_cuda(cudaStreamCreateWithFlags(&copy_stream_, cudaStreamNonBlocking));
}
~DataIteratorProxy() { dh::safe_cuda(cudaStreamDestroy(copy_stream_)); }
~HostMemProxy() { dh::safe_cuda(cudaStreamDestroy(copy_stream_)); }

DMatrixHandle GetDMatrixHandle() const { return proxy_.GetDMatrixHandle(); }

// Helper function for staging meta info.
void StageMetaInfo(Json json_interface) {
CHECK(!IsA<Null>(json_interface));
auto json_map = get<Object const>(json_interface);
void StageMetaInfo(Json jaif) {
CHECK(!IsA<Null>(jaif));
auto json_map = get<Object const>(jaif);
auto it = json_map.find(Symbols::kLabel);
if (it == json_map.cend()) {
LOG(FATAL) << "Must have a label field.";
}

Json label = json_interface[Symbols::kLabel.c_str()];
Json label = jaif[Symbols::kLabel.c_str()];
CHECK(!IsA<Null>(label));
labels_.emplace_back(std::make_unique<dh::device_vector<float>>());
CopyMetaInfo(&label, labels_.back().get(), copy_stream_);
Expand All @@ -263,7 +264,7 @@ class DataIteratorProxy {

it = json_map.find(Symbols::kWeight);
if (it != json_map.cend()) {
Json weight = json_interface[Symbols::kWeight.c_str()];
Json weight = jaif[Symbols::kWeight.c_str()];
CHECK(!IsA<Null>(weight));
weights_.emplace_back(new dh::device_vector<float>);
CopyMetaInfo(&weight, weights_.back().get(), copy_stream_);
Expand All @@ -274,7 +275,7 @@ class DataIteratorProxy {

it = json_map.find(Symbols::kBaseMargin);
if (it != json_map.cend()) {
Json base_margin = json_interface[Symbols::kBaseMargin.c_str()];
Json base_margin = jaif[Symbols::kBaseMargin.c_str()];
base_margins_.emplace_back(new dh::device_vector<float>);
CopyMetaInfo(&base_margin, base_margins_.back().get(), copy_stream_);
margin_interfaces_.emplace_back(base_margin);
Expand All @@ -284,7 +285,7 @@ class DataIteratorProxy {

it = json_map.find(Symbols::kQid);
if (it != json_map.cend()) {
Json qid = json_interface[Symbols::kQid.c_str()];
Json qid = jaif[Symbols::kQid.c_str()];
qids_.emplace_back(new dh::device_vector<int>);
CopyMetaInfo(&qid, qids_.back().get(), copy_stream_);
qid_interfaces_.emplace_back(qid);
Expand All @@ -304,13 +305,13 @@ class DataIteratorProxy {
using T = decltype(host_columns_)::value_type::element_type;
host_columns_.emplace_back(std::make_unique<T>());

// Stage the meta info.
auto json_interface = Json::Load({interface_str.c_str(), interface_str.size()});
CHECK(!IsA<Null>(json_interface));
// Stage the meta info, Json array interface.
auto jaif = Json::Load({interface_str.c_str(), interface_str.size()});
CHECK(!IsA<Null>(jaif));

StageMetaInfo(json_interface);
StageMetaInfo(jaif);

Json features = json_interface["features"];
Json features = jaif["features"];
auto json_columns = get<Array const>(features);
std::vector<ArrayInterface<1>> interfaces;

Expand Down Expand Up @@ -394,26 +395,84 @@ class DataIteratorProxy {
}
return NextSecondLoop();
}
};
}
};

namespace {
void Reset(DataIterHandle self) {
static_cast<xgboost::jni::DataIteratorProxy *>(self)->Reset();
}
// An iterator proxy for external memory.
class ExtMemProxy {
JvmIter jiter_;
DMatrixProxy proxy_;

int Next(DataIterHandle self) {
return static_cast<xgboost::jni::DataIteratorProxy *>(self)->Next();
}
public:
explicit ExtMemProxy(jobject jiter) : jiter_(jiter) {}

~ExtMemProxy() = default;

DMatrixHandle GetDMatrixHandle() const { return proxy_.GetDMatrixHandle(); }

/**
 * @brief Parse the JSON array interface of one batch and forward it to the DMatrix proxy.
 *
 * The "features" entry is set as the proxy's data. The label entry is mandatory; weight,
 * base margin and qid are forwarded as meta info only when present.
 *
 * @param aif JSON document describing the batch.
 */
void SetArrayInterface(StringView aif) {
  auto jaif = Json::Load(aif);
  CHECK(!IsA<Null>(jaif));

  Json features = jaif["features"];
  proxy_.SetData(features);

  // Bind by reference: a plain `auto` copies the entire object map for every batch. The
  // map is owned by `jaif`, which outlives every use below.
  auto const &json_map = get<Object const>(jaif);

  // Label is the only mandatory meta info.
  auto label_it = json_map.find(Symbols::kLabel);
  if (label_it == json_map.cend()) {
    LOG(FATAL) << "Must have a label field.";
  }
  // Reuse the iterator instead of looking the key up a second time.
  Json label = label_it->second;
  CHECK(!IsA<Null>(label));
  proxy_.SetInfo(Symbols::kLabel, label);

  auto weight_it = json_map.find(Symbols::kWeight);
  if (weight_it != json_map.cend()) {
    Json weight = weight_it->second;
    CHECK(!IsA<Null>(weight));
    proxy_.SetInfo(Symbols::kWeight, weight);
  }

  auto margin_it = json_map.find(Symbols::kBaseMargin);
  if (margin_it != json_map.cend()) {
    Json basemargin = margin_it->second;
    // The meta info field name is spelled out here rather than reusing the JSON key symbol.
    proxy_.SetInfo("base_margin", basemargin);
  }

  auto qid_it = json_map.find(Symbols::kQid);
  if (qid_it != json_map.cend()) {
    Json qid = qid_it->second;
    proxy_.SetInfo(Symbols::kQid, qid);
  }
}

/**
 * @brief Pull the next batch from the JVM iterator and stage it on the proxy.
 *
 * @return 1 when PullIterFromJVM reports success, 0 otherwise.
 */
int Next() {
  try {
    auto stage_batch = [this](char const *cjaif) { this->SetArrayInterface(cjaif); };
    if (!this->jiter_.PullIterFromJVM(stage_batch)) {
      return 0;
    }
    return 1;
  } catch (dmlc::Error const &e) {
    // Detach from the JVM first when the iterator reports a detached status, then
    // re-raise the error through the logging facility.
    if (jiter_.Status() == JNI_EDETACHED) {
      GlobalJvm()->DetachCurrentThread();
    }
    LOG(FATAL) << e.what();
  }
  return 0;
}

// Reset hook for the data iterator; delegates to JvmIter::CloseJvmBatch().
// NOTE(review): presumably this releases the current JVM batch -- confirm in JvmIter.
void Reset() { this->jiter_.CloseJvmBatch(); }
};

namespace {
template <typename T>
using Deleter = std::function<void(T *)>;
} // anonymous namespace
} // anonymous namespace

XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass, jobject jdata_iter,
jlongArray jref, char const *config,
jlongArray jout) {
xgboost::jni::DataIteratorProxy proxy(jdata_iter);
/**
* @brief Create QuantileDMatrix for both in-core version and the external memory version.
*/
int QdmFromCallback(JNIEnv *jenv, jobject jdata_iter, jlongArray jref, char const *config,
bool is_extmem, jlongArray jout) {
DMatrixHandle result;
DMatrixHandle ref{nullptr};

Expand All @@ -426,9 +485,25 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass, jobjec
ref = reinterpret_cast<DMatrixHandle>(refptr.get()[0]);
}

auto ret = XGQuantileDMatrixCreateFromCallback(&proxy, proxy.GetDMatrixHandle(), ref, Reset, Next,
config, &result);
int ret = 0;
if (is_extmem) {
xgboost::jni::ExtMemProxy proxy{jdata_iter};
ret = XGExtMemQuantileDMatrixCreateFromCallback(
&proxy, proxy.GetDMatrixHandle(), ref,
[](DataIterHandle self) { static_cast<xgboost::jni::ExtMemProxy *>(self)->Reset(); },
[](DataIterHandle self) { return static_cast<xgboost::jni::ExtMemProxy *>(self)->Next(); },
config, &result);
} else {
xgboost::jni::HostMemProxy proxy{jdata_iter};
ret = XGQuantileDMatrixCreateFromCallback(
&proxy, proxy.GetDMatrixHandle(), ref,
[](DataIterHandle self) { static_cast<xgboost::jni::HostMemProxy *>(self)->Reset(); },
[](DataIterHandle self) { return static_cast<xgboost::jni::HostMemProxy *>(self)->Next(); },
config, &result);
}

JVM_CHECK_CALL(ret);
setHandle(jenv, jout, result);
return ret;
}
} // namespace xgboost::jni
} // namespace xgboost::jni
Loading

0 comments on commit 73c95c7

Please sign in to comment.