Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cpp/core/compute/Runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ class Runtime : public std::enable_shared_from_this<Runtime> {
return objStore_->save(obj);
}

int64_t taskAttemptId() const {
return taskAttemptId_;
}

void setTaskAttemptId(int64_t id) {
taskAttemptId_ = id;
}

protected:
std::string kind_;
MemoryManager* memoryManager_;
Expand All @@ -206,5 +214,6 @@ class Runtime : public std::enable_shared_from_this<Runtime> {

std::optional<SparkTaskInfo> taskInfo_{std::nullopt};
std::shared_ptr<WholeStageDumper> dumper_{nullptr};
int64_t taskAttemptId_{-1};
};
} // namespace gluten
16 changes: 12 additions & 4 deletions cpp/core/jni/JniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,8 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_runtime_RuntimeJniWrapper_createR
jstring jBackendType,
jlong nmmHandle,
jlong ntmHandle,
jbyteArray sessionConf) {
jbyteArray sessionConf,
jlong taskAttemptId) {
JNI_METHOD_START
MemoryManager* memoryManager = jniCastOrThrow<MemoryManager>(nmmHandle);
ThreadManager* threadManager = jniCastOrThrow<ThreadManager>(ntmHandle);
Expand All @@ -354,6 +355,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_runtime_RuntimeJniWrapper_createR
auto backendType = jStringToCString(env, jBackendType);

auto runtime = Runtime::create(backendType, memoryManager, threadManager, sparkConf);
runtime->setTaskAttemptId(static_cast<int64_t>(taskAttemptId));
return reinterpret_cast<jlong>(runtime);
JNI_METHOD_END(kInvalidObjectHandle)
}
Expand All @@ -378,13 +380,15 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_memory_NativeMemoryManagerJniWrap
jclass,
jstring jBackendType,
jobject jListener,
jbyteArray sessionConf) {
jbyteArray sessionConf,
jstring jName) {
JNI_METHOD_START
JavaVM* vm;
if (env->GetJavaVM(&vm) != JNI_OK) {
throw GlutenException("Unable to get JavaVM instance");
}
auto backendType = jStringToCString(env, jBackendType);
auto name = jStringToCString(env, jName);
auto safeArray = getByteArrayElementsSafe(env, sessionConf);
auto sparkConf = parseConfMap(env, safeArray.elems(), safeArray.length());
std::unique_ptr<AllocationListener> listener = std::make_unique<SparkAllocationListener>(vm, jListener);
Expand All @@ -393,6 +397,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_memory_NativeMemoryManagerJniWrap
listener = std::make_unique<BacktraceAllocationListener>(std::move(listener));
}
MemoryManager* mm = MemoryManager::create(backendType, std::move(listener));
mm->setName(name);
return reinterpret_cast<jlong>(mm);
JNI_METHOD_END(-1L)
}
Expand Down Expand Up @@ -457,7 +462,9 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_memory_NativeMemoryManagerJniWrap
JNIEXPORT void JNICALL Java_org_apache_gluten_memory_NativeMemoryManagerJniWrapper_hold( // NOLINT
JNIEnv* env,
jclass,
jlong nmmHandle) {
jlong nmmHandle,
jstring jName,
jlong taskAttemptId) {
JNI_METHOD_START
auto* memoryManager = jniCastOrThrow<MemoryManager>(nmmHandle);
memoryManager->hold();
Expand All @@ -467,7 +474,8 @@ JNIEXPORT void JNICALL Java_org_apache_gluten_memory_NativeMemoryManagerJniWrapp
JNIEXPORT void JNICALL Java_org_apache_gluten_memory_NativeMemoryManagerJniWrapper_release( // NOLINT
JNIEnv* env,
jclass,
jlong nmmHandle) {
jlong nmmHandle,
jlong taskAttemptId) {
JNI_METHOD_START
auto* memoryManager = jniCastOrThrow<MemoryManager>(nmmHandle);
MemoryManager::release(memoryManager);
Expand Down
9 changes: 9 additions & 0 deletions cpp/core/memory/MemoryManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ class MemoryManager {
return kind_;
}

std::string name() const {
return name_;
}

void setName(const std::string& name) {
name_ = name;
}

// Get the default Arrow memory pool for this memory manager. This memory pool is held by the memory manager.
virtual arrow::MemoryPool* defaultArrowMemoryPool() = 0;

Expand All @@ -58,6 +66,7 @@ class MemoryManager {

private:
std::string kind_;
std::string name_;
};

} // namespace gluten
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ public class NativeMemoryManagerJniWrapper {
private NativeMemoryManagerJniWrapper() {}

public static native long create(
String backendType, ReservationListener listener, byte[] sessionConf);
String backendType, ReservationListener listener, byte[] sessionConf, String name);

public static native byte[] collectUsage(long handle);

public static native long shrink(long handle, long size);

public static native void hold(long handle);
public static native void hold(long handle, String name, long taskAttemptId);

public static native void release(long handle);
public static native void release(long handle, long taskAttemptId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public class RuntimeJniWrapper {
private RuntimeJniWrapper() {}

public static native long createRuntime(
String backendType, long nmm, long ntm, byte[] sessionConf);
String backendType, long nmm, long ntm, byte[] sessionConf, long taskAttemptId);

public static native void releaseRuntime(long handle);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.gluten.memory.memtarget.{KnownNameAndStats, MemoryTarget, Spil
import org.apache.gluten.proto.MemoryUsageStats
import org.apache.gluten.utils.ConfigUtil

import org.apache.spark.TaskContext
import org.apache.spark.memory.SparkMemoryUtil
import org.apache.spark.sql.internal.{GlutenConfigUtil, SQLConf}
import org.apache.spark.task.{TaskResource, TaskResources}
Expand Down Expand Up @@ -54,7 +55,8 @@ object NativeMemoryManager {
ConfigUtil.serialize(
GlutenConfig
.getNativeSessionConf(backendName, GlutenConfigUtil.parseConfig(SQLConf.get.getAllConfs))
.asJava)
.asJava),
name
)
spillers.append(new Spiller() {
override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = phase match {
Expand All @@ -76,7 +78,8 @@ object NativeMemoryManager {
private val released: AtomicBoolean = new AtomicBoolean(false)

override def addSpiller(spiller: Spiller): Unit = spillers.append(spiller)
override def hold(): Unit = NativeMemoryManagerJniWrapper.hold(handle)
override def hold(): Unit =
NativeMemoryManagerJniWrapper.hold(handle, name, TaskContext.get().taskAttemptId())
override def getHandle(): Long = handle
override def release(): Unit = {
if (!released.compareAndSet(false, true)) {
Expand All @@ -97,7 +100,7 @@ object NativeMemoryManager {
LOGGER.debug("About to release memory manager, " + dump())
}

NativeMemoryManagerJniWrapper.release(handle)
NativeMemoryManagerJniWrapper.release(handle, TaskContext.get().taskAttemptId())

if (rl.getUsedBytes != 0) {
LOGGER.warn(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.gluten.memory.NativeMemoryManager
import org.apache.gluten.threads.{NativeThreadManager, TaskChildThreadInitializer}
import org.apache.gluten.utils.ConfigUtil

import org.apache.spark.TaskContext
import org.apache.spark.sql.internal.{GlutenConfigUtil, SQLConf}
import org.apache.spark.task.{TaskResource, TaskResources}

Expand Down Expand Up @@ -64,7 +65,8 @@ object Runtime {
(GlutenConfig
.getNativeSessionConf(
backendName,
GlutenConfigUtil.parseConfig(SQLConf.get.getAllConfs)) ++ extraConf.asScala).asJava)
GlutenConfigUtil.parseConfig(SQLConf.get.getAllConfs)) ++ extraConf.asScala).asJava),
TaskContext.get().taskAttemptId()
)

private val released: AtomicBoolean = new AtomicBoolean(false)
Expand Down
Loading