703cffc58637cb777cf5dab333eb301e740a9c87
[panamaz] / src / notzed.vulkan.test / classes / vulkan / test / TestMandelbrot.java
1 /*
2 The MIT License (MIT)
3
4 Copyright (C) 2017 Eric Arnebäck
5 Copyright (C) 2019 Michael Zucchi
6
7 Permission is hereby granted, free of charge, to any person obtaining a copy
8 of this software and associated documentation files (the "Software"), to deal
9 in the Software without restriction, including without limitation the rights
10 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 copies of the Software, and to permit persons to whom the Software is
12 furnished to do so, subject to the following conditions:
13
14 The above copyright notice and this permission notice shall be included in
15 all copies or substantial portions of the Software.
16
17 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 THE SOFTWARE.
24
25  */
26
27  /*
28  * This is a Java conversion of a C conversion of this:
29  * https://github.com/Erkaman/vulkan_minimal_compute
30  *
31  * It's been simplified a bit and converted to the 'zvk' api.
32  */
33 package vulkan.test;
34
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

import java.awt.Dimension;
import java.awt.Graphics;
import java.awt.Image;
import java.awt.Toolkit;
import java.awt.image.MemoryImageSource;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;

import jdk.incubator.foreign.*;
import au.notzed.nativez.*;

import vulkan.*;
import static vulkan.Vulkan.*;
55
56 public class TestMandelbrot {
57
58         static final boolean debug = true;
59         ResourceScope scope = ResourceScope.newSharedScope();
60
61         int WIDTH = 1920 * 1;
62         int HEIGHT = 1080 * 1;
63
64         VkInstance instance;
65         VkPhysicalDevice physicalDevice;
66
67         VkDevice device;
68         VkQueue computeQueue;
69
70         long dstBufferSize = WIDTH * HEIGHT * 4;
71         //VkBuffer dstBuffer;
72         //VkDeviceMemory dstMemory;
73         BufferMemory dst;
74
75         VkDescriptorSetLayout descriptorSetLayout;
76         VkDescriptorPool descriptorPool;
77         HandleArray<VkDescriptorSet> descriptorSets;
78
79         int computeQueueIndex;
80         VkPhysicalDeviceMemoryProperties deviceMemoryProperties;
81
82         String mandelbrot_entry = "main";
83         IntArray mandelbrot_cs;
84
85         VkShaderModule mandelbrotShader;
86         VkPipelineLayout pipelineLayout;
87         HandleArray<VkPipeline> computePipeline = VkPipeline.createArray(1, (SegmentAllocator)scope);
88
89         VkCommandPool commandPool;
90         HandleArray<VkCommandBuffer> commandBuffers;
91
92         record BufferMemory(VkBuffer buffer, VkDeviceMemory memory) {
93
94         }
95
96         VkDebugUtilsMessengerEXT logger;
97
98         void init_debug() throws Exception {
99                 if (!debug)
100                         return;
101                 try ( Frame frame = Frame.frame()) {
102                         var cb = PFN_vkDebugUtilsMessengerCallbackEXT.upcall((severity, flags, data) -> {
103                                 System.out.printf("Debug: %d: %s\n", severity, data.getMessage());
104                                 return 0;
105                         }, scope);
106                         VkDebugUtilsMessengerCreateInfoEXT info = VkDebugUtilsMessengerCreateInfoEXT.create(
107                                 0,
108                                 //VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT |
109                                 VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT
110                                 | VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT
111                                 | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT,
112                                 VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT
113                                 | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT
114                                 | VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT,
115                                 cb,
116                                 null,
117                                 frame);
118
119                         logger = instance.vkCreateDebugUtilsMessengerEXT(info, scope);
120                 }
121         }
122
123         void init_instance() throws Exception {
124                 try ( Frame frame = Frame.frame()) {
125                         VkInstanceCreateInfo info = VkInstanceCreateInfo.create(
126                                 0,
127                                 VkApplicationInfo.create("test", 1, "test-engine", 2, VK_API_VERSION_1_0, frame),
128                                 new String[]{"VK_LAYER_KHRONOS_validation"},
129                                 debug ? new String[]{"VK_EXT_debug_utils"} : null,
130                                 frame
131                         );
132
133                         instance = VkInstance.vkCreateInstance(info, scope);
134                 }
135         }
136
137         void init_device() throws Exception {
138                 try ( Frame frame = Frame.frame()) {
139                         HandleArray<VkPhysicalDevice> devs;
140                         int count;
141                         int res;
142
143                         devs = instance.vkEnumeratePhysicalDevices(frame, scope);
144
145                         int best = 0;
146                         int devid = -1;
147                         int queueid = -1;
148
149                         for (int i = 0; i < devs.length(); i++) {
150                                 VkPhysicalDevice dev = devs.getAtIndex(i);
151                                 VkQueueFamilyProperties famprops = dev.vkGetPhysicalDeviceQueueFamilyProperties(frame);
152                                 int family_count = (int)famprops.length();
153
154                                 for (int j = 0; j < family_count; j++) {
155                                         var flags = famprops.getAtIndex(j).getQueueFlags();
156                                         int score = 0;
157
158                                         if ((flags & VK_QUEUE_COMPUTE_BIT) != 0)
159                                                 score += 1;
160                                         if ((flags & VK_QUEUE_GRAPHICS_BIT) == 0)
161                                                 score += 1;
162
163                                         if (score > best) {
164                                                 score = best;
165                                                 devid = i;
166                                                 queueid = j;
167                                         }
168                                 }
169                         }
170
171                         if (devid == -1)
172                                 throw new Exception("Cannot find a suitable device");
173
174                         computeQueueIndex = queueid;
175                         physicalDevice = devs.getAtIndex(devid);
176
177                         FloatArray qpri = FloatArray.create(frame, 0.0f);
178                         VkDeviceQueueCreateInfo qinfo = VkDeviceQueueCreateInfo.create(
179                                 0,
180                                 queueid,
181                                 qpri,
182                                 frame);
183                         VkDeviceCreateInfo devinfo = VkDeviceCreateInfo.create(
184                                 0,
185                                 qinfo,
186                                 null,
187                                 null,
188                                 null,
189                                 frame);
190
191                         device = physicalDevice.vkCreateDevice(devinfo, scope);
192
193                         System.out.printf("device = %s\n", device.address());
194
195                         // NOTE: app scope
196                         deviceMemoryProperties = VkPhysicalDeviceMemoryProperties.create((SegmentAllocator)scope);
197                         physicalDevice.vkGetPhysicalDeviceMemoryProperties(deviceMemoryProperties);
198
199                         computeQueue = device.vkGetDeviceQueue(queueid, 0, scope);
200                 }
201         }
202
203         /**
204          * Buffers are created in three steps:
205          * 1) create buffer, specifying usage and size
206          * 2) allocate memory based on memory requirements
207          * 3) bind memory
208          * <p>
209          */
210         BufferMemory init_buffer(long dataSize, int usage, int properties) throws Exception {
211                 try ( Frame frame = Frame.frame()) {
212                         VkMemoryRequirements req = VkMemoryRequirements.create(frame);
213                         VkBufferCreateInfo buf_info = VkBufferCreateInfo.create(
214                                 0,
215                                 dataSize,
216                                 usage,
217                                 VK_SHARING_MODE_EXCLUSIVE,
218                                 null,
219                                 frame);
220
221                         VkBuffer buffer = device.vkCreateBuffer(buf_info, scope);
222
223                         device.vkGetBufferMemoryRequirements(buffer, req);
224
225                         VkMemoryAllocateInfo alloc = VkMemoryAllocateInfo.create(
226                                 req.getSize(),
227                                 find_memory_type(deviceMemoryProperties, req.getMemoryTypeBits(), properties),
228                                 frame);
229
230                         VkDeviceMemory memory = device.vkAllocateMemory(alloc, scope);
231
232                         device.vkBindBufferMemory(buffer, memory, 0);
233
234                         return new BufferMemory(buffer, memory);
235                 }
236         }
237
238         /**
239          * Descriptors are used to bind and describe memory blocks
240          * to shaders.
241          * <p>
242          * *Pool is used to allocate descriptors, it is per-device.
243          * *Layout is used to group descriptors for a given pipeline,
244          * The descriptors describe individually-addressable blocks.
245          */
246         void init_descriptor() throws Exception {
247                 try ( Frame frame = Frame.frame()) {
248                         /* Create descriptorset layout */
249                         VkDescriptorSetLayoutBinding layout_binding = VkDescriptorSetLayoutBinding.create(
250                                 0,
251                                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
252                                 1,
253                                 VK_SHADER_STAGE_COMPUTE_BIT,
254                                 null,
255                                 frame);
256
257                         VkDescriptorSetLayoutCreateInfo descriptor_layout = VkDescriptorSetLayoutCreateInfo.create(
258                                 0,
259                                 layout_binding,
260                                 frame);
261
262                         descriptorSetLayout = device.vkCreateDescriptorSetLayout(descriptor_layout, scope);
263
264                         /* Create descriptor pool */
265                         VkDescriptorPoolSize type_count = VkDescriptorPoolSize.create(
266                                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
267                                 1,
268                                 frame);
269
270                         VkDescriptorPoolCreateInfo descriptor_pool = VkDescriptorPoolCreateInfo.create(
271                                 0,
272                                 1,
273                                 type_count,
274                                 frame);
275
276                         descriptorPool = device.vkCreateDescriptorPool(descriptor_pool, scope);
277
278                         /* Allocate from pool */
279                         HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(1, frame);
280
281                         layout_table.setAtIndex(0, descriptorSetLayout);
282
283                         VkDescriptorSetAllocateInfo alloc_info = VkDescriptorSetAllocateInfo.create(
284                                 descriptorPool,
285                                 layout_table,
286                                 frame);
287
288                         descriptorSets = device.vkAllocateDescriptorSets(alloc_info, (SegmentAllocator)scope);
289
290                         /* Bind a buffer to the descriptor */
291                         VkDescriptorBufferInfo bufferInfo = VkDescriptorBufferInfo.create(
292                                 dst.buffer,
293                                 0,
294                                 dstBufferSize,
295                                 frame);
296
297                         VkWriteDescriptorSet writeSet = VkWriteDescriptorSet.create(
298                                 descriptorSets.getAtIndex(0),
299                                 0,
300                                 0,
301                                 1,
302                                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
303                                 null,
304                                 bufferInfo,
305                                 null,
306                                 frame);
307
308                         System.out.println(writeSet);
309
310                         device.vkUpdateDescriptorSets(1, writeSet, 0, null);
311                 }
312         }
313
314         /**
315          * Create the compute pipeline. This is the shader and data layouts for it.
316          */
317         void init_pipeline() throws Exception {
318                 try ( Frame frame = Frame.frame()) {
319                         /* Set shader code */
320                         VkShaderModuleCreateInfo vsInfo = VkShaderModuleCreateInfo.create(
321                                 0,
322                                 mandelbrot_cs.length() * 4,
323                                 mandelbrot_cs,
324                                 frame);
325
326                         mandelbrotShader = device.vkCreateShaderModule(vsInfo, scope);
327
328                         /* Link shader to layout */
329                         HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(1, frame);
330
331                         layout_table.setAtIndex(0, descriptorSetLayout);
332
333                         VkPipelineLayoutCreateInfo pipelineinfo = VkPipelineLayoutCreateInfo.create(
334                                 0,
335                                 layout_table,
336                                 null,
337                                 frame);
338
339                         pipelineLayout = device.vkCreatePipelineLayout(pipelineinfo, scope);
340
341                         /* Create pipeline */
342                         VkComputePipelineCreateInfo pipeline = VkComputePipelineCreateInfo.create(
343                                 0,
344                                 pipelineLayout,
345                                 null,
346                                 0,
347                                 frame);
348
349                         VkPipelineShaderStageCreateInfo stage = pipeline.getStage();
350
351                         stage.setStage(VK_SHADER_STAGE_COMPUTE_BIT);
352                         stage.setModule(mandelbrotShader);
353                         stage.setName(mandelbrot_entry, frame);
354
355                         device.vkCreateComputePipelines(null, 1, pipeline, computePipeline);
356                 }
357         }
358
359         /**
360          * Create a command buffer, this is somewhat like a display list.
361          */
362         void init_command_buffer() throws Exception {
363                 try ( Frame frame = Frame.frame()) {
364                         VkCommandPoolCreateInfo poolinfo = VkCommandPoolCreateInfo.create(
365                                 0,
366                                 computeQueueIndex,
367                                 frame);
368
369                         commandPool = device.vkCreateCommandPool(poolinfo, scope);
370
371                         VkCommandBufferAllocateInfo cmdinfo = VkCommandBufferAllocateInfo.create(
372                                 commandPool,
373                                 VK_COMMAND_BUFFER_LEVEL_PRIMARY,
374                                 1,
375                                 frame);
376
377                         // should it take a scope?
378                         commandBuffers = device.vkAllocateCommandBuffers(cmdinfo, (SegmentAllocator)scope, scope);
379
380                         /* Fill command buffer with commands for later operation */
381                         VkCommandBufferBeginInfo beginInfo = VkCommandBufferBeginInfo.create(
382                                 VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT,
383                                 null,
384                                 frame);
385
386                         commandBuffers.get(0).vkBeginCommandBuffer(beginInfo);
387
388                         /* Bind the compute operation and data */
389                         commandBuffers.get(0).vkCmdBindPipeline(VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline.get(0));
390                         commandBuffers.get(0).vkCmdBindDescriptorSets(VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout, 0, 1, descriptorSets, 0, null);
391
392                         /* Run it */
393                         commandBuffers.get(0).vkCmdDispatch(WIDTH, HEIGHT, 1);
394
395                         commandBuffers.get(0).vkEndCommandBuffer();
396                 }
397         }
398
399         /**
400          * Execute the pre-created command buffer.
401          * <p>
402          * A fence is used to wait for completion.
403          */
404         void execute() throws Exception {
405                 try ( Frame frame = Frame.frame()) {
406                         VkSubmitInfo submitInfo = VkSubmitInfo.create(frame);
407
408                         submitInfo.setCommandBufferCount(1);
409                         submitInfo.setCommandBuffers(commandBuffers);
410
411                         /* Create fence to mark the task completion */
412                         VkFence fence;
413                         HandleArray<VkFence> fences = VkFence.createArray(1, frame);
414                         VkFenceCreateInfo fenceInfo = VkFenceCreateInfo.create(frame);
415
416                         // maybe this should take a HandleArray<Fence> rather than being a constructor
417                         // FIXME: some local scope
418                         fence = device.vkCreateFence(fenceInfo, scope);
419                         fences.set(0, fence);
420
421                         /* Await completion */
422                         computeQueue.vkQueueSubmit(1, submitInfo, fence);
423
424                         int VK_TRUE = 1;
425                         int res;
426                         do {
427                                 res = device.vkWaitForFences(1, fences, VK_TRUE, 1000000);
428                         } while (res == VK_TIMEOUT);
429
430                         device.vkDestroyFence(fence);
431                 }
432         }
433
434         void shutdown() {
435                 device.vkDestroyCommandPool(commandPool);
436                 device.vkDestroyPipeline(computePipeline.getAtIndex(0));
437                 device.vkDestroyPipelineLayout(pipelineLayout);
438                 device.vkDestroyShaderModule(mandelbrotShader);
439
440                 device.vkDestroyDescriptorPool(descriptorPool);
441                 device.vkDestroyDescriptorSetLayout(descriptorSetLayout);
442
443                 device.vkFreeMemory(dst.memory());
444                 device.vkDestroyBuffer(dst.buffer());
445
446                 device.vkDestroyDevice();
447                 if (logger != null)
448                         instance.vkDestroyDebugUtilsMessengerEXT(logger);
449                 instance.vkDestroyInstance();
450         }
451
452         /**
453          * Accesses the gpu buffer, converts it to RGB byte, and saves it as a pam file.
454          */
455         void save_result() throws Exception {
456                 try ( ResourceScope scope = ResourceScope.newConfinedScope()) {
457                         MemorySegment mem = device.vkMapMemory(dst.memory(), 0, dstBufferSize, 0, scope);
458                         byte[] pixels = new byte[WIDTH * HEIGHT * 3];
459
460                         System.out.printf("map %d bytes\n", dstBufferSize);
461
462                         for (int i = 0; i < WIDTH * HEIGHT; i++) {
463                                 pixels[i * 3 + 0] = mem.get(Memory.BYTE, i * 4 + 0);
464                                 pixels[i * 3 + 1] = mem.get(Memory.BYTE, i * 4 + 1);
465                                 pixels[i * 3 + 2] = mem.get(Memory.BYTE, i * 4 + 2);
466                         }
467
468                         device.vkUnmapMemory(dst.memory());
469
470                         pam_save("mandelbrot.pam", WIDTH, HEIGHT, 3, pixels);
471                 }
472         }
473
474         void show_result() throws Exception {
475                 try ( ResourceScope scope = ResourceScope.newConfinedScope()) {
476                         MemorySegment mem = device.vkMapMemory(dst.memory(), 0, dstBufferSize, 0, scope);
477                         int[] pixels = new int[WIDTH * HEIGHT];
478
479                         System.out.printf("map %d bytes\n", dstBufferSize);
480
481                         MemorySegment.ofArray(pixels).copyFrom(mem);
482
483                         device.vkUnmapMemory(dst.memory());
484
485                         swing_show(WIDTH, HEIGHT, pixels);
486                 }
487         }
488
489         /**
490          * Trivial pnm format image output.
491          */
492         void pam_save(String name, int width, int height, int depth, byte[] pixels) throws IOException {
493                 try ( FileOutputStream fos = new FileOutputStream(name)) {
494                         fos.write(String.format("P6\n%d\n%d\n255\n", width, height).getBytes());
495                         fos.write(pixels);
496                         System.out.printf("wrote: %s\n", name);
497                 }
498         }
499
500         static class DataImage extends JPanel {
501
502                 final int w, h, stride;
503                 final MemoryImageSource source;
504                 final Image image;
505                 final int[] pixels;
506
507                 public DataImage(int w, int h, int[] pixels) {
508                         this.w = w;
509                         this.h = h;
510                         this.stride = w;
511                         this.pixels = pixels;
512                         this.source = new MemoryImageSource(w, h, pixels, 0, w);
513                         this.source.setAnimated(true);
514                         this.source.setFullBufferUpdates(true);
515                         this.image = Toolkit.getDefaultToolkit().createImage(source);
516                 }
517
518                 @Override
519                 protected void paintComponent(Graphics g) {
520                         super.paintComponent(g);
521                         g.drawImage(image, 0, 0, this);
522                 }
523         }
524
525         void swing_show(int w, int h, int[] pixels) {
526                 JFrame window;
527                 DataImage image = new DataImage(w, h, pixels);
528
529                 window = new JFrame("mandelbrot");
530                 window.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
531                 window.setContentPane(image);
532                 window.setSize(w, h);
533                 window.setVisible(true);
534         }
535
536         /**
537          * This finds the memory type index for the memory on a specific device.
538          */
539         static int find_memory_type(VkPhysicalDeviceMemoryProperties memory, int typeMask, int query) {
540                 VkMemoryType mtypes = memory.getMemoryTypes();
541
542                 for (int i = 0; i < memory.getMemoryTypeCount(); i++) {
543                         if (((1 << i) & typeMask) != 0 && ((mtypes.getAtIndex(i).getPropertyFlags() & query) == query))
544                                 return i;
545                 }
546                 return -1;
547         }
548
549         void demo() throws Exception {
550                 mandelbrot_cs = ShaderIO.loadMandelbrot((SegmentAllocator)scope);
551
552                 init_instance();
553                 init_debug();
554
555                 init_device();
556
557                 dst = init_buffer(dstBufferSize,
558                         VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
559                         VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
560
561                 init_descriptor();
562
563                 init_pipeline();
564                 init_command_buffer();
565
566                 System.out.printf("Calculating %dx%d\n", WIDTH, HEIGHT);
567                 execute();
568                 //System.out.println("Saving ...");
569                 //save_result();
570                 System.out.println("Showing ...");
571                 show_result();
572                 System.out.println("Done.");
573
574                 shutdown();
575         }
576
577         public static void main(String[] args) throws Throwable {
578                 System.loadLibrary("vulkan");
579
580                 new TestMandelbrot().demo();
581         }
582 }