Updated for openjdk-19-internal
[panamaz] / src / template / Memory.java
1 /*
2  * Copyright (C) 2020 Michael Zucchi
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
16  */
17
18 package api;
19
20 import java.lang.invoke.*;
21 import java.lang.ref.Cleaner;
22 import jdk.incubator.foreign.*;
23 import static jdk.incubator.foreign.ValueLayout.*;
24
25 import java.util.AbstractList;
26 import java.util.function.Function;
27 import java.util.function.BiFunction;
28
29 /**
30  * A utility for memory operations including a stack allocator.
31  * <p>
32  * The stack allocator works like this
33  * <pre>
34  * try (Frame f = Memory.createFrame()) {
35  *              MemorySegment a = f.allocate(size);
36  * }
37  * </pre>
38  * Any memory allocated is freed when the frame is closed.
39  * <p>
40  * This is MUCH faster than using MemorySegment.allocateNative().
41  */
42 public class Memory {
43
44         // probably should be INT8 INT16, etc
45         public  static final OfByte BYTE = JAVA_BYTE;
46         public static final OfShort SHORT = JAVA_SHORT.withBitAlignment(16);
47         public static final OfInt INT = JAVA_INT.withBitAlignment(32);
48         public static final OfLong LONG = JAVA_LONG.withBitAlignment(64);
49         public static final OfFloat FLOAT = JAVA_FLOAT.withBitAlignment(32);
50         public static final OfDouble DOUBLE = JAVA_DOUBLE.withBitAlignment(64);
51         public static final OfAddress POINTER = ADDRESS.withBitAlignment(64);
52
53         static final ResourceScope sharedScope = ResourceScope.newSharedScope(); // cleaner?
54         static final MemorySegment NULL = MemorySegment.ofAddress(MemoryAddress.NULL, 1, ResourceScope.globalScope());
55
56         public static ResourceScope sharedScope() {
57                 return sharedScope;
58         }
59
60         public static MethodHandle downcall(String name, FunctionDescriptor desc) {
61                 return SymbolLookup.loaderLookup().lookup(name)
62                         .map(sym -> CLinker.systemCLinker().downcallHandle(sym, desc))
63                         .orElse(null);
64         }
65
66         public static MethodHandle downcall(NativeSymbol sym, FunctionDescriptor desc) {
67                 return CLinker.systemCLinker().downcallHandle(sym, desc);
68         }
69
70         public static MethodHandle downcall(String name, MemoryAddress sym, FunctionDescriptor desc, ResourceScope scope) {
71                 return sym != MemoryAddress.NULL
72                         ? CLinker.systemCLinker().downcallHandle(NativeSymbol.ofAddress(name, sym, scope), desc)
73                         : null;
74         }
75
76         static final MethodHandles.Lookup lookup = MethodHandles.lookup();
77
78         public static NativeSymbol upcall(Object instance, FunctionDescriptor desc, ResourceScope scope) {
79                 try {
80                         java.lang.reflect.Method m = instance.getClass().getMethods()[0];
81                         MethodHandle handle = lookup.findVirtual(instance.getClass(), "call", MethodType.methodType(m.getReturnType(), m.getParameterTypes()))
82                                               .bindTo(instance);
83                         return CLinker.systemCLinker().upcallStub(handle, desc, scope);
84                 } catch (Throwable t) {
85                         throw new AssertionError(t);
86                 }
87         }
88
89
90         public static NativeSymbol upcall(Object instance, String method, String signature, FunctionDescriptor desc, ResourceScope scope) {
91                 try {
92                         MethodHandle handle = lookup.findVirtual(instance.getClass(), method, MethodType.fromMethodDescriptorString(signature, Memory.class.getClassLoader()))
93                                               .bindTo(instance);
94                         return CLinker.systemCLinker().upcallStub(handle, desc, scope);
95                 } catch (Throwable t) {
96                         throw new AssertionError(t);
97                 }
98         }
99
100         static final ResourceScope scope = ResourceScope.newSharedScope(Cleaner.create());
101         private static final ThreadLocal<Stack> stacks = ThreadLocal.withInitial(() -> new Stack(scope));
102
103         public static Frame createFrame() {
104                 return stacks.get().createFrame();
105         }
106
107         static class Stack {
108
109                 private final MemorySegment stack;
110                 private long sp;
111                 private Thread thread = Thread.currentThread();
112
113                 Stack(ResourceScope scope) {
114                         stack = MemorySegment.allocateNative(4096, 4096, scope);
115                         sp = 4096;
116                 }
117
118                 Frame createFrame() {
119
120                         return new Frame() {
121                                 private final long tos = sp;
122                                 private Thread self = thread;
123                                 private ResourceScope scope;
124
125                                 @Override
126                                 public MemorySegment allocate(long size, long alignment) {
127                                         if (self != Thread.currentThread())
128                                                 throw new IllegalStateException();
129                                         if (alignment != Long.highestOneBit(alignment))
130                                                 throw new IllegalArgumentException();
131                                         if (sp >= size) {
132                                                 sp = (sp - size) & ~(alignment - 1);
133                                                 return stack.asSlice(sp, size).fill((byte)0);
134                                         } else {
135                                                 if (scope == null)
136                                                         scope = ResourceScope.newConfinedScope();
137                                                 return MemorySegment.allocateNative(size, alignment, scope);
138                                         }
139                                 }
140
141                                 @Override
142                                 public void close() {
143                                         sp = tos;
144                                         self = null;
145                                         if (scope != null) {
146                                                 scope.close();
147                                                 scope = null;
148                                         }
149                                 }
150                         };
151                 }
152         }
153
154         public interface Addressable {
155                 MemoryAddress address();
156                 ResourceScope scope();
157         }
158
159         public record FunctionPointer<T>(NativeSymbol symbol, T function) {
160         }
161
162         public static MemoryAddress address(jdk.incubator.foreign.Addressable v) {
163                 return v != null ? v.address() : MemoryAddress.NULL;
164         }
165
166         public static MemoryAddress address(Memory.Addressable v) {
167                 return v != null ? v.address() : MemoryAddress.NULL;
168         }
169
170         public static <T> MemoryAddress address(FunctionPointer<T> v) {
171                 return v != null ? v.symbol().address() : MemoryAddress.NULL;
172         }
173
174         // hmm do i want this or not?
175         // -> added 'type safety'
176         // -> load of crap to be written
177         public static class ByteArray extends AbstractList<Byte> implements Memory.Addressable {
178                 final MemorySegment segment;
179
180                 private ByteArray(MemorySegment segment) {
181                         this.segment = segment;
182                 }
183
184                 public static ByteArray create(MemorySegment segment) {
185                         return new ByteArray(segment);
186                 }
187
188                 public static ByteArray createArray(MemoryAddress address, long length, ResourceScope scope) {
189                         return create(MemorySegment.ofAddress(address, length, scope));
190                 }
191
192                 public static ByteArray createArray(long length, SegmentAllocator alloc) {
193                         return create(alloc.allocateArray(Memory.BYTE, length));
194                 }
195
196                 public static ByteArray create(String value, SegmentAllocator alloc) {
197                         return create(alloc.allocateUtf8String(value));
198                 }
199
200                 public static ByteArray create(String value, ResourceScope scope) {
201                         return create(SegmentAllocator.nativeAllocator(scope).allocateUtf8String(value));
202                 }
203
204                 public final MemoryAddress address() {
205                         return segment.address();
206                 }
207
208                 public final ResourceScope scope() {
209                         return segment.scope();
210                 }
211
212                 @Override
213                 public int size() {
214                         return (int)length();
215                 }
216
217                 @Override
218                 public Byte get(int index) {
219                         return getAtIndex(index);
220                 }
221
222                 @Override
223                 public Byte set(int index, Byte value) {
224                         byte old = getAtIndex(index);
225                         setAtIndex(index, value);
226                         return old;
227                 }
228
229                 public long length() {
230                         return segment.byteSize() / Memory.BYTE.byteSize();
231                 }
232
233                 public byte getAtIndex(long index) {
234                         return (byte)segment.get(Memory.BYTE, index);
235                 }
236
237                 public void setAtIndex(long index, byte value) {
238                         segment.set(Memory.BYTE, index, value);
239                 }
240         }
241
242         public static class ShortArray extends AbstractList<Short> implements Memory.Addressable {
243                 final MemorySegment segment;
244
245                 public ShortArray(MemorySegment segment) {
246                         this.segment = segment;
247                 }
248
249                 public ShortArray(Frame frame, long size) {
250                         this(frame.allocateArray(Memory.SHORT, size));
251                 }
252
253                 public ShortArray(Frame frame, short... values) {
254                         this(frame.allocateArray(Memory.SHORT, values));
255                 }
256
257                 public final MemoryAddress address() {
258                         return segment.address();
259                 }
260
261                 public final ResourceScope scope() {
262                         return segment.scope();
263                 }
264
265                 @Override
266                 public int size() {
267                         return (int)length();
268                 }
269
270                 @Override
271                 public Short get(int index) {
272                         return getAtIndex(index);
273                 }
274
275                 @Override
276                 public Short set(int index, Short value) {
277                         short old = getAtIndex(index);
278                         setAtIndex(index, value);
279                         return old;
280                 }
281
282                 public long length() {
283                         return segment.byteSize() / Memory.SHORT.byteSize();
284                 }
285
286                 public short getAtIndex(long index) {
287                         return segment.getAtIndex(Memory.SHORT, index);
288                 }
289
290                 public void setAtIndex(long index, short value) {
291                         segment.setAtIndex(Memory.SHORT, index, value);
292                 }
293         }
294
295         public static class IntArray extends AbstractList<Integer> implements Memory.Addressable {
296                 final MemorySegment segment;
297
298                 public IntArray(MemorySegment segment) {
299                         this.segment = segment;
300                 }
301
302                 public IntArray(Frame frame, long size) {
303                         this(frame.allocateArray(Memory.INT, size));
304                 }
305
306                 public IntArray(Frame frame, int... values) {
307                         this(frame.allocateArray(Memory.INT, values));
308                 }
309
310                 public final MemoryAddress address() {
311                         return segment.address();
312                 }
313
314                 public final ResourceScope scope() {
315                         return segment.scope();
316                 }
317
318                 @Override
319                 public int size() {
320                         return (int)length();
321                 }
322
323                 @Override
324                 public Integer get(int index) {
325                         return getAtIndex(index);
326                 }
327
328                 @Override
329                 public Integer set(int index, Integer value) {
330                         int old = getAtIndex(index);
331                         setAtIndex(index, value);
332                         return old;
333                 }
334
335                 public long length() {
336                         return segment.byteSize() / Memory.INT.byteSize();
337                 }
338
339                 public int getAtIndex(long index) {
340                         return segment.getAtIndex(Memory.INT, index);
341                 }
342
343                 public void setAtIndex(long index, int value) {
344                         segment.setAtIndex(Memory.INT, index, value);
345                 }
346         }
347
348         public static class LongArray extends AbstractList<Long> implements Memory.Addressable {
349                 final MemorySegment segment;
350
351                 public LongArray(MemorySegment segment) {
352                         this.segment = segment;
353                 }
354
355                 public LongArray(Frame frame, long size) {
356                         this(frame.allocateArray(Memory.LONG, size));
357                 }
358
359                 public LongArray(Frame frame, long... values) {
360                         this(frame.allocateArray(Memory.LONG, values));
361                 }
362
363                 public final MemoryAddress address() {
364                         return segment.address();
365                 }
366
367                 public final ResourceScope scope() {
368                         return segment.scope();
369                 }
370
371                 @Override
372                 public int size() {
373                         return (int)length();
374                 }
375
376                 @Override
377                 public Long get(int index) {
378                         return getAtIndex(index);
379                 }
380
381                 @Override
382                 public Long set(int index, Long value) {
383                         long old = getAtIndex(index);
384                         setAtIndex(index, value);
385                         return old;
386                 }
387
388                 public long length() {
389                         return segment.byteSize() / Memory.LONG.byteSize();
390                 }
391
392                 public long getAtIndex(long index) {
393                         return segment.getAtIndex(Memory.LONG, index);
394                 }
395
396                 public void setAtIndex(long index, long value) {
397                         segment.setAtIndex(Memory.LONG, index, value);
398                 }
399         }
400
401         public static class FloatArray extends AbstractList<Float> implements Memory.Addressable {
402                 final MemorySegment segment;
403
404                 public FloatArray(MemorySegment segment) {
405                         this.segment = segment;
406                 }
407
408                 public FloatArray(Frame frame, long size) {
409                         this(frame.allocateArray(Memory.FLOAT, size));
410                 }
411
412                 public FloatArray(Frame frame, float... values) {
413                         this(frame.allocateArray(Memory.FLOAT, values));
414                 }
415
416                 public final MemoryAddress address() {
417                         return segment.address();
418                 }
419
420                 public final ResourceScope scope() {
421                         return segment.scope();
422                 }
423
424                 @Override
425                 public int size() {
426                         return (int)length();
427                 }
428
429                 @Override
430                 public Float get(int index) {
431                         return getAtIndex(index);
432                 }
433
434                 @Override
435                 public Float set(int index, Float value) {
436                         float old = getAtIndex(index);
437                         setAtIndex(index, value);
438                         return old;
439                 }
440
441                 public long length() {
442                         return segment.byteSize() / Memory.FLOAT.byteSize();
443                 }
444
445                 public float getAtIndex(long index) {
446                         return segment.getAtIndex(Memory.FLOAT, index);
447                 }
448
449                 public void setAtIndex(long index, float value) {
450                         segment.setAtIndex(Memory.FLOAT, index, value);
451                 }
452         }
453
454         public static class DoubleArray extends AbstractList<Double> implements Memory.Addressable {
455                 final MemorySegment segment;
456
457                 public DoubleArray(MemorySegment segment) {
458                         this.segment = segment;
459                 }
460
461                 public DoubleArray(Frame frame, long size) {
462                         this(frame.allocateArray(Memory.DOUBLE, size));
463                 }
464
465                 public DoubleArray(Frame frame, double... values) {
466                         this(frame.allocateArray(Memory.DOUBLE, values));
467                 }
468
469                 public final MemoryAddress address() {
470                         return segment.address();
471                 }
472
473                 public final ResourceScope scope() {
474                         return segment.scope();
475                 }
476
477                 @Override
478                 public int size() {
479                         return (int)length();
480                 }
481
482                 @Override
483                 public Double get(int index) {
484                         return getAtIndex(index);
485                 }
486
487                 @Override
488                 public Double set(int index, Double value) {
489                         double old = getAtIndex(index);
490                         setAtIndex(index, value);
491                         return old;
492                 }
493
494                 public long length() {
495                         return segment.byteSize() / Memory.DOUBLE.byteSize();
496                 }
497
498                 public double getAtIndex(long index) {
499                         return segment.getAtIndex(Memory.DOUBLE, index);
500                 }
501
502                 public void setAtIndex(long index, double value) {
503                         segment.setAtIndex(Memory.DOUBLE, index, value);
504                 }
505         }
506
507         public static class PointerArray extends AbstractList<MemoryAddress> implements Memory.Addressable {
508                 final MemorySegment segment;
509
510                 private PointerArray(MemorySegment segment) {
511                         this.segment = segment;
512                 }
513
514                 public static PointerArray create(MemorySegment segment) {
515                         return new PointerArray(segment);
516                 }
517
518                 public static PointerArray createArray(long size, SegmentAllocator alloc) {
519                         return create(alloc.allocateArray(Memory.POINTER, size));
520                 }
521
522                 public final MemoryAddress address() {
523                         return segment.address();
524                 }
525
526                 public final ResourceScope scope() {
527                         return segment.scope();
528                 }
529
530                 @Override
531                 public int size() {
532                         return (int)length();
533                 }
534
535                 @Override
536                 public MemoryAddress get(int index) {
537                         return getAtIndex(index);
538                 }
539
540                 @Override
541                 public MemoryAddress set(int index, MemoryAddress value) {
542                         MemoryAddress old = getAtIndex(index);
543                         setAtIndex(index, value);
544                         return old;
545                 }
546
547                 public long length() {
548                         return segment.byteSize() / Memory.POINTER.byteSize();
549                 }
550
551                 public MemoryAddress getAtIndex(long index) {
552                         return segment.getAtIndex(Memory.POINTER, index);
553                 }
554
555                 public void setAtIndex(long index, MemoryAddress value) {
556                         segment.setAtIndex(Memory.POINTER, index, value);
557                 }
558         }
559
560         public static class HandleArray<T extends Memory.Addressable> extends AbstractList<T> implements Memory.Addressable {
561                 final MemorySegment segment;
562                 BiFunction<MemoryAddress,ResourceScope,T> create;
563
564                 private HandleArray(MemorySegment segment, BiFunction<MemoryAddress,ResourceScope,T> create) {
565                         this.segment = segment;
566                         this.create = create;
567                 }
568
569                 public static <T extends Memory.Addressable> HandleArray<T> create(MemorySegment segment, BiFunction<MemoryAddress,ResourceScope,T> create) {
570                         return new HandleArray<>(segment, create);
571                 }
572
573                 public static <T extends Memory.Addressable> HandleArray<T> createArray(long size, SegmentAllocator alloc, BiFunction<MemoryAddress,ResourceScope,T> create) {
574                         return create(alloc.allocateArray(Memory.POINTER, size), create);
575                 }
576
577                 @Override
578                 public final MemoryAddress address() {
579                         return segment.address();
580                 }
581
582                 public final ResourceScope scope() {
583                         return segment.scope();
584                 }
585
586                 @Override
587                 public int size() {
588                         return (int)length();
589                 }
590
591                 @Override
592                 public T get(int index) {
593                         return getAtIndex(index);
594                 }
595
596                 @Override
597                 public T set(int index, T value) {
598                         T old = getAtIndex(index);
599                         setAtIndex(index, value);
600                         return old;
601                 }
602
603                 public long length() {
604                         return segment.byteSize() / Memory.POINTER.byteSize();
605                 }
606
607                 public T getAtIndex(long index) {
608                         MemoryAddress ptr = segment.getAtIndex(Memory.POINTER, index);
609                         return ptr != null ? create.apply(ptr, scope) : null;
610                 }
611
612                 public void setAtIndex(long index, T value) {
613                         segment.setAtIndex(Memory.POINTER, index, value != null ? value.address() : MemoryAddress.NULL);
614                 }
615         }
616
617 }