Updated for openjdk-19-internal
[panamaz] / test-vulkan / src / zvk / test / TestVulkan.java
1  /*
2 The MIT License (MIT)
3
4 Copyright (C) 2017 Eric Arnebäck
5 Copyright (C) 2019 Michael Zucchi
6
7 Permission is hereby granted, free of charge, to any person obtaining a copy
8 of this software and associated documentation files (the "Software"), to deal
9 in the Software without restriction, including without limitation the rights
10 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 copies of the Software, and to permit persons to whom the Software is
12 furnished to do so, subject to the following conditions:
13
14 The above copyright notice and this permission notice shall be included in
15 all copies or substantial portions of the Software.
16
17 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 THE SOFTWARE.
24
25  */
26
27 /*
28  * This is a Java conversion of a C conversion of this:
29  * https://github.com/Erkaman/vulkan_minimal_compute
30  *
31  * It has been simplified a bit and converted to the 'zvk' API.
32  */
33
34 package zvk.test;
35
36 import java.io.InputStream;
37 import java.io.FileOutputStream;
38 import java.io.IOException;
39 import java.nio.channels.Channels;
40 import java.nio.ByteBuffer;
41 import java.nio.ByteOrder;
42
43 import java.lang.invoke.*;
44 import jdk.incubator.foreign.*;
45 import jdk.incubator.foreign.MemoryLayout.PathElement;
46
47 import zvk.*;
48
49 import static zvk.VkBufferUsageFlagBits.*;
50 import static zvk.VkMemoryPropertyFlagBits.*;
51 import static zvk.VkSharingMode.*;
52 import static zvk.VkDescriptorType.*;
53 import static zvk.VkShaderStageFlagBits.*;
54 import static zvk.VkCommandBufferLevel.*;
55 import static zvk.VkCommandBufferUsageFlagBits.*;
56 import static zvk.VkPipelineBindPoint.*;
57
58 import static zvk.VkDebugUtilsMessageSeverityFlagBitsEXT.*;
59 import static zvk.VkDebugUtilsMessageTypeFlagBitsEXT.*;
60
61 public class TestVulkan {
62         ResourceScope scope = ResourceScope.newSharedScope();
63
64         int WIDTH = 1920*1;
65         int HEIGHT = 1080*1;
66
67         VkInstance instance;
68         VkPhysicalDevice physicalDevice;
69
70         VkDevice device;
71         VkQueue computeQueue;
72
73         long dstBufferSize = WIDTH * HEIGHT * 4 * 4;
74         //VkBuffer dstBuffer;
75         //VkDeviceMemory dstMemory;
76         BufferMemory dst;
77
78         VkDescriptorSetLayout descriptorSetLayout;
79         VkDescriptorPool descriptorPool;
80         Memory.HandleArray<VkDescriptorSet> descriptorSets = VkDescriptorSet.createArray(scope, 1);
81
82         int computeQueueIndex;
83         VkPhysicalDeviceMemoryProperties deviceMemoryProperties;
84
85         String mandelbrot_entry = "main";
86         Memory.IntArray mandelbrot_cs;
87
88         VkShaderModule mandelbrotShader;
89         VkPipelineLayout pipelineLayout;
90         Memory.HandleArray<VkPipeline> computePipeline = VkPipeline.createArray(scope, 1);
91
92         VkCommandPool commandPool;
93         Memory.HandleArray<VkCommandBuffer> commandBuffers;
94
95         record BufferMemory ( VkBuffer buffer, VkDeviceMemory memory ) {};
96
97         VkDebugUtilsMessengerEXT logger;
98
99         void init_debug() throws Exception {
100                 try (Frame frame = Memory.createFrame()) {
101                         NativeSymbol cb = PFN_vkDebugUtilsMessengerCallbackEXT.of((severity, flags, data) -> {
102                                         System.out.printf("Debug: %d: %s\n", severity, data.getMessage());
103                                         return 0;
104                                 }, scope);
105                         VkDebugUtilsMessengerCreateInfoEXT info = VkDebugUtilsMessengerCreateInfoEXT.create(frame,
106                                 0,
107                                 VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT
108                                 | VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT
109                                 | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT,
110                                 VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT,
111                                 cb.address(),
112                                 null);
113
114                         logger = instance.vkCreateDebugUtilsMessengerEXT(info, null);
115                 }
116
117                 //typedef VkBool32 (*PFN_vkDebugUtilsMessengerCallbackEXT)(VkDebugUtilsMessageSeverityFlagBitsEXT, VkDebugUtilsMessageTypeFlagsEXT, const VkDebugUtilsMessengerCallbackDataEXT *, void *);
118
119         }
120
121         void init_instance() throws Exception {
122                 try (Frame frame = Memory.createFrame()) {
123                         VkInstanceCreateInfo info = VkInstanceCreateInfo.create(frame,
124                                 0,
125                                 VkApplicationInfo.create(frame, "test", 1, "test-engine", 2, VK_MAKE_API_VERSION(0, 1, 0, 0)),
126                                 new String[] { "VK_LAYER_KHRONOS_validation" },
127                                 null //new String[] { "VK_EXT_debug_utils" }
128                                 );
129
130                         instance = VkInstance.vkCreateInstance(info, null);
131                 }
132         }
133
134         void init_device() throws Exception {
135                 try (Frame frame = Memory.createFrame()) {
136                         Memory.IntArray count$h = new Memory.IntArray(frame, 1);
137                         Memory.HandleArray<VkPhysicalDevice> devs;
138                         int count;
139                         int res;
140
141                         devs = instance.vkEnumeratePhysicalDevices();
142
143                         int best = 0;
144                         int devid = -1;
145                         int queueid = -1;
146
147                         for (int i=0;i<devs.length();i++) {
148                                 VkPhysicalDevice dev = devs.getAtIndex(i);
149                                 VkQueueFamilyProperties famprops;
150
151                                 // TODO: change to return the allocated array directly
152                                 dev.vkGetPhysicalDeviceQueueFamilyProperties(count$h, null);
153                                 famprops = VkQueueFamilyProperties.createArray(frame, count$h.getAtIndex(0));
154                                 dev.vkGetPhysicalDeviceQueueFamilyProperties(count$h, famprops);
155
156                                 int family_count = count$h.getAtIndex(0);
157
158                                 for (int j=0;j<family_count;j++) {
159                                         int score = 0;
160
161                                         if ((famprops.getQueueFlags(j) & VkQueueFlagBits.VK_QUEUE_COMPUTE_BIT) != 0)
162                                                 score += 1;
163                                         if ((famprops.getQueueFlags(j) & VkQueueFlagBits.VK_QUEUE_GRAPHICS_BIT) == 0)
164                                                 score += 1;
165
166                                         if (score > best) {
167                                                 score = best;
168                                                 devid = i;
169                                                 queueid = j;
170                                         }
171                                 }
172                         }
173
174                         if (devid == -1)
175                                 throw new Exception("Cannot find a suitable device");
176
177                         computeQueueIndex = queueid;
178                         physicalDevice = devs.getAtIndex(devid);
179
180                         Memory.FloatArray qpri = new Memory.FloatArray(frame, 0.0f);
181                         VkDeviceQueueCreateInfo qinfo = VkDeviceQueueCreateInfo.create(
182                                 frame,
183                                 0,
184                                 queueid,
185                                 1,
186                                 qpri);
187                         VkDeviceCreateInfo devinfo = VkDeviceCreateInfo.create(
188                                 frame,
189                                 0,
190                                 1,
191                                 qinfo,
192                                 null,
193                                 null,
194                                 null);
195
196                         device = physicalDevice.vkCreateDevice(devinfo, null);
197
198                         System.out.printf("device = %s\n", device.address());
199
200                         // NOTE: app scope
201                         deviceMemoryProperties = VkPhysicalDeviceMemoryProperties.create(scope);
202                         physicalDevice.vkGetPhysicalDeviceMemoryProperties(deviceMemoryProperties);
203
204                         computeQueue = device.vkGetDeviceQueue(queueid, 0);
205                 }
206         }
207
208         /**
209          * Buffers are created in three steps:
210          * 1) create buffer, specifying usage and size
211          * 2) allocate memory based on memory requirements
212          * 3) bind memory
213          *
214          */
215         BufferMemory init_buffer(long dataSize, int usage, int properties) throws Exception {
216                 try (Frame frame = Memory.createFrame()) {
217                         VkMemoryRequirements req = VkMemoryRequirements.create(frame);
218                         VkBufferCreateInfo buf_info = VkBufferCreateInfo.create(frame,
219                                 0,
220                                 dataSize,
221                                 usage,
222                                 VK_SHARING_MODE_EXCLUSIVE,
223                                 0,
224                                 null);
225
226                         VkBuffer buffer = device.vkCreateBuffer(buf_info, null);
227
228                         device.vkGetBufferMemoryRequirements(buffer, req);
229
230                         VkMemoryAllocateInfo alloc = VkMemoryAllocateInfo.create(frame,
231                                 req.getSize(),
232                                 find_memory_type(deviceMemoryProperties, req.getMemoryTypeBits(), properties));
233
234                         VkDeviceMemory memory = device.vkAllocateMemory(alloc, null);
235
236                         device.vkBindBufferMemory(buffer, memory, 0);
237
238                         return new BufferMemory(buffer, memory);
239                 }
240         }
241
242         /**
243          * Descriptors are used to bind and describe memory blocks
244          * to shaders.
245          *
246          * *Pool is used to allocate descriptors, it is per-device.
247          * *Layout is used to group descriptors for a given pipeline,
248          * The descriptors describe individually-addressable blocks.
249          */
250         void init_descriptor() throws Exception {
251                 try (Frame frame = Memory.createFrame()) {
252                         /* Create descriptorset layout */
253                         VkDescriptorSetLayoutBinding layout_binding = VkDescriptorSetLayoutBinding.create(frame,
254                                 0,
255                                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
256                                 1,
257                                 VK_SHADER_STAGE_COMPUTE_BIT,
258                                 null);
259
260                         VkDescriptorSetLayoutCreateInfo descriptor_layout = VkDescriptorSetLayoutCreateInfo.create(frame,
261                                 0,
262                                 1,
263                                 layout_binding);
264
265                         descriptorSetLayout = device.vkCreateDescriptorSetLayout(descriptor_layout, null);
266
267                         /* Create descriptor pool */
268                         VkDescriptorPoolSize type_count = VkDescriptorPoolSize.create(frame,
269                                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
270                                 1);
271
272                         VkDescriptorPoolCreateInfo descriptor_pool = VkDescriptorPoolCreateInfo.create(frame,
273                                 0,
274                                 1,
275                                 1,
276                                 type_count);
277
278                         descriptorPool = device.vkCreateDescriptorPool(descriptor_pool, null);
279
280                         /* Allocate from pool */
281                         Memory.HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(frame, 1);
282
283                         layout_table.setAtIndex(0, descriptorSetLayout);
284
285                         VkDescriptorSetAllocateInfo alloc_info = VkDescriptorSetAllocateInfo.create(frame,
286                                 descriptorPool,
287                                 1,
288                                 layout_table);
289
290                         device.vkAllocateDescriptorSets(alloc_info, descriptorSets);
291
292                         /* Bind a buffer to the descriptor */
293                         VkDescriptorBufferInfo bufferInfo = VkDescriptorBufferInfo.create(frame,
294                                 dst.buffer,
295                                 0,
296                                 dstBufferSize);
297
298                         VkWriteDescriptorSet writeSet = VkWriteDescriptorSet.create(frame,
299                                 descriptorSets.getAtIndex(0),
300                                 0,
301                                 0,
302                                 1,
303                                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
304                                 null,
305                                 bufferInfo,
306                                 null);
307
308                         device.vkUpdateDescriptorSets(1, writeSet, 0, null);
309                 }
310         }
311
312         /**
313          * Create the compute pipeline.  This is the shader and data layouts for it.
314          */
315         void init_pipeline() throws Exception {
316                 try (Frame frame = Memory.createFrame()) {
317                         /* Set shader code */
318                         VkShaderModuleCreateInfo vsInfo = VkShaderModuleCreateInfo.create(frame,
319                                 0,
320                                 mandelbrot_cs.length() * 4,
321                                 mandelbrot_cs);
322
323                         mandelbrotShader = device.vkCreateShaderModule(vsInfo, null);
324
325                         /* Link shader to layout */
326                         Memory.HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(frame, 1);
327
328                         layout_table.setAtIndex(0, descriptorSetLayout);
329
330                         VkPipelineLayoutCreateInfo pipelineinfo = VkPipelineLayoutCreateInfo.create(frame,
331                                 0,
332                                 1,
333                                 layout_table,
334                                 0,
335                                 null);
336
337                         pipelineLayout = device.vkCreatePipelineLayout(pipelineinfo, null);
338
339                         /* Create pipeline */
340                         VkComputePipelineCreateInfo pipeline = VkComputePipelineCreateInfo.create(frame,
341                                 0,
342                                 pipelineLayout,
343                                 null,
344                                 0);
345
346                         VkPipelineShaderStageCreateInfo stage = pipeline.getStage();
347
348                         stage.setStage(VK_SHADER_STAGE_COMPUTE_BIT);
349                         stage.setModule(mandelbrotShader);
350                         stage.setName(frame, mandelbrot_entry);
351
352                         device.vkCreateComputePipelines(null, 1, pipeline, null, computePipeline);
353                 }
354         }
355         /**
356          * Create a command buffer, this is somewhat like a display list.
357          */
358         void init_command_buffer() throws Exception {
359                 try (Frame frame = Memory.createFrame()) {
360                         VkCommandPoolCreateInfo poolinfo = VkCommandPoolCreateInfo.create(frame,
361                                 0,
362                                 computeQueueIndex);
363
364                         commandPool = device.vkCreateCommandPool(poolinfo, null);
365
366                         VkCommandBufferAllocateInfo cmdinfo = VkCommandBufferAllocateInfo.create(frame,
367                                 commandPool,
368                                 VK_COMMAND_BUFFER_LEVEL_PRIMARY,
369                                 1);
370
371                         // should it take a scope?
372                         commandBuffers = device.vkAllocateCommandBuffers(cmdinfo);
373
374                         /* Fill command buffer with commands for later operation */
375                         VkCommandBufferBeginInfo beginInfo = VkCommandBufferBeginInfo.create(frame,
376                                 VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT,
377                                 null);
378
379                         commandBuffers.get(0).vkBeginCommandBuffer(beginInfo);
380
381                         /* Bind the compute operation and data */
382                         commandBuffers.get(0).vkCmdBindPipeline(VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline.get(0));
383                         commandBuffers.get(0).vkCmdBindDescriptorSets(VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout, 0, 1, descriptorSets, 0, null);
384
385                         /* Run it */
386                         commandBuffers.get(0).vkCmdDispatch(WIDTH, HEIGHT, 1);
387
388                         commandBuffers.get(0).vkEndCommandBuffer();
389                 }
390         }
391
392         /**
393          * Execute the pre-created command buffer.
394          *
395          * A fence is used to wait for completion.
396          */
397         void execute() throws Exception {
398                 try (Frame frame = Memory.createFrame()) {
399                         VkSubmitInfo submitInfo = VkSubmitInfo.create(frame);
400
401                         submitInfo.setCommandBufferCount(0, 1);
402                         submitInfo.setCommandBuffers(0, commandBuffers);
403
404                         /* Create fence to mark the task completion */
405                         VkFence fence;
406                         Memory.HandleArray<VkFence> fences = VkFence.createArray(frame, 1);
407                         VkFenceCreateInfo fenceInfo = VkFenceCreateInfo.create(frame);
408
409                         // maybe this should take a HandleArray<Fence> rather than being a constructor
410                         fence = device.vkCreateFence(fenceInfo, null);
411                         fences.set(0, fence);
412
413                         /* Await completion */
414                         computeQueue.vkQueueSubmit(1, submitInfo, fence);
415
416                         int VK_TRUE = 1;
417                         int res;
418                         do {
419                                 res = device.vkWaitForFences(1, fences, VK_TRUE, 1000000);
420                         } while (res == VkResult.VK_TIMEOUT);
421
422                         device.vkDestroyFence(fence, null);
423                 }
424         }
425
426         void shutdown() {
427                 device.vkDestroyCommandPool(commandPool, null);
428                 device.vkDestroyPipeline(computePipeline.getAtIndex(0), null);
429                 device.vkDestroyPipelineLayout(pipelineLayout, null);
430                 device.vkDestroyShaderModule(mandelbrotShader, null);
431
432                 device.vkDestroyDescriptorPool(descriptorPool, null);
433                 device.vkDestroyDescriptorSetLayout(descriptorSetLayout, null);
434
435                 device.vkFreeMemory(dst.memory(), null);
436                 device.vkDestroyBuffer(dst.buffer(), null);
437
438                 device.vkDestroyDevice(null);
439                 instance.vkDestroyInstance(null);
440         }
441
442         /**
443          * Accesses the gpu buffer, converts it to RGB byte, and saves it as a pam file.
444          */
445         void save_result() throws Exception {
446                 try (ResourceScope scope = ResourceScope.newConfinedScope()) {
447                         MemorySegment mem = device.vkMapMemory(dst.memory(), 0, dstBufferSize, 0, scope);
448                         byte[] pixels = new byte[WIDTH * HEIGHT * 3];
449
450                         // this is super-slow!
451                         for (int i = 0; i < WIDTH * HEIGHT; i++) {
452                                 pixels[i * 3 + 0] = (byte)(255.0f * mem.getAtIndex(Memory.FLOAT, i * 4 + 0));
453                                 pixels[i * 3 + 1] = (byte)(255.0f * mem.getAtIndex(Memory.FLOAT, i * 4 + 1));
454                                 pixels[i * 3 + 2] = (byte)(255.0f * mem.getAtIndex(Memory.FLOAT, i * 4 + 2));
455                         }
456
457                         device.vkUnmapMemory(dst.memory());
458
459                         pam_save("mandelbrot.pam", WIDTH, HEIGHT, 3, pixels);
460                 }
461         }
462
463
464         /**
465          * Trivial pnm format image output.
466          */
467         void pam_save(String name, int width, int height, int depth, byte[] pixels) throws IOException {
468                 try (FileOutputStream fos = new FileOutputStream(name)) {
469                         fos.write(String.format("P6\n%d\n%d\n255\n", width, height).getBytes());
470                         fos.write(pixels);
471                         System.out.printf("wrote: %s\n", name);
472                 }
473         }
474
475         static Memory.IntArray loadSPIRV(String name) throws IOException {
476                 // hmm any way to just load this directly?
477                 try (InputStream is = TestVulkan.class.getResourceAsStream(name)) {
478                         ByteBuffer bb = ByteBuffer.allocateDirect(8192).order(ByteOrder.nativeOrder());
479                         int length = Channels.newChannel(is).read(bb);
480
481                         bb.position(0);
482                         bb.limit(length);
483                         return new Memory.IntArray(MemorySegment.ofByteBuffer(bb));
484                 }
485         }
486
487         /**
488          * This finds the memory type index for the memory on a specific device.
489          */
490         static int find_memory_type(VkPhysicalDeviceMemoryProperties memory, int typeMask, int query) {
491                 VkMemoryType mtypes = memory.getMemoryTypes();
492
493                 for (int i = 0; i < memory.getMemoryTypeCount(); i++) {
494                         if (((1 << i) & typeMask) != 0 && ((mtypes.getPropertyFlags(i) & query) == query))
495                                 return i;
496                 }
497                 return -1;
498         }
499
500         public static int VK_MAKE_API_VERSION(int variant, int major, int minor, int patch) {
501                 return (variant << 29) | (major << 22) | (minor << 12) | patch;
502         }
503
504         void demo() throws Exception {
505                 mandelbrot_cs = loadSPIRV("mandelbrot.bin");
506
507                 init_instance();
508                 //init_debug();
509                 init_device();
510
511                 dst = init_buffer(dstBufferSize,
512                         VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
513                         VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
514
515                 init_descriptor();
516                 init_pipeline();
517                 init_command_buffer();
518
519                 System.out.printf("Calculating %dx%d\n", WIDTH, HEIGHT);
520                 execute();
521                 System.out.println("Saving ...");
522                 save_result();
523                 System.out.println("Done.");
524
525                 shutdown();
526         }
527
528
529         public static void main(String[] args) throws Throwable {
530                 System.loadLibrary("vulkan");
531
532                 new TestVulkan().demo();
533         }
534 }