Skip to content

Commit 00930d1

Browse files
fix: all tests passed
Signed-off-by: Akhilender Bongirwar <[email protected]>
1 parent 7ba2516 commit 00930d1

File tree

2 files changed

+59
-12
lines changed

2 files changed

+59
-12
lines changed

kyo-offheap/native/src/main/java/lang/foreign/Arena.scala

+13-7
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ import scala.scalanative.unsafe.*
99
* An Arena is used to track off-heap allocations so that they can be automatically deallocated when the scope ends.
1010
*/
1111
final class Arena extends AutoCloseable:
12-
private val allocations = new ConcurrentLinkedQueue[Ptr[Byte]]()
12+
private val allocations = new ConcurrentLinkedQueue[MemorySegment]()
13+
@volatile private var _isClosed: Boolean = false
1314

1415
/** Allocates a block of off-heap memory of the given size (in bytes).
1516
*
@@ -19,19 +20,24 @@ final class Arena extends AutoCloseable:
1920
* A MemorySegment representing the allocated block.
2021
*/
2122
def allocate(byteSize: Long): MemorySegment =
22-
val segment = MemorySegment.allocate(byteSize)
23-
allocations.add(segment.ptr)
23+
if _isClosed then
24+
throw new IllegalStateException("Arena is closed")
25+
val segment = MemorySegment.allocate(byteSize, this)
26+
allocations.add(segment)
2427
segment
2528
end allocate
2629

2730
/** Frees all memory that was allocated through this Arena.
2831
*/
2932
override def close(): Unit =
30-
var ptr = allocations.poll()
31-
while ptr != null do
32-
free(ptr)
33-
ptr = allocations.poll()
33+
_isClosed = true
34+
var seg = allocations.poll()
35+
while seg != null do
36+
free(seg.ptr)
37+
seg = allocations.poll()
3438
end close
39+
40+
def isClosed: Boolean = _isClosed
3541
end Arena
3642

3743
object Arena:

kyo-offheap/native/src/main/java/lang/foreign/MemorySegment.scala

+46-5
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@ import scala.scalanative.unsigned.*
1010
* This implementation wraps a pointer along with the allocated size in bytes. It also provides basic support for slicing and for
1111
* reading/writing primitive values, which are used by the Layout instances in the shared code.
1212
*/
13-
final class MemorySegment private (private[foreign] val ptr: Ptr[Byte], val byteSize: Long):
13+
final class MemorySegment private (private[foreign] val ptr: Ptr[Byte], val byteSize: Long, private val arena: Arena):
14+
private def checkOpen(): Unit =
15+
if arena.isClosed then
16+
throw new IllegalStateException("MemorySegment accessed after Arena was closed")
17+
1418
/** Creates a new MemorySegment that is a slice of this segment.
1519
*
1620
* @param offset
@@ -21,70 +25,107 @@ final class MemorySegment private (private[foreign] val ptr: Ptr[Byte], val byte
2125
* A new MemorySegment representing the slice.
2226
*/
2327
def asSlice(offset: Long, newSize: Long): MemorySegment =
28+
checkOpen()
2429
if offset < 0 || newSize < 0 || offset + newSize > byteSize then
2530
throw new IllegalArgumentException(s"Invalid slice parameters: byteSize=$byteSize, offset=$offset, newSize=$newSize")
2631
else
27-
MemorySegment(ptr + offset, newSize)
32+
MemorySegment(ptr + offset, newSize, arena)
33+
end if
34+
end asSlice
2835

2936
/** Reads a value from memory using the provided layout.
3037
*/
3138
def get(layout: ValueLayout.OfBoolean, offset: Long): Boolean =
39+
checkOpen()
3240
require(offset + layout.byteSize <= byteSize)
3341
!(ptr + offset).asInstanceOf[Ptr[CBool]]
42+
end get
3443
def get(layout: ValueLayout.OfByte, offset: Long): Byte =
44+
checkOpen()
3545
require(offset + layout.byteSize <= byteSize)
3646
!(ptr + offset).asInstanceOf[Ptr[Byte]]
47+
end get
3748
def get(layout: ValueLayout.OfShort, offset: Long): Short =
49+
checkOpen()
3850
require(offset + layout.byteSize <= byteSize)
3951
!(ptr + offset).asInstanceOf[Ptr[Short]]
52+
end get
4053
def get(layout: ValueLayout.OfInt, offset: Long): Int =
54+
checkOpen()
4155
require(offset + layout.byteSize <= byteSize)
4256
!(ptr + offset).asInstanceOf[Ptr[Int]]
57+
end get
4358
def get(layout: ValueLayout.OfLong, offset: Long): Long =
59+
checkOpen()
4460
require(offset + layout.byteSize <= byteSize)
4561
!(ptr + offset).asInstanceOf[Ptr[Long]]
62+
end get
4663
def get(layout: ValueLayout.OfFloat, offset: Long): Float =
64+
checkOpen()
4765
require(offset + layout.byteSize <= byteSize)
4866
!(ptr + offset).asInstanceOf[Ptr[Float]]
67+
end get
4968
def get(layout: ValueLayout.OfDouble, offset: Long): Double =
69+
checkOpen()
5070
require(offset + layout.byteSize <= byteSize)
5171
!(ptr + offset).asInstanceOf[Ptr[Double]]
72+
end get
5273
def get(layout: ValueLayout.OfChar, offset: Long): Char =
74+
checkOpen()
5375
require(offset + layout.byteSize <= byteSize)
5476
!(ptr + offset).asInstanceOf[Ptr[Char]]
77+
end get
5578
def get(layout: AddressLayout, offset: Long): MemorySegment =
79+
checkOpen()
5680
val newByteSize = byteSize - offset
5781
require(newByteSize >= 0)
58-
new MemorySegment((ptr + offset).asInstanceOf[Ptr[Byte]], newByteSize)
82+
new MemorySegment((ptr + offset).asInstanceOf[Ptr[Byte]], newByteSize, arena)
5983
end get
6084

6185
/** Writes a value to memory using the provided layout.
6286
*/
6387
def set(layout: ValueLayout.OfBoolean, offset: Long, value: Boolean): Unit =
88+
checkOpen()
6489
require(offset + layout.byteSize <= byteSize)
6590
!(ptr + offset).asInstanceOf[Ptr[CBool]] = value
91+
end set
6692
def set(layout: ValueLayout.OfByte, offset: Long, value: Byte): Unit =
93+
checkOpen()
6794
require(offset + layout.byteSize <= byteSize)
6895
!(ptr + offset).asInstanceOf[Ptr[Byte]] = value
96+
end set
6997
def set(layout: ValueLayout.OfShort, offset: Long, value: Short): Unit =
98+
checkOpen()
7099
require(offset + layout.byteSize <= byteSize)
71100
!(ptr + offset).asInstanceOf[Ptr[Short]] = value
101+
end set
72102
def set(layout: ValueLayout.OfInt, offset: Long, value: Int): Unit =
103+
checkOpen()
73104
require(offset + layout.byteSize <= byteSize)
74105
!(ptr + offset).asInstanceOf[Ptr[Int]] = value
106+
end set
75107
def set(layout: ValueLayout.OfLong, offset: Long, value: Long): Unit =
108+
checkOpen()
76109
require(offset + layout.byteSize <= byteSize)
77110
!(ptr + offset).asInstanceOf[Ptr[Long]] = value
111+
end set
78112
def set(layout: ValueLayout.OfFloat, offset: Long, value: Float): Unit =
113+
checkOpen()
79114
require(offset + layout.byteSize <= byteSize)
80115
!(ptr + offset).asInstanceOf[Ptr[Float]] = value
116+
end set
81117
def set(layout: ValueLayout.OfDouble, offset: Long, value: Double): Unit =
118+
checkOpen()
82119
require(offset + layout.byteSize <= byteSize)
83120
!(ptr + offset).asInstanceOf[Ptr[Double]] = value
121+
end set
84122
def set(layout: ValueLayout.OfChar, offset: Long, value: Char): Unit =
123+
checkOpen()
85124
require(offset + layout.byteSize <= byteSize)
86125
!(ptr + offset).asInstanceOf[Ptr[Char]] = value
126+
end set
87127
def set(layout: AddressLayout, offset: Long, value: MemorySegment): Unit =
128+
checkOpen()
88129
require(offset + layout.byteSize <= byteSize)
89130
val _ = memcpy((ptr + offset).asInstanceOf[Ptr[Byte]], value.ptr, value.byteSize.toCSize)
90131
end set
@@ -93,10 +134,10 @@ end MemorySegment
93134
object MemorySegment:
94135

95136
/** Allocates a new MemorySegment of the given byte size. */
96-
private[foreign] def allocate(byteSize: Long): MemorySegment =
137+
private[foreign] def allocate(byteSize: Long, arena: Arena): MemorySegment =
97138
val ptr = malloc(byteSize).asInstanceOf[Ptr[Byte]]
98139
if ptr == null then throw new RuntimeException("malloc returned null")
99-
new MemorySegment(ptr, byteSize)
140+
new MemorySegment(ptr, byteSize, arena)
100141
end allocate
101142

102143
/** Copies a block of memory from the source segment to the destination segment.

0 commit comments

Comments
 (0)