Introduction
For the past few years, I've been following the Enzyme AD project and its extremely fancy capabilities—autodiff for arbitrary functions (within reason) and integrating it into the optimization pipeline. Although I haven't found a good use for it in my own hacking, I have found it fun to explore.
Enzyme is implemented as compiler passes on LLVM and MLIR, and is usable from many LLVM-focused languages, including Julia, Rust (through rustc), and Swift. However, the most common LLVM frontend used with Enzyme is likely Clang, which handles support for C, C++, and CUDA (with, of course, all the caveats that come with abandoning NVIDIA's closed-source NVCC compiler driver).
This got my gears turning: Clang happens to support one more language in my wheelhouse—OpenCL, particularly compiling OpenCL C to SPIR-V. This leads to an obvious composition: Can we use Enzyme to get forward- and reverse-mode gradients of OpenCL kernels? Enzyme does not advertise support for OpenCL, but it does support other Clang languages.
What I found is that, although it isn't officially supported and comes with caveats—most notably very limited functionality—it does work today with upstream Enzyme.
Proof-of-concept
The classic Enzyme demonstration is reverse-mode differentiation of a unary squaring function, showing that the generated derivative is double the input argument. We can do one better to demonstrate parallel compute: sample the function over many inputs and show that this holds everywhere. This is quite straightforward to express over a linear view of a buffer in OpenCL C:
// Clang extension that lets OpenCL C take the address of a function; this is
// how square_impl is handed to __enzyme_autodiff below.
#pragma OPENCL EXTENSION __cl_clang_function_pointers : enable
// Enzyme's entry point. It is only declared here: the Enzyme pass (loaded with
// -fpass-plugin when compiling) replaces each call with a generated derivative
// of the function passed first. The remaining arguments come in
// (primal buffer, shadow buffer) pairs; the shadow buffers carry the gradients.
extern void __enzyme_autodiff(void*, global float*, global float*, global float*, global float*);
// Primal kernel: each work-item i computes output[i] = input[i] * input[i].
kernel void square_impl(global float* input, global float* output) {
size_t i = get_global_id(0);
output[i] = input[i] * input[i];
}
// Reverse-mode derivative of square_impl. Per the generated IR for this kernel,
// each work-item i writes output[i] = input[i]^2, reads d_output[i] and resets
// it to 0, then accumulates d_input[i] += 2 * input[i] * d_output[i].
// d_input is therefore both read and written, and must start zeroed.
kernel void square_diff(global float* input, global float* d_input, global float* output, global float* d_output) {
__enzyme_autodiff(*square_impl,
input, d_input,
output, d_output);
}
Here, square_impl is our function with an input and an output, and square_diff is the full kernel, which takes buffers for the partial derivatives of the input and output (d_input and d_output) in addition to the other two arguments.
To use this kernel, we first compile it to SPIR-V with Clang, loading the Enzyme plugin on the command line with -fpass-plugin:
# Compile kernel.cl to SPIR-V, running Enzyme as an LLVM pass plugin (adjust the plugin path to your build).
clang -fpass-plugin=/path/to/ClangEnzyme-20.so -target spirv64 -cl-std=cl3.0 kernel.cl -O2 -o kernel.spv
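Under Clang 20, this compiles to the following LLVM IR, which is then lowered to the SPIR-V target. (Note that the listing's !llvm.ident metadata reports Clang 16.0.6, so it may have been captured with an older toolchain than the one in the command above.)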
; ModuleID = 'kernel.cl'
source_filename = "kernel.cl"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spirv64"
; Primal kernel: output[i] = input[i] * input[i], with i = get_global_id(0).
; Function Attrs: convergent mustprogress nofree norecurse nounwind willreturn memory(argmem: readwrite)
define spir_kernel void @square_impl(ptr addrspace(1) nocapture noundef readonly align 4 %0, ptr addrspace(1) nocapture noundef writeonly align 4 %1) local_unnamed_addr #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_base_type !6 !kernel_arg_type_qual !7 {
%3 = tail call spir_func i64 @_Z13get_global_idj(i32 noundef 0) #2
%4 = getelementptr inbounds float, ptr addrspace(1) %0, i64 %3
%5 = load float, ptr addrspace(1) %4, align 4, !tbaa !8
%6 = fmul float %5, %5
%7 = getelementptr inbounds float, ptr addrspace(1) %1, i64 %3
store float %6, ptr addrspace(1) %7, align 4, !tbaa !8
ret void
}
; Function Attrs: convergent mustprogress nofree nounwind willreturn memory(none)
declare spir_func i64 @_Z13get_global_idj(i32 noundef) local_unnamed_addr #1
; Enzyme-generated kernel: the __enzyme_autodiff call has been replaced by an
; inlined forward sweep followed by the reverse (adjoint) sweep.
; Arguments: %0 = input, %1 = d_input, %2 = output, %3 = d_output.
; Function Attrs: convergent mustprogress nofree norecurse nounwind willreturn memory(argmem: readwrite)
define spir_kernel void @square_diff(ptr addrspace(1) nocapture noundef readonly align 4 %0, ptr addrspace(1) nocapture noundef align 4 %1, ptr addrspace(1) nocapture noundef writeonly align 4 %2, ptr addrspace(1) nocapture noundef align 4 %3) local_unnamed_addr #0 !kernel_arg_addr_space !12 !kernel_arg_access_qual !13 !kernel_arg_type !14 !kernel_arg_base_type !14 !kernel_arg_type_qual !15 {
%5 = tail call spir_func i64 @_Z13get_global_idj(i32 noundef 0) #3
%6 = getelementptr inbounds float, ptr addrspace(1) %1, i64 %5
%7 = getelementptr inbounds float, ptr addrspace(1) %0, i64 %5
; Forward sweep: output[i] = input[i] * input[i].
%8 = load float, ptr addrspace(1) %7, align 4, !tbaa !8, !alias.scope !16, !noalias !19
%9 = fmul float %8, %8
%10 = getelementptr inbounds float, ptr addrspace(1) %3, i64 %5
%11 = getelementptr inbounds float, ptr addrspace(1) %2, i64 %5
store float %9, ptr addrspace(1) %11, align 4, !tbaa !8, !alias.scope !21, !noalias !24
; Reverse sweep: read the output gradient d_output[i], then clear it to 0.
%12 = load float, ptr addrspace(1) %10, align 4, !tbaa !8, !alias.scope !24, !noalias !21
store float 0.000000e+00, ptr addrspace(1) %10, align 4, !tbaa !8, !alias.scope !24, !noalias !21
; d_input[i] += (2 * input[i]) * d_output[i]  -- d_input is read, then written.
%13 = load float, ptr addrspace(1) %6, align 4, !tbaa !8, !alias.scope !19, !noalias !16
%14 = fmul fast float %8, 2.000000e+00
%15 = fmul fast float %14, %12
%16 = fadd fast float %15, %13
store float %16, ptr addrspace(1) %6, align 4, !tbaa !8, !alias.scope !19, !noalias !16
ret void
}
attributes #0 = { convergent mustprogress nofree norecurse nounwind willreturn memory(argmem: readwrite) "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" }
attributes #1 = { convergent mustprogress nofree nounwind willreturn memory(none) "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #2 = { convergent nounwind willreturn memory(none) }
attributes #3 = { convergent mustprogress nounwind willreturn memory(none) }
!llvm.module.flags = !{!0, !1}
!opencl.ocl.version = !{!2}
!llvm.ident = !{!3}
!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 7, !"frame-pointer", i32 2}
!2 = !{i32 3, i32 0}
!3 = !{!"clang version 16.0.6 (Fedora 16.0.6-3.fc38)"}
!4 = !{i32 1, i32 1}
!5 = !{!"none", !"none"}
!6 = !{!"float*", !"float*"}
!7 = !{!"", !""}
!8 = !{!9, !9, i64 0}
!9 = !{!"float", !10, i64 0}
!10 = !{!"omnipotent char", !11, i64 0}
!11 = !{!"Simple C/C++ TBAA"}
!12 = !{i32 1, i32 1, i32 1, i32 1}
!13 = !{!"none", !"none", !"none", !"none"}
!14 = !{!"float*", !"float*", !"float*", !"float*"}
!15 = !{!"", !"", !"", !""}
; Alias scopes emitted by Enzyme ("primal" vs "shadow_0") for the inlined derivative.
!16 = !{!17}
!17 = distinct !{!17, !18, !"primal"}
!18 = distinct !{!18, !" diff: %"}
!19 = !{!20}
!20 = distinct !{!20, !18, !"shadow_0"}
!21 = !{!22}
!22 = distinct !{!22, !23, !"primal"}
!23 = distinct !{!23, !" diff: %"}
!24 = !{!25}
!25 = distinct !{!25, !23, !"shadow_0"}
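We can tell that Enzyme has done its trick by the SSA instruction %14 = fmul fast float %8, 2.000000e+00 in square_diff: it computes 2x, which is then multiplied by the incoming output gradient (%12) and accumulated into d_input.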
To run this kernel, we can write a simple C23 program with a traditional OpenCL pipeline, using #embed so that we neither read the SPIR-V from a file at runtime nor paste it into the source as an array literal.
// C standard includes
#include <stdio.h>
#include <stdlib.h> // For malloc, free, exit
#include <string.h> // For strstr, strcmp, strncpy
// OpenCL includes
//#include <CL/cl.h>
#include <CL/cl_ext.h>
// Fallbacks for names normally supplied by the OpenCL headers. Each one is only
// defined when the installed headers do not already provide it.
// NOTE(review): upstream cl_ext.h appears to place these under
// cl_khr_extended_versioning rather than cl_khr_il_program; confirm locally.
#if !defined(CL_NAME_MAX_KHR)
/* Bound used when copying the name field of a cl_name_version_khr entry. */
#define CL_NAME_MAX_KHR 64
#endif
/*
 * Packed version decoding: the major number occupies the bits above bit 21,
 * the minor number bits 12..21 (10 bits), and the patch number bits 0..11.
 */
#if !defined(CL_VERSION_MAJOR_KHR)
#define CL_VERSION_MAJOR_KHR(version) ((version) >> 22)
#endif
#if !defined(CL_VERSION_MINOR_KHR)
#define CL_VERSION_MINOR_KHR(version) (((version) >> 12) & 0x3FF)
#endif
#if !defined(CL_VERSION_PATCH_KHR)
#define CL_VERSION_PATCH_KHR(version) ((version) & 0xFFF)
#endif
/*
 * kernel.spv (the Clang + Enzyme output) embedded at compile time with C23
 * #embed. #embed produces byte values in 0..UCHAR_MAX, so an unsigned char
 * array holds them exactly instead of relying on an implementation-defined
 * conversion into a possibly signed plain char.
 */
static const unsigned char spirv[] = {
#embed "kernel.spv"
};
/*
 * Write `length` floats from `array` to `stream` as a C initializer list,
 * e.g. "{1.000000, 2.500000};\n". When `length` is 0 the output is "{};\n"
 * and `array` is never dereferenced (so it may be NULL).
 */
static void fprint_float_array(FILE *stream, const float *array, size_t length) {
    fputc('{', stream);
    if (length > 0) {
        fprintf(stream, "%f", array[0]);
    }
    for (size_t index = 1; index < length; index++) {
        fprintf(stream, ", %f", array[index]);
    }
    fputs("};\n", stream);
}

/*
 * Print `array` to stdout in the format produced by fprint_float_array().
 * The array is only read, hence the const-qualified pointer.
 */
void print_float_array(const float *array, size_t length) {
    fprint_float_array(stdout, array, length);
}
/*
 * Abort the program if an OpenCL call failed.
 *
 * err:       status code returned by (or written out by) an OpenCL API call.
 * operation: short description of that call, used in the error message.
 *
 * On CL_SUCCESS this returns immediately. Otherwise it prints the raw numeric
 * code to stderr and terminates with EXIT_FAILURE; it never returns an error.
 */
void check_cl_error(cl_int err, const char *operation) {
    if (err == CL_SUCCESS) {
        return;
    }
    // Every failure is fatal for this demo; mapping codes to names could be added here.
    fprintf(stderr, "OpenCL Error during %s: %d\n", operation, err);
    exit(EXIT_FAILURE);
}
void find_spirv_capable_device(cl_platform_id* p_platform, cl_device_id* p_device) {
cl_int CL_err;
cl_uint numPlatforms = 0;
CL_err = clGetPlatformIDs(0, NULL, &numPlatforms);
if (CL_err != CL_SUCCESS || numPlatforms == 0) {
fprintf(stderr, "No OpenCL platforms found or error getting count: %d\n", CL_err);
exit(EXIT_FAILURE);
}
cl_platform_id *platforms = (cl_platform_id *)malloc(sizeof(cl_platform_id) * numPlatforms);
if (!platforms) { perror("malloc platforms"); exit(EXIT_FAILURE); }
CL_err = clGetPlatformIDs(numPlatforms, platforms, NULL);
if (CL_err != CL_SUCCESS) {
fprintf(stderr, "Error getting platform IDs: %d\n", CL_err);
free(platforms);
exit(EXIT_FAILURE);
}
*p_platform = NULL;
*p_device = NULL;
printf("%u platform(s) found\n", numPlatforms);
for (cl_uint i = 0; i < numPlatforms; ++i) {
cl_platform_id current_platform = platforms[i];
char platform_name[256];
char platform_version_str[256];
clGetPlatformInfo(current_platform, CL_PLATFORM_NAME, sizeof(platform_name), platform_name, NULL);
clGetPlatformInfo(current_platform, CL_PLATFORM_VERSION, sizeof(platform_version_str), platform_version_str, NULL);
printf("Platform %u: %s (Version: %s)\n", i, platform_name, platform_version_str);
cl_uint numDev = 0;
CL_err = clGetDeviceIDs(current_platform, CL_DEVICE_TYPE_ALL, 0, NULL, &numDev);
if (CL_err != CL_SUCCESS || numDev == 0) {
// printf(" No devices found on platform %u or error: %d\n", i, CL_err); // Less verbose
continue;
}
cl_device_id *devs = (cl_device_id *)malloc(numDev * sizeof(cl_device_id));
if (!devs) { perror("malloc devs for platform"); continue; }
CL_err = clGetDeviceIDs(current_platform, CL_DEVICE_TYPE_ALL, numDev, devs, NULL);
if (CL_err != CL_SUCCESS) {
// printf(" Error getting device IDs for platform %u: %d\n", i, CL_err); // Less verbose
free(devs);
continue;
}
for (cl_uint j = 0; j < numDev; ++j) {
cl_device_id current_device = devs[j];
char device_name_str[256] = {0};
char device_version_str[256] = {0}; // OpenCL device version (e.g., OpenCL 1.2, OpenCL 3.0)
char extensions_str[16384] = {0}; // Increased size for extensions
clGetDeviceInfo(current_device, CL_DEVICE_NAME, sizeof(device_name_str), device_name_str, NULL);
clGetDeviceInfo(current_device, CL_DEVICE_VERSION, sizeof(device_version_str), device_version_str, NULL);
clGetDeviceInfo(current_device, CL_DEVICE_EXTENSIONS, sizeof(extensions_str), extensions_str, NULL);
printf(" Device %u: %s (Device Version: %s)\n", j, device_name_str, device_version_str);
int supports_spirv = 0;
int supports_create_with_il = 0;
// Determine if clCreateProgramWithIL is likely supported
if (strstr(device_version_str, "OpenCL 2.1") || strstr(device_version_str, "OpenCL 2.2")) {
supports_create_with_il = 1;
} else if (strstr(extensions_str, "cl_khr_il_program")) {
supports_create_with_il = 1;
}
// For OpenCL 3.0, cl_khr_il_program is an optional feature.
// If CL_DEVICE_ILS_WITH_VERSION_KHR query succeeds later, it also implies support.
// Check 1: OpenCL 2.2+ mandatory SPIR-V support via CL_DEVICE_SPIR_VERSIONS
char spir_versions_str[256] = {0};
if (clGetDeviceInfo(current_device, CL_DEVICE_SPIR_VERSIONS, sizeof(spir_versions_str), spir_versions_str, NULL) == CL_SUCCESS && strlen(spir_versions_str) > 0) {
if (strstr(spir_versions_str, "1.2") || strstr(spir_versions_str, "1.3") || strstr(spir_versions_str, "1.4") || strstr(spir_versions_str, "1.5") || strstr(spir_versions_str, "2.0")) {
printf(" SPIR-V support confirmed via CL_DEVICE_SPIR_VERSIONS: %s\n", spir_versions_str);
supports_spirv = 1;
supports_create_with_il = 1; // Mandatory for OpenCL 2.2+
}
}
// Check 2: OpenCL 2.1 specific: CL_DEVICE_IL_VERSION. Also for OpenCL 3.0 with cl_khr_il_program feature.
if (!supports_spirv) {
char il_version_str[1024] = {0}; // Increased size
if (clGetDeviceInfo(current_device, CL_DEVICE_IL_VERSION_KHR, sizeof(il_version_str), il_version_str, NULL) == CL_SUCCESS && strlen(il_version_str) > 0) {
if (strstr(il_version_str, "SPIR-V")) {
printf(" SPIR-V support confirmed via CL_DEVICE_IL_VERSION: %s\n", il_version_str);
supports_spirv = 1;
if (strstr(device_version_str, "OpenCL 2.1")) supports_create_with_il = 1; // Mandatory for 2.1
// If cl_khr_il_program was seen, supports_create_with_il would be true already
}
}
}
// Check 3: OpenCL 3.0 specific: CL_DEVICE_ILS_WITH_VERSION_KHR
if (!supports_spirv && (strstr(device_version_str, "OpenCL 3.0") || strstr(extensions_str, "cl_khr_il_program"))) {
size_t num_ils_size = 0; // Use size_t for the size parameter
cl_uint num_ils = 0; // For the actual count
// First, get the size of the data (number of cl_name_version_khr entries)
CL_err = clGetDeviceInfo(current_device, CL_DEVICE_ILS_WITH_VERSION_KHR, 0, NULL, &num_ils_size);
if (CL_err == CL_SUCCESS && num_ils_size > 0) {
num_ils = num_ils_size / sizeof(cl_name_version_khr); // Calculate number of entries
if (num_ils > 0) {
supports_create_with_il = 1; // If this query works and returns items, IL loading is supported
cl_name_version_khr *ils = (cl_name_version_khr *)malloc(num_ils_size);
if (ils) {
CL_err = clGetDeviceInfo(current_device, CL_DEVICE_ILS_WITH_VERSION_KHR, num_ils_size, ils, NULL);
if (CL_err == CL_SUCCESS) {
for (cl_uint k = 0; k < num_ils; ++k) {
char il_name[CL_NAME_MAX_KHR + 1] = {0}; // Ensure null termination
strncpy(il_name, ils[k].name, CL_NAME_MAX_KHR);
if (strcmp(il_name, "SPIR-V") == 0) {
printf(" SPIR-V support confirmed via CL_DEVICE_ILS_WITH_VERSION_KHR (Version: %u.%u.%u)\n",
CL_VERSION_MAJOR_KHR(ils[k].version),
CL_VERSION_MINOR_KHR(ils[k].version),
CL_VERSION_PATCH_KHR(ils[k].version));
supports_spirv = 1;
break;
}
}
}
free(ils);
}
}
}
}
// Check 4: Fallback - if clCreateProgramWithIL is supported, check for SPIR-V extensions
if (supports_create_with_il && !supports_spirv) {
if (strstr(extensions_str, "cl_khr_spirv_no_priority") ||
strstr(extensions_str, "cl_khr_spirv_linker") || // Example of other SPIR-V related extensions
strstr(extensions_str, "cl_khr_spirv_validation")) {
printf(" SPIR-V support inferred from cl_khr_il_program and specific SPIR-V extensions.\n");
supports_spirv = 1;
}
}
if (supports_spirv && supports_create_with_il) {
printf(" Selected device: %s on platform %s for SPIR-V.\n", device_name_str, platform_name);
*p_platform = current_platform;
*p_device = current_device;
free(devs);
goto found_device_and_platform;
}
}
free(devs);
}
found_device_and_platform:
free(platforms); // Free the list of all platform IDs
if (*p_device == NULL) {
fprintf(stderr, "\nNo suitable OpenCL platform/device found that supports clCreateProgramWithIL with SPIR-V.\n");
fprintf(stderr, "Please ensure your OpenCL drivers are up to date and support SPIR-V.\n");
exit(EXIT_FAILURE);
}
}
int main() {
cl_int CL_err = CL_SUCCESS;
cl_platform_id platform;
cl_device_id device;
find_spirv_capable_device(&platform, &device);
cl_context_properties context_props[] = {CL_CONTEXT_PLATFORM, (cl_context_properties)platform, 0};
cl_context context = clCreateContext(context_props, 1, &device, NULL, NULL, &CL_err);
check_cl_error(CL_err, "clCreateContext");
cl_command_queue command_queue = clCreateCommandQueueWithProperties(context, device, NULL, &CL_err);
check_cl_error(CL_err, "clCreateCommandQueueWithProperties");
const size_t num = 32;
cl_float *p_x = (cl_float *)malloc(sizeof(cl_float) * num);
cl_float *p_y = (cl_float *)malloc(sizeof(cl_float) * num);
cl_float *p_dx = (cl_float *)malloc(sizeof(cl_float) * num);
cl_float *p_dy = (cl_float *)malloc(sizeof(cl_float) * num);
if (!p_x || !p_y || !p_dx || !p_dy) {
perror("malloc host buffers");
// Basic cleanup, more would be needed in a real app
if(command_queue) clReleaseCommandQueue(command_queue);
if(context) clReleaseContext(context);
exit(EXIT_FAILURE);
}
cl_mem x = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR | CL_MEM_HOST_WRITE_ONLY, sizeof(cl_float) * num, p_x, &CL_err);
check_cl_error(CL_err, "clCreateBuffer x");
cl_mem y = clCreateBuffer(context, CL_MEM_WRITE_ONLY | CL_MEM_USE_HOST_PTR | CL_MEM_HOST_READ_ONLY, sizeof(cl_float) * num, p_y, &CL_err);
check_cl_error(CL_err, "clCreateBuffer y");
cl_mem dx = clCreateBuffer(context, CL_MEM_WRITE_ONLY | CL_MEM_USE_HOST_PTR | CL_MEM_HOST_READ_ONLY, sizeof(cl_float) * num, p_dx, &CL_err);
check_cl_error(CL_err, "clCreateBuffer dx");
cl_mem dy = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_USE_HOST_PTR | CL_MEM_HOST_WRITE_ONLY, sizeof(cl_float) * num, p_dy, &CL_err);
check_cl_error(CL_err, "clCreateBuffer dy");
for (size_t index = 0; index < num; index++) {
p_x[index] = (cl_float)(index);
p_dx[index] = (cl_float)0.0f;
p_dy[index] = (cl_float)1.0f;
}
cl_event write_ops_complete[2];
CL_err = clEnqueueWriteBuffer(command_queue, x, CL_FALSE, 0, sizeof(cl_float) * num, (const void *)p_x, 0, NULL, &(write_ops_complete[0]));
check_cl_error(CL_err, "clEnqueueWriteBuffer x");
CL_err = clEnqueueWriteBuffer(command_queue, dy, CL_FALSE, 0, sizeof(cl_float) * num, (const void *)p_dy, 0, NULL, &(write_ops_complete[1]));
check_cl_error(CL_err, "clEnqueueWriteBuffer dy");
cl_program square_program = clCreateProgramWithIL(context, (const void *)(spirv), sizeof(spirv), &CL_err);
check_cl_error(CL_err, "clCreateProgramWithIL");
CL_err = clBuildProgram(square_program, 1, &device, NULL, NULL, NULL);
if (CL_err != CL_SUCCESS) {
fprintf(stderr, "Error building program: %d\n", CL_err);
size_t log_size;
clGetProgramBuildInfo(square_program, device, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
if (log_size > 1) {
char *log = (char *)malloc(log_size);
if (log) {
clGetProgramBuildInfo(square_program, device, CL_PROGRAM_BUILD_LOG, log_size, log, NULL);
fprintf(stderr, "Build Log:\n%s\n", log);
free(log);
}
}
exit(EXIT_FAILURE);
}
cl_kernel kernel_square_diff = clCreateKernel(square_program, "square_diff", &CL_err);
check_cl_error(CL_err, "clCreateKernel square_diff");
CL_err = clSetKernelArg(kernel_square_diff, 0, sizeof(cl_mem), (void *)&x); check_cl_error(CL_err, "clSetKernelArg 0");
CL_err = clSetKernelArg(kernel_square_diff, 1, sizeof(cl_mem), (void *)&dx); check_cl_error(CL_err, "clSetKernelArg 1");
CL_err = clSetKernelArg(kernel_square_diff, 2, sizeof(cl_mem), (void *)&y); check_cl_error(CL_err, "clSetKernelArg 2");
CL_err = clSetKernelArg(kernel_square_diff, 3, sizeof(cl_mem), (void *)&dy); check_cl_error(CL_err, "clSetKernelArg 3");
cl_event kernel_complete;
const size_t global_work_offset[1] = { 0U };
const size_t global_work_size[1] = { num };
const size_t local_work_size[1] = { 16U };
CL_err = clEnqueueNDRangeKernel(command_queue, kernel_square_diff, 1, global_work_offset, global_work_size, local_work_size, 2, write_ops_complete, &kernel_complete);
check_cl_error(CL_err, "clEnqueueNDRangeKernel");
cl_event read_complete[2];
CL_err = clEnqueueReadBuffer(command_queue, y, CL_FALSE, 0, sizeof(cl_float) * num, p_y, 1, &kernel_complete, &(read_complete[0]));
check_cl_error(CL_err, "clEnqueueReadBuffer y");
CL_err = clEnqueueReadBuffer(command_queue, dx, CL_FALSE, 0, sizeof(cl_float) * num, p_dx, 1, &kernel_complete, &(read_complete[1]));
check_cl_error(CL_err, "clEnqueueReadBuffer dx");
CL_err = clWaitForEvents(2, read_complete);
check_cl_error(CL_err, "clWaitForEvents for reads");
printf("float x[%zu] = ", num); print_float_array(p_x, num);
printf("float y[%zu] = ", num); print_float_array(p_y, num);
printf("float dx[%zu] = ", num); print_float_array(p_dx, num);
printf("float dy[%zu] = ", num); print_float_array(p_dy, num);
clFlush(command_queue);
clFinish(command_queue);
clReleaseEvent(write_ops_complete[0]);
clReleaseEvent(write_ops_complete[1]);
clReleaseEvent(kernel_complete);
clReleaseEvent(read_complete[0]);
clReleaseEvent(read_complete[1]);
clReleaseMemObject(x);
clReleaseMemObject(y);
clReleaseMemObject(dx);
clReleaseMemObject(dy);
clReleaseKernel(kernel_square_diff);
clReleaseProgram(square_program);
clReleaseCommandQueue(command_queue);
clReleaseContext(context);
free(p_x); free(p_y); free(p_dx); free(p_dy);
return 0;
}
The relatively high complexity of platform and device selection in this program exposes the biggest restriction of this approach: since Clang can only output compiled OpenCL as SPIR-V, the underlying implementation must support creating programs from SPIR-V with clCreateProgramWithIL. OpenCL 1.2 and basic OpenCL 3.0 implementations do not support this—the functionality was introduced in OpenCL 2.1, retained in OpenCL 2.2, and made optional in OpenCL 3.0. This means that only implementations that care to provide a rich OpenCL 3.0 experience carry the feature, the most prominent being Intel (with Compute Runtime) and the Mesa project (through Rusticl). Luckily, these are widely available and cover enough hardware that such programs should port well to all GPU vendors under Linux, but NVIDIA and AMD notably lag behind Intel with their lack of first-party implementations. Rusticl really is a godsend here: in combination with Zink, it leverages the excellent Vulkan drivers available for those lacking platforms (such as NVK and NVIDIA's proprietary driver for NVIDIA, and RADV for AMD).
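Regardless, running the program provides us with what we are after, printed as C array initializers for clarity. The expected gradient is dx[i] = 2·x[i]·dy[i]. In this capture, however, dx[1] through dx[7] read back as 0 and dy is only partially cleared. Because the generated kernel both reads and accumulates into dx and resets dy to zero, the dx and dy buffers must be created read-write, and dx must be zero-initialized on the device, for every entry to come out right: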
float x[32] = {0.000000, 1.000000, 2.000000, 3.000000, 4.000000, 5.000000, 6.000000, 7.000000, 8.000000, 9.000000, 10.000000, 11.000000, 12.000000, 13.000000, 14.000000, 15.000000, 16.000000, 17.000000, 18.000000, 19.000000, 20.000000, 21.000000, 22.000000, 23.000000, 24.000000, 25.000000, 26.000000, 27.000000, 28.000000, 29.000000, 30.000000, 31.000000};
float y[32] = {0.000000, 1.000000, 4.000000, 9.000000, 16.000000, 25.000000, 36.000000, 49.000000, 64.000000, 81.000000, 100.000000, 121.000000, 144.000000, 169.000000, 196.000000, 225.000000, 256.000000, 289.000000, 324.000000, 361.000000, 400.000000, 441.000000, 484.000000, 529.000000, 576.000000, 625.000000, 676.000000, 729.000000, 784.000000, 841.000000, 900.000000, 961.000000};
float dx[32] = {0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 16.000000, 18.000000, 20.000000, 22.000000, 24.000000, 26.000000, 28.000000, 30.000000, 32.000000, 34.000000, 36.000000, 38.000000, 40.000000, 42.000000, 44.000000, 46.000000, 48.000000, 50.000000, 52.000000, 54.000000, 56.000000, 58.000000, 60.000000, 62.000000};
float dy[32] = {1.000000, 1.000000, 1.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000};
Future Work
The fact that this is possible at all is a testament to how flexible, robust, and well-designed Clang and LLVM are, but this is a very simple example and more performant kernels would need to leverage features that Enzyme cannot currently handle (such as subgroup operations). It is however conceivable that these limitations could be eliminated through contributions upstream to enable and robustify support for OpenCL.