Skip to content

Commit

Permalink
Merge pull request #3 from Passant-Ihab/SupportJson
Browse files Browse the repository at this point in the history
Support the response format (JSON mode) provided by Mistral
  • Loading branch information
tghamm authored May 18, 2024
2 parents 4f65b01 + eacb91e commit f6c8e3e
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 4 deletions.
45 changes: 45 additions & 0 deletions Mistral.SDK.Tests/Completions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,51 @@ public async Task TestMistralCompletionModel5()

}

/// <summary>
/// Minimal deserialization target for the JSON-mode tests: the prompts ask the
/// model for an object with a single 'result' key, and this type mirrors that shape.
/// </summary>
public class JsonResult
{
    // Lowercase on purpose: System.Text.Json matches property names
    // case-sensitively by default, and the model is asked for a "result" key.
    public string result { get; set; }
}

[TestMethod]
public async Task TestMistralCompletionJsonMode()
{
    var client = new MistralClient();
    var request = new ChatCompletionRequest(ModelDefinitions.MistralLarge, new List<ChatMessage>()
    {
        new ChatMessage(ChatMessage.RoleEnum.System, "You are an expert at writing Json."),
        new ChatMessage(ChatMessage.RoleEnum.User, "Write me a simple 'hello world' statement in a json object with a single 'result' key.")
    }, responseFormat: new ResponseFormat()
    {
        Type = ResponseFormat.ResponseFormatEnum.JSON
    });
    var response = await client.Completions.GetCompletionAsync(request);

    // JSON mode should yield a parseable JSON object; verify both that the
    // content deserializes AND that the requested 'result' key was populated
    // (the previous assertion alone would pass for an empty "{}" response).
    var result = JsonSerializer.Deserialize<JsonResult>(response.Choices.First().Message.Content);
    Assert.IsNotNull(result);
    Assert.IsNotNull(result.result);
}

[TestMethod]
public async Task TestMistralCompletionJsonModeStreaming()
{
    var client = new MistralClient();
    var request = new ChatCompletionRequest(ModelDefinitions.MistralLarge, new List<ChatMessage>()
    {
        new ChatMessage(ChatMessage.RoleEnum.System, "You are an expert at writing Json."),
        new ChatMessage(ChatMessage.RoleEnum.User, "Write me a simple 'hello world' statement in a json object with a single 'result' key.")
    }, responseFormat: new ResponseFormat()
    {
        Type = ResponseFormat.ResponseFormatEnum.JSON
    });

    // Accumulate streamed deltas with a StringBuilder; repeated string
    // concatenation in the loop is O(n^2) as chunks arrive.
    var builder = new System.Text.StringBuilder();
    await foreach (var res in client.Completions.StreamCompletionAsync(request))
    {
        builder.Append(res.Choices.First().Delta.Content);
    }

    // The fully assembled stream must be a valid JSON object containing the
    // requested 'result' key (not just any deserializable object).
    var result = JsonSerializer.Deserialize<JsonResult>(builder.ToString());
    Assert.IsNotNull(result);
    Assert.IsNotNull(result.result);
}

[TestMethod]
public async Task TestMistralCompletionSafeWithOptions()
{
Expand Down
18 changes: 14 additions & 4 deletions Mistral.SDK/DTOs/ChatCompletionRequest.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using Mistral.SDK.Converters;
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization;
Expand All @@ -18,7 +19,8 @@ public class ChatCompletionRequest
/// <param name="stream">Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message. Otherwise, the server will hold the request open until the timeout or until completion, with the response containing the full result as JSON. (default to false).</param>
/// <param name="safePrompt">Whether to inject a safety prompt before all conversations. (default to false).</param>
/// <param name="randomSeed">The seed to use for random sampling. If set, different calls will generate deterministic results. .</param>
public ChatCompletionRequest(string model = default(string), List<ChatMessage> messages = default(List<ChatMessage>), decimal? temperature = 0.7M, decimal? topP = 1M, int? maxTokens = default(int?), bool? stream = false, bool safePrompt = false, int? randomSeed = default(int?))
/// <param name="responseFormat">The response format needed. If set, the API will be forced to return the data in this mode.</param>
public ChatCompletionRequest(string model = default(string), List<ChatMessage> messages = default(List<ChatMessage>), decimal? temperature = 0.7M, decimal? topP = 1M, int? maxTokens = default(int?), bool? stream = false, bool safePrompt = false, int? randomSeed = default(int?), ResponseFormat responseFormat = default)
{
// to ensure "model" is required (not null)
if (model == null)
Expand All @@ -41,6 +43,7 @@ public class ChatCompletionRequest
this.Stream = stream ?? false;
this.SafePrompt = safePrompt;
this.RandomSeed = randomSeed;
this.ResponseFormat = responseFormat;
}
/// <summary>
/// ID of the model to use. You can use the [List Available Models](/api#operation/listModels) API to see all of your available models, or see our [Model overview](/models) for model descriptions.
Expand Down Expand Up @@ -103,7 +106,14 @@ public class ChatCompletionRequest
[JsonPropertyName("random_seed")]
public int? RandomSeed { get; set; }

IEnumerable<ValidationResult> Validate()
/// <summary>
/// Optional response format. When set, the API is forced to return its output
/// in this format (e.g. JSON mode); serialized as "response_format".
/// </summary>
[JsonPropertyName("response_format")]
public ResponseFormat ResponseFormat { get; set; }


IEnumerable<ValidationResult> Validate()
{
// Temperature (decimal?) maximum
if (this.Temperature > (decimal?)1)
Expand Down Expand Up @@ -137,7 +147,7 @@ IEnumerable<ValidationResult> Validate()

yield break;
}
}
}


}
28 changes: 28 additions & 0 deletions Mistral.SDK/DTOs/ResponseFormat.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using Mistral.SDK.Converters;
using System;
using System.Collections.Generic;
using System.Text;
using System.Text.Json.Serialization;

namespace Mistral.SDK.DTOs
{
    /// <summary>
    /// The response format the API is forced to use for its output
    /// (serialized as the "response_format" request property).
    /// </summary>
    public class ResponseFormat
    {
        /// <summary>
        /// Supported response output formats.
        /// </summary>
        [JsonConverter(typeof(JsonPropertyNameEnumConverter<ResponseFormatEnum>))]
        public enum ResponseFormatEnum
        {
            /// <summary>
            /// Enum json for value: json_object
            /// </summary>
            [JsonPropertyName("json_object")]
            JSON = 1
        }

        /// <summary>
        /// The output type.
        /// </summary>
        [JsonPropertyName("type")]
        public ResponseFormatEnum Type { get; set; }
    }
}

0 comments on commit f6c8e3e

Please sign in to comment.