4 Copyright (C) 2017 Eric Arnebäck
5 Copyright (C) 2019 Michael Zucchi
7 Permission is hereby granted, free of charge, to any person obtaining a copy
8 of this software and associated documentation files (the "Software"), to deal
9 in the Software without restriction, including without limitation the rights
10 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 copies of the Software, and to permit persons to whom the Software is
12 furnished to do so, subject to the following conditions:
14 The above copyright notice and this permission notice shall be included in
15 all copies or substantial portions of the Software.
17 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
28 * This is a Java conversion of a C conversion of this:
29 * https://github.com/Erkaman/vulkan_minimal_compute
31 * It's been simplified a bit and converted to the 'zvk' api.
36 import java.io.InputStream;
37 import java.io.FileOutputStream;
38 import java.io.IOException;
39 import java.nio.channels.Channels;
40 import java.nio.ByteBuffer;
41 import java.nio.ByteOrder;
43 import java.awt.Graphics;
44 import java.awt.Image;
45 import java.awt.Toolkit;
46 import java.awt.event.ActionEvent;
47 import java.awt.event.KeyEvent;
48 import java.awt.image.MemoryImageSource;
49 import javax.swing.AbstractAction;
50 import javax.swing.JComponent;
51 import javax.swing.JFrame;
52 import javax.swing.JPanel;
53 import javax.swing.KeyStroke;
55 import java.lang.ref.WeakReference;
57 import java.lang.invoke.*;
58 import jdk.incubator.foreign.*;
59 import jdk.incubator.foreign.MemoryLayout.PathElement;
60 import au.notzed.nativez.*;
64 import static vulkan.VkConstants.*;
/**
 * Vulkan compute test: runs a "mandelbrot" compute shader into a buffer and
 * then shows (or optionally saves) the resulting pixels.
 *
 * The Vulkan handles are kept as instance fields so that the separate init_*
 * steps, execute() and the result readers can share them.
 */
66 public class TestMandelbrot {
// When true, init_instance() also requests the VK_EXT_debug_utils extension.
67 static final boolean debug = true;
// Shared scope passed to the Vulkan create/allocate calls below; it is also
// cast to SegmentAllocator where a native array has to be allocated.
68 ResourceScope scope = ResourceScope.newSharedScope();
// Physical device selected by init_device().
74 VkPhysicalDevice physicalDevice;
// Size in bytes of the output buffer, 4 bytes per pixel.
// NOTE(review): the product is evaluated in int arithmetic and only then
// widened to long -- confirm WIDTH * HEIGHT * 4 cannot overflow an int.
79 long dstBufferSize = WIDTH * HEIGHT * 4;
81 //VkDeviceMemory dstMemory;
// Descriptor objects created by init_descriptor().
84 VkDescriptorSetLayout descriptorSetLayout;
85 VkDescriptorPool descriptorPool;
86 HandleArray<VkDescriptorSet> descriptorSets;
// Queue family index chosen by init_device().
88 int computeQueueIndex;
// Filled in by init_device(); consulted by init_buffer() via find_memory_type().
89 VkPhysicalDeviceMemoryProperties deviceMemoryProperties;
// Entry point name set on the compute shader stage in init_pipeline().
91 String mandelbrot_entry = "main";
// Shader code as an int array; loaded in demo() from "mandelbrot.bin".
92 IntArray mandelbrot_cs;
// Pipeline objects created by init_pipeline().
94 VkShaderModule mandelbrotShader;
95 VkPipelineLayout pipelineLayout;
// One-element array that receives the pipeline from vkCreateComputePipelines.
96 HandleArray<VkPipeline> computePipeline = VkPipeline.createArray(1, (SegmentAllocator)scope);
// Command pool and buffers created by init_command_buffer().
98 VkCommandPool commandPool;
99 HandleArray<VkCommandBuffer> commandBuffers;
// Pairs a buffer with the device memory bound to it; returned by init_buffer().
101 record BufferMemory ( VkBuffer buffer, VkDeviceMemory memory ) {};
// Debug messenger created by init_debug().
103 VkDebugUtilsMessengerEXT logger;
/**
 * Creates a debug-utils messenger and stores it in {@code logger}.
 *
 * The callback prints the severity and the message text to standard output.
 * The create-info enables the INFO, WARNING and ERROR severities (VERBOSE is
 * commented out) and the GENERAL, VALIDATION and PERFORMANCE message types.
 *
 * @throws Exception presumably when the Vulkan binding reports a failure
 *         (the throwing call is not identifiable from this method alone)
 */
105 void init_debug() throws Exception {
108 try (Frame frame = Frame.frame()) {
// Upcall stub invoked by the Vulkan loader for every matching message.
109 var cb = PFN_vkDebugUtilsMessengerCallbackEXT.upcall((severity, flags, data) -> {
110 System.out.printf("Debug: %d: %s\n", severity, data.getMessage());
113 VkDebugUtilsMessengerCreateInfoEXT info = VkDebugUtilsMessengerCreateInfoEXT.create(frame,
115 //VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT |
116 VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT
117 | VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT
118 | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT,
119 VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT
120 | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT
121 | VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT,
// The messenger is created against the shared scope, not the temporary frame.
125 logger = instance.vkCreateDebugUtilsMessengerEXT(info, scope);
/**
 * Creates the Vulkan instance and stores it in {@code instance}.
 *
 * The application info names the application "test" and the engine
 * "test-engine" and asks for API version 1.0.0 (variant 0). The
 * "VK_LAYER_KHRONOS_validation" layer name is always passed; the
 * "VK_EXT_debug_utils" extension name is passed only when {@code debug} is
 * true, otherwise null is passed in its place.
 *
 * @throws Exception presumably when instance creation fails in the binding
 */
129 void init_instance() throws Exception {
130 try (Frame frame = Frame.frame()) {
131 VkInstanceCreateInfo info = VkInstanceCreateInfo.create(frame,
// NOTE(review): 1 and 2 look like the application and engine version numbers -- confirm against VkApplicationInfo.create.
133 VkApplicationInfo.create(frame, "test", 1, "test-engine", 2, VK_MAKE_API_VERSION(0, 1, 0, 0)),
134 new String[] { "VK_LAYER_KHRONOS_validation" },
135 debug ? new String[] { "VK_EXT_debug_utils" } : null
138 instance = VkInstance.vkCreateInstance(info, scope);
/**
 * Picks a physical device and queue family, then creates the logical device,
 * reads the device memory properties and fetches the compute queue.
 *
 * Every queue family of every enumerated device is tested for the COMPUTE
 * flag and for the absence of the GRAPHICS flag. The statements controlled by
 * those tests, and the declarations of {@code devid} and {@code queueid}, are
 * not visible in this listing.
 *
 * @throws Exception with "Cannot find a suitable device" when the search fails
 */
142 void init_device() throws Exception {
143 try (Frame frame = Frame.frame()) {
144 HandleArray<VkPhysicalDevice> devs;
148 devs = instance.vkEnumeratePhysicalDevices(frame, scope);
// Scan each device's queue families.
154 for (int i=0;i<devs.length();i++) {
155 VkPhysicalDevice dev = devs.getAtIndex(i);
156 VkQueueFamilyProperties famprops = dev.vkGetPhysicalDeviceQueueFamilyProperties(frame);
157 int family_count = (int)famprops.length();
159 for (int j=0;j<family_count;j++) {
160 var flags = famprops.getAtIndex(j).getQueueFlags();
// NOTE(review): presumably prefers a compute-capable family without graphics -- confirm against the full source.
163 if ((flags & VK_QUEUE_COMPUTE_BIT) != 0)
165 if ((flags & VK_QUEUE_GRAPHICS_BIT) == 0)
177 throw new Exception("Cannot find a suitable device");
// Remember the selection for later steps.
179 computeQueueIndex = queueid;
180 physicalDevice = devs.getAtIndex(devid);
// A single queue priority value of 0.0.
182 FloatArray qpri = FloatArray.create(frame, 0.0f);
183 VkDeviceQueueCreateInfo qinfo = VkDeviceQueueCreateInfo.create(
188 VkDeviceCreateInfo devinfo = VkDeviceCreateInfo.create(
196 device = physicalDevice.vkCreateDevice(devinfo, scope);
198 System.out.printf("device = %s\n", device.address());
// Memory properties are allocated from the shared scope because they are
// read again later by init_buffer().
201 deviceMemoryProperties = VkPhysicalDeviceMemoryProperties.create((SegmentAllocator)scope);
202 physicalDevice.vkGetPhysicalDeviceMemoryProperties(deviceMemoryProperties);
// Queue index 0 of the selected family.
204 computeQueue = device.vkGetDeviceQueue(queueid, 0, scope);
209 * Buffers are created in three steps:
210 * 1) create buffer, specifying usage and size
211 * 2) allocate memory based on memory requirements
 * 3) bind the allocated memory to the buffer
/**
 * Creates a buffer, allocates device memory for it and binds the two.
 *
 * @param dataSize   requested size; how it is passed into the create-info is
 *                   not visible in this listing
 * @param usage      buffer usage flags (see the call in demo())
 * @param properties memory property flags that the chosen memory type must
 *                   all have (passed to find_memory_type as the query)
 * @return the created buffer together with the memory bound to it
 * @throws Exception presumably when a Vulkan call fails in the binding
 */
215 BufferMemory init_buffer(long dataSize, int usage, int properties) throws Exception {
216 try (Frame frame = Frame.frame()) {
217 VkMemoryRequirements req = VkMemoryRequirements.create(frame);
218 VkBufferCreateInfo buf_info = VkBufferCreateInfo.create(frame,
222 VK_SHARING_MODE_EXCLUSIVE,
225 VkBuffer buffer = device.vkCreateBuffer(buf_info, scope);
// Ask the device what this buffer needs before allocating.
227 device.vkGetBufferMemoryRequirements(buffer, req);
229 VkMemoryAllocateInfo alloc = VkMemoryAllocateInfo.create(frame,
// Memory type index permitted by the buffer and carrying the requested properties.
231 find_memory_type(deviceMemoryProperties, req.getMemoryTypeBits(), properties));
233 VkDeviceMemory memory = device.vkAllocateMemory(alloc, scope);
// Bind at offset 0 of the allocation.
235 device.vkBindBufferMemory(buffer, memory, 0);
237 return new BufferMemory(buffer, memory);
242 * Descriptors are used to bind and describe memory blocks
245 * *Pool is used to allocate descriptors; it is per-device.
246 * *Layout is used to group descriptors for a given pipeline,
247 * The descriptors describe individually-addressable blocks.
/**
 * Creates the descriptor set layout, the descriptor pool and the descriptor
 * set, then writes a storage-buffer descriptor into the first set.
 *
 * Results are stored in {@code descriptorSetLayout}, {@code descriptorPool}
 * and {@code descriptorSets}. Temporary create-info structures come from a
 * local frame; the created handles use the shared scope.
 *
 * @throws Exception presumably when a Vulkan call fails in the binding
 */
249 void init_descriptor() throws Exception {
250 try (Frame frame = Frame.frame()) {
251 /* Create descriptorset layout */
// A storage-buffer binding for the compute stage.
252 VkDescriptorSetLayoutBinding layout_binding = VkDescriptorSetLayoutBinding.create(frame,
254 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
256 VK_SHADER_STAGE_COMPUTE_BIT,
259 VkDescriptorSetLayoutCreateInfo descriptor_layout = VkDescriptorSetLayoutCreateInfo.create(frame,
263 descriptorSetLayout = device.vkCreateDescriptorSetLayout(descriptor_layout, scope);
265 /* Create descriptor pool */
266 VkDescriptorPoolSize type_count = VkDescriptorPoolSize.create(frame,
267 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
270 VkDescriptorPoolCreateInfo descriptor_pool = VkDescriptorPoolCreateInfo.create(frame,
275 descriptorPool = device.vkCreateDescriptorPool(descriptor_pool, scope);
277 /* Allocate from pool */
// One-element table holding the layout created above.
278 HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(1, frame);
280 layout_table.setAtIndex(0, descriptorSetLayout);
282 VkDescriptorSetAllocateInfo alloc_info = VkDescriptorSetAllocateInfo.create(frame,
286 descriptorSets = device.vkAllocateDescriptorSets(alloc_info, (SegmentAllocator)scope);
288 /* Bind a buffer to the descriptor */
289 VkDescriptorBufferInfo bufferInfo = VkDescriptorBufferInfo.create(frame,
// The write targets the first allocated set.
294 VkWriteDescriptorSet writeSet = VkWriteDescriptorSet.create(frame,
295 descriptorSets.getAtIndex(0),
299 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
// Diagnostic dump of the write structure.
304 System.out.println(writeSet);
// Only descriptor writes are supplied; the second argument is null.
306 device.vkUpdateDescriptorSets(writeSet, null);
311 * Create the compute pipeline. This is the shader and data layouts for it.
/**
 * Creates the shader module, the pipeline layout and the compute pipeline.
 *
 * Results are stored in {@code mandelbrotShader}, {@code pipelineLayout} and
 * element 0 of {@code computePipeline}. Requires {@code mandelbrot_cs} and
 * {@code descriptorSetLayout} to have been set beforehand.
 *
 * @throws Exception presumably when a Vulkan call fails in the binding
 */
313 void init_pipeline() throws Exception {
314 try (Frame frame = Frame.frame()) {
315 /* Set shader code */
316 VkShaderModuleCreateInfo vsInfo = VkShaderModuleCreateInfo.create(frame,
// NOTE(review): element count times 4, presumably the code size in bytes for 4-byte ints -- confirm.
318 mandelbrot_cs.length() * 4,
321 mandelbrotShader = device.vkCreateShaderModule(vsInfo, scope);
323 /* Link shader to layout */
324 HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(1, frame);
326 layout_table.setAtIndex(0, descriptorSetLayout);
328 VkPipelineLayoutCreateInfo pipelineinfo = VkPipelineLayoutCreateInfo.create(frame,
333 pipelineLayout = device.vkCreatePipelineLayout(pipelineinfo, scope);
335 /* Create pipeline */
336 VkComputePipelineCreateInfo pipeline = VkComputePipelineCreateInfo.create(frame,
// The stage structure is obtained from the pipeline create-info and filled in here.
342 VkPipelineShaderStageCreateInfo stage = pipeline.getStage();
344 stage.setStage(VK_SHADER_STAGE_COMPUTE_BIT);
345 stage.setModule(mandelbrotShader);
346 stage.setName(mandelbrot_entry, frame);
// The first argument is null; the result lands in computePipeline.
348 device.vkCreateComputePipelines(null, pipeline, computePipeline);
353 * Create a command buffer, this is somewhat like a display list.
/**
 * Creates the command pool, allocates the command buffers and records the
 * compute work into the first one.
 *
 * The recorded sequence is: begin, bind the compute pipeline, bind the
 * descriptor sets, dispatch, end. Nothing is submitted here; see execute().
 *
 * @throws Exception presumably when a Vulkan call fails in the binding
 */
355 void init_command_buffer() throws Exception {
356 try (Frame frame = Frame.frame()) {
357 VkCommandPoolCreateInfo poolinfo = VkCommandPoolCreateInfo.create(frame,
361 commandPool = device.vkCreateCommandPool(poolinfo, scope);
363 VkCommandBufferAllocateInfo cmdinfo = VkCommandBufferAllocateInfo.create(frame,
365 VK_COMMAND_BUFFER_LEVEL_PRIMARY,
368 // should it take a scope?
369 commandBuffers = device.vkAllocateCommandBuffers(cmdinfo, (SegmentAllocator)scope, scope);
371 /* Fill command buffer with commands for later operation */
// The buffer is flagged as one-time-submit.
372 VkCommandBufferBeginInfo beginInfo = VkCommandBufferBeginInfo.create(frame,
373 VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT,
376 commandBuffers.get(0).vkBeginCommandBuffer(beginInfo);
378 /* Bind the compute operation and data */
379 commandBuffers.get(0).vkCmdBindPipeline(VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline.get(0));
380 commandBuffers.get(0).vkCmdBindDescriptorSets(VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout, 0, descriptorSets, null);
// NOTE(review): presumably WIDTH x HEIGHT x 1 workgroups, as with vkCmdDispatch in the Vulkan API -- confirm this matches the shader's local size.
383 commandBuffers.get(0).vkCmdDispatch(WIDTH, HEIGHT, 1);
385 commandBuffers.get(0).vkEndCommandBuffer();
390 * Execute the pre-created command buffer.
392 * A fence is used to wait for completion.
/**
 * Submits the recorded command buffers to the compute queue and waits on a
 * fence until the wait no longer reports VK_TIMEOUT.
 *
 * The fence is created here and destroyed again after the wait. The
 * declarations of {@code fence} and {@code res} and the head of the wait
 * loop are not visible in this listing.
 *
 * @throws Exception presumably when a Vulkan call fails in the binding
 */
394 void execute() throws Exception {
395 try (Frame frame = Frame.frame()) {
396 VkSubmitInfo submitInfo = VkSubmitInfo.create(frame);
398 submitInfo.setCommandBufferCount(1);
399 submitInfo.setCommandBuffers(commandBuffers);
401 /* Create fence to mark the task completion */
// One-element array, because vkWaitForFences takes an array of fences.
403 HandleArray<VkFence> fences = VkFence.createArray(1, frame);
404 VkFenceCreateInfo fenceInfo = VkFenceCreateInfo.create(frame);
406 // maybe this should take a HandleArray<Fence> rather than being a constructor
407 // FIXME: some local scope
408 fence = device.vkCreateFence(fenceInfo, scope);
409 fences.set(0, fence);
411 /* Await completion */
412 computeQueue.vkQueueSubmit(submitInfo, fence);
// Retry for as long as the wait times out.
// NOTE(review): 1000000 is presumably a timeout in nanoseconds, as in the Vulkan API -- confirm.
417 res = device.vkWaitForFences(fences, VK_TRUE, 1000000);
418 } while (res == VK_TIMEOUT);
420 device.vkDestroyFence(fence);
// Teardown sequence: destroys the objects made by the init_* steps.
// The header of the enclosing method is not visible in this listing.
// Command and pipeline objects first.
425 device.vkDestroyCommandPool(commandPool);
426 device.vkDestroyPipeline(computePipeline.getAtIndex(0));
427 device.vkDestroyPipelineLayout(pipelineLayout);
428 device.vkDestroyShaderModule(mandelbrotShader);
// Descriptor objects.
430 device.vkDestroyDescriptorPool(descriptorPool);
431 device.vkDestroyDescriptorSetLayout(descriptorSetLayout);
// Output buffer and its memory.
433 device.vkFreeMemory(dst.memory());
434 device.vkDestroyBuffer(dst.buffer());
// The device goes before the instance-level objects.
436 device.vkDestroyDevice();
// NOTE(review): logger is only assigned in init_debug(); confirm that it is always called before this point.
438 instance.vkDestroyDebugUtilsMessengerEXT(logger);
439 instance.vkDestroyInstance();
443 * Maps the GPU buffer, keeps the first three bytes of every 4-byte pixel (RGB), and saves them as a binary PPM ("P6") image file.
/**
 * Maps the output buffer, repacks it from 4 bytes per pixel to 3 bytes per
 * pixel and writes the result to "mandelbrot.pam" through pam_save().
 *
 * Bytes 0..2 of every pixel are copied and byte 3 is dropped.
 *
 * @throws Exception presumably when mapping fails; pam_save() can throw IOException
 */
445 void save_result() throws Exception {
// This local confined scope shadows the field of the same name, so the
// mapping belongs to this try block.
446 try (ResourceScope scope = ResourceScope.newConfinedScope()) {
447 MemorySegment mem = device.vkMapMemory(dst.memory(), 0, dstBufferSize, 0, scope);
448 byte[] pixels = new byte[WIDTH * HEIGHT * 3];
450 System.out.printf("map %d bytes\n", dstBufferSize);
// Source stride is 4 bytes, destination stride is 3 bytes.
452 for (int i = 0; i < WIDTH * HEIGHT; i++) {
453 pixels[i * 3 + 0] = mem.get(Memory.BYTE, i * 4 + 0);
454 pixels[i * 3 + 1] = mem.get(Memory.BYTE, i * 4 + 1);
455 pixels[i * 3 + 2] = mem.get(Memory.BYTE, i * 4 + 2);
458 device.vkUnmapMemory(dst.memory());
460 pam_save("mandelbrot.pam", WIDTH, HEIGHT, 3, pixels);
/**
 * Maps the output buffer, copies it into an int array with one int per
 * pixel, and shows that array in a window through swing_show().
 *
 * @throws Exception presumably when mapping fails in the binding
 */
464 void show_result() throws Exception {
// This local confined scope shadows the field of the same name.
465 try (ResourceScope scope = ResourceScope.newConfinedScope()) {
466 MemorySegment mem = device.vkMapMemory(dst.memory(), 0, dstBufferSize, 0, scope);
467 int[] pixels = new int[WIDTH * HEIGHT];
469 System.out.printf("map %d bytes\n", dstBufferSize);
// Bulk copy from the mapped memory into the heap array.
471 MemorySegment.ofArray(pixels).copyFrom(mem);
// The data now lives in the heap array, so the mapping can be released
// before the window is shown.
473 device.vkUnmapMemory(dst.memory());
475 swing_show(WIDTH, HEIGHT, pixels);
480 * Trivial pnm format image output.
/**
 * Writes an image file that starts with a "P6" header.
 *
 * The header is "P6", the width, the height and the maximum value 255, each
 * followed by a newline. The statement that writes the pixel bytes is not
 * visible in this listing.
 *
 * @param name   path of the file to create
 * @param width  image width written into the header
 * @param height image height written into the header
 * @param depth  not referenced in the visible lines
 * @param pixels pixel data; NOTE(review): presumably written after the header -- confirm
 * @throws IOException if the file cannot be opened or written
 */
482 void pam_save(String name, int width, int height, int depth, byte[] pixels) throws IOException {
483 try (FileOutputStream fos = new FileOutputStream(name)) {
// getBytes() without an argument uses the platform default charset.
484 fos.write(String.format("P6\n%d\n%d\n255\n", width, height).getBytes());
486 System.out.printf("wrote: %s\n", name);
/**
 * Panel that paints an image backed directly by an int pixel array.
 *
 * The image is produced by a MemoryImageSource over the array given to the
 * constructor, with a scanline stride of {@code w} and offset 0. The
 * declarations of the {@code pixels} and {@code image} fields are not visible
 * in this listing.
 */
490 static class DataImage extends JPanel {
492 final int w, h, stride;
493 final MemoryImageSource source;
/**
 * @param w      image width in pixels; also used as the scanline stride
 * @param h      image height in pixels
 * @param pixels one int per pixel; NOTE(review): presumably interpreted in
 *               the default RGB colour model used by this MemoryImageSource
 *               constructor -- confirm
 */
497 public DataImage(int w, int h, int[] pixels) {
501 this.pixels = pixels;
502 this.source = new MemoryImageSource(w, h, pixels, 0, w);
// Animated mode with full-buffer updates enabled on the source.
503 this.source.setAnimated(true);
504 this.source.setFullBufferUpdates(true);
505 this.image = Toolkit.getDefaultToolkit().createImage(source);
// Draws the image at the panel origin, unscaled.
509 protected void paintComponent(Graphics g) {
510 super.paintComponent(g);
511 g.drawImage(image, 0, 0, this);
/**
 * Opens a window titled "mandelbrot" that displays the given pixels.
 *
 * Closing the window exits the JVM (EXIT_ON_CLOSE). The declaration of
 * {@code window} is not visible in this listing.
 *
 * @param w      image width; also used as the window width
 * @param h      image height; also used as the window height
 * @param pixels one int per pixel, handed to DataImage
 */
515 void swing_show(int w, int h, int[] pixels) {
517 DataImage image = new DataImage(w, h, pixels);
519 window = new JFrame("mandelbrot");
520 window.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
521 window.setContentPane(image);
// setSize sets the size of the whole frame, not just the content pane.
522 window.setSize(w, h);
523 window.setVisible(true);
/**
 * Variant loader that reads a class-path resource into a direct ByteBuffer
 * of 8192 bytes in native byte order and wraps it as an IntArray.
 *
 * The statements between the read and the return are not visible in this
 * listing.
 *
 * @param name resource name, resolved relative to TestMandelbrot.class
 * @return an IntArray over the buffer's memory
 * @throws IOException if reading the resource fails
 */
526 IntArray loadSPIRV0(String name) throws IOException {
527 // hmm any way to just load this directly?
// NOTE(review): getResourceAsStream returns null when the resource is missing -- confirm the resource is always present.
528 try (InputStream is = TestMandelbrot.class.getResourceAsStream(name)) {
529 ByteBuffer bb = ByteBuffer.allocateDirect(8192).order(ByteOrder.nativeOrder());
// NOTE(review): a single read() is not guaranteed to drain the stream.
530 int length = Channels.newChannel(is).read(bb);
535 return IntArray.create(MemorySegment.ofByteBuffer(bb));
/**
 * Loads a class-path resource into native memory and returns it as an
 * IntArray that covers only the bytes actually read.
 *
 * Space for 2048 ints is allocated from the shared scope, and one read() call
 * fills it through a ByteBuffer view of the segment.
 *
 * @param name resource name, resolved relative to TestMandelbrot.class
 * @return an IntArray over the first {@code length} bytes of the segment
 * @throws IOException if reading the resource fails
 */
539 IntArray loadSPIRV(String name) throws IOException {
// NOTE(review): getResourceAsStream returns null when the resource is missing -- confirm the resource is always present.
540 try (InputStream is = TestMandelbrot.class.getResourceAsStream(name)) {
541 MemorySegment seg = ((SegmentAllocator)scope).allocateArray(Memory.INT, 2048);
// NOTE(review): a single read() may return fewer bytes than the resource
// holds, or -1 at end of stream, and anything beyond the 2048-int capacity
// is not read -- confirm the shader always fits and arrives in one read.
542 int length = Channels.newChannel(is).read(seg.asByteBuffer());
544 return IntArray.create(seg.asSlice(0, length));
549 * This finds the memory type index for the memory on a specific device.
/**
 * Scans the device's memory types for one that is allowed by
 * {@code typeMask} and has every property bit in {@code query}.
 *
 * The return statements are not visible in this listing.
 *
 * @param memory   memory properties of the physical device
 * @param typeMask bit mask in which bit i set means memory type i is acceptable
 * @param query    property flags that must all be present on the memory type
 * @return NOTE(review): presumably the index of the first match -- confirm,
 *         including what is returned when nothing matches
 */
551 static int find_memory_type(VkPhysicalDeviceMemoryProperties memory, int typeMask, int query) {
552 VkMemoryType mtypes = memory.getMemoryTypes();
554 for (int i = 0; i < memory.getMemoryTypeCount(); i++) {
// Both conditions must hold: type i is permitted and carries all queried flags.
555 if (((1 << i) & typeMask) != 0 && ((mtypes.getAtIndex(i).getPropertyFlags() & query) == query))
/**
 * Packs a version into one int: variant shifted left by 29, major by 22,
 * minor by 12, and patch in the low bits.
 *
 * The arguments are not masked, so a value too large for its field spills
 * into the neighbouring field.
 *
 * @param variant variant number (bits 29 and up)
 * @param major   major version (bits 22 and up)
 * @param minor   minor version (bits 12 and up)
 * @param patch   patch version (low bits)
 * @return the packed version number
 */
561 public static int VK_MAKE_API_VERSION(int variant, int major, int minor, int patch) {
562 return (variant << 29) | (major << 22) | (minor << 12) | patch;
/**
 * Runs the whole test from shader loading to showing the result.
 *
 * Visible steps: load the shader from "mandelbrot.bin", create the output
 * buffer, record the command buffer, and print progress messages. The other
 * calls (initialisation, execution, display, teardown) are not visible in
 * this listing.
 *
 * @throws Exception propagated from the individual steps
 */
565 void demo() throws Exception {
566 mandelbrot_cs = loadSPIRV("mandelbrot.bin");
// The output buffer is a storage buffer in host-visible, host-coherent
// memory; save_result() and show_result() map it.
573 dst = init_buffer(dstBufferSize,
574 VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
575 VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
580 init_command_buffer();
582 System.out.printf("Calculating %dx%d\n", WIDTH, HEIGHT);
584 //System.out.println("Saving ...");
586 System.out.println("Showing ...");
588 System.out.println("Done.");
/**
 * Program entry point: loads the native "vulkan" library and runs the demo.
 *
 * @param args command-line arguments; not used in the visible lines
 * @throws Throwable anything raised while loading the library or running the demo
 */
594 public static void main(String[] args) throws Throwable {
// The native library must be loaded before any Vulkan binding call.
595 System.loadLibrary("vulkan");
597 new TestMandelbrot().demo();