Skip to content

Commit e9f19b8

Browse files
[aot] Add stream_ variable for CUDAContext to use a specific CUDA stream to launch CUDA kernel (#8579)
### Brief Summary copilot:summary ### Walkthrough copilot:walkthrough --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 45b3275 commit e9f19b8

5 files changed

Lines changed: 62 additions & 3 deletions

File tree

c_api/include/taichi/taichi_cuda.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,17 @@ ti_export_cuda_memory(TiRuntime runtime,
2020
TiMemory memory,
2121
TiCudaMemoryInteropInfo *interop_info);
2222

23+
// Function `ti_import_cuda_memory`
2324
TI_DLL_EXPORT TiMemory TI_API_CALL ti_import_cuda_memory(TiRuntime runtime,
2425
void *ptr,
2526
size_t memory_size);
2627

28+
// Function `ti_set_cuda_stream`
29+
TI_DLL_EXPORT void TI_API_CALL ti_set_cuda_stream(void *stream);
30+
31+
// Function `ti_get_cuda_stream`
32+
TI_DLL_EXPORT void TI_API_CALL ti_get_cuda_stream(void **stream);
33+
2734
#ifdef __cplusplus
2835
} // extern "C"
2936
#endif // __cplusplus

c_api/src/taichi_llvm_impl.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#ifdef TI_WITH_CUDA
1616
#include "taichi/rhi/cuda/cuda_device.h"
17+
#include "taichi/rhi/cuda/cuda_context.h"
1718
#include "taichi/runtime/cuda/kernel_launcher.h"
1819
#endif
1920

@@ -242,4 +243,24 @@ TI_DLL_EXPORT TiMemory TI_API_CALL ti_import_cuda_memory(TiRuntime runtime,
242243
#endif
243244
}
244245

246+
// function.set_cuda_stream
247+
TI_DLL_EXPORT void TI_API_CALL ti_set_cuda_stream(void *stream) {
248+
#ifdef TI_WITH_CUDA
249+
taichi::lang::CUDAContext::get_instance().set_stream(stream);
250+
251+
#else
252+
TI_NOT_IMPLEMENTED;
253+
#endif
254+
}
255+
256+
// function.get_cuda_stream
257+
TI_DLL_EXPORT void TI_API_CALL ti_get_cuda_stream(void **stream) {
258+
#ifdef TI_WITH_CUDA
259+
*stream = taichi::lang::CUDAContext::get_instance().get_stream();
260+
#else
261+
TI_NOT_IMPLEMENTED;
262+
263+
#endif
264+
}
265+
245266
#endif // TI_WITH_LLVM

c_api/tests/c_api_interop_test.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,23 @@ TEST_F(CapiTest, TestCUDAImport) {
160160
EXPECT_EQ(data_out[3], 4.0);
161161
}
162162
#endif // TI_WITH_CUDA
163+
164+
#ifdef TI_WITH_CUDA
165+
TEST_F(CapiTest, TestCUDAStreamSet) {
166+
void *temp_stream = nullptr;
167+
168+
ti_get_cuda_stream(&temp_stream);
169+
EXPECT_EQ(temp_stream, nullptr);
170+
171+
void *stream1 = reinterpret_cast<void *>(0x12345678);
172+
void *stream2 = reinterpret_cast<void *>(0x87654321);
173+
174+
ti_set_cuda_stream(stream1);
175+
ti_get_cuda_stream(&temp_stream);
176+
EXPECT_EQ(temp_stream, stream1);
177+
178+
ti_set_cuda_stream(stream2);
179+
ti_get_cuda_stream(&temp_stream);
180+
EXPECT_EQ(temp_stream, stream2);
181+
}
182+
#endif

taichi/rhi/cuda/cuda_context.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
namespace taichi::lang {
1313

1414
CUDAContext::CUDAContext()
15-
: profiler_(nullptr), driver_(CUDADriver::get_instance_without_context()) {
15+
: profiler_(nullptr),
16+
driver_(CUDADriver::get_instance_without_context()),
17+
stream_(nullptr) {
1618
// CUDA initialization
1719
dev_count_ = 0;
1820
driver_.init(0);
@@ -156,14 +158,14 @@ void CUDAContext::launch(void *func,
156158
dynamic_shared_mem_bytes);
157159
}
158160
driver_.launch_kernel(func, grid_dim, 1, 1, block_dim, 1, 1,
159-
dynamic_shared_mem_bytes, nullptr,
161+
dynamic_shared_mem_bytes, stream_,
160162
arg_pointers.data(), nullptr);
161163
}
162164
if (profiler_)
163165
profiler_->stop(task_handle);
164166

165167
if (debug_) {
166-
driver_.stream_synchronize(nullptr);
168+
driver_.stream_synchronize(stream_);
167169
}
168170
}
169171

taichi/rhi/cuda/cuda_context.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class CUDAContext {
2929
int max_shared_memory_bytes_;
3030
bool debug_;
3131
bool supports_mem_pool_;
32+
void *stream_;
3233

3334
public:
3435
CUDAContext();
@@ -108,6 +109,14 @@ class CUDAContext {
108109
}
109110

110111
static CUDAContext &get_instance();
112+
113+
void set_stream(void *stream) {
114+
stream_ = stream;
115+
}
116+
117+
void *get_stream() const {
118+
return stream_;
119+
}
111120
};
112121

113122
} // namespace taichi::lang

0 commit comments

Comments
 (0)