Add abstracted display module.
[panamaz] / src / notzed.vulkan.test / classes / vulkan / test / TestMandelbrot.java
1 /*
2 The MIT License (MIT)
3
4 Copyright (C) 2017 Eric Arnebäck
5 Copyright (C) 2019 Michael Zucchi
6
7 Permission is hereby granted, free of charge, to any person obtaining a copy
8 of this software and associated documentation files (the "Software"), to deal
9 in the Software without restriction, including without limitation the rights
10 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 copies of the Software, and to permit persons to whom the Software is
12 furnished to do so, subject to the following conditions:
13
14 The above copyright notice and this permission notice shall be included in
15 all copies or substantial portions of the Software.
16
17 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 THE SOFTWARE.
24
25  */
26
27  /*
28  * This is a Java conversion of a C conversion of this:
29  * https://github.com/Erkaman/vulkan_minimal_compute
30  *
31  * It's been simplified a bit and converted to the 'zvk' api.
32  */
33 package vulkan.test;
34
35 import java.io.InputStream;
36 import java.io.FileOutputStream;
37 import java.io.IOException;
38 import java.nio.channels.Channels;
39 import java.nio.ByteBuffer;
40 import java.nio.ByteOrder;
41
42 import java.awt.Graphics;
43 import java.awt.Image;
44 import java.awt.Toolkit;
45 import java.awt.image.MemoryImageSource;
46 import javax.swing.JFrame;
47 import javax.swing.JPanel;
48 import jdk.incubator.foreign.*;
49 import au.notzed.nativez.*;
50 import java.util.ArrayList;
51 import java.util.List;
52
53 import vulkan.*;
54 import static vulkan.Vulkan.*;
55
56 public class TestMandelbrot {
57
58         static final boolean debug = true;
59         ResourceScope scope = ResourceScope.newSharedScope();
60
61         int WIDTH = 1920 * 1;
62         int HEIGHT = 1080 * 1;
63
64         VkInstance instance;
65         VkPhysicalDevice physicalDevice;
66
67         VkDevice device;
68         VkQueue computeQueue;
69
70         long dstBufferSize = WIDTH * HEIGHT * 4;
71         //VkBuffer dstBuffer;
72         //VkDeviceMemory dstMemory;
73         BufferMemory dst;
74
75         VkDescriptorSetLayout descriptorSetLayout;
76         VkDescriptorPool descriptorPool;
77         HandleArray<VkDescriptorSet> descriptorSets;
78
79         int computeQueueIndex;
80         VkPhysicalDeviceMemoryProperties deviceMemoryProperties;
81
82         String mandelbrot_entry = "main";
83         IntArray mandelbrot_cs;
84
85         VkShaderModule mandelbrotShader;
86         VkPipelineLayout pipelineLayout;
87         HandleArray<VkPipeline> computePipeline = VkPipeline.createArray(1, (SegmentAllocator)scope);
88
89         VkCommandPool commandPool;
90         HandleArray<VkCommandBuffer> commandBuffers;
91
92         record BufferMemory(VkBuffer buffer, VkDeviceMemory memory) {
93
94         }
95
96         VkDebugUtilsMessengerEXT logger;
97
98         void init_debug() throws Exception {
99                 if (!debug)
100                         return;
101                 try ( Frame frame = Frame.frame()) {
102                         var cb = PFN_vkDebugUtilsMessengerCallbackEXT.upcall((severity, flags, data) -> {
103                                 System.out.printf("Debug: %d: %s\n", severity, data.getMessage());
104                                 return 0;
105                         }, scope);
106                         VkDebugUtilsMessengerCreateInfoEXT info = VkDebugUtilsMessengerCreateInfoEXT.create(frame,
107                                 0,
108                                 //VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT |
109                                 VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT
110                                 | VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT
111                                 | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT,
112                                 VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT
113                                 | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT
114                                 | VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT,
115                                 cb,
116                                 null);
117
118                         logger = instance.vkCreateDebugUtilsMessengerEXT(info, scope);
119                 }
120         }
121
122         void init_instance() throws Exception {
123                 try ( Frame frame = Frame.frame()) {
124                         VkInstanceCreateInfo info = VkInstanceCreateInfo.create(frame,
125                                 0,
126                                 VkApplicationInfo.create(frame, "test", 1, "test-engine", 2, VK_API_VERSION_1_0),
127                                 new String[]{"VK_LAYER_KHRONOS_validation"},
128                                 debug ? new String[]{"VK_EXT_debug_utils"} : null
129                         );
130
131                         instance = VkInstance.vkCreateInstance(info, scope);
132                 }
133         }
134
135         void init_device() throws Exception {
136                 try ( Frame frame = Frame.frame()) {
137                         HandleArray<VkPhysicalDevice> devs;
138                         int count;
139                         int res;
140
141                         devs = instance.vkEnumeratePhysicalDevices(frame, scope);
142
143                         int best = 0;
144                         int devid = -1;
145                         int queueid = -1;
146
147                         for (int i = 0; i < devs.length(); i++) {
148                                 VkPhysicalDevice dev = devs.getAtIndex(i);
149                                 VkQueueFamilyProperties famprops = dev.vkGetPhysicalDeviceQueueFamilyProperties(frame);
150                                 int family_count = (int)famprops.length();
151
152                                 for (int j = 0; j < family_count; j++) {
153                                         var flags = famprops.getAtIndex(j).getQueueFlags();
154                                         int score = 0;
155
156                                         if ((flags & VK_QUEUE_COMPUTE_BIT) != 0)
157                                                 score += 1;
158                                         if ((flags & VK_QUEUE_GRAPHICS_BIT) == 0)
159                                                 score += 1;
160
161                                         if (score > best) {
162                                                 score = best;
163                                                 devid = i;
164                                                 queueid = j;
165                                         }
166                                 }
167                         }
168
169                         if (devid == -1)
170                                 throw new Exception("Cannot find a suitable device");
171
172                         computeQueueIndex = queueid;
173                         physicalDevice = devs.getAtIndex(devid);
174
175                         FloatArray qpri = FloatArray.create(frame, 0.0f);
176                         VkDeviceQueueCreateInfo qinfo = VkDeviceQueueCreateInfo.create(
177                                 frame,
178                                 0,
179                                 queueid,
180                                 qpri);
181                         VkDeviceCreateInfo devinfo = VkDeviceCreateInfo.create(
182                                 frame,
183                                 0,
184                                 qinfo,
185                                 null,
186                                 null,
187                                 null);
188
189                         device = physicalDevice.vkCreateDevice(devinfo, scope);
190
191                         System.out.printf("device = %s\n", device.address());
192
193                         // NOTE: app scope
194                         deviceMemoryProperties = VkPhysicalDeviceMemoryProperties.create((SegmentAllocator)scope);
195                         physicalDevice.vkGetPhysicalDeviceMemoryProperties(deviceMemoryProperties);
196
197                         computeQueue = device.vkGetDeviceQueue(queueid, 0, scope);
198                 }
199         }
200
201         /**
202          * Buffers are created in three steps:
203          * 1) create buffer, specifying usage and size
204          * 2) allocate memory based on memory requirements
205          * 3) bind memory
206          * <p>
207          */
208         BufferMemory init_buffer(long dataSize, int usage, int properties) throws Exception {
209                 try ( Frame frame = Frame.frame()) {
210                         VkMemoryRequirements req = VkMemoryRequirements.create(frame);
211                         VkBufferCreateInfo buf_info = VkBufferCreateInfo.create(frame,
212                                 0,
213                                 dataSize,
214                                 usage,
215                                 VK_SHARING_MODE_EXCLUSIVE,
216                                 null);
217
218                         VkBuffer buffer = device.vkCreateBuffer(buf_info, scope);
219
220                         device.vkGetBufferMemoryRequirements(buffer, req);
221
222                         VkMemoryAllocateInfo alloc = VkMemoryAllocateInfo.create(frame,
223                                 req.getSize(),
224                                 find_memory_type(deviceMemoryProperties, req.getMemoryTypeBits(), properties));
225
226                         VkDeviceMemory memory = device.vkAllocateMemory(alloc, scope);
227
228                         device.vkBindBufferMemory(buffer, memory, 0);
229
230                         return new BufferMemory(buffer, memory);
231                 }
232         }
233
234         /**
235          * Descriptors are used to bind and describe memory blocks
236          * to shaders.
237          * <p>
238          * *Pool is used to allocate descriptors, it is per-device.
239          * *Layout is used to group descriptors for a given pipeline,
240          * The descriptors describe individually-addressable blocks.
241          */
242         void init_descriptor() throws Exception {
243                 try ( Frame frame = Frame.frame()) {
244                         /* Create descriptorset layout */
245                         VkDescriptorSetLayoutBinding layout_binding = VkDescriptorSetLayoutBinding.create(frame,
246                                 0,
247                                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
248                                 1,
249                                 VK_SHADER_STAGE_COMPUTE_BIT,
250                                 null);
251
252                         VkDescriptorSetLayoutCreateInfo descriptor_layout = VkDescriptorSetLayoutCreateInfo.create(frame,
253                                 0,
254                                 layout_binding);
255
256                         descriptorSetLayout = device.vkCreateDescriptorSetLayout(descriptor_layout, scope);
257
258                         /* Create descriptor pool */
259                         VkDescriptorPoolSize type_count = VkDescriptorPoolSize.create(frame,
260                                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
261                                 1);
262
263                         VkDescriptorPoolCreateInfo descriptor_pool = VkDescriptorPoolCreateInfo.create(frame,
264                                 0,
265                                 1,
266                                 type_count);
267
268                         descriptorPool = device.vkCreateDescriptorPool(descriptor_pool, scope);
269
270                         /* Allocate from pool */
271                         HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(1, frame);
272
273                         layout_table.setAtIndex(0, descriptorSetLayout);
274
275                         VkDescriptorSetAllocateInfo alloc_info = VkDescriptorSetAllocateInfo.create(frame,
276                                 descriptorPool,
277                                 layout_table);
278
279                         descriptorSets = device.vkAllocateDescriptorSets(alloc_info, (SegmentAllocator)scope);
280
281                         /* Bind a buffer to the descriptor */
282                         VkDescriptorBufferInfo bufferInfo = VkDescriptorBufferInfo.create(frame,
283                                 dst.buffer,
284                                 0,
285                                 dstBufferSize);
286
287                         VkWriteDescriptorSet writeSet = VkWriteDescriptorSet.create(frame,
288                                 descriptorSets.getAtIndex(0),
289                                 0,
290                                 0,
291                                 1,
292                                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
293                                 null,
294                                 bufferInfo,
295                                 null);
296
297                         System.out.println(writeSet);
298
299                         device.vkUpdateDescriptorSets(1, writeSet, 0, null);
300                 }
301         }
302
303         /**
304          * Create the compute pipeline. This is the shader and data layouts for it.
305          */
306         void init_pipeline() throws Exception {
307                 try ( Frame frame = Frame.frame()) {
308                         /* Set shader code */
309                         VkShaderModuleCreateInfo vsInfo = VkShaderModuleCreateInfo.create(frame,
310                                 0,
311                                 mandelbrot_cs.length() * 4,
312                                 mandelbrot_cs);
313
314                         mandelbrotShader = device.vkCreateShaderModule(vsInfo, scope);
315
316                         /* Link shader to layout */
317                         HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(1, frame);
318
319                         layout_table.setAtIndex(0, descriptorSetLayout);
320
321                         VkPipelineLayoutCreateInfo pipelineinfo = VkPipelineLayoutCreateInfo.create(frame,
322                                 0,
323                                 layout_table,
324                                 null);
325
326                         pipelineLayout = device.vkCreatePipelineLayout(pipelineinfo, scope);
327
328                         /* Create pipeline */
329                         VkComputePipelineCreateInfo pipeline = VkComputePipelineCreateInfo.create(frame,
330                                 0,
331                                 pipelineLayout,
332                                 null,
333                                 0);
334
335                         VkPipelineShaderStageCreateInfo stage = pipeline.getStage();
336
337                         stage.setStage(VK_SHADER_STAGE_COMPUTE_BIT);
338                         stage.setModule(mandelbrotShader);
339                         stage.setName(mandelbrot_entry, frame);
340
341                         device.vkCreateComputePipelines(null, 1, pipeline, computePipeline);
342                 }
343         }
344
345         /**
346          * Create a command buffer, this is somewhat like a display list.
347          */
348         void init_command_buffer() throws Exception {
349                 try ( Frame frame = Frame.frame()) {
350                         VkCommandPoolCreateInfo poolinfo = VkCommandPoolCreateInfo.create(frame,
351                                 0,
352                                 computeQueueIndex);
353
354                         commandPool = device.vkCreateCommandPool(poolinfo, scope);
355
356                         VkCommandBufferAllocateInfo cmdinfo = VkCommandBufferAllocateInfo.create(frame,
357                                 commandPool,
358                                 VK_COMMAND_BUFFER_LEVEL_PRIMARY,
359                                 1);
360
361                         // should it take a scope?
362                         commandBuffers = device.vkAllocateCommandBuffers(cmdinfo, (SegmentAllocator)scope, scope);
363
364                         /* Fill command buffer with commands for later operation */
365                         VkCommandBufferBeginInfo beginInfo = VkCommandBufferBeginInfo.create(frame,
366                                 VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT,
367                                 null);
368
369                         commandBuffers.get(0).vkBeginCommandBuffer(beginInfo);
370
371                         /* Bind the compute operation and data */
372                         commandBuffers.get(0).vkCmdBindPipeline(VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline.get(0));
373                         commandBuffers.get(0).vkCmdBindDescriptorSets(VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout, 0, 1, descriptorSets, 0, null);
374
375                         /* Run it */
376                         commandBuffers.get(0).vkCmdDispatch(WIDTH, HEIGHT, 1);
377
378                         commandBuffers.get(0).vkEndCommandBuffer();
379                 }
380         }
381
382         /**
383          * Execute the pre-created command buffer.
384          * <p>
385          * A fence is used to wait for completion.
386          */
387         void execute() throws Exception {
388                 try ( Frame frame = Frame.frame()) {
389                         VkSubmitInfo submitInfo = VkSubmitInfo.create(frame);
390
391                         submitInfo.setCommandBufferCount(1);
392                         submitInfo.setCommandBuffers(commandBuffers);
393
394                         /* Create fence to mark the task completion */
395                         VkFence fence;
396                         HandleArray<VkFence> fences = VkFence.createArray(1, frame);
397                         VkFenceCreateInfo fenceInfo = VkFenceCreateInfo.create(frame);
398
399                         // maybe this should take a HandleArray<Fence> rather than being a constructor
400                         // FIXME: some local scope
401                         fence = device.vkCreateFence(fenceInfo, scope);
402                         fences.set(0, fence);
403
404                         /* Await completion */
405                         computeQueue.vkQueueSubmit(1, submitInfo, fence);
406
407                         int VK_TRUE = 1;
408                         int res;
409                         do {
410                                 res = device.vkWaitForFences(1, fences, VK_TRUE, 1000000);
411                         } while (res == VK_TIMEOUT);
412
413                         device.vkDestroyFence(fence);
414                 }
415         }
416
        /**
         * Destroys every Vulkan object created by the init_* methods.
         * <p>
         * Objects owned by the device are released first, then the device
         * itself, then the optional debug messenger and finally the instance.
         */
        void shutdown() {
                // Command and pipeline objects.
                device.vkDestroyCommandPool(commandPool);
                device.vkDestroyPipeline(computePipeline.getAtIndex(0));
                device.vkDestroyPipelineLayout(pipelineLayout);
                device.vkDestroyShaderModule(mandelbrotShader);

                // Descriptor objects.
                device.vkDestroyDescriptorPool(descriptorPool);
                device.vkDestroyDescriptorSetLayout(descriptorSetLayout);

                // Destination buffer and its backing memory.
                device.vkFreeMemory(dst.memory());
                device.vkDestroyBuffer(dst.buffer());

                device.vkDestroyDevice();
                // logger is only assigned by init_debug(), which returns early when debug is off.
                if (logger != null)
                        instance.vkDestroyDebugUtilsMessengerEXT(logger);
                instance.vkDestroyInstance();
        }
434
435         /**
436          * Accesses the gpu buffer, converts it to RGB byte, and saves it as a pam file.
437          */
438         void save_result() throws Exception {
439                 try ( ResourceScope scope = ResourceScope.newConfinedScope()) {
440                         MemorySegment mem = device.vkMapMemory(dst.memory(), 0, dstBufferSize, 0, scope);
441                         byte[] pixels = new byte[WIDTH * HEIGHT * 3];
442
443                         System.out.printf("map %d bytes\n", dstBufferSize);
444
445                         for (int i = 0; i < WIDTH * HEIGHT; i++) {
446                                 pixels[i * 3 + 0] = mem.get(Memory.BYTE, i * 4 + 0);
447                                 pixels[i * 3 + 1] = mem.get(Memory.BYTE, i * 4 + 1);
448                                 pixels[i * 3 + 2] = mem.get(Memory.BYTE, i * 4 + 2);
449                         }
450
451                         device.vkUnmapMemory(dst.memory());
452
453                         pam_save("mandelbrot.pam", WIDTH, HEIGHT, 3, pixels);
454                 }
455         }
456
457         void show_result() throws Exception {
458                 try ( ResourceScope scope = ResourceScope.newConfinedScope()) {
459                         MemorySegment mem = device.vkMapMemory(dst.memory(), 0, dstBufferSize, 0, scope);
460                         int[] pixels = new int[WIDTH * HEIGHT];
461
462                         System.out.printf("map %d bytes\n", dstBufferSize);
463
464                         MemorySegment.ofArray(pixels).copyFrom(mem);
465
466                         device.vkUnmapMemory(dst.memory());
467
468                         swing_show(WIDTH, HEIGHT, pixels);
469                 }
470         }
471
472         /**
473          * Trivial pnm format image output.
474          */
475         void pam_save(String name, int width, int height, int depth, byte[] pixels) throws IOException {
476                 try ( FileOutputStream fos = new FileOutputStream(name)) {
477                         fos.write(String.format("P6\n%d\n%d\n255\n", width, height).getBytes());
478                         fos.write(pixels);
479                         System.out.printf("wrote: %s\n", name);
480                 }
481         }
482
483         static class DataImage extends JPanel {
484
485                 final int w, h, stride;
486                 final MemoryImageSource source;
487                 final Image image;
488                 final int[] pixels;
489
490                 public DataImage(int w, int h, int[] pixels) {
491                         this.w = w;
492                         this.h = h;
493                         this.stride = w;
494                         this.pixels = pixels;
495                         this.source = new MemoryImageSource(w, h, pixels, 0, w);
496                         this.source.setAnimated(true);
497                         this.source.setFullBufferUpdates(true);
498                         this.image = Toolkit.getDefaultToolkit().createImage(source);
499                 }
500
501                 @Override
502                 protected void paintComponent(Graphics g) {
503                         super.paintComponent(g);
504                         g.drawImage(image, 0, 0, this);
505                 }
506         }
507
508         void swing_show(int w, int h, int[] pixels) {
509                 JFrame window;
510                 DataImage image = new DataImage(w, h, pixels);
511
512                 window = new JFrame("mandelbrot");
513                 window.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
514                 window.setContentPane(image);
515                 window.setSize(w, h);
516                 window.setVisible(true);
517         }
518
519         /**
520          * This finds the memory type index for the memory on a specific device.
521          */
522         static int find_memory_type(VkPhysicalDeviceMemoryProperties memory, int typeMask, int query) {
523                 VkMemoryType mtypes = memory.getMemoryTypes();
524
525                 for (int i = 0; i < memory.getMemoryTypeCount(); i++) {
526                         if (((1 << i) & typeMask) != 0 && ((mtypes.getAtIndex(i).getPropertyFlags() & query) == query))
527                                 return i;
528                 }
529                 return -1;
530         }
531
        /**
         * Runs the whole demo: loads the shader, initialises Vulkan, computes
         * the image, shows it in a window and tears everything down again.
         * <p>
         * The init_* calls depend on state set by the ones before them, so
         * their order matters.
         */
        void demo() throws Exception {
                mandelbrot_cs = ShaderIO.loadMandelbrot((SegmentAllocator)scope);

                init_instance();
                init_debug();

                init_device();

                // The destination buffer must be host visible because the result is mapped and read back.
                dst = init_buffer(dstBufferSize,
                        VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
                        VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);

                // Reads dst, so it has to follow init_buffer().
                init_descriptor();

                init_pipeline();
                init_command_buffer();

                System.out.printf("Calculating %dx%d\n", WIDTH, HEIGHT);
                execute();
                //System.out.println("Saving ...");
                //save_result();
                System.out.println("Showing ...");
                show_result();
                System.out.println("Done.");

                shutdown();
        }
559
560         public static void main(String[] args) throws Throwable {
561                 System.loadLibrary("vulkan");
562
563                 new TestMandelbrot().demo();
564         }
565 }