Java with WebGPU/Vulkan: GPU Acceleration for Graphics Rendering and General-Purpose Computing, and Interface Design

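Hello, everyone. Today we look at a very interesting topic: Java with WebGPU/Vulkan. We will see how Java can use these modern graphics APIs for GPU acceleration, covering both graphics rendering and general-purpose computation, and we will discuss the interface-design questions this raises.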

1. Why GPU acceleration: from CPU to GPU

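In the traditional computing model, the CPU does most of the work. As data volumes explode and fields such as graphics rendering and machine learning demand ever more compute, CPU throughput increasingly becomes the bottleneck. GPUs were originally designed for graphics rendering, but their parallel architecture also makes them well suited to large-scale data-parallel workloads. A GPU has hundreds to thousands of comparatively simple cores that apply the same operation to many data elements at once, so problems that split into many independent computations (for example, adding two large vectors element by element) can achieve much higher throughput than on a CPU.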

2. WebGPU and Vulkan: choosing a modern graphics API

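WebGPU and Vulkan are two modern graphics APIs. Both expose lower-level access to the hardware than older APIs such as OpenGL, trading convenience for higher performance and finer control.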

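  • Vulkan: a cross-platform, low-overhead graphics and compute API maintained by the Khronos Group. It gives developers explicit control over the GPU, which enables fine-grained performance tuning, and it is widely regarded as the successor to OpenGL, with considerably lower CPU overhead. Because the application must manage memory, synchronization and similar details itself, Vulkan has a steep learning curve, but it also offers the most room for optimization.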

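  • WebGPU: a newer Web API, developed by the W3C GPU for the Web working group, that brings high-performance 3D graphics and compute to web applications. Its design goals are safety, portability and performance. It borrows ideas from modern APIs such as Vulkan, Metal and Direct3D 12 and adapts them to the constraints of the Web. Compared with Vulkan it is easier to learn and places more emphasis on portability across browsers and platforms.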

3. Java and GPU acceleration: choosing a bridge

Java cannot call WebGPU or Vulkan directly, so we need a bridge to native code. The main options today are:

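  • JNI (Java Native Interface): JNI lets Java code call native code (such as C/C++), and through it the C APIs of Vulkan or WebGPU. It is a common approach, but it requires writing a substantial amount of native glue code, which is costly to maintain.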

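  • LWJGL (Lightweight Java Game Library): a Java library that provides access to low-level APIs such as OpenGL, Vulkan and GLFW. It wraps these APIs through JNI, with bindings generated by LWJGL itself, so you do not write native code, and it exposes a Java-friendly interface. LWJGL is widely used in Java game development and is equally suitable for GPU compute.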

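  • GraalVM Native Image: GraalVM can compile Java code ahead of time into a native executable, and its C interoperability API lets the resulting image call C functions, such as the Vulkan or WebGPU C API, without hand-written JNI glue. The main benefits are faster startup and a smaller memory footprint rather than faster GPU work, and the approach needs extra build configuration and debugging.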

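  • Native WebGPU implementations (Dawn/wgpu-native): Dawn (the WebGPU implementation behind Google Chrome) and wgpu-native (built on the Rust wgpu library) implement WebGPU outside the browser and expose it through the C header webgpu.h, which Java can reach via JNI or a binding library. This makes the WebGPU API usable from Java and suits experimentation and prototyping. Because WebGPU validates commands and abstracts over Vulkan, Metal and Direct3D 12, it may not reach the peak performance of hand-tuned Vulkan.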
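Whichever bridge you choose, it helps to keep it behind a small Java interface, so that application code does not depend directly on LWJGL, hand-written JNI glue or a WebGPU binding, and so that a plain-Java fallback is always available. The sketch below is hypothetical: `ComputeBackend`, `CpuBackend` and `Backends.select` are illustrative names, not part of any library, and a Vulkan- or WebGPU-backed class would implement the same interface.

```java
import java.util.List;
import java.util.function.Supplier;

/** Hypothetical abstraction over a GPU bridge (Vulkan via LWJGL, WebGPU via wgpu-native, ...). */
interface ComputeBackend extends AutoCloseable {
    String name();

    /** Element-wise c[i] = a[i] + b[i]. */
    float[] add(float[] a, float[] b);

    /** Real implementations would release native resources (buffers, devices) here. */
    @Override
    default void close() {}
}

/** Plain-Java implementation: always available, and useful as a reference for GPU results. */
final class CpuBackend implements ComputeBackend {
    @Override
    public String name() { return "cpu"; }

    @Override
    public float[] add(float[] a, float[] b) {
        if (a.length != b.length) throw new IllegalArgumentException("length mismatch");
        float[] c = new float[a.length];
        for (int i = 0; i < a.length; i++) c[i] = a[i] + b[i];
        return c;
    }
}

final class Backends {
    /**
     * Tries each factory in order; a factory signals "unavailable" by throwing
     * (e.g. no Vulkan-capable GPU, or a missing native library). Falls back to the CPU.
     */
    static ComputeBackend select(List<Supplier<ComputeBackend>> candidates) {
        for (Supplier<ComputeBackend> factory : candidates) {
            try {
                return factory.get();
            } catch (RuntimeException | UnsatisfiedLinkError e) {
                // This backend cannot be used on this machine; try the next one.
            }
        }
        return new CpuBackend();
    }
}
```

A caller would typically write `try (ComputeBackend backend = Backends.select(factories)) { ... }`, listing a GPU-backed factory first. Keeping the bridge-specific types out of the interface is what lets the choice between JNI, LWJGL or a WebGPU binding be revisited later.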

4. Java and LWJGL: Vulkan in practice

Using LWJGL as the example, we now show how to use Vulkan for GPU acceleration from Java.

4.1 Environment setup

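  • Get LWJGL from the official site: https://www.lwjgl.org/ (its build customizer can also generate the Maven or Gradle configuration).
  • Configure the LWJGL dependencies. In a Maven project, add the following: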
<!-- Java bindings -->
<dependency>
    <groupId>org.lwjgl</groupId>
    <artifactId>lwjgl</artifactId>
    <version>${lwjgl.version}</version>
</dependency>
<dependency>
    <groupId>org.lwjgl</groupId>
    <artifactId>lwjgl-vulkan</artifactId>
    <version>${lwjgl.version}</version>
</dependency>
<dependency>
    <groupId>org.lwjgl</groupId>
    <artifactId>lwjgl-glfw</artifactId>
    <version>${lwjgl.version}</version>
</dependency>

<!-- Native libraries -->
<dependency>
    <groupId>org.lwjgl</groupId>
    <artifactId>lwjgl</artifactId>
    <version>${lwjgl.version}</version>
    <classifier>${native.platform}</classifier>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>org.lwjgl</groupId>
    <artifactId>lwjgl-glfw</artifactId>
    <version>${lwjgl.version}</version>
    <classifier>${native.platform}</classifier>
    <scope>runtime</scope>
</dependency>
<!-- lwjgl-vulkan publishes no Windows/Linux natives: it loads the system Vulkan loader.
     Only on macOS, add lwjgl-vulkan with classifier natives-macos (or natives-macos-arm64)
     to bundle MoltenVK. lwjgl-opengl is not needed, since this example renders nothing with OpenGL. -->
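  • Here ${lwjgl.version} and ${native.platform} are Maven properties that you set for your environment, for example lwjgl.version = 3.3.1 and native.platform = natives-windows on Windows (natives-linux on Linux, natives-macos on Intel macOS). A GPU driver that provides the Vulkan loader is also required, and the validation layer used below (VK_LAYER_KHRONOS_validation) ships with the Vulkan SDK.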
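The `${native.platform}` value follows LWJGL's classifier naming. As a rough aid, a hypothetical helper could derive it from the JVM's `os.name` and `os.arch` system properties. The mapping below covers common desktop platforms only and is an assumption, so check the LWJGL customizer for the classifiers your version actually publishes.

```java
import java.util.Locale;

/** Hypothetical helper: maps the JVM's os.name / os.arch to an LWJGL natives classifier. */
final class LwjglNatives {
    static String classifier(String osName, String osArch) {
        String os = osName.toLowerCase(Locale.ROOT);
        String arch = osArch.toLowerCase(Locale.ROOT);
        boolean arm64 = arch.equals("aarch64") || arch.equals("arm64");

        if (os.startsWith("windows")) {
            if (arm64) return "natives-windows-arm64";
            // 32-bit JVMs usually report "x86"
            boolean x86 = arch.equals("x86") || arch.equals("i386") || arch.equals("i686");
            return x86 ? "natives-windows-x86" : "natives-windows";
        }
        if (os.startsWith("mac") || os.startsWith("darwin")) {
            return arm64 ? "natives-macos-arm64" : "natives-macos";
        }
        if (os.startsWith("linux")) {
            if (arm64) return "natives-linux-arm64";
            if (arch.startsWith("arm")) return "natives-linux-arm32";
            return "natives-linux";
        }
        throw new IllegalArgumentException("Unsupported platform: " + osName + "/" + osArch);
    }

    public static void main(String[] args) {
        // Prints the classifier for the JVM this runs on, e.g. to paste into <native.platform>
        System.out.println(classifier(System.getProperty("os.name"), System.getProperty("os.arch")));
    }
}
```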

4.2 Vulkan initialization

import org.lwjgl.PointerBuffer;
import org.lwjgl.glfw.*;
import org.lwjgl.system.*;
import org.lwjgl.vulkan.*;

import java.nio.*;

import static org.lwjgl.glfw.GLFW.*;
import static org.lwjgl.glfw.GLFWVulkan.*;
import static org.lwjgl.system.MemoryStack.*;
import static org.lwjgl.system.MemoryUtil.*;
import static org.lwjgl.vulkan.KHRSwapchain.*;
import static org.lwjgl.vulkan.VK10.*;

public class VulkanExample {

    private long window;
    private VkInstance instance;
    private VkPhysicalDevice physicalDevice;
    private VkDevice device;
    private VkQueue graphicsQueue;
    private int queueFamilyIndex;

    public void run() {
        init();
        loop();
        cleanup();
    }

    private void init() {
        initGLFW();
        initVulkan();
    }

    private void initGLFW() {
        if (!glfwInit()) {
            throw new IllegalStateException("Unable to initialize GLFW");
        }

        glfwWindowHint(GLFW_CLIENT_API, GLFW_NO_API);
        window = glfwCreateWindow(800, 600, "Vulkan Example", NULL, NULL);

        if (window == NULL) {
            throw new RuntimeException("Failed to create the GLFW window");
        }
    }

    private void initVulkan() {
        createInstance();
        pickPhysicalDevice();
        createLogicalDevice();
    }

    private void createInstance() {
        try (MemoryStack stack = stackPush()) {
            VkApplicationInfo appInfo = VkApplicationInfo.calloc(stack);
            appInfo.sType(VK_STRUCTURE_TYPE_APPLICATION_INFO);
            appInfo.pApplicationName(stack.UTF8Safe("Vulkan Example"));
            appInfo.applicationVersion(VK_MAKE_VERSION(1, 0, 0));
            appInfo.pEngineName(stack.UTF8Safe("No Engine"));
            appInfo.engineVersion(VK_MAKE_VERSION(1, 0, 0));
            appInfo.apiVersion(VK_API_VERSION_1_0);

            VkInstanceCreateInfo createInfo = VkInstanceCreateInfo.calloc(stack);
            createInfo.sType(VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO);
            createInfo.pApplicationInfo(appInfo);

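            // GLFW reports the instance extensions it needs for window surfaces (e.g. VK_KHR_surface)
            PointerBuffer glfwExtensions = glfwGetRequiredInstanceExtensions();
            if (glfwExtensions == null) {
                throw new RuntimeException("GLFW found no Vulkan support on this system");
            }
            createInfo.ppEnabledExtensionNames(glfwExtensions);

            // Validation layers (optional)
            boolean enableValidationLayers = true; // Set to false in release builds
            if (enableValidationLayers && !checkValidationLayerSupport()) {
                throw new RuntimeException("Validation layers requested, but not available!");
            }

            if (enableValidationLayers) {
                // stack.UTF8 allocates on the MemoryStack, so the string is freed with the frame
                createInfo.ppEnabledLayerNames(stack.pointers(stack.UTF8("VK_LAYER_KHRONOS_validation")));
            }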

            PointerBuffer pInstance = stack.mallocPointer(1);
            if (vkCreateInstance(createInfo, null, pInstance) != VK_SUCCESS) {
                throw new RuntimeException("Failed to create instance!");
            }

            instance = new VkInstance(pInstance.get(0), createInfo);
        }
    }

    private boolean checkValidationLayerSupport() {
        try (MemoryStack stack = stackPush()) {
            IntBuffer layerCount = stack.mallocInt(1);
            vkEnumerateInstanceLayerProperties(layerCount, null);

            VkLayerProperties.Buffer availableLayers = VkLayerProperties.malloc(layerCount.get(0), stack);
            vkEnumerateInstanceLayerProperties(layerCount, availableLayers);

            // Check if all of the requested layers are available.
            // This example only uses the "VK_LAYER_KHRONOS_validation" layer.
            while (availableLayers.hasRemaining()) {
                VkLayerProperties layerProperties = availableLayers.get();
                if (layerProperties.layerNameString().equals("VK_LAYER_KHRONOS_validation")) {
                    return true;
                }
            }

            return false;
        }
    }

    private void pickPhysicalDevice() {
        try (MemoryStack stack = stackPush()) {
            IntBuffer deviceCount = stack.mallocInt(1);
            vkEnumeratePhysicalDevices(instance, deviceCount, null);

            int deviceCountValue = deviceCount.get(0);
            if (deviceCountValue == 0) {
                throw new RuntimeException("Failed to find GPUs with Vulkan support!");
            }

            PointerBuffer devices = stack.mallocPointer(deviceCountValue);
            vkEnumeratePhysicalDevices(instance, deviceCount, devices);

            for (int i = 0; i < deviceCountValue; i++) {
                VkPhysicalDevice device = new VkPhysicalDevice(devices.get(i), instance);
                if (isDeviceSuitable(device)) {
                    physicalDevice = device;
                    return;
                }
            }

            throw new RuntimeException("Failed to find a suitable GPU!");
        }
    }

    private boolean isDeviceSuitable(VkPhysicalDevice device) {
        try (MemoryStack stack = stackPush()) {
            VkPhysicalDeviceProperties deviceProperties = VkPhysicalDeviceProperties.malloc(stack);
            vkGetPhysicalDeviceProperties(device, deviceProperties);

            VkPhysicalDeviceFeatures deviceFeatures = VkPhysicalDeviceFeatures.malloc(stack);
            vkGetPhysicalDeviceFeatures(device, deviceFeatures);

            QueueFamilyIndices indices = findQueueFamilies(device);

            boolean extensionsSupported = checkDeviceExtensionSupport(device);

            // This example never creates a VkSurfaceKHR, so swap chain support cannot be queried yet.
            // Once a surface exists, also require non-empty formats and present modes here.
            // geometryShader is not required: nothing here uses it, and some GPUs (e.g. via MoltenVK) lack it.
            boolean suitable = indices.isComplete() && extensionsSupported;

            if (suitable) {
                System.out.println("Using device: " + deviceProperties.deviceNameString());
            }

            return suitable;
        }
    }

    private QueueFamilyIndices findQueueFamilies(VkPhysicalDevice device) {
        QueueFamilyIndices indices = new QueueFamilyIndices();

        try (MemoryStack stack = stackPush()) {
            IntBuffer queueFamilyCount = stack.mallocInt(1);
            vkGetPhysicalDeviceQueueFamilyProperties(device, queueFamilyCount, null);

            VkQueueFamilyProperties.Buffer queueFamilies = VkQueueFamilyProperties.malloc(queueFamilyCount.get(0), stack);
            vkGetPhysicalDeviceQueueFamilyProperties(device, queueFamilyCount, queueFamilies);

            int i = 0;
            while (queueFamilies.hasRemaining()) {
                VkQueueFamilyProperties queueFamily = queueFamilies.get();

                // True if this queue family can present to GLFW windows on this device
                boolean presentSupport = glfwGetPhysicalDevicePresentationSupport(instance, device, i);

                // Test the graphics bit with a bitwise AND (note the parentheses)
                if ((queueFamily.queueFlags() & VK_QUEUE_GRAPHICS_BIT) != 0) {
                    indices.graphicsFamily = i;
                }

                if (presentSupport) {
                    indices.presentFamily = i;
                }

                if (indices.isComplete()) {
                    break;
                }

                i++;
            }
        }

        return indices;
    }

    static class QueueFamilyIndices {
        Integer graphicsFamily;
        Integer presentFamily;

        boolean isComplete() {
            return graphicsFamily != null && presentFamily != null;
        }
    }

    private boolean checkDeviceExtensionSupport(VkPhysicalDevice device) {
        try (MemoryStack stack = stackPush()) {
            IntBuffer extensionCount = stack.mallocInt(1);
            vkEnumerateDeviceExtensionProperties(device, (String) null, extensionCount, null);

            VkExtensionProperties.Buffer availableExtensions = VkExtensionProperties.malloc(extensionCount.get(0), stack);
            vkEnumerateDeviceExtensionProperties(device, (String) null, extensionCount, availableExtensions);

            java.util.Set<String> requiredExtensions = new java.util.HashSet<>(java.util.Arrays.asList(VK_KHR_SWAPCHAIN_EXTENSION_NAME));

            while (availableExtensions.hasRemaining()) {
                VkExtensionProperties extension = availableExtensions.get();
                requiredExtensions.remove(extension.extensionNameString());
            }

            return requiredExtensions.isEmpty();
        }
    }

    static class SwapChainSupportDetails {
        VkSurfaceCapabilitiesKHR capabilities;
        VkSurfaceFormatKHR.Buffer formats;
        IntBuffer presentModes; // VkPresentModeKHR values are plain ints in LWJGL
    }

    private SwapChainSupportDetails querySwapChainSupport(VkPhysicalDevice device) {
        SwapChainSupportDetails details = new SwapChainSupportDetails();

        // Omitted for brevity: querying swap chain support requires a VkSurfaceKHR (e.g. from
        // glfwCreateWindowSurface), which this example does not create yet. A full version calls
        // vkGetPhysicalDeviceSurfaceCapabilitiesKHR, vkGetPhysicalDeviceSurfaceFormatsKHR and
        // vkGetPhysicalDeviceSurfacePresentModesKHR. Allocate the results on the heap or on a caller's
        // MemoryStack: buffers from a stack frame popped inside this method would become invalid.
        // See a Vulkan tutorial for the details.

        return details;
    }

    private void createLogicalDevice() {
        QueueFamilyIndices indices = findQueueFamilies(physicalDevice);

        try (MemoryStack stack = stackPush()) {
            FloatBuffer queuePriorities = stack.floats(1.0f);

            // calloc, not malloc: pNext and flags must be zero
            VkDeviceQueueCreateInfo.Buffer queueCreateInfos = VkDeviceQueueCreateInfo.calloc(1, stack);

            VkDeviceQueueCreateInfo queueCreateInfo = queueCreateInfos.get(0);
            queueCreateInfo.sType(VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO);
            queueCreateInfo.queueFamilyIndex(indices.graphicsFamily);
            queueCreateInfo.pQueuePriorities(queuePriorities);

            VkPhysicalDeviceFeatures deviceFeatures = VkPhysicalDeviceFeatures.calloc(stack);

            VkDeviceCreateInfo createInfo = VkDeviceCreateInfo.calloc(stack);
            createInfo.sType(VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO);
            createInfo.pQueueCreateInfos(queueCreateInfos);
            createInfo.pEnabledFeatures(deviceFeatures);

            createInfo.ppEnabledExtensionNames(stack.pointers(stack.UTF8(VK_KHR_SWAPCHAIN_EXTENSION_NAME)));

            // Validation layers (optional). Device layers are deprecated (instance layers apply),
            // but setting them keeps older implementations happy.
            boolean enableValidationLayers = true; // Set to false in release builds
            if (enableValidationLayers) {
                createInfo.ppEnabledLayerNames(stack.pointers(stack.UTF8("VK_LAYER_KHRONOS_validation")));
            }

            PointerBuffer pDevice = stack.mallocPointer(1);
            if (vkCreateDevice(physicalDevice, createInfo, null, pDevice) != VK_SUCCESS) {
                throw new RuntimeException("Failed to create logical device!");
            }

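            device = new VkDevice(pDevice.get(0), physicalDevice, createInfo);

            // vkGetDeviceQueue writes the raw VkQueue handle into a pointer buffer
            PointerBuffer pQueue = stack.mallocPointer(1);
            vkGetDeviceQueue(device, indices.graphicsFamily, 0, pQueue);
            graphicsQueue = new VkQueue(pQueue.get(0), device);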

            queueFamilyIndex = indices.graphicsFamily;

        }
    }

    private void loop() {
        while (!glfwWindowShouldClose(window)) {
            glfwPollEvents();
            // Rendering logic will go here later
        }
    }

    private void cleanup() {
        vkDestroyDevice(device, null);
        vkDestroyInstance(instance, null);
        glfwDestroyWindow(window);
        glfwTerminate();
    }

    public static void main(String[] args) {
        new VulkanExample().run();
    }
}

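This code demonstrates the Vulkan initialization sequence:

  • GLFW initialization: create a window (with GLFW_NO_API, because Vulkan rather than OpenGL will drive it)
  • Vulkan instance creation: create the Vulkan instance with the extensions GLFW requires and, optionally, validation layers
  • Physical device selection: choose a suitable GPU
  • Logical device creation: create the logical device and retrieve a queue from it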
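The queue-family checks in `findQueueFamilies` are plain bitmask tests on `VkQueueFamilyProperties.queueFlags()`. The sketch below reproduces that logic on plain `int` flags so it runs without a GPU. The constant values are the ones defined by the Vulkan specification (`VK_QUEUE_GRAPHICS_BIT = 0x1`, `VK_QUEUE_COMPUTE_BIT = 0x2`, `VK_QUEUE_TRANSFER_BIT = 0x4`); the class name and the "prefer a compute family without graphics" policy are illustrative assumptions, relevant once a compute queue is needed as in section 4.3.

```java
/** Illustrative queue-family selection over queueFlags() values (one int per family). */
final class QueueFamilySelector {
    static final int VK_QUEUE_GRAPHICS_BIT = 0x1;
    static final int VK_QUEUE_COMPUTE_BIT = 0x2;
    static final int VK_QUEUE_TRANSFER_BIT = 0x4;

    /** Index of the first family whose flags contain every bit in {@code required}, or -1. */
    static int firstWith(int[] familyFlags, int required) {
        for (int i = 0; i < familyFlags.length; i++) {
            // Parenthesize the AND: '==' binds tighter than '&' in Java
            if ((familyFlags[i] & required) == required) return i;
        }
        return -1;
    }

    /** Prefers a compute family without graphics (often an async-compute queue), else any compute family. */
    static int pickComputeFamily(int[] familyFlags) {
        for (int i = 0; i < familyFlags.length; i++) {
            boolean compute = (familyFlags[i] & VK_QUEUE_COMPUTE_BIT) != 0;
            boolean graphics = (familyFlags[i] & VK_QUEUE_GRAPHICS_BIT) != 0;
            if (compute && !graphics) return i;
        }
        return firstWith(familyFlags, VK_QUEUE_COMPUTE_BIT);
    }
}
```

For example, with families `{0x7, 0x4, 0x6}` the graphics family is 0 and the preferred compute family is 2. In the LWJGL code, the array would be filled from `queueFamily.queueFlags()` for each family.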

4.3 General-purpose compute example: vector addition

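// Shader code (compute.glsl). Vulkan only accepts SPIR-V, so compile it first, e.g.:
//   glslc -fshader-stage=compute compute.glsl -o compute.spv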
#version 450

layout (local_size_x = 256) in;

layout(set = 0, binding = 0) buffer A {
    vec4 data[];
} a;

layout(set = 0, binding = 1) buffer B {
    vec4 data[];
} b;

layout(set = 0, binding = 2) buffer C {
    vec4 data[];
} c;

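void main() {
    uint index = gl_GlobalInvocationID.x;
    // The last workgroup may extend past the end of the arrays; skip those invocations
    if (index >= uint(a.data.length())) {
        return;
    }
    c.data[index] = a.data[index] + b.data[index];
}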
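On the host side, two numbers follow directly from this shader. Each workgroup runs `local_size_x = 256` invocations, so covering n `vec4` elements takes ⌈n / 256⌉ workgroups, the value that would be passed as `groupCountX` to `vkCmdDispatch(commandBuffer, groupCountX, 1, 1)` when recording the command buffer. Each storage buffer needs n × 16 bytes, since a `vec4` of floats has an array stride of 16 bytes. The helper names below are hypothetical:

```java
/** Hypothetical sizing helpers for the vector-add shader above. */
final class DispatchMath {
    static final int LOCAL_SIZE_X = 256;            // must match layout(local_size_x = 256)
    static final int VEC4_BYTES = 4 * Float.BYTES;  // one vec4 of floats = 16 bytes

    /** Smallest group count with groupCount * LOCAL_SIZE_X >= elementCount (ceiling division). */
    static int groupCountX(int elementCount) {
        if (elementCount < 0) throw new IllegalArgumentException("negative element count");
        // Widen to long so elementCount + 255 cannot overflow
        return (int) (((long) elementCount + LOCAL_SIZE_X - 1) / LOCAL_SIZE_X);
    }

    /** Byte size of a storage buffer holding elementCount vec4 values, computed in long. */
    static long bufferBytes(int elementCount) {
        return (long) elementCount * VEC4_BYTES;
    }
}
```

For 1000 elements this gives 4 workgroups (1024 invocations) and 16000-byte buffers. Because the last workgroup runs past n, the shader has to bounds-check `gl_GlobalInvocationID.x` before indexing.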

// Java code
import org.lwjgl.system.MemoryStack;
import org.lwjgl.vulkan.*;

import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;

import static org.lwjgl.system.MemoryStack.stackPush;
import static org.lwjgl.vulkan.VK10.*;

public class VulkanCompute {

    private final VkDevice device;
    private final VkQueue computeQueue;
    private final int computeQueueFamilyIndex;
    private final VkPhysicalDevice physicalDevice;

    private long computePipeline;
    private long pipelineLayout;
    private long descriptorSetLayout;
    private long descriptorPool;
    private long descriptorSet;

    private long bufferA;
    private long bufferB;
    private long bufferC;
    private long memoryA;
    private long memoryB;
    private long memoryC;

    private int bufferSize;

    public VulkanCompute(VkDevice device, VkQueue computeQueue, int computeQueueFamilyIndex, VkPhysicalDevice physicalDevice) {
        this.device = device;
        this.computeQueue = computeQueue;
        this.computeQueueFamilyIndex = computeQueueFamilyIndex;
        this.physicalDevice = physicalDevice;
    }

    public void init(int size) {
        this.bufferSize = size;
        createBuffers();
        createDescriptorSetLayout();
        createComputePipeline();
        createDescriptorPool();
        createDescriptorSet();
        updateDescriptorSet();
    }

    private void createBuffers() {
        try (MemoryStack stack = stackPush()) {
            // Create input buffer A
            VkBufferCreateInfo bufferInfoA = VkBufferCreateInfo.calloc(stack);
            bufferInfoA.sType(VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO);
            bufferInfoA.size(bufferSize * 4 * 4); // size * vec4 * float
            bufferInfoA.usage(VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
            bufferInfoA.sharingMode(VK_SHARING_MODE_EXCLUSIVE);

            LongBuffer pBufferA = stack.mallocLong(1);
            if (vkCreateBuffer(device, bufferInfoA, null, pBufferA) != VK_SUCCESS) {
                throw new RuntimeException("Failed to create buffer A");
            }
            bufferA = pBufferA.get(0);

            VkMemoryRequirements memRequirementsA = VkMemoryRequirements.malloc(stack);
            vkGetBufferMemoryRequirements(device, bufferA, memRequirementsA);

            VkMemoryAllocateInfo allocInfoA = VkMemoryAllocateInfo.calloc(stack);
            allocInfoA.sType(VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO);
            allocInfoA.allocationSize(memRequirementsA.size());
            allocInfoA.memoryTypeIndex(findMemoryType(memRequirementsA.memoryTypeBits(), VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT));

            LongBuffer pMemoryA = stack.mallocLong(1);
            if (vkAllocateMemory(device, allocInfoA, null, pMemoryA) != VK_SUCCESS) {
                throw new RuntimeException("Failed to allocate memory A");
            }
            memoryA = pMemoryA.get(0);
            vkBindBufferMemory(device, bufferA, memoryA, 0);

            // Create input buffer B (similar to A)
            VkBufferCreateInfo bufferInfoB = VkBufferCreateInfo.calloc(stack);
            bufferInfoB.sType(VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO);
            bufferInfoB.size(bufferSize * 4 * 4); // size * vec4 * float
            bufferInfoB.usage(VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
            bufferInfoB.sharingMode(VK_SHARING_MODE_EXCLUSIVE);

            LongBuffer pBufferB = stack.mallocLong(1);
            if (vkCreateBuffer(device, bufferInfoB, null, pBufferB) != VK_SUCCESS) {
                throw new RuntimeException("Failed to create buffer B");
            }
            bufferB = pBufferB.get(0);

            VkMemoryRequirements memRequirementsB = VkMemoryRequirements.malloc(stack);
            vkGetBufferMemoryRequirements(device, bufferB, memRequirementsB);

            VkMemoryAllocateInfo allocInfoB = VkMemoryAllocateInfo.calloc(stack);
            allocInfoB.sType(VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO);
            allocInfoB.allocationSize(memRequirementsB.size());
            allocInfoB.memoryTypeIndex(findMemoryType(memRequirementsB.memoryTypeBits(), VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT));

            LongBuffer pMemoryB = stack.mallocLong(1);
            if (vkAllocateMemory(device, allocInfoB, null, pMemoryB) != VK_SUCCESS) {
                throw new RuntimeException("Failed to allocate memory B");
            }
            memoryB = pMemoryB.get(0);
            vkBindBufferMemory(device, bufferB, memoryB, 0);

            // Create output buffer C (similar to A and B)
            VkBufferCreateInfo bufferInfoC = VkBufferCreateInfo.calloc(stack);
            bufferInfoC.sType(VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO);
            bufferInfoC.size(bufferSize * 4 * 4); // size * vec4 * float
            bufferInfoC.usage(VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
            bufferInfoC.sharingMode(VK_SHARING_MODE_EXCLUSIVE);

            LongBuffer pBufferC = stack.mallocLong(1);
            if (vkCreateBuffer(device, bufferInfoC, null, pBufferC) != VK_SUCCESS) {
                throw new RuntimeException("Failed to create buffer C");
            }
            bufferC = pBufferC.get(0);

            VkMemoryRequirements memRequirementsC = VkMemoryRequirements.malloc(stack);
            vkGetBufferMemoryRequirements(device, bufferC, memRequirementsC);

            VkMemoryAllocateInfo allocInfoC = VkMemoryAllocateInfo.calloc(stack);
            allocInfoC.sType(VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO);
            allocInfoC.allocationSize(memRequirementsC.size());
            allocInfoC.memoryTypeIndex(findMemoryType(memRequirementsC.memoryTypeBits(), VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT));

            LongBuffer pMemoryC = stack.mallocLong(1);
            if (vkAllocateMemory(device, allocInfoC, null, pMemoryC) != VK_SUCCESS) {
                throw new RuntimeException("Failed to allocate memory C");
            }
            memoryC = pMemoryC.get(0);
            vkBindBufferMemory(device, bufferC, memoryC, 0);
        }
    }

    private int findMemoryType(int typeFilter, int properties) {
        try (MemoryStack stack = stackPush()) {
            VkPhysicalDeviceMemoryProperties memProperties = VkPhysicalDeviceMemoryProperties.malloc(stack);
            vkGetPhysicalDeviceMemoryProperties(physicalDevice, memProperties);

            for (int i = 0; i < memProperties.memoryTypeCount(); i++) {
                if ((typeFilter & (1 << i)) != 0 && (memProperties.memoryTypes(i).propertyFlags() & properties) == properties) {
                    return i;
                }
            }

            throw new RuntimeException("Failed to find suitable memory type!");
        }
    }

    private void createDescriptorSetLayout() {
        try (MemoryStack stack = stackPush()) {
            // calloc, not malloc: pImmutableSamplers must be NULL
            VkDescriptorSetLayoutBinding.Buffer layoutBindings = VkDescriptorSetLayoutBinding.calloc(3, stack);

            // Binding 0: Buffer A
            VkDescriptorSetLayoutBinding uboLayoutBindingA = layoutBindings.get(0);
            uboLayoutBindingA.binding(0);
            uboLayoutBindingA.descriptorType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
            uboLayoutBindingA.descriptorCount(1);
            uboLayoutBindingA.stageFlags(VK_SHADER_STAGE_COMPUTE_BIT);

            // Binding 1: Buffer B
            VkDescriptorSetLayoutBinding uboLayoutBindingB = layoutBindings.get(1);
            uboLayoutBindingB.binding(1);
            uboLayoutBindingB.descriptorType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
            uboLayoutBindingB.descriptorCount(1);
            uboLayoutBindingB.stageFlags(VK_SHADER_STAGE_COMPUTE_BIT);

            // Binding 2: Buffer C
            VkDescriptorSetLayoutBinding uboLayoutBindingC = layoutBindings.get(2);
            uboLayoutBindingC.binding(2);
            uboLayoutBindingC.descriptorType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
            uboLayoutBindingC.descriptorCount(1);
            uboLayoutBindingC.stageFlags(VK_SHADER_STAGE_COMPUTE_BIT);

            VkDescriptorSetLayoutCreateInfo layoutInfo = VkDescriptorSetLayoutCreateInfo.calloc(stack);
            layoutInfo.sType(VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO);
            layoutInfo.pBindings(layoutBindings);

            LongBuffer pDescriptorSetLayout = stack.mallocLong(1);
            if (vkCreateDescriptorSetLayout(device, layoutInfo, null, pDescriptorSetLayout) != VK_SUCCESS) {
                throw new RuntimeException("Failed to create descriptor set layout!");
            }
            descriptorSetLayout = pDescriptorSetLayout.get(0);
        }
    }

    private void createComputePipeline() {
        try (MemoryStack stack = stackPush()) {
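            // ShaderLoader is a helper not shown here; it must return a direct ByteBuffer holding
            // SPIR-V bytecode, because vkCreateShaderModule does not accept GLSL source
            ByteBuffer computeShaderCode = ShaderLoader.loadShader("compute.spv");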

            long computeShaderModule = createShaderModule(computeShaderCode);

            VkPipelineShaderStageCreateInfo computeShaderStageInfo = VkPipelineShaderStageCreateInfo.calloc(stack);
            computeShaderStageInfo.sType(VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO);
            computeShaderStageInfo.stage(VK_SHADER_STAGE_COMPUTE_BIT);
            computeShaderStageInfo.module(computeShaderModule);
            computeShaderStageInfo.pName(stack.UTF8Safe("main"));

            VkPipelineLayoutCreateInfo pipelineLayoutInfo = VkPipelineLayoutCreateInfo.calloc(stack);
            pipelineLayoutInfo.sType(VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO);
            LongBuffer pSetLayout = stack.longs(descriptorSetLayout);
            pipelineLayoutInfo.pSetLayouts(pSetLayout);

            LongBuffer pPipelineLayout = stack.mallocLong(1);
            if (vkCreatePipelineLayout(device, pipelineLayoutInfo, null, pPipelineLayout) != VK_SUCCESS) {
                throw new RuntimeException("Failed to create pipeline layout!");
            }
            pipelineLayout = pPipelineLayout.get(0);

            // vkCreateComputePipelines takes an array (Buffer) of create-info structs
            VkComputePipelineCreateInfo.Buffer pipelineInfo = VkComputePipelineCreateInfo.calloc(1, stack);
            pipelineInfo.get(0)
                    .sType(VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO)
                    .stage(computeShaderStageInfo)
                    .layout(pipelineLayout);

            LongBuffer pComputePipeline = stack.mallocLong(1);
            // 0L: no pipeline cache
            if (vkCreateComputePipelines(device, 0L, pipelineInfo, null, pComputePipeline) != VK_SUCCESS) {
                throw new RuntimeException("Failed to create compute pipeline!");
            }
            computePipeline = pComputePipeline.get(0);

            vkDestroyShaderModule(device, computeShaderModule, null);
        }
    }

    private long createShaderModule(ByteBuffer code) {
        try (MemoryStack stack = stackPush()) {
            VkShaderModuleCreateInfo createInfo = VkShaderModuleCreateInfo.calloc(stack);
            createInfo.sType(VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO);
            createInfo.pCode(code);

            LongBuffer pShaderModule = stack.mallocLong(1);
            if (vkCreateShaderModule(device, createInfo, null, pShaderModule) != VK_SUCCESS) {
                throw new RuntimeException("Failed to create shader module!");
            }
            return pShaderModule.get(0);
        }
    }

    private void createDescriptorPool() {
        try (MemoryStack stack = stackPush()) {
            VkDescriptorPoolSize.Buffer poolSizes = VkDescriptorPoolSize.malloc(1, stack);
            VkDescriptorPoolSize poolSize = poolSizes.get(0);
            poolSize.type(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
            poolSize.descriptorCount(3); // Three buffers

            VkDescriptorPoolCreateInfo poolInfo = VkDescriptorPoolCreateInfo.calloc(stack);
            poolInfo.sType(VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO);
            poolInfo.pPoolSizes(poolSizes);
            poolInfo.maxSets(1);

            LongBuffer pDescriptorPool = stack.mallocLong(1);
            if (vkCreateDescriptorPool(device, poolInfo, null, pDescriptorPool) != VK_SUCCESS) {
                throw new RuntimeException("Failed to create descriptor pool!");
            }
            descriptorPool = pDescriptorPool.get(0);
        }
    }

    private void createDescriptorSet() {
        try (MemoryStack stack = stackPush()) {
            LongBuffer pSetLayouts = stack.longs(descriptorSetLayout);

            VkDescriptorSetAllocateInfo allocateInfo = VkDescriptorSetAllocateInfo.calloc(stack);
            allocateInfo.sType(VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO);
            allocateInfo.descriptorPool(descriptorPool);
            // LWJGL derives descriptorSetCount from the number of layouts in the buffer
            allocateInfo.pSetLayouts(pSetLayouts);

            LongBuffer pDescriptorSet = stack.mallocLong(1);
            if (vkAllocateDescriptorSets(device, allocateInfo, pDescriptorSet) != VK_SUCCESS) {
                throw new RuntimeException("Failed to allocate descriptor sets!");
            }
            descriptorSet = pDescriptorSet.get(0);
        }
    }

    private void updateDescriptorSet() {
        try (MemoryStack stack = stackPush()) {
            VkDescriptorBufferInfo.Buffer bufferInfoA = VkDescriptorBufferInfo.malloc(1, stack);
            bufferInfoA.buffer(bufferA);
            bufferInfoA.offset(0);
            bufferInfoA.range(VK_WHOLE_SIZE);

            VkDescriptorBufferInfo.Buffer bufferInfoB = VkDescriptorBufferInfo.malloc(1, stack);
            bufferInfoB.buffer(bufferB);
            bufferInfoB.offset(0);
            bufferInfoB.range(VK_WHOLE_SIZE);

            VkDescriptorBufferInfo.Buffer bufferInfoC = VkDescriptorBufferInfo.malloc(1, stack);
            bufferInfoC.buffer(bufferC);
            bufferInfoC.offset(0);
            bufferInfoC.range(VK_WHOLE_SIZE);

            // calloc, not malloc: pNext, pImageInfo and pTexelBufferView must be NULL
            VkWriteDescriptorSet.Buffer descriptorWrites = VkWriteDescriptorSet.calloc(3, stack);

            // Buffer A
            VkWriteDescriptorSet writeDescriptorSetA = descriptorWrites.get(0);
            writeDescriptorSetA.sType(VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET);
            writeDescriptorSetA.dstSet(descriptorSet);
            writeDescriptorSetA.dstBinding(0);
            writeDescriptorSetA.dstArrayElement(0);
            writeDescriptorSetA.descriptorType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
            writeDescriptorSetA.descriptorCount(1);
            writeDescriptorSetA.pBufferInfo(bufferInfoA);

            // Buffer B
            VkWriteDescriptorSet writeDescriptorSetB = descriptorWrites.get(1);
            writeDescriptorSetB.sType(VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET);
            writeDescriptorSetB.dstSet(descriptorSet);
            writeDescriptorSetB.dstBinding(1);
            writeDescriptorSetB.dstArrayElement(0);
            writeDescriptorSetB.descriptorType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
            writeDescriptorSetB.descriptorCount(1);
            writeDescriptorSetB.pBufferInfo(bufferInfoB);

            // Buffer C
            VkWriteDescriptorSet writeDescriptorSetC = descriptorWrites.get(2);
            writeDescriptorSetC.sType(VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET);
            writeDescriptorSetC.dstSet(descriptorSet);
            writeDescriptorSetC.dstBinding(2);
            writeDescriptorSetC.dstArrayElement(0);
            writeDescriptorSetC.descriptorType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
            writeDescriptorSetC.descriptorCount(1);
            writeDescriptorSetC.pBufferInfo(bufferInfoC);

            vkUpdateDescriptorSets(device, descriptorWrites, null);
        }
    }

    public void compute(float[] inputA, float[] inputB, float[] output) {

        if(inputA.length != bufferSize * 4 || inputB.length != bufferSize * 4 || output.length != bufferSize * 4){
            throw new IllegalArgumentException("Input/Output array size mismatch");
        }
        try (MemoryStack stack = stackPush()) {
            // Map memory and copy data to buffers
            FloatBuffer mappedMemoryA = mapMemory(memoryA, bufferSize * 4 * 4);
            mappedMemoryA.put(inputA);
            unmapMemory(memoryA);

            FloatBuffer mappedMemoryB = mapMemory(memoryB, bufferSize * 4 * 4);
            mappedMemoryB.put(inputB);
            unmapMemory(memoryB);
