Updated for openjdk-19-internal
[panamaz] / test-vulkan / src / zvk / Memory.java
1 /*
2  * Copyright (C) 2020 Michael Zucchi
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
16  */
17
18 package zvk;
19
20 import java.lang.invoke.*;
21 import java.lang.ref.Cleaner;
22 import jdk.incubator.foreign.*;
23 import static jdk.incubator.foreign.ValueLayout.*;
24
25 import java.util.AbstractList;
26 import java.util.function.Function;
27
28 /**
29  * A utility for memory operations including a stack allocator.
30  * <p>
31  * The stack allocator works like this
32  * <pre>
33  * try (Frame f = Memory.createFrame()) {
34  *              MemorySegment a = f.allocate(size);
35  * }
36  * </pre>
37  * Any memory allocated is freed when the frame is closed.
38  * <p>
39  * This is MUCH faster than using MemorySegment.allocateNative().
40  */
41 public class Memory {
42
43         // probably should be INT8 INT16, etc
44         public static final OfByte BYTE = JAVA_BYTE;
45         public static final OfShort SHORT = JAVA_SHORT.withBitAlignment(16);
46         public static final OfInt INT = JAVA_INT.withBitAlignment(32);
47         public static final OfLong LONG = JAVA_LONG.withBitAlignment(64);
48         public static final OfFloat FLOAT = JAVA_FLOAT.withBitAlignment(32);
49         public static final OfDouble DOUBLE = JAVA_DOUBLE.withBitAlignment(64);
50         public static final OfAddress POINTER = ADDRESS.withBitAlignment(64);
51
52         static final ResourceScope sharedScope = ResourceScope.newSharedScope(); // cleaner?
53         static final MemorySegment NULL = MemorySegment.ofAddress(MemoryAddress.NULL, 1, ResourceScope.globalScope());
54
55         public static ResourceScope sharedScope() {
56                 return sharedScope;
57         }
58
59         public static MethodHandle downcall(String name, FunctionDescriptor desc) {
60                 return SymbolLookup.loaderLookup().lookup(name)
61                         .map(sym -> CLinker.systemCLinker().downcallHandle(sym, desc))
62                         .orElse(null);
63         }
64
65         public static MethodHandle downcall(NativeSymbol sym, FunctionDescriptor desc) {
66                 return CLinker.systemCLinker().downcallHandle(sym, desc);
67         }
68
69         public static MethodHandle downcall(String name, MemoryAddress sym, FunctionDescriptor desc, ResourceScope scope) {
70                 return sym != MemoryAddress.NULL
71                         ? CLinker.systemCLinker().downcallHandle(NativeSymbol.ofAddress(name, sym, scope), desc)
72                         : null;
73         }
74
75         static final MethodHandles.Lookup lookup = MethodHandles.lookup();
76
77         public static NativeSymbol upcall(Object instance, FunctionDescriptor desc, ResourceScope scope) {
78                 try {
79                         java.lang.reflect.Method m = instance.getClass().getMethods()[0];
80                         MethodHandle handle = lookup.findVirtual(instance.getClass(), "call", MethodType.methodType(m.getReturnType(), m.getParameterTypes()))
81                                 .bindTo(instance);
82                         return CLinker.systemCLinker().upcallStub(handle, desc, scope);
83                 } catch (Throwable t) {
84                         throw new AssertionError(t);
85                 }
86         }
87
88
89         public static NativeSymbol upcall(Object instance, String method, String signature, FunctionDescriptor desc, ResourceScope scope) {
90                 try {
91                         MethodHandle handle = lookup.findVirtual(instance.getClass(), method, MethodType.fromMethodDescriptorString(signature, Memory.class.getClassLoader()))
92                                               .bindTo(instance);
93                         return CLinker.systemCLinker().upcallStub(handle, desc, scope);
94                 } catch (Throwable t) {
95                         throw new AssertionError(t);
96                 }
97         }
98
99 static final ResourceScope scope = ResourceScope.newSharedScope(Cleaner.create());
100         private static final ThreadLocal<Stack> stacks = ThreadLocal.withInitial(() -> new Stack(scope));
101
102         public static Frame createFrame() {
103                 return stacks.get().createFrame();
104         }
105
106         static class Stack {
107
108                 private final MemorySegment stack;
109                 private long sp;
110                 private Thread thread = Thread.currentThread();
111
112                 Stack(ResourceScope scope) {
113                         stack = MemorySegment.allocateNative(4096, 4096, scope);
114                         sp = 4096;
115                 }
116
117                 Frame createFrame() {
118
119                         return new Frame() {
120                                 private final long tos = sp;
121                                 private Thread self = thread;
122                                 private ResourceScope scope;
123
124                                 @Override
125                                 public MemorySegment allocate(long size, long alignment) {
126                                         if (self != Thread.currentThread())
127                                                 throw new IllegalStateException();
128                                         if (alignment != Long.highestOneBit(alignment))
129                                                 throw new IllegalArgumentException();
130                                         if (sp >= size) {
131                                                 sp = (sp - size) & ~(alignment - 1);
132                                                 return stack.asSlice(sp, size).fill((byte)0);
133                                         } else {
134                                                 if (scope == null)
135                                                         scope = ResourceScope.newConfinedScope();
136                                                 return MemorySegment.allocateNative(size, alignment, scope);
137                                         }
138                                 }
139
140                                 @Override
141                                 public void close() {
142                                         sp = tos;
143                                         self = null;
144                                         if (scope != null) {
145                                                 scope.close();
146                                                 scope = null;
147                                         }
148                                 }
149                         };
150                 }
151         }
152
153         public static MemoryAddress address(jdk.incubator.foreign.Addressable v) {
154                 return v != null ? v.address() : MemoryAddress.NULL;
155         }
156
157         public static MemoryAddress address(Memory.Addressable v) {
158                 return v != null ? v.address() : MemoryAddress.NULL;
159         }
160
161         public interface Addressable {
162                 MemoryAddress address();
163         }
164
165         // hmm do i want this or not?
166         // -> added 'type safety'
167         // -> load of crap to be written
168         public static class ByteArray extends AbstractList<Byte> implements Memory.Addressable {
169                 final MemorySegment segment;
170
171                 public ByteArray(MemorySegment segment) {
172                         this.segment = segment;
173                 }
174
175                 public ByteArray(Frame frame, long size) {
176                         this(frame.allocateArray(Memory.BYTE, size));
177                 }
178
179                 public ByteArray(Frame frame, byte... values) {
180                         this(frame.allocateArray(Memory.BYTE, values));
181                 }
182
183                 public final MemoryAddress address() {
184                         return segment.address();
185                 }
186
187                 @Override
188                 public int size() {
189                         return (int)length();
190                 }
191
192                 @Override
193                 public Byte get(int index) {
194                         return getAtIndex(index);
195                 }
196
197                 @Override
198                 public Byte set(int index, Byte value) {
199                         byte old = getAtIndex(index);
200                         setAtIndex(index, value);
201                         return old;
202                 }
203
204                 public long length() {
205                         return segment.byteSize() / Memory.BYTE.byteSize();
206                 }
207
208                 public byte getAtIndex(long index) {
209                         return (byte)segment.get(Memory.BYTE, index);
210                 }
211
212                 public void setAtIndex(long index, byte value) {
213                         segment.set(Memory.BYTE, index, value);
214                 }
215         }
216
217         public static class ShortArray extends AbstractList<Short> implements Memory.Addressable {
218                 final MemorySegment segment;
219
220                 public ShortArray(MemorySegment segment) {
221                         this.segment = segment;
222                 }
223
224                 public ShortArray(Frame frame, long size) {
225                         this(frame.allocateArray(Memory.SHORT, size));
226                 }
227
228                 public ShortArray(Frame frame, short... values) {
229                         this(frame.allocateArray(Memory.SHORT, values));
230                 }
231
232                 public final MemoryAddress address() {
233                         return segment.address();
234                 }
235
236                 @Override
237                 public int size() {
238                         return (int)length();
239                 }
240
241                 @Override
242                 public Short get(int index) {
243                         return getAtIndex(index);
244                 }
245
246                 @Override
247                 public Short set(int index, Short value) {
248                         short old = getAtIndex(index);
249                         setAtIndex(index, value);
250                         return old;
251                 }
252
253                 public long length() {
254                         return segment.byteSize() / Memory.SHORT.byteSize();
255                 }
256
257                 public short getAtIndex(long index) {
258                         return segment.getAtIndex(Memory.SHORT, index);
259                 }
260
261                 public void setAtIndex(long index, short value) {
262                         segment.setAtIndex(Memory.SHORT, index, value);
263                 }
264         }
265
266         public static class IntArray extends AbstractList<Integer> implements Memory.Addressable {
267                 final MemorySegment segment;
268
269                 public IntArray(MemorySegment segment) {
270                         this.segment = segment;
271                 }
272
273                 public IntArray(Frame frame, long size) {
274                         this(frame.allocateArray(Memory.INT, size));
275                 }
276
277                 public IntArray(Frame frame, int... values) {
278                         this(frame.allocateArray(Memory.INT, values));
279                 }
280
281                 public final MemoryAddress address() {
282                         return segment.address();
283                 }
284
285                 @Override
286                 public int size() {
287                         return (int)length();
288                 }
289
290                 @Override
291                 public Integer get(int index) {
292                         return getAtIndex(index);
293                 }
294
295                 @Override
296                 public Integer set(int index, Integer value) {
297                         int old = getAtIndex(index);
298                         setAtIndex(index, value);
299                         return old;
300                 }
301
302                 public long length() {
303                         return segment.byteSize() / Memory.INT.byteSize();
304                 }
305
306                 public int getAtIndex(long index) {
307                         return segment.getAtIndex(Memory.INT, index);
308                 }
309
310                 public void setAtIndex(long index, int value) {
311                         segment.setAtIndex(Memory.INT, index, value);
312                 }
313         }
314
315         public static class LongArray extends AbstractList<Long> implements Memory.Addressable {
316                 final MemorySegment segment;
317
318                 public LongArray(MemorySegment segment) {
319                         this.segment = segment;
320                 }
321
322                 public LongArray(Frame frame, long size) {
323                         this(frame.allocateArray(Memory.LONG, size));
324                 }
325
326                 public LongArray(Frame frame, long... values) {
327                         this(frame.allocateArray(Memory.LONG, values));
328                 }
329
330                 public final MemoryAddress address() {
331                         return segment.address();
332                 }
333
334                 @Override
335                 public int size() {
336                         return (int)length();
337                 }
338
339                 @Override
340                 public Long get(int index) {
341                         return getAtIndex(index);
342                 }
343
344                 @Override
345                 public Long set(int index, Long value) {
346                         long old = getAtIndex(index);
347                         setAtIndex(index, value);
348                         return old;
349                 }
350
351                 public long length() {
352                         return segment.byteSize() / Memory.LONG.byteSize();
353                 }
354
355                 public long getAtIndex(long index) {
356                         return segment.getAtIndex(Memory.LONG, index);
357                 }
358
359                 public void setAtIndex(long index, long value) {
360                         segment.setAtIndex(Memory.LONG, index, value);
361                 }
362         }
363
364         public static class FloatArray extends AbstractList<Float> implements Memory.Addressable {
365                 final MemorySegment segment;
366
367                 public FloatArray(MemorySegment segment) {
368                         this.segment = segment;
369                 }
370
371                 public FloatArray(Frame frame, long size) {
372                         this(frame.allocateArray(Memory.FLOAT, size));
373                 }
374
375                 public FloatArray(Frame frame, float... values) {
376                         this(frame.allocateArray(Memory.FLOAT, values));
377                 }
378
379                 public final MemoryAddress address() {
380                         return segment.address();
381                 }
382
383                 @Override
384                 public int size() {
385                         return (int)length();
386                 }
387
388                 @Override
389                 public Float get(int index) {
390                         return getAtIndex(index);
391                 }
392
393                 @Override
394                 public Float set(int index, Float value) {
395                         float old = getAtIndex(index);
396                         setAtIndex(index, value);
397                         return old;
398                 }
399
400                 public long length() {
401                         return segment.byteSize() / Memory.FLOAT.byteSize();
402                 }
403
404                 public float getAtIndex(long index) {
405                         return segment.getAtIndex(Memory.FLOAT, index);
406                 }
407
408                 public void setAtIndex(long index, float value) {
409                         segment.setAtIndex(Memory.FLOAT, index, value);
410                 }
411         }
412
413         public static class DoubleArray extends AbstractList<Double> implements Memory.Addressable {
414                 final MemorySegment segment;
415
416                 public DoubleArray(MemorySegment segment) {
417                         this.segment = segment;
418                 }
419
420                 public DoubleArray(Frame frame, long size) {
421                         this(frame.allocateArray(Memory.DOUBLE, size));
422                 }
423
424                 public DoubleArray(Frame frame, double... values) {
425                         this(frame.allocateArray(Memory.DOUBLE, values));
426                 }
427
428                 public final MemoryAddress address() {
429                         return segment.address();
430                 }
431
432                 @Override
433                 public int size() {
434                         return (int)length();
435                 }
436
437                 @Override
438                 public Double get(int index) {
439                         return getAtIndex(index);
440                 }
441
442                 @Override
443                 public Double set(int index, Double value) {
444                         double old = getAtIndex(index);
445                         setAtIndex(index, value);
446                         return old;
447                 }
448
449                 public long length() {
450                         return segment.byteSize() / Memory.DOUBLE.byteSize();
451                 }
452
453                 public double getAtIndex(long index) {
454                         return segment.getAtIndex(Memory.DOUBLE, index);
455                 }
456
457                 public void setAtIndex(long index, double value) {
458                         segment.setAtIndex(Memory.DOUBLE, index, value);
459                 }
460         }
461
462         public static class PointerArray extends AbstractList<MemoryAddress> implements Memory.Addressable {
463                 final MemorySegment segment;
464
465                 public PointerArray(MemorySegment segment) {
466                         this.segment = segment;
467                 }
468
469                 public PointerArray(Frame frame, long size) {
470                         this(frame.allocateArray(Memory.POINTER, size));
471                 }
472
473                 public PointerArray(Frame frame, MemoryAddress... values) {
474                         this(frame.allocateArray(Memory.POINTER, values));
475                 }
476
477                 public final MemoryAddress address() {
478                         return segment.address();
479                 }
480
481                 @Override
482                 public int size() {
483                         return (int)length();
484                 }
485
486                 @Override
487                 public MemoryAddress get(int index) {
488                         return getAtIndex(index);
489                 }
490
491                 @Override
492                 public MemoryAddress set(int index, MemoryAddress value) {
493                         MemoryAddress old = getAtIndex(index);
494                         setAtIndex(index, value);
495                         return old;
496                 }
497
498                 public long length() {
499                         return segment.byteSize() / Memory.POINTER.byteSize();
500                 }
501
502                 public MemoryAddress getAtIndex(long index) {
503                         return segment.getAtIndex(Memory.POINTER, index);
504                 }
505
506                 public void setAtIndex(long index, MemoryAddress value) {
507                         segment.setAtIndex(Memory.POINTER, index, value);
508                 }
509         }
510
511         public static class HandleArray<T extends Memory.Addressable> extends AbstractList<T> implements Memory.Addressable {
512                 final MemorySegment segment;
513                 Function<MemoryAddress,T> create;
514
515                 public HandleArray(Function<MemoryAddress,T> create, MemorySegment segment) {
516                         this.segment = segment;
517                         this.create = create;
518                 }
519
520                 public HandleArray(Frame frame, Function<MemoryAddress,T> create, long size) {
521                         this(create, frame.allocateArray(Memory.POINTER, size));
522                 }
523
524                 @Override
525                 public final MemoryAddress address() {
526                         return segment.address();
527                 }
528
529                 @Override
530                 public int size() {
531                         return (int)length();
532                 }
533
534                 @Override
535                 public T get(int index) {
536                         return getAtIndex(index);
537                 }
538
539                 @Override
540                 public T set(int index, T value) {
541                         T old = getAtIndex(index);
542                         setAtIndex(index, value);
543                         return old;
544                 }
545
546                 public long length() {
547                         return segment.byteSize() / Memory.POINTER.byteSize();
548                 }
549
550                 public T getAtIndex(long index) {
551                         MemoryAddress ptr = segment.getAtIndex(Memory.POINTER, index);
552                         return ptr != null ? create.apply(ptr) : null;
553                 }
554
555                 public void setAtIndex(long index, T value) {
556                         segment.setAtIndex(Memory.POINTER, index, value != null ? value.address() : MemoryAddress.NULL);
557                 }
558         }
559
560 }