Skip to content

Commit

Permalink
Disable type inference from message content (#6274)
Browse files Browse the repository at this point in the history
  • Loading branch information
timbussmann authored Feb 17, 2022
1 parent 7476f1f commit 9263070
Show file tree
Hide file tree
Showing 9 changed files with 272 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
namespace NServiceBus.AcceptanceTests.Serialization
{
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using AcceptanceTesting;
using EndpointTemplates;
using MessageInterfaces;
using NServiceBus.Pipeline;
using NServiceBus.Serialization;
using NUnit.Framework;
using Settings;

class When_disabling_serializer_type_inference : NServiceBusAcceptanceTest
{
[Test]
public async Task Should_not_deserialize_messages_without_types_header()
{
var context = await Scenario.Define<Context>()
.WithEndpoint<ReceivingEndpoint>(e => e
.DoNotFailOnErrorMessages()
.When(s => s.SendLocal(new MessageWithoutTypeHeader())))
.Done(c => c.IncomingMessageReceived)
.Run(TimeSpan.FromSeconds(20));

Assert.IsFalse(context.HandlerInvoked);
Assert.AreEqual(1, context.FailedMessages.Single().Value.Count);
Exception exception = context.FailedMessages.Single().Value.Single().Exception;
Assert.IsInstanceOf<MessageDeserializationException>(exception);
StringAssert.Contains($"Could not determine the message type from the '{Headers.EnclosedMessageTypes}' header", exception.InnerException.Message);
}

[Test]
public async Task Should_not_deserialize_messages_with_unknown_type_header()
{
var context = await Scenario.Define<Context>()
.WithEndpoint<ReceivingEndpoint>(e => e
.DoNotFailOnErrorMessages()
.When(s => s.SendLocal(new UnknownMessage())))
.Done(c => c.IncomingMessageReceived)
.Run(TimeSpan.FromSeconds(20));

Assert.IsFalse(context.HandlerInvoked);
Assert.AreEqual(1, context.FailedMessages.Single().Value.Count);
Exception exception = context.FailedMessages.Single().Value.Single().Exception;
Assert.IsInstanceOf<MessageDeserializationException>(exception);
StringAssert.Contains($"Could not determine the message type from the '{Headers.EnclosedMessageTypes}' header", exception.InnerException.Message);
}

class Context : ScenarioContext
{
public bool HandlerInvoked { get; set; }
public bool IncomingMessageReceived { get; set; }
}

class ReceivingEndpoint : EndpointConfigurationBuilder
{
public ReceivingEndpoint() =>
EndpointSetup<DefaultServer>(c =>
{
c.Pipeline.Register(typeof(TypeHeaderManipulationBehavior), "Removes the EnclosedMessageTypes header from incoming messages");
var serializerSettings = c.UseSerialization<CustomSerializer>();
serializerSettings.DisableMessageTypeInference();
});

public class MessageHandler : IHandleMessages<MessageWithoutTypeHeader>
{
Context testContext;

public MessageHandler(Context testContext) => this.testContext = testContext;

public Task Handle(MessageWithoutTypeHeader message, IMessageHandlerContext context)
{
testContext.HandlerInvoked = true;
return Task.FromResult(0);
}
}

class TypeHeaderManipulationBehavior : Behavior<IIncomingPhysicalMessageContext>
{
Context testContext;

public TypeHeaderManipulationBehavior(Context testContext) => this.testContext = testContext;

public override Task Invoke(IIncomingPhysicalMessageContext context, Func<Task> next)
{
testContext.IncomingMessageReceived = true;

if (context.MessageHeaders[Headers.EnclosedMessageTypes].Contains(typeof(MessageWithoutTypeHeader).FullName))
{
context.Message.Headers.Remove(Headers.EnclosedMessageTypes);
}
else if (context.MessageHeaders[Headers.EnclosedMessageTypes].Contains(typeof(UnknownMessage).FullName))
{
context.Message.Headers[Headers.EnclosedMessageTypes] = "SomeNamespace.SomeMessageType";
}

return next();
}
}
}

public class MessageWithoutTypeHeader : IMessage
{
}

public class UnknownMessage : IMessage
{
}

class CustomSerializer : SerializationDefinition, IMessageSerializer
{
public string ContentType { get; } = "CustomSerializer";

public void Serialize(object message, Stream stream)
{
stream.WriteByte(42); // need to write some byte for message serialization to work
}

public object[] Deserialize(Stream stream, IList<Type> messageTypes = null)
{
if (messageTypes?.Count > 0)
{
throw new InvalidOperationException("Did not expect message types to be detected in this test");
}

throw new InvalidOperationException("Should not invoke deserializer without type information");
}

public override Func<IMessageMapper, IMessageSerializer> Configure(ReadOnlySettings settings) => _ => this;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
namespace NServiceBus.AcceptanceTests.Serialization
{
using System;
using System.Linq;
using System.Threading.Tasks;
using AcceptanceTesting;
using EndpointTemplates;
using NServiceBus.Pipeline;
using NUnit.Framework;

public class When_message_type_header_is_whitespaces : NServiceBusAcceptanceTest
{
[Test]
public async Task Should_move_message_to_error_queue()
{
var context = await Scenario.Define<Context>()
.WithEndpoint<ReceivingEndpoint>(e => e
.DoNotFailOnErrorMessages()
.When(s => s.SendLocal(new MessageWithEmptyTypeHeader())))
.Done(c => c.IncomingMessageReceived)
.Run(TimeSpan.FromSeconds(20));

Assert.IsFalse(context.HandlerInvoked);
Assert.AreEqual(1, context.FailedMessages.Single().Value.Count);
Exception exception = context.FailedMessages.Single().Value.Single().Exception;
Assert.IsInstanceOf<MessageDeserializationException>(exception);
}

class Context : ScenarioContext
{
public bool HandlerInvoked { get; set; }
public bool IncomingMessageReceived { get; set; }
}

class ReceivingEndpoint : EndpointConfigurationBuilder
{
public ReceivingEndpoint() =>
EndpointSetup<DefaultServer>(c =>
{
c.Pipeline.Register(typeof(TypeHeaderRemovingBehavior), "Removes the EnclosedMessageTypes header from incoming messages");
});

public class MessageHandler : IHandleMessages<MessageWithEmptyTypeHeader>
{
Context testContext;

public MessageHandler(Context testContext) => this.testContext = testContext;

public Task Handle(MessageWithEmptyTypeHeader message, IMessageHandlerContext context)
{
testContext.HandlerInvoked = true;
return Task.FromResult(0);
}
}

class TypeHeaderRemovingBehavior : Behavior<IIncomingPhysicalMessageContext>
{
Context testContext;

public TypeHeaderRemovingBehavior(Context testContext) => this.testContext = testContext;

public override Task Invoke(IIncomingPhysicalMessageContext context, Func<Task> next)
{
testContext.IncomingMessageReceived = true;

// add some whitespace instead of removing the header completely
context.Message.Headers[Headers.EnclosedMessageTypes] = " ";

return next();
}
}
}

public class MessageWithEmptyTypeHeader : IMessage
{
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,11 @@ namespace NServiceBus
public static bool ShouldSkipSerialization(this NServiceBus.Pipeline.IOutgoingLogicalMessageContext context) { }
public static void SkipSerialization(this NServiceBus.Pipeline.IOutgoingLogicalMessageContext context) { }
}
public class static SerializationExtensionsExtensions
{
public static void DisableMessageTypeInference<T>(this NServiceBus.Serialization.SerializationExtensions<T> config)
where T : NServiceBus.Serialization.SerializationDefinition { }
}
public class static SettingsExtensions
{
public static string EndpointName(this NServiceBus.Settings.ReadOnlySettings settings) { }
Expand Down Expand Up @@ -2499,7 +2504,7 @@ namespace NServiceBus.Serialization
public class SerializationExtensions<T> : NServiceBus.Configuration.AdvancedExtensibility.ExposeSettings
where T : NServiceBus.Serialization.SerializationDefinition
{
public SerializationExtensions(NServiceBus.Settings.SettingsHolder settings) { }
public SerializationExtensions(NServiceBus.Settings.SettingsHolder serializerSettings, NServiceBus.Settings.SettingsHolder endpointConfigurationSettings) { }
}
}
namespace NServiceBus.Settings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,11 @@ namespace NServiceBus
public static bool ShouldSkipSerialization(this NServiceBus.Pipeline.IOutgoingLogicalMessageContext context) { }
public static void SkipSerialization(this NServiceBus.Pipeline.IOutgoingLogicalMessageContext context) { }
}
public class static SerializationExtensionsExtensions
{
public static void DisableMessageTypeInference<T>(this NServiceBus.Serialization.SerializationExtensions<T> config)
where T : NServiceBus.Serialization.SerializationDefinition { }
}
public class static SettingsExtensions
{
public static string EndpointName(this NServiceBus.Settings.ReadOnlySettings settings) { }
Expand Down Expand Up @@ -2501,7 +2506,7 @@ namespace NServiceBus.Serialization
public class SerializationExtensions<T> : NServiceBus.Configuration.AdvancedExtensibility.ExposeSettings
where T : NServiceBus.Serialization.SerializationDefinition
{
public SerializationExtensions(NServiceBus.Settings.SettingsHolder settings) { }
public SerializationExtensions(NServiceBus.Settings.SettingsHolder serializerSettings, NServiceBus.Settings.SettingsHolder endpointConfigurationSettings) { }
}
}
namespace NServiceBus.Settings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ namespace NServiceBus

class DeserializeMessageConnector : StageConnector<IIncomingPhysicalMessageContext, IIncomingLogicalMessageContext>
{
public DeserializeMessageConnector(MessageDeserializerResolver deserializerResolver, LogicalMessageFactory logicalMessageFactory, MessageMetadataRegistry messageMetadataRegistry, IMessageMapper mapper)
public DeserializeMessageConnector(MessageDeserializerResolver deserializerResolver, LogicalMessageFactory logicalMessageFactory, MessageMetadataRegistry messageMetadataRegistry, IMessageMapper mapper, bool allowContentTypeInference)
{
this.deserializerResolver = deserializerResolver;
this.logicalMessageFactory = logicalMessageFactory;
this.messageMetadataRegistry = messageMetadataRegistry;
this.mapper = mapper;
this.allowContentTypeInference = allowContentTypeInference;
}

public override async Task Invoke(IIncomingPhysicalMessageContext context, Func<IIncomingLogicalMessageContext, Task> stage)
Expand Down Expand Up @@ -99,12 +100,17 @@ LogicalMessage[] Extract(IncomingMessage physicalMessage)
messageMetadata.Add(metadata);
}

if (messageMetadata.Count == 0 && physicalMessage.GetMessageIntent() != MessageIntentEnum.Publish)
if (messageMetadata.Count == 0 && allowContentTypeInference && physicalMessage.GetMessageIntent() != MessageIntentEnum.Publish)
{
log.WarnFormat("Could not determine message type from message header '{0}'. MessageId: {1}", messageTypeIdentifier, physicalMessage.MessageId);
}
}

if (messageMetadata.Count == 0 && !allowContentTypeInference)
{
throw new Exception($"Could not determine the message type from the '{Headers.EnclosedMessageTypes}' header and message type inference from the message body has been disabled. Ensure the header is set or enable message type inference.");
}

var messageTypes = messageMetadata.Select(metadata => metadata.MessageType).ToList();
var messageSerializer = deserializerResolver.Resolve(physicalMessage.Headers);

Expand Down Expand Up @@ -141,6 +147,7 @@ static bool IsV4OrBelowScheduledTask(string existingTypeString)
readonly LogicalMessageFactory logicalMessageFactory;
readonly MessageMetadataRegistry messageMetadataRegistry;
readonly IMessageMapper mapper;
readonly bool allowContentTypeInference;

static readonly LogicalMessage[] NoMessagesFound = new LogicalMessage[0];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public static SerializationExtensions<T> UseSerialization<T>(this EndpointConfig

var settings = new SettingsHolder();
config.Settings.SetMainSerializer(serializationDefinition, settings);
return CreateSerializationExtension<T>(settings);
return CreateSerializationExtension<T>(settings, config.Settings);
}

/// <summary>
Expand Down Expand Up @@ -67,14 +67,9 @@ public static SerializationExtensions<T> AddDeserializer<T>(this EndpointConfigu

var settings = new SettingsHolder();
additionalSerializers.Add(Tuple.Create<SerializationDefinition, SettingsHolder>(serializationDefinition, settings));
return CreateSerializationExtension<T>(settings);
return CreateSerializationExtension<T>(settings, config.Settings);
}

static SerializationExtensions<T> CreateSerializationExtension<T>(SettingsHolder settings) where T : SerializationDefinition
{
var type = typeof(SerializationExtensions<>).MakeGenericType(typeof(T));
var extension = (SerializationExtensions<T>)Activator.CreateInstance(type, settings);
return extension;
}
static SerializationExtensions<T> CreateSerializationExtension<T>(SettingsHolder serializerSettings, SettingsHolder endpointConfigurationSettings) where T : SerializationDefinition => new SerializationExtensions<T>(serializerSettings, endpointConfigurationSettings);
}
}
8 changes: 5 additions & 3 deletions src/NServiceBus.Core/Serialization/SerializationExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ public class SerializationExtensions<T> : ExposeSettings where T : Serialization
/// <summary>
/// Initializes a new instance of <see cref="SerializationExtensions{T}" />.
/// </summary>
public SerializationExtensions(SettingsHolder settings) : base(settings)
{
}
public SerializationExtensions(SettingsHolder serializerSettings, SettingsHolder endpointConfigurationSettings) : base(serializerSettings)
=> EndpointConfigurationSettings = endpointConfigurationSettings;

// provides access to the settings backing EndpointConfiguration. The settings provided by the 'Settings' property are isolated settings for the serializer.
internal readonly SettingsHolder EndpointConfigurationSettings;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
namespace NServiceBus
{
using Serialization;
using Settings;

/// <summary>
/// Provides extensions methods for the <see cref="SerializationExtensions{T}"/> class.
/// </summary>
public static class SerializationExtensionsExtensions
{
/// <summary>
/// Disables inference of message type based on the content type if the message type can't be determined by the 'NServiceBus.EnclosedMessageTypes' header.
/// </summary>
public static void DisableMessageTypeInference<T>(this SerializationExtensions<T> config) where T : SerializationDefinition
{
Guard.AgainstNull(nameof(config), config);

config.EndpointConfigurationSettings.Set(DisableMessageTypeInferenceKey, true);
}

internal static bool IsMessageTypeInferenceEnabled(this ReadOnlySettings endpointConfigurationSettings) =>
!endpointConfigurationSettings.GetOrDefault<bool>(DisableMessageTypeInferenceKey);

const string DisableMessageTypeInferenceKey = "NServiceBus.Serialization.DisableMessageTypeInference";
}
}
7 changes: 4 additions & 3 deletions src/NServiceBus.Core/Serialization/SerializationFeature.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ protected internal sealed override void Setup(FeatureConfigurationContext contex
});
}

var allowMessageTypeInference = settings.IsMessageTypeInferenceEnabled();
var resolver = new MessageDeserializerResolver(defaultSerializer, additionalDeserializers);

var logicalMessageFactory = new LogicalMessageFactory(messageMetadataRegistry, mapper);
context.Pipeline.Register("DeserializeLogicalMessagesConnector", new DeserializeMessageConnector(resolver, logicalMessageFactory, messageMetadataRegistry, mapper), "Deserializes the physical message body into logical messages");
context.Pipeline.Register("DeserializeLogicalMessagesConnector", new DeserializeMessageConnector(resolver, logicalMessageFactory, messageMetadataRegistry, mapper, allowMessageTypeInference), "Deserializes the physical message body into logical messages");
context.Pipeline.Register("SerializeMessageConnector", new SerializeMessageConnector(defaultSerializer, messageMetadataRegistry), "Converts a logical message into a physical message");

context.Container.ConfigureComponent(_ => mapper, DependencyLifecycle.SingleInstance);
Expand All @@ -68,7 +68,8 @@ protected internal sealed override void Setup(FeatureConfigurationContext contex
Version = FileVersionRetriever.GetFileVersion(defaultSerializerAndDefinition.Item1.GetType()),
defaultSerializer.ContentType
},
AdditionalDeserializers = additionalDeserializerDiagnostics
AdditionalDeserializers = additionalDeserializerDiagnostics,
AllowMessageTypeInference = allowMessageTypeInference
});
}

Expand Down

0 comments on commit 9263070

Please sign in to comment.