Skip to content

Commit

Permalink
Fix NEO callstates (#3599)
Browse files Browse the repository at this point in the history
* Allow callstates to use HF

* Rename to method

* Other rename

* Change the way

* Reduce changes

* Reduce changes

* Adapt name always

* Avoid string when only is lower the first char

* UT

* Test all

* Update src/Neo/ProtocolSettings.cs

Co-authored-by: Christopher Schuchardt <[email protected]>

* Update src/Neo/ProtocolSettings.cs

Co-authored-by: Christopher Schuchardt <[email protected]>

* Reuse Load from stream

* Unify

* Fix default logic

* Change ContractMethod to allowMultiple

* Use LowerInvariant

* Move CheckingHardfork

* Remove optional arg

* Fix build

* Avoid file not found error

---------

Co-authored-by: Christopher Schuchardt <[email protected]>
  • Loading branch information
shargon and cschuchardt88 authored Dec 19, 2024
1 parent f91b680 commit 705f4bb
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 31 deletions.
29 changes: 22 additions & 7 deletions src/Neo/ProtocolSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using System.Linq;

namespace Neo
Expand Down Expand Up @@ -123,19 +124,32 @@ public record ProtocolSettings

public static ProtocolSettings Custom { get; set; }

/// <summary>
/// Loads the <see cref="ProtocolSettings"/> from the specified stream.
/// </summary>
/// <param name="stream">The stream of the settings.</param>
/// <returns>The loaded <see cref="ProtocolSettings"/>.</returns>
public static ProtocolSettings Load(Stream stream)
{
var config = new ConfigurationBuilder().AddJsonStream(stream).Build();
var section = config.GetSection("ProtocolConfiguration");
return Load(section);
}

/// <summary>
/// Loads the <see cref="ProtocolSettings"/> at the specified path.
/// </summary>
/// <param name="path">The path of the settings file.</param>
/// <param name="optional">Indicates whether the file is optional.</param>
/// <returns>The loaded <see cref="ProtocolSettings"/>.</returns>
public static ProtocolSettings Load(string path, bool optional = true)
public static ProtocolSettings Load(string path)
{
IConfigurationRoot config = new ConfigurationBuilder().AddJsonFile(path, optional).Build();
IConfigurationSection section = config.GetSection("ProtocolConfiguration");
var settings = Load(section);
CheckingHardfork(settings);
return settings;
if (!File.Exists(path))
{
return Default;
}

using var stream = File.OpenRead(path);
return Load(stream);
}

/// <summary>
Expand Down Expand Up @@ -165,6 +179,7 @@ public static ProtocolSettings Load(IConfigurationSection section)
? EnsureOmmitedHardforks(section.GetSection("Hardforks").GetChildren().ToDictionary(p => Enum.Parse<Hardfork>(p.Key, true), p => uint.Parse(p.Value))).ToImmutableDictionary()
: Default.Hardforks
};
CheckingHardfork(Custom);
return Custom;
}

Expand Down
3 changes: 2 additions & 1 deletion src/Neo/SmartContract/Native/ContractMethodAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
namespace Neo.SmartContract.Native
{
[DebuggerDisplay("{Name}")]
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, AllowMultiple = false)]
// We allow multiple attributes because the fees or requiredCallFlags may change between hard forks.
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, AllowMultiple = true)]
internal class ContractMethodAttribute : Attribute, IHardforkActivable
{
public string Name { get; init; }
Expand Down
3 changes: 2 additions & 1 deletion src/Neo/SmartContract/Native/ContractMethodMetadata.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ internal class ContractMethodMetadata : IHardforkActivable

public ContractMethodMetadata(MemberInfo member, ContractMethodAttribute attribute)
{
Name = attribute.Name ?? member.Name.ToLower()[0] + member.Name[1..];
Name = attribute.Name ?? member.Name;
Name = Name.ToLowerInvariant()[0] + Name[1..];
Handler = member switch
{
MethodInfo m => m,
Expand Down
2 changes: 1 addition & 1 deletion src/Neo/SmartContract/Native/CryptoLib.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public static bool VerifyWithECDsa(byte[] message, byte[] pubkey, byte[] signatu
}

// This is for solving the hardfork issue in https://github.com/neo-project/neo/pull/3209
[ContractMethod(true, Hardfork.HF_Cockatrice, CpuFee = 1 << 15, Name = "verifyWithECDsa")]
[ContractMethod(true, Hardfork.HF_Cockatrice, CpuFee = 1 << 15, Name = nameof(VerifyWithECDsa))]
public static bool VerifyWithECDsaV0(byte[] message, byte[] pubkey, byte[] signature, NamedCurveHash curve)
{
if (curve != NamedCurveHash.secp256k1SHA256 && curve != NamedCurveHash.secp256r1SHA256)
Expand Down
2 changes: 1 addition & 1 deletion src/Neo/SmartContract/Native/LedgerContract.cs
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ public Transaction GetTransaction(DataCache snapshot, UInt256 hash)
return GetTransactionState(snapshot, hash)?.Transaction;
}

[ContractMethod(CpuFee = 1 << 15, RequiredCallFlags = CallFlags.ReadStates, Name = "getTransaction")]
[ContractMethod(CpuFee = 1 << 15, RequiredCallFlags = CallFlags.ReadStates, Name = nameof(GetTransaction))]
private Transaction GetTransactionForContract(ApplicationEngine engine, UInt256 hash)
{
TransactionState state = GetTransactionState(engine.SnapshotCache, hash);
Expand Down
25 changes: 16 additions & 9 deletions src/Neo/SmartContract/Native/NativeContract.cs
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,12 @@ protected NativeContract()
// Reflection to get the methods

List<ContractMethodMetadata> listMethods = [];
foreach (MemberInfo member in GetType().GetMembers(BindingFlags.Instance | BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public))
foreach (var member in GetType().GetMembers(BindingFlags.Instance | BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public))
{
ContractMethodAttribute attribute = member.GetCustomAttribute<ContractMethodAttribute>();
if (attribute is null) continue;
listMethods.Add(new ContractMethodMetadata(member, attribute));
foreach (var attribute in member.GetCustomAttributes<ContractMethodAttribute>())
{
listMethods.Add(new ContractMethodMetadata(member, attribute));
}
}
_methodDescriptors = listMethods.OrderBy(p => p.Name, StringComparer.Ordinal).ThenBy(p => p.Parameters.Length).ToList().AsReadOnly();

Expand Down Expand Up @@ -363,23 +364,29 @@ public static NativeContract GetContract(UInt160 hash)
return contract;
}

internal Dictionary<int, ContractMethodMetadata> GetContractMethods(ApplicationEngine engine)
{
var nativeContracts = engine.GetState(() => new NativeContractsCache());
var currentAllowedMethods = nativeContracts.GetAllowedMethods(this, engine);
return currentAllowedMethods.Methods;
}

internal async void Invoke(ApplicationEngine engine, byte version)
{
try
{
if (version != 0)
throw new InvalidOperationException($"The native contract of version {version} is not active.");
// Get native contracts invocation cache
NativeContractsCache nativeContracts = engine.GetState(() => new NativeContractsCache());
NativeContractsCache.CacheEntry currentAllowedMethods = nativeContracts.GetAllowedMethods(this, engine);
var currentAllowedMethods = GetContractMethods(engine);
// Check if the method is allowed
ExecutionContext context = engine.CurrentContext;
ContractMethodMetadata method = currentAllowedMethods.Methods[context.InstructionPointer];
var context = engine.CurrentContext;
var method = currentAllowedMethods[context.InstructionPointer];
if (method.ActiveIn is not null && !engine.IsHardforkEnabled(method.ActiveIn.Value))
throw new InvalidOperationException($"Cannot call this method before hardfork {method.ActiveIn}.");
if (method.DeprecatedIn is not null && engine.IsHardforkEnabled(method.DeprecatedIn.Value))
throw new InvalidOperationException($"Cannot call this method after hardfork {method.DeprecatedIn}.");
ExecutionContextState state = context.GetState<ExecutionContextState>();
var state = context.GetState<ExecutionContextState>();
if (!state.CallFlags.HasFlag(method.RequiredCallFlags))
throw new InvalidOperationException($"Cannot call this method with the flag {state.CallFlags}.");
// In the unit of datoshi, 1 datoshi = 1e-8 GAS
Expand Down
9 changes: 6 additions & 3 deletions src/Neo/SmartContract/Native/NeoToken.cs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,8 @@ public BigInteger UnclaimedGas(DataCache snapshot, UInt160 account, uint end)
return CalculateBonus(snapshot, state, end);
}

[ContractMethod(RequiredCallFlags = CallFlags.States)]
[ContractMethod(true, Hardfork.HF_Echidna, RequiredCallFlags = CallFlags.States)]
[ContractMethod(Hardfork.HF_Echidna, /* */ RequiredCallFlags = CallFlags.States | CallFlags.AllowNotify)]
private bool RegisterCandidate(ApplicationEngine engine, ECPoint pubkey)
{
if (!engine.CheckWitnessInternal(Contract.CreateSignatureRedeemScript(pubkey).ToScriptHash()))
Expand All @@ -349,7 +350,8 @@ private bool RegisterCandidate(ApplicationEngine engine, ECPoint pubkey)
return true;
}

[ContractMethod(CpuFee = 1 << 16, RequiredCallFlags = CallFlags.States)]
[ContractMethod(true, Hardfork.HF_Echidna, CpuFee = 1 << 16, RequiredCallFlags = CallFlags.States)]
[ContractMethod(Hardfork.HF_Echidna, /* */ CpuFee = 1 << 16, RequiredCallFlags = CallFlags.States | CallFlags.AllowNotify)]
private bool UnregisterCandidate(ApplicationEngine engine, ECPoint pubkey)
{
if (!engine.CheckWitnessInternal(Contract.CreateSignatureRedeemScript(pubkey).ToScriptHash()))
Expand All @@ -366,7 +368,8 @@ private bool UnregisterCandidate(ApplicationEngine engine, ECPoint pubkey)
return true;
}

[ContractMethod(CpuFee = 1 << 16, RequiredCallFlags = CallFlags.States)]
[ContractMethod(true, Hardfork.HF_Echidna, CpuFee = 1 << 16, RequiredCallFlags = CallFlags.States)]
[ContractMethod(Hardfork.HF_Echidna, /* */ CpuFee = 1 << 16, RequiredCallFlags = CallFlags.States | CallFlags.AllowNotify)]
private async ContractTask<bool> Vote(ApplicationEngine engine, UInt160 account, ECPoint voteTo)
{
if (!engine.CheckWitnessInternal(account)) return false;
Expand Down
6 changes: 3 additions & 3 deletions tests/Neo.UnitTests/SmartContract/Native/UT_NativeContract.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public void TestActiveDeprecatedIn()
string json = UT_ProtocolSettings.CreateHFSettings("\"HF_Cockatrice\": 20");
var file = Path.GetTempFileName();
File.WriteAllText(file, json);
ProtocolSettings settings = ProtocolSettings.Load(file, false);
ProtocolSettings settings = ProtocolSettings.Load(file);
File.Delete(file);

Assert.IsFalse(NativeContract.IsActive(new active() { ActiveIn = Hardfork.HF_Cockatrice, DeprecatedIn = null }, settings.IsHardforkEnabled, 1));
Expand All @@ -87,7 +87,7 @@ public void TestActiveDeprecatedInRoleManagement()
string json = UT_ProtocolSettings.CreateHFSettings("\"HF_Echidna\": 20");
var file = Path.GetTempFileName();
File.WriteAllText(file, json);
ProtocolSettings settings = ProtocolSettings.Load(file, false);
ProtocolSettings settings = ProtocolSettings.Load(file);
File.Delete(file);

var before = NativeContract.RoleManagement.GetContractState(settings.IsHardforkEnabled, 19);
Expand All @@ -112,7 +112,7 @@ public void TestIsInitializeBlock()

var file = Path.GetTempFileName();
File.WriteAllText(file, json);
ProtocolSettings settings = ProtocolSettings.Load(file, false);
ProtocolSettings settings = ProtocolSettings.Load(file);
File.Delete(file);

Assert.IsTrue(NativeContract.CryptoLib.IsInitializeBlock(settings, 0, out var hf));
Expand Down
45 changes: 45 additions & 0 deletions tests/Neo.UnitTests/SmartContract/Native/UT_NeoToken.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
using Neo.VM;
using Neo.Wallets;
using System;
using System.IO;
using System.Linq;
using System.Numerics;
using System.Security.Principal;
using System.Text;
using static Neo.SmartContract.Native.NeoToken;

namespace Neo.UnitTests.SmartContract.Native
Expand Down Expand Up @@ -54,6 +57,48 @@ public void TestSetup()
[TestMethod]
public void Check_Decimals() => NativeContract.NEO.Decimals(_snapshotCache).Should().Be(0);

[TestMethod]
public void Test_HF_EchidnaStates()
{
string json = UT_ProtocolSettings.CreateHFSettings("\"HF_Echidna\": 10");
using var stream = new MemoryStream(Encoding.UTF8.GetBytes(json));
var settings = ProtocolSettings.Load(stream);

var clonedCache = _snapshotCache.CloneCache();
var persistingBlock = new Block { Header = new Header() };

foreach (var method in new string[] { "vote", "registerCandidate", "unregisterCandidate" })
{
// Test WITHOUT HF_Echidna

persistingBlock.Header.Index = 9;

using (var engine = ApplicationEngine.Create(TriggerType.Application,
new Nep17NativeContractExtensions.ManualWitness(UInt160.Zero), clonedCache, persistingBlock, settings: settings))
{
var methods = NativeContract.NEO.GetContractMethods(engine);
var entries = methods.Values.Where(u => u.Name == method).ToArray();

Assert.AreEqual(entries.Length, 1);
Assert.AreEqual(entries[0].RequiredCallFlags, CallFlags.States);
}

// Test WITH HF_Echidna

persistingBlock.Header.Index = 10;

using (var engine = ApplicationEngine.Create(TriggerType.Application,
new Nep17NativeContractExtensions.ManualWitness(UInt160.Zero), clonedCache, persistingBlock, settings: settings))
{
var methods = NativeContract.NEO.GetContractMethods(engine);
var entries = methods.Values.Where(u => u.Name == method).ToArray();

Assert.AreEqual(entries.Length, 1);
Assert.AreEqual(entries[0].RequiredCallFlags, CallFlags.States | CallFlags.AllowNotify);
}
}
}

[TestMethod]
public void Check_Vote()
{
Expand Down
10 changes: 5 additions & 5 deletions tests/Neo.UnitTests/UT_ProtocolSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public void HardForkTestBAndNotA()

var file = Path.GetTempFileName();
File.WriteAllText(file, json);
ProtocolSettings settings = ProtocolSettings.Load(file, false);
ProtocolSettings settings = ProtocolSettings.Load(file);
File.Delete(file);

settings.Hardforks[Hardfork.HF_Aspidochelone].Should().Be(0);
Expand All @@ -78,7 +78,7 @@ public void HardForkTestAAndNotB()

var file = Path.GetTempFileName();
File.WriteAllText(file, json);
ProtocolSettings settings = ProtocolSettings.Load(file, false);
ProtocolSettings settings = ProtocolSettings.Load(file);
File.Delete(file);

settings.Hardforks[Hardfork.HF_Aspidochelone].Should().Be(0);
Expand All @@ -100,7 +100,7 @@ public void HardForkTestNone()

var file = Path.GetTempFileName();
File.WriteAllText(file, json);
ProtocolSettings settings = ProtocolSettings.Load(file, false);
ProtocolSettings settings = ProtocolSettings.Load(file);
File.Delete(file);

settings.Hardforks[Hardfork.HF_Aspidochelone].Should().Be(0);
Expand All @@ -120,7 +120,7 @@ public void HardForkTestAMoreThanB()
string json = CreateHFSettings("\"HF_Aspidochelone\": 4120001, \"HF_Basilisk\": 4120000");
var file = Path.GetTempFileName();
File.WriteAllText(file, json);
Assert.ThrowsException<ArgumentException>(() => ProtocolSettings.Load(file, false));
Assert.ThrowsException<ArgumentException>(() => ProtocolSettings.Load(file));
File.Delete(file);
}

Expand Down Expand Up @@ -316,7 +316,7 @@ public void TestTimePerBlockCalculation()
[TestMethod]
public void TestLoad()
{
var loadedSetting = ProtocolSettings.Load("test.config.json", false);
var loadedSetting = ProtocolSettings.Load("test.config.json");

// Comparing all properties
TestProtocolSettings.Default.Network.Should().Be(loadedSetting.Network);
Expand Down

0 comments on commit 705f4bb

Please sign in to comment.