4ebb0027150e35a0267678787a2b8a98c373c73d
[panamaz] / src / notzed.vulkan.test / classes / vulkan / test / TestMandelbrot.java
1  /*
2 The MIT License (MIT)
3
4 Copyright (C) 2017 Eric Arnebäck
5 Copyright (C) 2019 Michael Zucchi
6
7 Permission is hereby granted, free of charge, to any person obtaining a copy
8 of this software and associated documentation files (the "Software"), to deal
9 in the Software without restriction, including without limitation the rights
10 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 copies of the Software, and to permit persons to whom the Software is
12 furnished to do so, subject to the following conditions:
13
14 The above copyright notice and this permission notice shall be included in
15 all copies or substantial portions of the Software.
16
17 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 THE SOFTWARE.
24
25  */
26
27 /*
28  * This is a Java conversion of a C conversion of this:
29  * https://github.com/Erkaman/vulkan_minimal_compute
30  *
31  * It's been simplified a bit and converted to the 'zvk' api.
32  */
33
34 package vulkan.test;
35
36 import java.io.InputStream;
37 import java.io.FileOutputStream;
38 import java.io.IOException;
39 import java.nio.channels.Channels;
40 import java.nio.ByteBuffer;
41 import java.nio.ByteOrder;
42
43 import java.awt.Graphics;
44 import java.awt.Image;
45 import java.awt.Toolkit;
46 import java.awt.event.ActionEvent;
47 import java.awt.event.KeyEvent;
48 import java.awt.image.MemoryImageSource;
49 import javax.swing.AbstractAction;
50 import javax.swing.JComponent;
51 import javax.swing.JFrame;
52 import javax.swing.JPanel;
53 import javax.swing.KeyStroke;
54
55 import java.lang.ref.WeakReference;
56
57 import java.lang.invoke.*;
58 import jdk.incubator.foreign.*;
59 import jdk.incubator.foreign.MemoryLayout.PathElement;
60 import au.notzed.nativez.*;
61
62 import vulkan.*;
63
64 import static vulkan.VkConstants.*;
65
66 public class TestMandelbrot {
67         static final boolean debug = true;
68         ResourceScope scope = ResourceScope.newSharedScope();
69
70         int WIDTH = 1920*1;
71         int HEIGHT = 1080*1;
72
73         VkInstance instance;
74         VkPhysicalDevice physicalDevice;
75
76         VkDevice device;
77         VkQueue computeQueue;
78
79         long dstBufferSize = WIDTH * HEIGHT * 4;
80         //VkBuffer dstBuffer;
81         //VkDeviceMemory dstMemory;
82         BufferMemory dst;
83
84         VkDescriptorSetLayout descriptorSetLayout;
85         VkDescriptorPool descriptorPool;
86         HandleArray<VkDescriptorSet> descriptorSets;
87
88         int computeQueueIndex;
89         VkPhysicalDeviceMemoryProperties deviceMemoryProperties;
90
91         String mandelbrot_entry = "main";
92         IntArray mandelbrot_cs;
93
94         VkShaderModule mandelbrotShader;
95         VkPipelineLayout pipelineLayout;
96         HandleArray<VkPipeline> computePipeline = VkPipeline.createArray(1, (SegmentAllocator)scope);
97
98         VkCommandPool commandPool;
99         HandleArray<VkCommandBuffer> commandBuffers;
100
101         record BufferMemory ( VkBuffer buffer, VkDeviceMemory memory ) {};
102
103         VkDebugUtilsMessengerEXT logger;
104
105         void init_debug() throws Exception {
106                 if (!debug)
107                         return;
108                 /*
109                 try (Frame frame = Frame.frame()) {
110                         var cb = PFN_vkDebugUtilsMessengerCallbackEXT.upcall((severity, flags, data, dummy) -> {
111                                         System.out.printf("Debug: %d: %s\n", severity, data.getMessage());
112                                         return 0;
113                                 }, scope);
114                         VkDebugUtilsMessengerCreateInfoEXT info = VkDebugUtilsMessengerCreateInfoEXT.create(frame,
115                                 0,
116                                 VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT
117                                 | VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT
118                                 | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT,
119                                 VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT,
120                                 cb,
121                                 null);
122
123                         logger = instance.vkCreateDebugUtilsMessengerEXT(info, null, scope);
124                 }
125                 */
126                 //typedef VkBool32 (*PFN_vkDebugUtilsMessengerCallbackEXT)(VkDebugUtilsMessageSeverityFlagBitsEXT, VkDebugUtilsMessageTypeFlagsEXT, const VkDebugUtilsMessengerCallbackDataEXT *, void *);
127
128         }
129
130         void init_instance() throws Exception {
131                 try (Frame frame = Frame.frame()) {
132                         VkInstanceCreateInfo info = VkInstanceCreateInfo.create(frame,
133                                 0,
134                                 VkApplicationInfo.create(frame, "test", 1, "test-engine", 2, VK_MAKE_API_VERSION(0, 1, 0, 0)),
135                                 new String[] { "VK_LAYER_KHRONOS_validation" },
136                                 debug ? new String[] { "VK_EXT_debug_utils" } : null
137                                 );
138
139                         instance = VkInstance.vkCreateInstance(info, null, scope);
140                 }
141         }
142
143         void init_device() throws Exception {
144                 try (Frame frame = Frame.frame()) {
145                         HandleArray<VkPhysicalDevice> devs;
146                         int count;
147                         int res;
148
149                         devs = instance.vkEnumeratePhysicalDevices(frame, scope);
150
151                         int best = 0;
152                         int devid = -1;
153                         int queueid = -1;
154
155                         for (int i=0;i<devs.length();i++) {
156                                 VkPhysicalDevice dev = devs.getAtIndex(i);
157                                 VkQueueFamilyProperties famprops = dev.vkGetPhysicalDeviceQueueFamilyProperties(frame);
158                                 int family_count = (int)famprops.length();
159
160                                 for (int j=0;j<family_count;j++) {
161                                         var flags = famprops.getAtIndex(j).getQueueFlags();
162                                         int score = 0;
163
164                                         if ((flags & VK_QUEUE_COMPUTE_BIT) != 0)
165                                                 score += 1;
166                                         if ((flags & VK_QUEUE_GRAPHICS_BIT) == 0)
167                                                 score += 1;
168
169                                         if (score > best) {
170                                                 score = best;
171                                                 devid = i;
172                                                 queueid = j;
173                                         }
174                                 }
175                         }
176
177                         if (devid == -1)
178                                 throw new Exception("Cannot find a suitable device");
179
180                         computeQueueIndex = queueid;
181                         physicalDevice = devs.getAtIndex(devid);
182
183                         FloatArray qpri = FloatArray.create(frame, 0.0f);
184                         VkDeviceQueueCreateInfo qinfo = VkDeviceQueueCreateInfo.create(
185                                 frame,
186                                 0,
187                                 queueid,
188                                 qpri);
189                         VkDeviceCreateInfo devinfo = VkDeviceCreateInfo.create(
190                                 frame,
191                                 0,
192                                 qinfo,
193                                 null,
194                                 null,
195                                 null);
196
197                         device = physicalDevice.vkCreateDevice(devinfo, null, scope);
198
199                         System.out.printf("device = %s\n", device.address());
200
201                         // NOTE: app scope
202                         deviceMemoryProperties = VkPhysicalDeviceMemoryProperties.create((SegmentAllocator)scope);
203                         physicalDevice.vkGetPhysicalDeviceMemoryProperties(deviceMemoryProperties);
204
205                         computeQueue = device.vkGetDeviceQueue(queueid, 0, scope);
206                 }
207         }
208
209         /**
210          * Buffers are created in three steps:
211          * 1) create buffer, specifying usage and size
212          * 2) allocate memory based on memory requirements
213          * 3) bind memory
214          *
215          */
216         BufferMemory init_buffer(long dataSize, int usage, int properties) throws Exception {
217                 try (Frame frame = Frame.frame()) {
218                         VkMemoryRequirements req = VkMemoryRequirements.create(frame);
219                         VkBufferCreateInfo buf_info = VkBufferCreateInfo.create(frame,
220                                 0,
221                                 dataSize,
222                                 usage,
223                                 VK_SHARING_MODE_EXCLUSIVE,
224                                 null);
225
226                         VkBuffer buffer = device.vkCreateBuffer(buf_info, null, scope);
227
228                         device.vkGetBufferMemoryRequirements(buffer, req);
229
230                         VkMemoryAllocateInfo alloc = VkMemoryAllocateInfo.create(frame,
231                                 req.getSize(),
232                                 find_memory_type(deviceMemoryProperties, req.getMemoryTypeBits(), properties));
233
234                         VkDeviceMemory memory = device.vkAllocateMemory(alloc, null, scope);
235
236                         device.vkBindBufferMemory(buffer, memory, 0);
237
238                         return new BufferMemory(buffer, memory);
239                 }
240         }
241
242         /**
243          * Descriptors are used to bind and describe memory blocks
244          * to shaders.
245          *
246          * *Pool is used to allocate descriptors, it is per-device.
247          * *Layout is used to group descriptors for a given pipeline,
248          * The descriptors describe individually-addressable blocks.
249          */
250         void init_descriptor() throws Exception {
251                 try (Frame frame = Frame.frame()) {
252                         /* Create descriptorset layout */
253                         VkDescriptorSetLayoutBinding layout_binding = VkDescriptorSetLayoutBinding.create(frame,
254                                 0,
255                                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
256                                 1,
257                                 VK_SHADER_STAGE_COMPUTE_BIT,
258                                 null);
259
260                         VkDescriptorSetLayoutCreateInfo descriptor_layout = VkDescriptorSetLayoutCreateInfo.create(frame,
261                                 0,
262                                 layout_binding);
263
264                         descriptorSetLayout = device.vkCreateDescriptorSetLayout(descriptor_layout, null, scope);
265
266                         /* Create descriptor pool */
267                         VkDescriptorPoolSize type_count = VkDescriptorPoolSize.create(frame,
268                                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
269                                 1);
270
271                         VkDescriptorPoolCreateInfo descriptor_pool = VkDescriptorPoolCreateInfo.create(frame,
272                                 0,
273                                 1,
274                                 type_count);
275
276                         descriptorPool = device.vkCreateDescriptorPool(descriptor_pool, null, scope);
277
278                         /* Allocate from pool */
279                         HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(1, frame);
280
281                         layout_table.setAtIndex(0, descriptorSetLayout);
282
283                         VkDescriptorSetAllocateInfo alloc_info = VkDescriptorSetAllocateInfo.create(frame,
284                                 descriptorPool,
285                                 layout_table);
286
287                         descriptorSets = device.vkAllocateDescriptorSets(alloc_info, (SegmentAllocator)scope);
288
289                         /* Bind a buffer to the descriptor */
290                         VkDescriptorBufferInfo bufferInfo = VkDescriptorBufferInfo.create(frame,
291                                 dst.buffer,
292                                 0,
293                                 dstBufferSize);
294
295                         VkWriteDescriptorSet writeSet = VkWriteDescriptorSet.create(frame,
296                                 descriptorSets.getAtIndex(0),
297                                 0,
298                                 0,
299                                 1,
300                                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
301                                 null,
302                                 bufferInfo,
303                                 null);
304
305                         System.out.println(writeSet);
306
307                         device.vkUpdateDescriptorSets(writeSet, null);
308                 }
309         }
310
311         /**
312          * Create the compute pipeline.  This is the shader and data layouts for it.
313          */
314         void init_pipeline() throws Exception {
315                 try (Frame frame = Frame.frame()) {
316                         /* Set shader code */
317                         VkShaderModuleCreateInfo vsInfo = VkShaderModuleCreateInfo.create(frame,
318                                 0,
319                                 mandelbrot_cs.length() * 4,
320                                 mandelbrot_cs);
321
322                         mandelbrotShader = device.vkCreateShaderModule(vsInfo, null, scope);
323
324                         /* Link shader to layout */
325                         HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(1, frame);
326
327                         layout_table.setAtIndex(0, descriptorSetLayout);
328
329                         VkPipelineLayoutCreateInfo pipelineinfo = VkPipelineLayoutCreateInfo.create(frame,
330                                 0,
331                                 layout_table,
332                                 null);
333
334                         pipelineLayout = device.vkCreatePipelineLayout(pipelineinfo, null, scope);
335
336                         /* Create pipeline */
337                         VkComputePipelineCreateInfo pipeline = VkComputePipelineCreateInfo.create(frame,
338                                 0,
339                                 pipelineLayout,
340                                 null,
341                                 0);
342
343                         VkPipelineShaderStageCreateInfo stage = pipeline.getStage();
344
345                         stage.setStage(VK_SHADER_STAGE_COMPUTE_BIT);
346                         stage.setModule(mandelbrotShader);
347                         stage.setName(mandelbrot_entry, frame);
348
349                         device.vkCreateComputePipelines(null, pipeline, null, computePipeline);
350                 }
351         }
352
353         /**
354          * Create a command buffer, this is somewhat like a display list.
355          */
356         void init_command_buffer() throws Exception {
357                 try (Frame frame = Frame.frame()) {
358                         VkCommandPoolCreateInfo poolinfo = VkCommandPoolCreateInfo.create(frame,
359                                 0,
360                                 computeQueueIndex);
361
362                         commandPool = device.vkCreateCommandPool(poolinfo, null, scope);
363
364                         VkCommandBufferAllocateInfo cmdinfo = VkCommandBufferAllocateInfo.create(frame,
365                                 commandPool,
366                                 VK_COMMAND_BUFFER_LEVEL_PRIMARY,
367                                 1);
368
369                         // should it take a scope?
370                         commandBuffers = device.vkAllocateCommandBuffers(cmdinfo, (SegmentAllocator)scope, scope);
371
372                         /* Fill command buffer with commands for later operation */
373                         VkCommandBufferBeginInfo beginInfo = VkCommandBufferBeginInfo.create(frame,
374                                 VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT,
375                                 null);
376
377                         commandBuffers.get(0).vkBeginCommandBuffer(beginInfo);
378
379                         /* Bind the compute operation and data */
380                         commandBuffers.get(0).vkCmdBindPipeline(VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline.get(0));
381                         commandBuffers.get(0).vkCmdBindDescriptorSets(VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout, 0, descriptorSets, null);
382
383                         /* Run it */
384                         commandBuffers.get(0).vkCmdDispatch(WIDTH, HEIGHT, 1);
385
386                         commandBuffers.get(0).vkEndCommandBuffer();
387                 }
388         }
389
390         /**
391          * Execute the pre-created command buffer.
392          *
393          * A fence is used to wait for completion.
394          */
395         void execute() throws Exception {
396                 try (Frame frame = Frame.frame()) {
397                         VkSubmitInfo submitInfo = VkSubmitInfo.create(frame);
398
399                         submitInfo.setCommandBufferCount(1);
400                         submitInfo.setCommandBuffers(commandBuffers);
401
402                         /* Create fence to mark the task completion */
403                         VkFence fence;
404                         HandleArray<VkFence> fences = VkFence.createArray(1, frame);
405                         VkFenceCreateInfo fenceInfo = VkFenceCreateInfo.create(frame);
406
407                         // maybe this should take a HandleArray<Fence> rather than being a constructor
408                         // FIXME: some local scope
409                         fence = device.vkCreateFence(fenceInfo, null, scope);
410                         fences.set(0, fence);
411
412                         /* Await completion */
413                         computeQueue.vkQueueSubmit(submitInfo, fence);
414
415                         int VK_TRUE = 1;
416                         int res;
417                         do {
418                                 res = device.vkWaitForFences(fences, VK_TRUE, 1000000);
419                         } while (res == VK_TIMEOUT);
420
421                         device.vkDestroyFence(fence, null);
422                 }
423         }
424
425         void shutdown() {
426                 device.vkDestroyCommandPool(commandPool, null);
427                 device.vkDestroyPipeline(computePipeline.getAtIndex(0), null);
428                 device.vkDestroyPipelineLayout(pipelineLayout, null);
429                 device.vkDestroyShaderModule(mandelbrotShader, null);
430
431                 device.vkDestroyDescriptorPool(descriptorPool, null);
432                 device.vkDestroyDescriptorSetLayout(descriptorSetLayout, null);
433
434                 device.vkFreeMemory(dst.memory(), null);
435                 device.vkDestroyBuffer(dst.buffer(), null);
436
437                 device.vkDestroyDevice(null);
438                 if (logger != null)
439                         instance.vkDestroyDebugUtilsMessengerEXT(logger, null);
440                 instance.vkDestroyInstance(null);
441         }
442
443         /**
444          * Accesses the gpu buffer, converts it to RGB byte, and saves it as a pam file.
445          */
446         void save_result() throws Exception {
447                 try (ResourceScope scope = ResourceScope.newConfinedScope()) {
448                         MemorySegment mem = device.vkMapMemory(dst.memory(), 0, dstBufferSize, 0, scope);
449                         byte[] pixels = new byte[WIDTH * HEIGHT * 3];
450
451                         System.out.printf("map %d bytes\n", dstBufferSize);
452
453                         for (int i = 0; i < WIDTH * HEIGHT; i++) {
454                                 pixels[i * 3 + 0] = mem.get(Memory.BYTE, i * 4 + 0);
455                                 pixels[i * 3 + 1] = mem.get(Memory.BYTE, i * 4 + 1);
456                                 pixels[i * 3 + 2] = mem.get(Memory.BYTE, i * 4 + 2);
457                         }
458
459                         device.vkUnmapMemory(dst.memory());
460
461                         pam_save("mandelbrot.pam", WIDTH, HEIGHT, 3, pixels);
462                 }
463         }
464
465         void show_result() throws Exception {
466                 try (ResourceScope scope = ResourceScope.newConfinedScope()) {
467                         MemorySegment mem = device.vkMapMemory(dst.memory(), 0, dstBufferSize, 0, scope);
468                         int[] pixels = new int[WIDTH * HEIGHT];
469
470                         System.out.printf("map %d bytes\n", dstBufferSize);
471
472                         MemorySegment.ofArray(pixels).copyFrom(mem);
473
474                         device.vkUnmapMemory(dst.memory());
475
476                         swing_show(WIDTH, HEIGHT, pixels);
477                 }
478         }
479
480         /**
481          * Trivial pnm format image output.
482          */
483         void pam_save(String name, int width, int height, int depth, byte[] pixels) throws IOException {
484                 try (FileOutputStream fos = new FileOutputStream(name)) {
485                         fos.write(String.format("P6\n%d\n%d\n255\n", width, height).getBytes());
486                         fos.write(pixels);
487                         System.out.printf("wrote: %s\n", name);
488                 }
489         }
490
491         static class DataImage extends JPanel {
492
493                 final int w, h, stride;
494                 final MemoryImageSource source;
495                 final Image image;
496                 final int[] pixels;
497
498                 public DataImage(int w, int h, int[] pixels) {
499                         this.w = w;
500                         this.h = h;
501                         this.stride = w;
502                         this.pixels = pixels;
503                         this.source = new MemoryImageSource(w, h, pixels, 0, w);
504                         this.source.setAnimated(true);
505                         this.source.setFullBufferUpdates(true);
506                         this.image = Toolkit.getDefaultToolkit().createImage(source);
507                 }
508
509                 @Override
510                 protected void paintComponent(Graphics g) {
511                         super.paintComponent(g);
512                         g.drawImage(image, 0, 0, this);
513                 }
514         }
515
516         void swing_show(int w, int h, int[] pixels) {
517                 JFrame window;
518                 DataImage image = new DataImage(w, h, pixels);
519
520                 window = new JFrame("mandelbrot");
521                 window.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
522                 window.setContentPane(image);
523                 window.setSize(w, h);
524                 window.setVisible(true);
525         }
526
527         IntArray loadSPIRV0(String name) throws IOException {
528                 // hmm any way to just load this directly?
529                 try (InputStream is = TestMandelbrot.class.getResourceAsStream(name)) {
530                         ByteBuffer bb = ByteBuffer.allocateDirect(8192).order(ByteOrder.nativeOrder());
531                         int length = Channels.newChannel(is).read(bb);
532
533                         bb.position(0);
534                         bb.limit(length);
535
536                         return IntArray.create(MemorySegment.ofByteBuffer(bb));
537                 }
538         }
539
540         IntArray loadSPIRV(String name) throws IOException {
541                 try (InputStream is = TestMandelbrot.class.getResourceAsStream(name)) {
542                         MemorySegment seg = ((SegmentAllocator)scope).allocateArray(Memory.INT, 2048);
543                         int length = Channels.newChannel(is).read(seg.asByteBuffer());
544
545                         return IntArray.create(seg.asSlice(0, length));
546                 }
547         }
548
549         /**
550          * This finds the memory type index for the memory on a specific device.
551          */
552         static int find_memory_type(VkPhysicalDeviceMemoryProperties memory, int typeMask, int query) {
553                 VkMemoryType mtypes = memory.getMemoryTypes();
554
555                 for (int i = 0; i < memory.getMemoryTypeCount(); i++) {
556                         if (((1 << i) & typeMask) != 0 && ((mtypes.getAtIndex(i).getPropertyFlags() & query) == query))
557                                 return i;
558                 }
559                 return -1;
560         }
561
562         public static int VK_MAKE_API_VERSION(int variant, int major, int minor, int patch) {
563                 return (variant << 29) | (major << 22) | (minor << 12) | patch;
564         }
565
566         void demo() throws Exception {
567                 mandelbrot_cs = loadSPIRV("mandelbrot.bin");
568
569                 init_instance();
570                 init_debug();
571
572                 init_device();
573
574                 dst = init_buffer(dstBufferSize,
575                         VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
576                         VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
577
578                 init_descriptor();
579
580                 init_pipeline();
581                 init_command_buffer();
582
583                 System.out.printf("Calculating %dx%d\n", WIDTH, HEIGHT);
584                 execute();
585                 //System.out.println("Saving ...");
586                 //save_result();
587                 System.out.println("Showing ...");
588                 show_result();
589                 System.out.println("Done.");
590
591                 shutdown();
592         }
593
594
595         public static void main(String[] args) throws Throwable {
596                 System.loadLibrary("vulkan");
597
598                 new TestMandelbrot().demo();
599         }
600 }