Skip to content

Commit 32948cd

Browse files
authored
Merge pull request #2038 from sommerlukas/raw-kernel-arg
Add support for raw_kernel_arg extension
2 parents 29eeac7 + 5fb74e6 commit 32948cd

16 files changed

+694
-1
lines changed

dpctl/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from ._sycl_platform import SyclPlatform, get_platforms, lsplatform
5151
from ._sycl_queue import (
5252
LocalAccessor,
53+
RawKernelArg,
5354
SyclKernelInvalidRangeError,
5455
SyclKernelSubmitError,
5556
SyclQueue,
@@ -106,6 +107,7 @@
106107
"SyclQueueCreationError",
107108
"WorkGroupMemory",
108109
"LocalAccessor",
110+
"RawKernelArg",
109111
]
110112
__all__ += [
111113
"get_device_cached_queue",

dpctl/_backend.pxd

+12
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ cdef extern from "syclinterface/dpctl_sycl_enum_types.h":
7171
_VOID_PTR "DPCTL_VOID_PTR",
7272
_LOCAL_ACCESSOR "DPCTL_LOCAL_ACCESSOR",
7373
_WORK_GROUP_MEMORY "DPCTL_WORK_GROUP_MEMORY"
74+
_RAW_KERNEL_ARG "DPCTL_RAW_KERNEL_ARG"
7475

7576
ctypedef enum _queue_property_type "DPCTLQueuePropertyType":
7677
_DEFAULT_PROPERTY "DPCTL_DEFAULT_PROPERTY"
@@ -571,3 +572,14 @@ cdef extern from "syclinterface/dpctl_sycl_extension_interface.h":
571572
DPCTLSyclWorkGroupMemoryRef Ref)
572573

573574
cdef bint DPCTLWorkGroupMemory_Available()
575+
576+
cdef struct DPCTLOpaqueRawKernelArg
577+
ctypedef DPCTLOpaqueRawKernelArg *DPCTLSyclRawKernelArgRef
578+
579+
cdef DPCTLSyclRawKernelArgRef DPCTLRawKernelArg_Create(void* bytes,
580+
size_t count)
581+
582+
cdef void DPCTLRawKernelArg_Delete(
583+
DPCTLSyclRawKernelArgRef Ref)
584+
585+
cdef bint DPCTLRawKernelArg_Available()

dpctl/_sycl_queue.pxd

+11
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ from libcpp cimport bool as cpp_bool
2525
from ._backend cimport (
2626
DPCTLSyclDeviceRef,
2727
DPCTLSyclQueueRef,
28+
DPCTLSyclRawKernelArgRef,
2829
DPCTLSyclWorkGroupMemoryRef,
2930
_arg_data_type,
3031
)
@@ -115,3 +116,13 @@ cdef public api class WorkGroupMemory(_WorkGroupMemory) [
115116
object PyWorkGroupMemoryObject, type PyWorkGroupMemoryType
116117
]:
117118
pass
119+
120+
cdef public api class _RawKernelArg [
121+
object Py_RawKernelArgObject, type Py_RawKernelArgType
122+
]:
123+
cdef DPCTLSyclRawKernelArgRef _arg_ref
124+
125+
cdef public api class RawKernelArg(_RawKernelArg) [
126+
object PyRawKernelArgObject, type PyRawKernelArgType
127+
]:
128+
pass

dpctl/_sycl_queue.pyx

+110
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ from ._backend cimport ( # noqa: E211
5151
DPCTLQueue_SubmitNDRange,
5252
DPCTLQueue_SubmitRange,
5353
DPCTLQueue_Wait,
54+
DPCTLRawKernelArg_Available,
55+
DPCTLRawKernelArg_Create,
56+
DPCTLRawKernelArg_Delete,
5457
DPCTLSyclContextRef,
5558
DPCTLSyclDeviceSelectorRef,
5659
DPCTLSyclEventRef,
@@ -364,6 +367,15 @@ cdef class _kernel_arg_type:
364367
_arg_data_type._WORK_GROUP_MEMORY
365368
)
366369

370+
@property
371+
def dpctl_raw_kernel_arg(self):
372+
cdef str p_name = "dpctl_raw_kernel_arg"
373+
return kernel_arg_type_attribute(
374+
self._name,
375+
p_name,
376+
_arg_data_type._RAW_KERNEL_ARG
377+
)
378+
367379

368380
kernel_arg_type = _kernel_arg_type()
369381

@@ -973,6 +985,9 @@ cdef class SyclQueue(_SyclQueue):
973985
elif isinstance(arg, LocalAccessor):
974986
kargs[idx] = <void*>((<LocalAccessor>arg).addressof())
975987
kargty[idx] = _arg_data_type._LOCAL_ACCESSOR
988+
elif isinstance(arg, RawKernelArg):
989+
kargs[idx] = <void*>(<size_t>arg._ref)
990+
kargty[idx] = _arg_data_type._RAW_KERNEL_ARG
976991
else:
977992
ret = -1
978993
return ret
@@ -1738,3 +1753,98 @@ cdef class WorkGroupMemory:
17381753
"""
17391754
def __get__(self):
17401755
return <size_t>self._mem_ref
1756+
1757+
1758+
cdef class _RawKernelArg:
1759+
def __dealloc(self):
1760+
if(self._arg_ref):
1761+
DPCTLRawKernelArg_Delete(self._arg_ref)
1762+
1763+
1764+
cdef class RawKernelArg:
1765+
"""
1766+
RawKernelArg(*args)
1767+
Python class representing the ``raw_kernel_arg`` class from the Raw Kernel
1768+
Argument oneAPI SYCL extension for passing binary data as data to kernels.
1769+
1770+
This class is intended to be used as kernel argument when launching kernels.
1771+
1772+
This is based on a DPC++ SYCL extension and only available in newer
1773+
versions. Use ``is_available()`` to check availability in your build.
1774+
1775+
There are multiple ways to create a ``RawKernelArg``.
1776+
1777+
- If the constructor is invoked with just a single argument, this argument
1778+
is expected to expose the Python buffer interface. The raw kernel arg will
1779+
be constructed from the data in that buffer.
1780+
1781+
- If the constructor is invoked with two arguments, the first argument is
1782+
interpreted as the number of bytes in the binary argument, while the
1783+
second argument is interpreted as a pointer to the data.
1784+
1785+
Note that construction of the ``RawKernelArg`` copies the bytes, so
1786+
modifications made after construction of the ``RawKernelArg`` will not be
1787+
reflected in the kernel launch.
1788+
1789+
Args:
1790+
args:
1791+
Variadic argument, see class documentation.
1792+
1793+
Raises:
1794+
TypeError: In case of incorrect arguments given to constructurs,
1795+
unexpected types of input arguments.
1796+
"""
1797+
def __cinit__(self, *args):
1798+
cdef void* ptr = NULL
1799+
cdef size_t count
1800+
cdef int ret_code = 0
1801+
cdef Py_buffer _buffer
1802+
cdef bint _is_buf
1803+
1804+
if not DPCTLRawKernelArg_Available():
1805+
raise RuntimeError("Raw kernel arg extension not available")
1806+
1807+
if not (0 < len(args) < 3):
1808+
raise TypeError("RawKernelArg constructor takes 1 or 2 "
1809+
f"arguments, but {len(args)} were given")
1810+
1811+
if len(args) == 1:
1812+
if not _is_buffer(args[0]):
1813+
raise TypeError("RawKernelArg single argument constructor"
1814+
"expects argument to be buffer",
1815+
f"but got {type(args[0])}")
1816+
1817+
ret_code = PyObject_GetBuffer(args[0], &(_buffer),
1818+
PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS)
1819+
if ret_code != 0: # pragma: no cover
1820+
raise RuntimeError("Could not access buffer")
1821+
1822+
ptr = _buffer.buf
1823+
count = _buffer.len
1824+
_is_buf = True
1825+
else:
1826+
if not isinstance(args[0], numbers.Integral):
1827+
raise TypeError("RawKernelArg constructor expects first"
1828+
"argument to be `int`, but got {type(args[0])}")
1829+
if not isinstance(args[1], numbers.Integral):
1830+
raise TypeError("RawKernelArg constructor expects second"
1831+
"argument to be `int`, but got {type(args[1])}")
1832+
1833+
_is_buf = False
1834+
count = args[0]
1835+
ptr = <void*>(<unsigned long long>args[1])
1836+
1837+
self._arg_ref = DPCTLRawKernelArg_Create(ptr, count)
1838+
if(_is_buf):
1839+
PyBuffer_Release(&(_buffer))
1840+
1841+
@staticmethod
1842+
def is_available():
1843+
return DPCTLRawKernelArg_Available()
1844+
1845+
property _ref:
1846+
"""Returns the address of the C API ``DPCTLRawKernelArgRef`` pointer
1847+
as a ``size_t``.
1848+
"""
1849+
def __get__(self):
1850+
return <size_t>self._arg_ref
1.6 KB
Binary file not shown.

dpctl/tests/test_raw_kernel_arg.py

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Defines unit test cases for the work_group_memory in a SYCL kernel"""
18+
19+
import ctypes
20+
import os
21+
22+
import pytest
23+
24+
import dpctl
25+
import dpctl.tensor
26+
27+
28+
def get_spirv_abspath(fn):
29+
curr_dir = os.path.dirname(os.path.abspath(__file__))
30+
spirv_file = os.path.join(curr_dir, "input_files", fn)
31+
return spirv_file
32+
33+
34+
# The kernel in the SPIR-V file used in this test was generated from the
35+
# following SYCL source code:
36+
# #include <sycl/sycl.hpp>
37+
#
38+
# using namespace sycl;
39+
#
40+
# namespace syclexp = sycl::ext::oneapi::experimental;
41+
# namespace syclext = sycl::ext::oneapi;
42+
#
43+
# using data_t = int32_t;
44+
#
45+
# struct Params { data_t mul; data_t add; };
46+
#
47+
# extern "C" SYCL_EXTERNAL
48+
# SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
49+
# void raw_arg_kernel(data_t* in, data_t* out, Params p){
50+
# auto item = syclext::this_work_item::get_nd_item<1>();
51+
# size_t global_id = item.get_global_linear_id();
52+
# out[global_id] = (in[global_id] * p.mul) + p.add;
53+
# }
54+
55+
56+
class Params(ctypes.Structure):
57+
_fields_ = [("mul", ctypes.c_int32), ("add", ctypes.c_int32)]
58+
59+
60+
def launch_raw_arg_kernel(raw):
61+
if not dpctl.RawKernelArg.is_available():
62+
pytest.skip("Raw kernel arg extension not supported")
63+
64+
try:
65+
q = dpctl.SyclQueue("level_zero")
66+
except dpctl.SyclQueueCreationError:
67+
pytest.skip("LevelZero queue could not be created")
68+
spirv_file = get_spirv_abspath("raw-arg-kernel.spv")
69+
with open(spirv_file, "br") as spv:
70+
spv_bytes = spv.read()
71+
prog = dpctl.program.create_program_from_spirv(q, spv_bytes)
72+
kernel = prog.get_sycl_kernel("__sycl_kernel_raw_arg_kernel")
73+
local_size = 16
74+
global_size = local_size * 8
75+
76+
x = dpctl.tensor.ones(global_size, dtype="int32")
77+
y = dpctl.tensor.zeros(global_size, dtype="int32")
78+
x.sycl_queue.wait()
79+
y.sycl_queue.wait()
80+
81+
try:
82+
q.submit(
83+
kernel,
84+
[
85+
x.usm_data,
86+
y.usm_data,
87+
raw,
88+
],
89+
[global_size],
90+
[local_size],
91+
)
92+
q.wait()
93+
except dpctl._sycl_queue.SyclKernelSubmitError:
94+
pytest.skip(f"Kernel submission to {q.sycl_device} failed")
95+
96+
assert dpctl.tensor.all(y == 9)
97+
98+
99+
def test_submit_raw_kernel_arg_pointer():
100+
paramStruct = Params(4, 5)
101+
raw = dpctl.RawKernelArg(
102+
ctypes.sizeof(paramStruct), ctypes.addressof(paramStruct)
103+
)
104+
launch_raw_arg_kernel(raw)
105+
106+
107+
def test_submit_raw_kernel_arg_buffer():
108+
paramStruct = Params(4, 5)
109+
byteArr = bytearray(paramStruct)
110+
raw = dpctl.RawKernelArg(byteArr)
111+
del byteArr
112+
launch_raw_arg_kernel(raw)

dpctl/tests/test_sycl_kernel_submit.py

+1
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def test_kernel_arg_type():
280280
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_void_ptr)
281281
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_local_accessor)
282282
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_work_group_memory)
283+
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_raw_kernel_arg)
283284

284285

285286
def get_spirv_abspath(fn):

libsyclinterface/include/syclinterface/dpctl_sycl_enum_types.h

+1
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ typedef enum
101101
DPCTL_VOID_PTR,
102102
DPCTL_LOCAL_ACCESSOR,
103103
DPCTL_WORK_GROUP_MEMORY,
104+
DPCTL_RAW_KERNEL_ARG,
104105
DPCTL_UNSUPPORTED_KERNEL_ARG
105106
} DPCTLKernelArgType;
106107

libsyclinterface/include/syclinterface/dpctl_sycl_extension_interface.h

+12
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,16 @@ void DPCTLWorkGroupMemory_Delete(__dpctl_take DPCTLSyclWorkGroupMemoryRef Ref);
5353
DPCTL_API
5454
bool DPCTLWorkGroupMemory_Available();
5555

56+
typedef struct DPCTLOpaqueSyclRawKernelArg *DPCTLSyclRawKernelArgRef;
57+
58+
DPCTL_API
59+
__dpctl_give DPCTLSyclRawKernelArgRef DPCTLRawKernelArg_Create(void *bytes,
60+
size_t count);
61+
62+
DPCTL_API
63+
void DPCTLRawKernelArg_Delete(__dpctl_take DPCTLSyclRawKernelArgRef Ref);
64+
65+
DPCTL_API
66+
bool DPCTLRawKernelArg_Available();
67+
5668
DPCTL_C_EXTERN_C_END

libsyclinterface/include/syclinterface/dpctl_sycl_type_casters.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ DEFINE_SIMPLE_CONVERSION_FUNCTIONS(std::vector<DPCTLSyclEventRef>,
8484
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(RawWorkGroupMemory,
8585
DPCTLSyclWorkGroupMemoryRef)
8686

87+
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(std::vector<unsigned char>,
88+
DPCTLSyclRawKernelArgRef)
89+
8790
#endif
8891

8992
} // namespace dpctl::syclinterface

libsyclinterface/source/dpctl_sycl_extension_interface.cpp

+35
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,38 @@ bool DPCTLWorkGroupMemory_Available()
6262
return false;
6363
#endif
6464
}
65+
66+
using raw_kernel_arg_t = std::vector<unsigned char>;
67+
68+
DPCTL_API
69+
__dpctl_give DPCTLSyclRawKernelArgRef DPCTLRawKernelArg_Create(void *bytes,
70+
size_t count)
71+
{
72+
DPCTLSyclRawKernelArgRef rka = nullptr;
73+
try {
74+
auto RawKernelArg =
75+
std::unique_ptr<raw_kernel_arg_t>(new raw_kernel_arg_t(count));
76+
std::memcpy(RawKernelArg->data(), bytes, count);
77+
rka = wrap<raw_kernel_arg_t>(RawKernelArg.get());
78+
RawKernelArg.release();
79+
} catch (std::exception const &e) {
80+
error_handler(e, __FILE__, __func__, __LINE__);
81+
}
82+
return rka;
83+
}
84+
85+
DPCTL_API
86+
void DPCTLRawKernelArg_Delete(__dpctl_take DPCTLSyclRawKernelArgRef Ref)
87+
{
88+
delete unwrap<raw_kernel_arg_t>(Ref);
89+
}
90+
91+
DPCTL_API
92+
bool DPCTLRawKernelArg_Available()
93+
{
94+
#ifdef SYCL_EXT_ONEAPI_RAW_KERNEL_ARG
95+
return true;
96+
#else
97+
return false;
98+
#endif
99+
}

0 commit comments

Comments
 (0)