Skip to content

Commit

Permalink
Enable nullable annotations in StructType and UnionType from ctypes (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
BCSharp authored Feb 25, 2025
1 parent 99a014f commit 62b2f65
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 95 deletions.
6 changes: 3 additions & 3 deletions src/core/IronPython.Modules/_ctypes/INativeType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ int Alignment {
/// Serializes the provided value into the specified address at the given
/// offset.
/// </summary>
object SetValue(MemoryHolder/*!*/ address, int offset, object value);
object? SetValue(MemoryHolder/*!*/ address, int offset, object value);

/// <summary>
/// Gets the .NET type which is used when calling or returning the value
Expand All @@ -68,12 +68,12 @@ int Alignment {
/// Emits marshalling of an object from Python to native code. This produces the
/// native type from the Python type.
/// </summary>
MarshalCleanup EmitMarshalling(ILGenerator/*!*/ method, LocalOrArg/*!*/ argIndex, List<object>/*!*/ constantPool, int constantPoolArgument);
MarshalCleanup? EmitMarshalling(ILGenerator/*!*/ method, LocalOrArg/*!*/ argIndex, List<object>/*!*/ constantPool, int constantPoolArgument);

/// <summary>
/// Emits marshalling from native code to Python code This produces the python type
/// from the native type. This is used for return values and parameters
/// to Python callable objects that are passed back out to native code.
/// to Python callable objects that are passed back out of native code.
/// </summary>
void EmitReverseMarshalling(ILGenerator/*!*/ method, LocalOrArg/*!*/ value, List<object>/*!*/ constantPool, int constantPoolArgument);

Expand Down
125 changes: 68 additions & 57 deletions src/core/IronPython.Modules/_ctypes/StructType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

#nullable enable

#if FEATURE_CTYPES

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Numerics;
using System.Reflection.Emit;
using System.Runtime.InteropServices;
using System.Text;

using Microsoft.Scripting;
Expand All @@ -19,6 +20,7 @@
using IronPython.Runtime.Operations;
using IronPython.Runtime.Types;


namespace IronPython.Modules {
/// <summary>
/// Provides support for interop with native code from Python code.
Expand All @@ -30,15 +32,16 @@ public static partial class CTypes {
/// </summary>
[PythonType, PythonHidden]
public class StructType : PythonType, INativeType {
internal Field[] _fields;
[DisallowNull]
internal Field[]? _fields; // not null after type construction completes
private int? _size, _alignment, _pack;
private static readonly Field[] _emptyFields = System.Array.Empty<Field>(); // fields were never initialized before a type was created
private static readonly Field[] _emptyFields = []; // fields were never initialized before a type was created

public StructType(CodeContext/*!*/ context, string name, PythonTuple bases, PythonDictionary members)
public StructType(CodeContext/*!*/ context, [NotNone] string name, [NotNone] PythonTuple bases, [NotNone] PythonDictionary members)
: base(context, name, bases, members) {

foreach (PythonType pt in ResolutionOrder) {
StructType st = pt as StructType;
StructType? st = pt as StructType;
if (st != this) {
st?.EnsureFinal();
}
Expand Down Expand Up @@ -71,11 +74,11 @@ private StructType(Type underlyingSystemType)
: base(underlyingSystemType) {
}

public static ArrayType/*!*/ operator *(StructType type, int count) {
public static ArrayType/*!*/ operator *([NotNone] StructType type, int count) {
return MakeArrayType(type, count);
}

public static ArrayType/*!*/ operator *(int count, StructType type) {
public static ArrayType/*!*/ operator *(int count, [NotNone] StructType type) {
return MakeArrayType(type, count);
}

Expand All @@ -93,13 +96,13 @@ public _Structure from_address(CodeContext/*!*/ context, IntPtr ptr) {
return res;
}

public _Structure from_buffer(CodeContext/*!*/ context, object/*?*/ data, int offset = 0) {
public _Structure from_buffer(CodeContext/*!*/ context, object? data, int offset = 0) {
_Structure res = (_Structure)CreateInstance(context);
res.InitializeFromBuffer(data, offset, ((INativeType)this).Size);
return res;
}

public _Structure from_buffer_copy(CodeContext/*!*/ context, object/*?*/ data, int offset = 0) {
public _Structure from_buffer_copy(CodeContext/*!*/ context, object? data, int offset = 0) {
_Structure res = (_Structure)CreateInstance(context);
res.InitializeFromBufferCopy(data, offset, ((INativeType)this).Size);
return res;
Expand All @@ -110,19 +113,19 @@ public _Structure from_buffer_copy(CodeContext/*!*/ context, object/*?*/ data, i
///
/// Structures just return themselves.
/// </summary>
public object from_param(object obj) {
public object from_param(object? obj) {
if (!Builtin.isinstance(obj, this)) {
throw PythonOps.TypeError("expected {0} instance got {1}", Name, PythonOps.GetPythonTypeName(obj));
throw PythonOps.TypeError("expected {0} instance, got {1}", Name, PythonOps.GetPythonTypeName(obj));
}

return obj;
return obj!;
}

public object in_dll(object library, string name) {
public object in_dll(object? library, [NotNone] string name) {
throw new NotImplementedException("in dll");
}

public new virtual void __setattr__(CodeContext/*!*/ context, string name, object value) {
public new virtual void __setattr__(CodeContext/*!*/ context, [NotNone] string name, object? value) {
if (name == "_fields_") {
lock (this) {
if (_fields != null) {
Expand Down Expand Up @@ -160,7 +163,7 @@ object INativeType.GetValue(MemoryHolder/*!*/ owner, object readingFrom, int off
return res;
}

object INativeType.SetValue(MemoryHolder/*!*/ address, int offset, object value) {
object? INativeType.SetValue(MemoryHolder/*!*/ address, int offset, object value) {
try {
return SetValueInternal(address, offset, value);
} catch (ArgumentTypeException e) {
Expand All @@ -174,24 +177,21 @@ object INativeType.SetValue(MemoryHolder/*!*/ address, int offset, object value)
}
}

internal object SetValueInternal(MemoryHolder address, int offset, object value) {
IList<object> init = value as IList<object>;
if (init != null) {
internal object? SetValueInternal(MemoryHolder address, int offset, object value) {
if (value is IList<object> init) {
EnsureFinal();
if (init.Count > _fields.Length) {
throw PythonOps.TypeError("too many initializers");
}

for (int i = 0; i < init.Count; i++) {
_fields[i].SetValue(address, offset, init[i]);
}
} else if (value is CData data) {
data.MemHolder.CopyTo(address, offset, data.Size);
return data.MemHolder.EnsureObjects();
} else {
CData data = value as CData;
if (data != null) {
data.MemHolder.CopyTo(address, offset, data.Size);
return data.MemHolder.EnsureObjects();
} else {
throw new NotImplementedException("set value");
}
throw new NotImplementedException("set value");
}
return null;
}
Expand All @@ -202,7 +202,7 @@ internal object SetValueInternal(MemoryHolder address, int offset, object value)
return GetMarshalTypeFromSize(_size.Value);
}

MarshalCleanup INativeType.EmitMarshalling(ILGenerator/*!*/ method, LocalOrArg argIndex, List<object>/*!*/ constantPool, int constantPoolArgument) {
MarshalCleanup? INativeType.EmitMarshalling(ILGenerator/*!*/ method, LocalOrArg argIndex, List<object>/*!*/ constantPool, int constantPoolArgument) {
Type argumentType = argIndex.Type;
argIndex.Emit(method);
if (argumentType.IsValueType) {
Expand All @@ -212,8 +212,8 @@ MarshalCleanup INativeType.EmitMarshalling(ILGenerator/*!*/ method, LocalOrArg a
method.Emit(OpCodes.Ldarg, constantPoolArgument);
method.Emit(OpCodes.Ldc_I4, constantPool.Count - 1);
method.Emit(OpCodes.Ldelem_Ref);
method.Emit(OpCodes.Call, typeof(ModuleOps).GetMethod(nameof(ModuleOps.CheckCDataType)));
method.Emit(OpCodes.Call, typeof(CData).GetProperty(nameof(CData.UnsafeAddress)).GetGetMethod());
method.Emit(OpCodes.Call, typeof(ModuleOps).GetMethod(nameof(ModuleOps.CheckCDataType))!);
method.Emit(OpCodes.Call, typeof(CData).GetProperty(nameof(CData.UnsafeAddress))!.GetGetMethod()!);
method.Emit(OpCodes.Ldobj, ((INativeType)this).GetNativeType());
return null;
}
Expand Down Expand Up @@ -251,24 +251,24 @@ internal static PythonType MakeSystemType(Type underlyingSystemType) {
return PythonType.SetPythonType(underlyingSystemType, new StructType(underlyingSystemType));
}

private void SetFields(object fields) {
[MemberNotNull(nameof(_fields), nameof(_size), nameof(_alignment))]
private void SetFields(object? fields) {
lock (this) {
IList<object> list = GetFieldsList(fields);
IList<object> fieldDefList = GetFieldsList(fields);

int? bitCount = null;
int? curBitCount = null;
INativeType lastType = null;
INativeType? lastType = null;
List<Field> allFields = GetBaseSizeAlignmentAndFields(out int size, out int alignment);

IList<object> anonFields = GetAnonymousFields(this);
IList<object>? anonFields = GetAnonymousFields(this);

for (int fieldIndex = 0; fieldIndex < list.Count; fieldIndex++) {
object o = list[fieldIndex];
GetFieldInfo(this, o, out string fieldName, out INativeType cdata, out bitCount);
foreach (object fieldDef in fieldDefList) {
GetFieldInfo(this, fieldDef, out string fieldName, out INativeType cdata, out bitCount);

int prevSize = UpdateSizeAndAlignment(cdata, bitCount, lastType, ref size, ref alignment, ref curBitCount);

Field newField = new Field(fieldName, cdata, prevSize, allFields.Count, bitCount, curBitCount - bitCount);
var newField = new Field(fieldName, cdata, prevSize, allFields.Count, bitCount, curBitCount - bitCount);
allFields.Add(newField);
AddSlot(fieldName, newField);

Expand All @@ -282,16 +282,18 @@ private void SetFields(object fields) {
CheckAnonymousFields(allFields, anonFields);

if (bitCount != null) {
size += lastType.Size;
// incomplete last bitfield
// bitCount not null implies at least one bitfield, so at least one iteration of the loop above
size += lastType!.Size;
}

_fields = allFields.ToArray();
_fields = [..allFields];
_size = PythonStruct.Align(size, alignment);
_alignment = alignment;
}
}

internal static void CheckAnonymousFields(List<Field> allFields, IList<object> anonFields) {
internal static void CheckAnonymousFields(List<Field> allFields, IList<object>? anonFields) {
if (anonFields != null) {
foreach (string s in anonFields) {
bool found = false;
Expand All @@ -309,9 +311,9 @@ internal static void CheckAnonymousFields(List<Field> allFields, IList<object> a
}
}

internal static IList<object> GetAnonymousFields(PythonType type) {
internal static IList<object>? GetAnonymousFields(PythonType type) {
object anonymous;
IList<object> anonFields = null;
IList<object>? anonFields = null;
if (type.TryGetBoundAttr(type.Context.SharedContext, type, "_anonymous_", out anonymous)) {
anonFields = anonymous as IList<object>;
if (anonFields == null) {
Expand All @@ -323,16 +325,18 @@ internal static IList<object> GetAnonymousFields(PythonType type) {

internal static void AddAnonymousFields(PythonType type, List<Field> allFields, INativeType cdata, Field newField) {
Field[] childFields;
if (cdata is StructType) {
childFields = ((StructType)cdata)._fields;
} else if (cdata is UnionType) {
childFields = ((UnionType)cdata)._fields;
if (cdata is StructType st) {
st.EnsureFinal();
childFields = st._fields;
} else if (cdata is UnionType un) {
un.EnsureFinal();
childFields = un._fields;
} else {
throw PythonOps.TypeError("anonymous field must be struct or union");
}

foreach (Field existingField in childFields) {
Field anonField = new Field(
var anonField = new Field(
existingField.FieldName,
existingField.NativeType,
checked(existingField.offset + newField.offset),
Expand All @@ -347,12 +351,12 @@ internal static void AddAnonymousFields(PythonType type, List<Field> allFields,
private List<Field> GetBaseSizeAlignmentAndFields(out int size, out int alignment) {
size = 0;
alignment = 1;
List<Field> allFields = new List<Field>();
INativeType lastType = null;
List<Field> allFields = [];
INativeType? lastType = null;
int? totalBitCount = null;
foreach (PythonType pt in BaseTypes) {
StructType st = pt as StructType;
if (st != null) {
if (pt is StructType st) {
st.EnsureFinal();
foreach (Field f in st._fields) {
allFields.Add(f);
UpdateSizeAndAlignment(f.NativeType, f.BitCount, lastType, ref size, ref alignment, ref totalBitCount);
Expand All @@ -368,7 +372,8 @@ private List<Field> GetBaseSizeAlignmentAndFields(out int size, out int alignmen
return allFields;
}

private int UpdateSizeAndAlignment(INativeType cdata, int? bitCount, INativeType lastType, ref int size, ref int alignment, ref int? totalBitCount) {
private int UpdateSizeAndAlignment(INativeType cdata, int? bitCount, INativeType? lastType, ref int size, ref int alignment, ref int? totalBitCount) {
Debug.Assert(totalBitCount == null || lastType != null); // lastType is null only on the first iteration, when totalBitCount is null as well
int prevSize = size;
if (bitCount != null) {
if (lastType != null && lastType.Size != cdata.Size) {
Expand All @@ -382,7 +387,7 @@ private int UpdateSizeAndAlignment(INativeType cdata, int? bitCount, INativeType
if ((bitCount + totalBitCount + 7) / 8 <= cdata.Size) {
totalBitCount = bitCount + totalBitCount;
} else {
size += lastType.Size;
size += lastType!.Size;
prevSize = size;
totalBitCount = bitCount;
}
Expand All @@ -391,7 +396,7 @@ private int UpdateSizeAndAlignment(INativeType cdata, int? bitCount, INativeType
}
} else {
if (totalBitCount != null) {
size += lastType.Size;
size += lastType!.Size;
prevSize = size;
totalBitCount = null;
}
Expand All @@ -411,6 +416,7 @@ private int UpdateSizeAndAlignment(INativeType cdata, int? bitCount, INativeType
return prevSize;
}

[MemberNotNull(nameof(_fields), nameof(_size), nameof(_alignment))]
internal void EnsureFinal() {
if (_fields == null) {
SetFields(PythonTuple.EMPTY);
Expand All @@ -419,6 +425,8 @@ internal void EnsureFinal() {
// track that we were initialized w/o fields.
_fields = _emptyFields;
}
} else if (_size == null || _alignment == null) {
throw new InvalidOperationException("fields initialized w/o size or alignment");
}
}

Expand All @@ -427,19 +435,22 @@ internal void EnsureFinal() {
/// from all of our base classes. If later new _fields_ are added we'll be
/// initialized and these values will be replaced.
/// </summary>
[MemberNotNull(nameof(_size), nameof(_alignment))]
private void EnsureSizeAndAlignment() {
Debug.Assert(_size.HasValue == _alignment.HasValue);
// these are always iniitalized together
// these are always initialized together
if (_size == null) {
lock (this) {
if (_size == null) {
int size, alignment;
GetBaseSizeAlignmentAndFields(out size, out alignment);
GetBaseSizeAlignmentAndFields(out int size, out int alignment);
_size = size;
_alignment = alignment;
}
}
}
if (_alignment == null) {
throw new InvalidOperationException("size and alignment should always be initialized together");
}
}
}
}
Expand Down
Loading

0 comments on commit 62b2f65

Please sign in to comment.