4 Copyright (C) 2017 Eric Arnebäck
5 Copyright (C) 2019 Michael Zucchi
7 Permission is hereby granted, free of charge, to any person obtaining a copy
8 of this software and associated documentation files (the "Software"), to deal
9 in the Software without restriction, including without limitation the rights
10 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 copies of the Software, and to permit persons to whom the Software is
12 furnished to do so, subject to the following conditions:
14 The above copyright notice and this permission notice shall be included in
15 all copies or substantial portions of the Software.
17 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
28 * This is a Java conversion of a C conversion of this:
29 * https://github.com/Erkaman/vulkan_minimal_compute
31 * It has been simplified somewhat and converted to the 'zvk' API.
36 import java.io.InputStream;
37 import java.io.FileOutputStream;
38 import java.io.IOException;
39 import java.nio.channels.Channels;
40 import java.nio.ByteBuffer;
41 import java.nio.ByteOrder;
43 import java.awt.Graphics;
44 import java.awt.Image;
45 import java.awt.Toolkit;
46 import java.awt.event.ActionEvent;
47 import java.awt.event.KeyEvent;
48 import java.awt.image.MemoryImageSource;
49 import javax.swing.AbstractAction;
50 import javax.swing.JComponent;
51 import javax.swing.JFrame;
52 import javax.swing.JPanel;
53 import javax.swing.KeyStroke;
55 import java.lang.ref.WeakReference;
57 import java.lang.invoke.*;
58 import jdk.incubator.foreign.*;
59 import jdk.incubator.foreign.MemoryLayout.PathElement;
60 import au.notzed.nativez.*;
64 import static vulkan.VkConstants.*;
66 public class TestMandelbrot {
67 static final boolean debug = true;
68 ResourceScope scope = ResourceScope.newSharedScope();
74 VkPhysicalDevice physicalDevice;
79 long dstBufferSize = WIDTH * HEIGHT * 4;
81 //VkDeviceMemory dstMemory;
84 VkDescriptorSetLayout descriptorSetLayout;
85 VkDescriptorPool descriptorPool;
86 HandleArray<VkDescriptorSet> descriptorSets = VkDescriptorSet.createArray(1, (SegmentAllocator)scope);
88 int computeQueueIndex;
89 VkPhysicalDeviceMemoryProperties deviceMemoryProperties;
91 String mandelbrot_entry = "main";
92 IntArray mandelbrot_cs;
94 VkShaderModule mandelbrotShader;
95 VkPipelineLayout pipelineLayout;
96 HandleArray<VkPipeline> computePipeline = VkPipeline.createArray(1, (SegmentAllocator)scope);
98 VkCommandPool commandPool;
99 HandleArray<VkCommandBuffer> commandBuffers;
101 record BufferMemory ( VkBuffer buffer, VkDeviceMemory memory ) {};
103 VkDebugUtilsMessengerEXT logger;
105 void init_debug() throws Exception {
108 try (Frame frame = Frame.frame()) {
109 var cb = PFN_vkDebugUtilsMessengerCallbackEXT.upcall((severity, flags, data, dummy) -> {
110 System.out.printf("Debug: %d: %s\n", severity, data.getMessage());
113 VkDebugUtilsMessengerCreateInfoEXT info = VkDebugUtilsMessengerCreateInfoEXT.create(frame,
115 VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT
116 | VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT
117 | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT,
118 VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT,
122 logger = instance.vkCreateDebugUtilsMessengerEXT(info, null, scope);
125 //typedef VkBool32 (*PFN_vkDebugUtilsMessengerCallbackEXT)(VkDebugUtilsMessageSeverityFlagBitsEXT, VkDebugUtilsMessageTypeFlagsEXT, const VkDebugUtilsMessengerCallbackDataEXT *, void *);
129 void dumpStructure(GroupLayout layout, MemorySegment s) {
130 System.out.println(s.address());
131 for (MemoryLayout m: layout.memberLayouts()) {
132 if (m instanceof jdk.incubator.foreign.ValueLayout) {
133 var vh = layout.varHandle(jdk.incubator.foreign.MemoryLayout.PathElement.groupElement(m.name().get()));
134 System.out.printf(" %s: %s\n", m.name().get(), vh.get(s));
139 void init_instance() throws Exception {
140 try (Frame frame = Frame.frame()) {
141 VkInstanceCreateInfo info = VkInstanceCreateInfo.create(frame,
143 VkApplicationInfo.create(frame, "test", 1, "test-engine", 2, VK_MAKE_API_VERSION(0, 1, 0, 0)),
144 new String[] { "VK_LAYER_KHRONOS_validation" },
145 debug ? new String[] { "VK_EXT_debug_utils" } : null
148 instance = VkInstance.vkCreateInstance(info, null, scope);
152 void init_device() throws Exception {
153 try (Frame frame = Frame.frame()) {
154 HandleArray<VkPhysicalDevice> devs;
158 devs = instance.vkEnumeratePhysicalDevices(frame, scope);
164 for (int i=0;i<devs.length();i++) {
165 VkPhysicalDevice dev = devs.getAtIndex(i);
166 VkQueueFamilyProperties famprops = dev.vkGetPhysicalDeviceQueueFamilyProperties(frame);
167 int family_count = (int)famprops.length();
169 for (int j=0;j<family_count;j++) {
172 if ((famprops.getQueueFlagsAtIndex(j) & VK_QUEUE_COMPUTE_BIT) != 0)
174 if ((famprops.getQueueFlagsAtIndex(j) & VK_QUEUE_GRAPHICS_BIT) == 0)
186 throw new Exception("Cannot find a suitable device");
188 computeQueueIndex = queueid;
189 physicalDevice = devs.getAtIndex(devid);
191 FloatArray qpri = FloatArray.create(frame, 0.0f);
192 VkDeviceQueueCreateInfo qinfo = VkDeviceQueueCreateInfo.create(
197 VkDeviceCreateInfo devinfo = VkDeviceCreateInfo.create(
205 device = physicalDevice.vkCreateDevice(devinfo, null, scope);
207 System.out.printf("device = %s\n", device.address());
210 deviceMemoryProperties = VkPhysicalDeviceMemoryProperties.create((SegmentAllocator)scope);
211 physicalDevice.vkGetPhysicalDeviceMemoryProperties(deviceMemoryProperties);
213 computeQueue = device.vkGetDeviceQueue(queueid, 0, scope);
218 * Buffers are created in three steps:
219 * 1) create buffer, specifying usage and size
220 * 2) allocate memory based on memory requirements
224 BufferMemory init_buffer(long dataSize, int usage, int properties) throws Exception {
225 try (Frame frame = Frame.frame()) {
226 VkMemoryRequirements req = VkMemoryRequirements.create(frame);
227 VkBufferCreateInfo buf_info = VkBufferCreateInfo.create(frame,
231 VK_SHARING_MODE_EXCLUSIVE,
234 VkBuffer buffer = device.vkCreateBuffer(buf_info, null, scope);
236 device.vkGetBufferMemoryRequirements(buffer, req);
238 VkMemoryAllocateInfo alloc = VkMemoryAllocateInfo.create(frame,
240 find_memory_type(deviceMemoryProperties, req.getMemoryTypeBits(), properties));
242 VkDeviceMemory memory = device.vkAllocateMemory(alloc, null, scope);
244 device.vkBindBufferMemory(buffer, memory, 0);
246 return new BufferMemory(buffer, memory);
251 * Descriptors are used to bind and describe memory blocks
254 * *Pool is used to allocate descriptors; it is per-device.
255 * *Layout is used to group descriptors for a given pipeline.
256 * The descriptors describe individually-addressable blocks.
258 void init_descriptor() throws Exception {
259 try (Frame frame = Frame.frame()) {
260 /* Create descriptorset layout */
261 VkDescriptorSetLayoutBinding layout_binding = VkDescriptorSetLayoutBinding.create(frame,
263 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
265 VK_SHADER_STAGE_COMPUTE_BIT,
268 VkDescriptorSetLayoutCreateInfo descriptor_layout = VkDescriptorSetLayoutCreateInfo.create(frame,
272 descriptorSetLayout = device.vkCreateDescriptorSetLayout(descriptor_layout, null, scope);
274 /* Create descriptor pool */
275 VkDescriptorPoolSize type_count = VkDescriptorPoolSize.create(frame,
276 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
279 VkDescriptorPoolCreateInfo descriptor_pool = VkDescriptorPoolCreateInfo.create(frame,
284 descriptorPool = device.vkCreateDescriptorPool(descriptor_pool, null, scope);
286 /* Allocate from pool */
287 HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(1, frame);
289 layout_table.setAtIndex(0, descriptorSetLayout);
291 VkDescriptorSetAllocateInfo alloc_info = VkDescriptorSetAllocateInfo.create(frame,
296 device.vkAllocateDescriptorSets(alloc_info, descriptorSets);
298 /* Bind a buffer to the descriptor */
299 VkDescriptorBufferInfo bufferInfo = VkDescriptorBufferInfo.create(frame,
304 VkWriteDescriptorSet writeSet = VkWriteDescriptorSet.create(frame,
305 descriptorSets.getAtIndex(0),
308 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
313 device.vkUpdateDescriptorSets(1, writeSet, 0, null);
318 * Create the compute pipeline: the shader module plus the data layout it uses.
320 void init_pipeline() throws Exception {
321 try (Frame frame = Frame.frame()) {
322 /* Set shader code */
323 VkShaderModuleCreateInfo vsInfo = VkShaderModuleCreateInfo.create(frame,
325 mandelbrot_cs.length() * 4,
328 mandelbrotShader = device.vkCreateShaderModule(vsInfo, null, scope);
330 /* Link shader to layout */
331 HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(1, frame);
333 layout_table.setAtIndex(0, descriptorSetLayout);
335 VkPipelineLayoutCreateInfo pipelineinfo = VkPipelineLayoutCreateInfo.create(frame,
341 pipelineLayout = device.vkCreatePipelineLayout(pipelineinfo, null, scope);
343 /* Create pipeline */
344 VkComputePipelineCreateInfo pipeline = VkComputePipelineCreateInfo.create(frame,
350 VkPipelineShaderStageCreateInfo stage = pipeline.getStage();
352 stage.setStage(VK_SHADER_STAGE_COMPUTE_BIT);
353 stage.setModule(mandelbrotShader);
354 stage.setName(mandelbrot_entry);
356 device.vkCreateComputePipelines(null, 1, pipeline, null, computePipeline);
361 * Create a command buffer; this is somewhat like a display list.
363 void init_command_buffer() throws Exception {
364 try (Frame frame = Frame.frame()) {
365 VkCommandPoolCreateInfo poolinfo = VkCommandPoolCreateInfo.create(frame,
369 commandPool = device.vkCreateCommandPool(poolinfo, null, scope);
371 VkCommandBufferAllocateInfo cmdinfo = VkCommandBufferAllocateInfo.create(frame,
373 VK_COMMAND_BUFFER_LEVEL_PRIMARY,
376 // should it take a scope?
377 commandBuffers = VkCommandBuffer.createArray(instance, 1, (SegmentAllocator)scope, scope);
378 device.vkAllocateCommandBuffers(cmdinfo, commandBuffers);
380 /* Fill command buffer with commands for later operation */
381 VkCommandBufferBeginInfo beginInfo = VkCommandBufferBeginInfo.create(frame,
382 VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT,
385 commandBuffers.get(0).vkBeginCommandBuffer(beginInfo);
387 /* Bind the compute operation and data */
388 commandBuffers.get(0).vkCmdBindPipeline(VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline.get(0));
389 commandBuffers.get(0).vkCmdBindDescriptorSets(VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout, 0, 1, descriptorSets, 0, null);
392 commandBuffers.get(0).vkCmdDispatch(WIDTH, HEIGHT, 1);
394 commandBuffers.get(0).vkEndCommandBuffer();
399 * Execute the pre-created command buffer.
401 * A fence is used to wait for completion.
403 void execute() throws Exception {
404 try (Frame frame = Frame.frame()) {
405 VkSubmitInfo submitInfo = VkSubmitInfo.create(frame);
407 submitInfo.setCommandBufferCount(1);
408 submitInfo.setCommandBuffers(commandBuffers);
410 /* Create fence to mark the task completion */
412 HandleArray<VkFence> fences = VkFence.createArray(1, frame);
413 VkFenceCreateInfo fenceInfo = VkFenceCreateInfo.create(frame);
415 // maybe this should take a HandleArray<Fence> rather than being a constructor
416 // FIXME: some local scope
417 fence = device.vkCreateFence(fenceInfo, null, scope);
418 fences.set(0, fence);
420 /* Await completion */
421 computeQueue.vkQueueSubmit(1, submitInfo, fence);
426 res = device.vkWaitForFences(1, fences, VK_TRUE, 1000000);
427 } while (res == VK_TIMEOUT);
429 device.vkDestroyFence(fence, null);
434 device.vkDestroyCommandPool(commandPool, null);
435 device.vkDestroyPipeline(computePipeline.getAtIndex(0), null);
436 device.vkDestroyPipelineLayout(pipelineLayout, null);
437 device.vkDestroyShaderModule(mandelbrotShader, null);
439 device.vkDestroyDescriptorPool(descriptorPool, null);
440 device.vkDestroyDescriptorSetLayout(descriptorSetLayout, null);
442 device.vkFreeMemory(dst.memory(), null);
443 device.vkDestroyBuffer(dst.buffer(), null);
445 device.vkDestroyDevice(null);
447 instance.vkDestroyDebugUtilsMessengerEXT(logger, null);
448 instance.vkDestroyInstance(null);
452 * Maps the GPU buffer, converts it to RGB bytes, and saves it as a binary PPM ("P6") image named mandelbrot.pam.
454 void save_result() throws Exception {
455 try (ResourceScope scope = ResourceScope.newConfinedScope()) {
456 MemorySegment mem = device.vkMapMemory(dst.memory(), 0, dstBufferSize, 0, scope);
457 byte[] pixels = new byte[WIDTH * HEIGHT * 3];
459 System.out.printf("map %d bytes\n", dstBufferSize);
461 for (int i = 0; i < WIDTH * HEIGHT; i++) {
462 pixels[i * 3 + 0] = mem.get(Memory.BYTE, i * 4 + 0);
463 pixels[i * 3 + 1] = mem.get(Memory.BYTE, i * 4 + 1);
464 pixels[i * 3 + 2] = mem.get(Memory.BYTE, i * 4 + 2);
467 device.vkUnmapMemory(dst.memory());
469 pam_save("mandelbrot.pam", WIDTH, HEIGHT, 3, pixels);
473 void show_result() throws Exception {
474 try (ResourceScope scope = ResourceScope.newConfinedScope()) {
475 MemorySegment mem = device.vkMapMemory(dst.memory(), 0, dstBufferSize, 0, scope);
476 int[] pixels = new int[WIDTH * HEIGHT];
478 System.out.printf("map %d bytes\n", dstBufferSize);
480 MemorySegment.ofArray(pixels).copyFrom(mem);
482 device.vkUnmapMemory(dst.memory());
484 swing_show(WIDTH, HEIGHT, pixels);
489 * Trivial PNM image output (binary PPM, "P6" header).
491 void pam_save(String name, int width, int height, int depth, byte[] pixels) throws IOException {
492 try (FileOutputStream fos = new FileOutputStream(name)) {
493 fos.write(String.format("P6\n%d\n%d\n255\n", width, height).getBytes());
495 System.out.printf("wrote: %s\n", name);
499 static class DataImage extends JPanel {
501 final int w, h, stride;
502 final MemoryImageSource source;
506 public DataImage(int w, int h, int[] pixels) {
510 this.pixels = pixels;
511 this.source = new MemoryImageSource(w, h, pixels, 0, w);
512 this.source.setAnimated(true);
513 this.source.setFullBufferUpdates(true);
514 this.image = Toolkit.getDefaultToolkit().createImage(source);
518 protected void paintComponent(Graphics g) {
519 super.paintComponent(g);
520 g.drawImage(image, 0, 0, this);
524 void swing_show(int w, int h, int[] pixels) {
526 DataImage image = new DataImage(w, h, pixels);
528 window = new JFrame("mandelbrot");
529 window.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
530 window.setContentPane(image);
531 window.setSize(w, h);
532 window.setVisible(true);
535 IntArray loadSPIRV0(String name) throws IOException {
536 // hmm any way to just load this directly?
537 try (InputStream is = TestMandelbrot.class.getResourceAsStream(name)) {
538 ByteBuffer bb = ByteBuffer.allocateDirect(8192).order(ByteOrder.nativeOrder());
539 int length = Channels.newChannel(is).read(bb);
544 return IntArray.create(MemorySegment.ofByteBuffer(bb));
548 IntArray loadSPIRV(String name) throws IOException {
549 try (InputStream is = TestMandelbrot.class.getResourceAsStream(name)) {
550 MemorySegment seg = ((SegmentAllocator)scope).allocateArray(Memory.INT, 2048);
551 int length = Channels.newChannel(is).read(seg.asByteBuffer());
553 return IntArray.create(seg.asSlice(0, length));
558 * This finds the memory type index for the memory on a specific device.
560 static int find_memory_type(VkPhysicalDeviceMemoryProperties memory, int typeMask, int query) {
561 VkMemoryType mtypes = memory.getMemoryTypes();
563 for (int i = 0; i < memory.getMemoryTypeCount(); i++) {
564 if (((1 << i) & typeMask) != 0 && ((mtypes.getPropertyFlagsAtIndex(i) & query) == query))
570 public static int VK_MAKE_API_VERSION(int variant, int major, int minor, int patch) {
571 return (variant << 29) | (major << 22) | (minor << 12) | patch;
574 void demo() throws Exception {
575 mandelbrot_cs = loadSPIRV("mandelbrot.bin");
582 dst = init_buffer(dstBufferSize,
583 VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
584 VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
589 init_command_buffer();
591 System.out.printf("Calculating %dx%d\n", WIDTH, HEIGHT);
593 //System.out.println("Saving ...");
595 System.out.println("Showing ...");
597 System.out.println("Done.");
603 public static void main(String[] args) throws Throwable {
604 System.loadLibrary("vulkan");
606 new TestMandelbrot().demo();