Add upcall support for notzed.vulkan.
[panamaz] / src / notzed.vulkan.test / classes / vulkan / test / TestMandelbrot.java
1  /*
2 The MIT License (MIT)
3
4 Copyright (C) 2017 Eric Arnebäck
5 Copyright (C) 2019 Michael Zucchi
6
7 Permission is hereby granted, free of charge, to any person obtaining a copy
8 of this software and associated documentation files (the "Software"), to deal
9 in the Software without restriction, including without limitation the rights
10 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 copies of the Software, and to permit persons to whom the Software is
12 furnished to do so, subject to the following conditions:
13
14 The above copyright notice and this permission notice shall be included in
15 all copies or substantial portions of the Software.
16
17 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 THE SOFTWARE.
24
25  */
26
27 /*
28  * This is a Java conversion of a C conversion of this:
29  * https://github.com/Erkaman/vulkan_minimal_compute
30  *
31  * It's been simplified a bit and converted to the 'zvk' api.
32  */
33
34 package vulkan.test;
35
36 import java.io.InputStream;
37 import java.io.FileOutputStream;
38 import java.io.IOException;
39 import java.nio.channels.Channels;
40 import java.nio.ByteBuffer;
41 import java.nio.ByteOrder;
42
43 import java.awt.Graphics;
44 import java.awt.Image;
45 import java.awt.Toolkit;
46 import java.awt.event.ActionEvent;
47 import java.awt.event.KeyEvent;
48 import java.awt.image.MemoryImageSource;
49 import javax.swing.AbstractAction;
50 import javax.swing.JComponent;
51 import javax.swing.JFrame;
52 import javax.swing.JPanel;
53 import javax.swing.KeyStroke;
54
55 import java.lang.ref.WeakReference;
56
57 import java.lang.invoke.*;
58 import jdk.incubator.foreign.*;
59 import jdk.incubator.foreign.MemoryLayout.PathElement;
60 import au.notzed.nativez.*;
61
62 import vulkan.*;
63
64 import static vulkan.VkConstants.*;
65
66 public class TestMandelbrot {
67         static final boolean debug = true;
68         ResourceScope scope = ResourceScope.newSharedScope();
69
70         int WIDTH = 1920*1;
71         int HEIGHT = 1080*1;
72
73         VkInstance instance;
74         VkPhysicalDevice physicalDevice;
75
76         VkDevice device;
77         VkQueue computeQueue;
78
79         long dstBufferSize = WIDTH * HEIGHT * 4;
80         //VkBuffer dstBuffer;
81         //VkDeviceMemory dstMemory;
82         BufferMemory dst;
83
84         VkDescriptorSetLayout descriptorSetLayout;
85         VkDescriptorPool descriptorPool;
86         HandleArray<VkDescriptorSet> descriptorSets;
87
88         int computeQueueIndex;
89         VkPhysicalDeviceMemoryProperties deviceMemoryProperties;
90
91         String mandelbrot_entry = "main";
92         IntArray mandelbrot_cs;
93
94         VkShaderModule mandelbrotShader;
95         VkPipelineLayout pipelineLayout;
96         HandleArray<VkPipeline> computePipeline = VkPipeline.createArray(1, (SegmentAllocator)scope);
97
98         VkCommandPool commandPool;
99         HandleArray<VkCommandBuffer> commandBuffers;
100
101         record BufferMemory ( VkBuffer buffer, VkDeviceMemory memory ) {};
102
103         VkDebugUtilsMessengerEXT logger;
104
105         void init_debug() throws Exception {
106                 if (!debug)
107                         return;
108                 try (Frame frame = Frame.frame()) {
109                         var cb = PFN_vkDebugUtilsMessengerCallbackEXT.upcall((severity, flags, data) -> {
110                                         System.out.printf("Debug: %d: %s\n", severity, data.getMessage());
111                                         return 0;
112                                 }, scope);
113                         VkDebugUtilsMessengerCreateInfoEXT info = VkDebugUtilsMessengerCreateInfoEXT.create(frame,
114                                 0,
115                                 //VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT |
116                                 VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT
117                                 | VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT
118                                 | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT,
119                                 VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT
120                                 | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT
121                                 | VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT,
122                                 cb,
123                                 null);
124
125                         logger = instance.vkCreateDebugUtilsMessengerEXT(info, scope);
126                 }
127         }
128
129         void init_instance() throws Exception {
130                 try (Frame frame = Frame.frame()) {
131                         VkInstanceCreateInfo info = VkInstanceCreateInfo.create(frame,
132                                 0,
133                                 VkApplicationInfo.create(frame, "test", 1, "test-engine", 2, VK_MAKE_API_VERSION(0, 1, 0, 0)),
134                                 new String[] { "VK_LAYER_KHRONOS_validation" },
135                                 debug ? new String[] { "VK_EXT_debug_utils" } : null
136                                 );
137
138                         instance = VkInstance.vkCreateInstance(info, scope);
139                 }
140         }
141
142         void init_device() throws Exception {
143                 try (Frame frame = Frame.frame()) {
144                         HandleArray<VkPhysicalDevice> devs;
145                         int count;
146                         int res;
147
148                         devs = instance.vkEnumeratePhysicalDevices(frame, scope);
149
150                         int best = 0;
151                         int devid = -1;
152                         int queueid = -1;
153
154                         for (int i=0;i<devs.length();i++) {
155                                 VkPhysicalDevice dev = devs.getAtIndex(i);
156                                 VkQueueFamilyProperties famprops = dev.vkGetPhysicalDeviceQueueFamilyProperties(frame);
157                                 int family_count = (int)famprops.length();
158
159                                 for (int j=0;j<family_count;j++) {
160                                         var flags = famprops.getAtIndex(j).getQueueFlags();
161                                         int score = 0;
162
163                                         if ((flags & VK_QUEUE_COMPUTE_BIT) != 0)
164                                                 score += 1;
165                                         if ((flags & VK_QUEUE_GRAPHICS_BIT) == 0)
166                                                 score += 1;
167
168                                         if (score > best) {
169                                                 score = best;
170                                                 devid = i;
171                                                 queueid = j;
172                                         }
173                                 }
174                         }
175
176                         if (devid == -1)
177                                 throw new Exception("Cannot find a suitable device");
178
179                         computeQueueIndex = queueid;
180                         physicalDevice = devs.getAtIndex(devid);
181
182                         FloatArray qpri = FloatArray.create(frame, 0.0f);
183                         VkDeviceQueueCreateInfo qinfo = VkDeviceQueueCreateInfo.create(
184                                 frame,
185                                 0,
186                                 queueid,
187                                 qpri);
188                         VkDeviceCreateInfo devinfo = VkDeviceCreateInfo.create(
189                                 frame,
190                                 0,
191                                 qinfo,
192                                 null,
193                                 null,
194                                 null);
195
196                         device = physicalDevice.vkCreateDevice(devinfo, scope);
197
198                         System.out.printf("device = %s\n", device.address());
199
200                         // NOTE: app scope
201                         deviceMemoryProperties = VkPhysicalDeviceMemoryProperties.create((SegmentAllocator)scope);
202                         physicalDevice.vkGetPhysicalDeviceMemoryProperties(deviceMemoryProperties);
203
204                         computeQueue = device.vkGetDeviceQueue(queueid, 0, scope);
205                 }
206         }
207
208         /**
209          * Buffers are created in three steps:
210          * 1) create buffer, specifying usage and size
211          * 2) allocate memory based on memory requirements
212          * 3) bind memory
213          *
214          */
215         BufferMemory init_buffer(long dataSize, int usage, int properties) throws Exception {
216                 try (Frame frame = Frame.frame()) {
217                         VkMemoryRequirements req = VkMemoryRequirements.create(frame);
218                         VkBufferCreateInfo buf_info = VkBufferCreateInfo.create(frame,
219                                 0,
220                                 dataSize,
221                                 usage,
222                                 VK_SHARING_MODE_EXCLUSIVE,
223                                 null);
224
225                         VkBuffer buffer = device.vkCreateBuffer(buf_info, scope);
226
227                         device.vkGetBufferMemoryRequirements(buffer, req);
228
229                         VkMemoryAllocateInfo alloc = VkMemoryAllocateInfo.create(frame,
230                                 req.getSize(),
231                                 find_memory_type(deviceMemoryProperties, req.getMemoryTypeBits(), properties));
232
233                         VkDeviceMemory memory = device.vkAllocateMemory(alloc, scope);
234
235                         device.vkBindBufferMemory(buffer, memory, 0);
236
237                         return new BufferMemory(buffer, memory);
238                 }
239         }
240
241         /**
242          * Descriptors are used to bind and describe memory blocks
243          * to shaders.
244          *
245          * *Pool is used to allocate descriptors, it is per-device.
246          * *Layout is used to group descriptors for a given pipeline,
247          * The descriptors describe individually-addressable blocks.
248          */
249         void init_descriptor() throws Exception {
250                 try (Frame frame = Frame.frame()) {
251                         /* Create descriptorset layout */
252                         VkDescriptorSetLayoutBinding layout_binding = VkDescriptorSetLayoutBinding.create(frame,
253                                 0,
254                                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
255                                 1,
256                                 VK_SHADER_STAGE_COMPUTE_BIT,
257                                 null);
258
259                         VkDescriptorSetLayoutCreateInfo descriptor_layout = VkDescriptorSetLayoutCreateInfo.create(frame,
260                                 0,
261                                 layout_binding);
262
263                         descriptorSetLayout = device.vkCreateDescriptorSetLayout(descriptor_layout, scope);
264
265                         /* Create descriptor pool */
266                         VkDescriptorPoolSize type_count = VkDescriptorPoolSize.create(frame,
267                                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
268                                 1);
269
270                         VkDescriptorPoolCreateInfo descriptor_pool = VkDescriptorPoolCreateInfo.create(frame,
271                                 0,
272                                 1,
273                                 type_count);
274
275                         descriptorPool = device.vkCreateDescriptorPool(descriptor_pool, scope);
276
277                         /* Allocate from pool */
278                         HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(1, frame);
279
280                         layout_table.setAtIndex(0, descriptorSetLayout);
281
282                         VkDescriptorSetAllocateInfo alloc_info = VkDescriptorSetAllocateInfo.create(frame,
283                                 descriptorPool,
284                                 layout_table);
285
286                         descriptorSets = device.vkAllocateDescriptorSets(alloc_info, (SegmentAllocator)scope);
287
288                         /* Bind a buffer to the descriptor */
289                         VkDescriptorBufferInfo bufferInfo = VkDescriptorBufferInfo.create(frame,
290                                 dst.buffer,
291                                 0,
292                                 dstBufferSize);
293
294                         VkWriteDescriptorSet writeSet = VkWriteDescriptorSet.create(frame,
295                                 descriptorSets.getAtIndex(0),
296                                 0,
297                                 0,
298                                 1,
299                                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
300                                 null,
301                                 bufferInfo,
302                                 null);
303
304                         System.out.println(writeSet);
305
306                         device.vkUpdateDescriptorSets(writeSet, null);
307                 }
308         }
309
310         /**
311          * Create the compute pipeline.  This is the shader and data layouts for it.
312          */
313         void init_pipeline() throws Exception {
314                 try (Frame frame = Frame.frame()) {
315                         /* Set shader code */
316                         VkShaderModuleCreateInfo vsInfo = VkShaderModuleCreateInfo.create(frame,
317                                 0,
318                                 mandelbrot_cs.length() * 4,
319                                 mandelbrot_cs);
320
321                         mandelbrotShader = device.vkCreateShaderModule(vsInfo, scope);
322
323                         /* Link shader to layout */
324                         HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(1, frame);
325
326                         layout_table.setAtIndex(0, descriptorSetLayout);
327
328                         VkPipelineLayoutCreateInfo pipelineinfo = VkPipelineLayoutCreateInfo.create(frame,
329                                 0,
330                                 layout_table,
331                                 null);
332
333                         pipelineLayout = device.vkCreatePipelineLayout(pipelineinfo, scope);
334
335                         /* Create pipeline */
336                         VkComputePipelineCreateInfo pipeline = VkComputePipelineCreateInfo.create(frame,
337                                 0,
338                                 pipelineLayout,
339                                 null,
340                                 0);
341
342                         VkPipelineShaderStageCreateInfo stage = pipeline.getStage();
343
344                         stage.setStage(VK_SHADER_STAGE_COMPUTE_BIT);
345                         stage.setModule(mandelbrotShader);
346                         stage.setName(mandelbrot_entry, frame);
347
348                         device.vkCreateComputePipelines(null, pipeline, computePipeline);
349                 }
350         }
351
352         /**
353          * Create a command buffer, this is somewhat like a display list.
354          */
355         void init_command_buffer() throws Exception {
356                 try (Frame frame = Frame.frame()) {
357                         VkCommandPoolCreateInfo poolinfo = VkCommandPoolCreateInfo.create(frame,
358                                 0,
359                                 computeQueueIndex);
360
361                         commandPool = device.vkCreateCommandPool(poolinfo, scope);
362
363                         VkCommandBufferAllocateInfo cmdinfo = VkCommandBufferAllocateInfo.create(frame,
364                                 commandPool,
365                                 VK_COMMAND_BUFFER_LEVEL_PRIMARY,
366                                 1);
367
368                         // should it take a scope?
369                         commandBuffers = device.vkAllocateCommandBuffers(cmdinfo, (SegmentAllocator)scope, scope);
370
371                         /* Fill command buffer with commands for later operation */
372                         VkCommandBufferBeginInfo beginInfo = VkCommandBufferBeginInfo.create(frame,
373                                 VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT,
374                                 null);
375
376                         commandBuffers.get(0).vkBeginCommandBuffer(beginInfo);
377
378                         /* Bind the compute operation and data */
379                         commandBuffers.get(0).vkCmdBindPipeline(VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline.get(0));
380                         commandBuffers.get(0).vkCmdBindDescriptorSets(VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout, 0, descriptorSets, null);
381
382                         /* Run it */
383                         commandBuffers.get(0).vkCmdDispatch(WIDTH, HEIGHT, 1);
384
385                         commandBuffers.get(0).vkEndCommandBuffer();
386                 }
387         }
388
389         /**
390          * Execute the pre-created command buffer.
391          *
392          * A fence is used to wait for completion.
393          */
394         void execute() throws Exception {
395                 try (Frame frame = Frame.frame()) {
396                         VkSubmitInfo submitInfo = VkSubmitInfo.create(frame);
397
398                         submitInfo.setCommandBufferCount(1);
399                         submitInfo.setCommandBuffers(commandBuffers);
400
401                         /* Create fence to mark the task completion */
402                         VkFence fence;
403                         HandleArray<VkFence> fences = VkFence.createArray(1, frame);
404                         VkFenceCreateInfo fenceInfo = VkFenceCreateInfo.create(frame);
405
406                         // maybe this should take a HandleArray<Fence> rather than being a constructor
407                         // FIXME: some local scope
408                         fence = device.vkCreateFence(fenceInfo, scope);
409                         fences.set(0, fence);
410
411                         /* Await completion */
412                         computeQueue.vkQueueSubmit(submitInfo, fence);
413
414                         int VK_TRUE = 1;
415                         int res;
416                         do {
417                                 res = device.vkWaitForFences(fences, VK_TRUE, 1000000);
418                         } while (res == VK_TIMEOUT);
419
420                         device.vkDestroyFence(fence);
421                 }
422         }
423
424         void shutdown() {
425                 device.vkDestroyCommandPool(commandPool);
426                 device.vkDestroyPipeline(computePipeline.getAtIndex(0));
427                 device.vkDestroyPipelineLayout(pipelineLayout);
428                 device.vkDestroyShaderModule(mandelbrotShader);
429
430                 device.vkDestroyDescriptorPool(descriptorPool);
431                 device.vkDestroyDescriptorSetLayout(descriptorSetLayout);
432
433                 device.vkFreeMemory(dst.memory());
434                 device.vkDestroyBuffer(dst.buffer());
435
436                 device.vkDestroyDevice();
437                 if (logger != null)
438                         instance.vkDestroyDebugUtilsMessengerEXT(logger);
439                 instance.vkDestroyInstance();
440         }
441
442         /**
443          * Accesses the gpu buffer, converts it to RGB byte, and saves it as a pam file.
444          */
445         void save_result() throws Exception {
446                 try (ResourceScope scope = ResourceScope.newConfinedScope()) {
447                         MemorySegment mem = device.vkMapMemory(dst.memory(), 0, dstBufferSize, 0, scope);
448                         byte[] pixels = new byte[WIDTH * HEIGHT * 3];
449
450                         System.out.printf("map %d bytes\n", dstBufferSize);
451
452                         for (int i = 0; i < WIDTH * HEIGHT; i++) {
453                                 pixels[i * 3 + 0] = mem.get(Memory.BYTE, i * 4 + 0);
454                                 pixels[i * 3 + 1] = mem.get(Memory.BYTE, i * 4 + 1);
455                                 pixels[i * 3 + 2] = mem.get(Memory.BYTE, i * 4 + 2);
456                         }
457
458                         device.vkUnmapMemory(dst.memory());
459
460                         pam_save("mandelbrot.pam", WIDTH, HEIGHT, 3, pixels);
461                 }
462         }
463
464         void show_result() throws Exception {
465                 try (ResourceScope scope = ResourceScope.newConfinedScope()) {
466                         MemorySegment mem = device.vkMapMemory(dst.memory(), 0, dstBufferSize, 0, scope);
467                         int[] pixels = new int[WIDTH * HEIGHT];
468
469                         System.out.printf("map %d bytes\n", dstBufferSize);
470
471                         MemorySegment.ofArray(pixels).copyFrom(mem);
472
473                         device.vkUnmapMemory(dst.memory());
474
475                         swing_show(WIDTH, HEIGHT, pixels);
476                 }
477         }
478
479         /**
480          * Trivial pnm format image output.
481          */
482         void pam_save(String name, int width, int height, int depth, byte[] pixels) throws IOException {
483                 try (FileOutputStream fos = new FileOutputStream(name)) {
484                         fos.write(String.format("P6\n%d\n%d\n255\n", width, height).getBytes());
485                         fos.write(pixels);
486                         System.out.printf("wrote: %s\n", name);
487                 }
488         }
489
490         static class DataImage extends JPanel {
491
492                 final int w, h, stride;
493                 final MemoryImageSource source;
494                 final Image image;
495                 final int[] pixels;
496
497                 public DataImage(int w, int h, int[] pixels) {
498                         this.w = w;
499                         this.h = h;
500                         this.stride = w;
501                         this.pixels = pixels;
502                         this.source = new MemoryImageSource(w, h, pixels, 0, w);
503                         this.source.setAnimated(true);
504                         this.source.setFullBufferUpdates(true);
505                         this.image = Toolkit.getDefaultToolkit().createImage(source);
506                 }
507
508                 @Override
509                 protected void paintComponent(Graphics g) {
510                         super.paintComponent(g);
511                         g.drawImage(image, 0, 0, this);
512                 }
513         }
514
515         void swing_show(int w, int h, int[] pixels) {
516                 JFrame window;
517                 DataImage image = new DataImage(w, h, pixels);
518
519                 window = new JFrame("mandelbrot");
520                 window.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
521                 window.setContentPane(image);
522                 window.setSize(w, h);
523                 window.setVisible(true);
524         }
525
526         IntArray loadSPIRV0(String name) throws IOException {
527                 // hmm any way to just load this directly?
528                 try (InputStream is = TestMandelbrot.class.getResourceAsStream(name)) {
529                         ByteBuffer bb = ByteBuffer.allocateDirect(8192).order(ByteOrder.nativeOrder());
530                         int length = Channels.newChannel(is).read(bb);
531
532                         bb.position(0);
533                         bb.limit(length);
534
535                         return IntArray.create(MemorySegment.ofByteBuffer(bb));
536                 }
537         }
538
539         IntArray loadSPIRV(String name) throws IOException {
540                 try (InputStream is = TestMandelbrot.class.getResourceAsStream(name)) {
541                         MemorySegment seg = ((SegmentAllocator)scope).allocateArray(Memory.INT, 2048);
542                         int length = Channels.newChannel(is).read(seg.asByteBuffer());
543
544                         return IntArray.create(seg.asSlice(0, length));
545                 }
546         }
547
548         /**
549          * This finds the memory type index for the memory on a specific device.
550          */
551         static int find_memory_type(VkPhysicalDeviceMemoryProperties memory, int typeMask, int query) {
552                 VkMemoryType mtypes = memory.getMemoryTypes();
553
554                 for (int i = 0; i < memory.getMemoryTypeCount(); i++) {
555                         if (((1 << i) & typeMask) != 0 && ((mtypes.getAtIndex(i).getPropertyFlags() & query) == query))
556                                 return i;
557                 }
558                 return -1;
559         }
560
561         public static int VK_MAKE_API_VERSION(int variant, int major, int minor, int patch) {
562                 return (variant << 29) | (major << 22) | (minor << 12) | patch;
563         }
564
565         void demo() throws Exception {
566                 mandelbrot_cs = loadSPIRV("mandelbrot.bin");
567
568                 init_instance();
569                 init_debug();
570
571                 init_device();
572
573                 dst = init_buffer(dstBufferSize,
574                         VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
575                         VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
576
577                 init_descriptor();
578
579                 init_pipeline();
580                 init_command_buffer();
581
582                 System.out.printf("Calculating %dx%d\n", WIDTH, HEIGHT);
583                 execute();
584                 //System.out.println("Saving ...");
585                 //save_result();
586                 System.out.println("Showing ...");
587                 show_result();
588                 System.out.println("Done.");
589
590                 shutdown();
591         }
592
593
594         public static void main(String[] args) throws Throwable {
595                 System.loadLibrary("vulkan");
596
597                 new TestMandelbrot().demo();
598         }
599 }