diff --git a/src/AI/ChatTools.cs b/src/AI/ChatTools.cs new file mode 100644 index 00000000..e4c1cba6 --- /dev/null +++ b/src/AI/ChatTools.cs @@ -0,0 +1,92 @@ +using System; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using OpenAI.Chat; + +namespace SourceGit.AI +{ + public static class ChatTools + { + public static readonly ChatTool Tool_GetDetailChangesInFile = ChatTool.CreateFunctionTool( + nameof(GetDetailChangesInFile), + "Get the detailed changes in the specified file in the specified repository.", + BinaryData.FromBytes(Encoding.UTF8.GetBytes(""" + { + "type": "object", + "properties": { + "repo": { + "type": "string", + "description": "The path to the repository." + }, + "file": { + "type": "string", + "description": "The path to the file." + }, + "originalFile": { + "type": "string", + "description": "The path to the original file when it has been renamed." + } + }, + "required": ["repo", "file"] + } + """)), false); + + public static async Task Process(ChatToolCall call, Action output) + { + using var doc = JsonDocument.Parse(call.FunctionArguments); + + switch (call.FunctionName) + { + case nameof(GetDetailChangesInFile): + { + var hasRepo = doc.RootElement.TryGetProperty("repo", out var repoPath); + var hasFile = doc.RootElement.TryGetProperty("file", out var filePath); + var hasOriginalFile = doc.RootElement.TryGetProperty("originalFile", out var originalFilePath); + if (!hasRepo) + throw new ArgumentException("repo", "The repo argument is required"); + if (!hasFile) + throw new ArgumentException("file", "The file argument is required"); + + output?.Invoke($"Read changes in file: {filePath.GetString()}"); + + var toolResult = await ChatTools.GetDetailChangesInFile( + repoPath.GetString(), + filePath.GetString(), + hasOriginalFile ? originalFilePath.GetString() : string.Empty); + return new ToolChatMessage(call.Id, toolResult); + } + default: + throw new NotSupportedException($"The tool {call.FunctionName} is not supported"); + } + } + + private static async Task GetDetailChangesInFile(string repo, string file, string originalFile) + { + var rs = await new GetDiffContentCommand(repo, file, originalFile).ReadAsync(); + return rs.IsSuccess ? rs.StdOut : string.Empty; + } + + private class GetDiffContentCommand : Commands.Command + { + public GetDiffContentCommand(string repo, string file, string originalFile) + { + WorkingDirectory = repo; + Context = repo; + + var builder = new StringBuilder(); + builder.Append("diff --no-color --no-ext-diff --diff-algorithm=minimal --cached -- "); + if (!string.IsNullOrEmpty(originalFile) && !file.Equals(originalFile, StringComparison.Ordinal)) + builder.Append(originalFile.Quoted()).Append(' '); + builder.Append(file.Quoted()); + + Args = builder.ToString(); + } + + public async Task ReadAsync() + { + return await ReadToEndAsync().ConfigureAwait(false); + } + } + } +} diff --git a/src/AI/Service.cs b/src/AI/Service.cs new file mode 100644 index 00000000..70e29ab6 --- /dev/null +++ b/src/AI/Service.cs @@ -0,0 +1,98 @@ +using System; +using System.ClientModel; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Azure.AI.OpenAI; +using OpenAI; +using OpenAI.Chat; + +namespace SourceGit.AI +{ + public class Service + { + public Service(Models.AIProvider ai) + { + _ai = ai; + } + + public async Task GenerateCommitMessage(string repo, string changeList, Action onUpdate, CancellationToken cancellation) + { + var key = _ai.ReadApiKeyFromEnv ? Environment.GetEnvironmentVariable(_ai.ApiKey) : _ai.ApiKey; + var endPoint = new Uri(_ai.Server); + var credential = new ApiKeyCredential(key); + var client = _ai.Server.Contains("openai.azure.com/", StringComparison.Ordinal) + ? new AzureOpenAIClient(endPoint, credential) + : new OpenAIClient(credential, new() { Endpoint = endPoint }); + + var chatClient = client.GetChatClient(_ai.Model); + var options = new ChatCompletionOptions() { Tools = { ChatTools.Tool_GetDetailChangesInFile } }; + + var userMessageBuilder = new StringBuilder(); + userMessageBuilder + .AppendLine("Generate a commit message (follow the rule of conventional commit message) for given git repository.") + .AppendLine("- Read all given changed files before generating. Do not skip any one file.") + .Append("Reposiory path: ").AppendLine(repo.Quoted()) + .AppendLine("Changed files: ") + .Append(changeList); + + var messages = new List() { new UserChatMessage(userMessageBuilder.ToString()) }; + + do + { + var inProgress = false; + var updates = chatClient.CompleteChatStreamingAsync(messages, options).WithCancellation(cancellation); + var toolCalls = new ToolCallsBuilder(); + var contentBuilder = new StringBuilder(); + + await foreach (var update in updates) + { + foreach (var contentPart in update.ContentUpdate) + contentBuilder.Append(contentPart.Text); + + foreach (var toolCall in update.ToolCallUpdates) + toolCalls.Append(toolCall); + + switch (update.FinishReason) + { + case ChatFinishReason.Stop: + onUpdate?.Invoke(string.Empty); + onUpdate?.Invoke("[Assistant]:"); + onUpdate?.Invoke(contentBuilder.ToString()); + break; + case ChatFinishReason.Length: + throw new Exception("The response was cut off because it reached the maximum length. Consider increasing the max tokens limit."); + case ChatFinishReason.ToolCalls: + { + var calls = toolCalls.Build(); + var assistantMessage = new AssistantChatMessage(calls); + if (contentBuilder.Length > 0) + assistantMessage.Content.Add(ChatMessageContentPart.CreateTextPart(contentBuilder.ToString())); + messages.Add(assistantMessage); + + foreach (var call in calls) + { + var result = await ChatTools.Process(call, onUpdate); + messages.Add(result); + } + + inProgress = true; + break; + } + case ChatFinishReason.ContentFilter: + throw new Exception("Ommitted content due to a content filter flag"); + default: + break; + } + + } + + if (!inProgress) + break; + } while (true); + } + + private readonly Models.AIProvider _ai; + } +} diff --git a/src/AI/ToolCallsBuilder.cs b/src/AI/ToolCallsBuilder.cs new file mode 100644 index 00000000..948e3104 --- /dev/null +++ b/src/AI/ToolCallsBuilder.cs @@ -0,0 +1,119 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using OpenAI.Chat; + +namespace SourceGit.AI +{ + public class ToolCallsBuilder + { + private readonly Dictionary _indexToToolCallId = []; + private readonly Dictionary _indexToFunctionName = []; + private readonly Dictionary> _indexToFunctionArguments = []; + + public void Append(StreamingChatToolCallUpdate toolCallUpdate) + { + if (toolCallUpdate.ToolCallId != null) + { + _indexToToolCallId[toolCallUpdate.Index] = toolCallUpdate.ToolCallId; + } + + if (toolCallUpdate.FunctionName != null) + { + _indexToFunctionName[toolCallUpdate.Index] = toolCallUpdate.FunctionName; + } + + if (toolCallUpdate.FunctionArgumentsUpdate != null && !toolCallUpdate.FunctionArgumentsUpdate.ToMemory().IsEmpty) + { + if (!_indexToFunctionArguments.TryGetValue(toolCallUpdate.Index, out SequenceBuilder argumentsBuilder)) + { + argumentsBuilder = new SequenceBuilder(); + _indexToFunctionArguments[toolCallUpdate.Index] = argumentsBuilder; + } + + argumentsBuilder.Append(toolCallUpdate.FunctionArgumentsUpdate); + } + } + + public IReadOnlyList Build() + { + List toolCalls = []; + + foreach ((int index, string toolCallId) in _indexToToolCallId) + { + ReadOnlySequence sequence = _indexToFunctionArguments[index].Build(); + + ChatToolCall toolCall = ChatToolCall.CreateFunctionToolCall( + id: toolCallId, + functionName: _indexToFunctionName[index], + functionArguments: BinaryData.FromBytes(sequence.ToArray())); + + toolCalls.Add(toolCall); + } + + return toolCalls; + } + } + + public class SequenceBuilder + { + Segment _first; + Segment _last; + + public void Append(ReadOnlyMemory data) + { + if (_first == null) + { + Debug.Assert(_last == null); + _first = new Segment(data); + _last = _first; + } + else + { + _last = _last!.Append(data); + } + } + + public ReadOnlySequence Build() + { + if (_first == null) + { + Debug.Assert(_last == null); + return ReadOnlySequence.Empty; + } + + if (_first == _last) + { + Debug.Assert(_first.Next == null); + return new ReadOnlySequence(_first.Memory); + } + + return new ReadOnlySequence(_first, 0, _last!, _last!.Memory.Length); + } + + private sealed class Segment : ReadOnlySequenceSegment + { + public Segment(ReadOnlyMemory items) : this(items, 0) + { + } + + private Segment(ReadOnlyMemory items, long runningIndex) + { + Debug.Assert(runningIndex >= 0); + Memory = items; + RunningIndex = runningIndex; + } + + public Segment Append(ReadOnlyMemory items) + { + long runningIndex; + checked + { runningIndex = RunningIndex + Memory.Length; } + Segment segment = new(items, runningIndex); + Next = segment; + return segment; + } + } + } +} diff --git a/src/Commands/GenerateCommitMessage.cs b/src/Commands/GenerateCommitMessage.cs deleted file mode 100644 index bbefa34e..00000000 --- a/src/Commands/GenerateCommitMessage.cs +++ /dev/null @@ -1,101 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; -using System.Threading; -using System.Threading.Tasks; - -namespace SourceGit.Commands -{ - /// - /// A C# version of https://github.com/anjerodev/commitollama - /// - public class GenerateCommitMessage - { - public class GetDiffContent : Command - { - public GetDiffContent(string repo, Models.DiffOption opt) - { - WorkingDirectory = repo; - Context = repo; - Args = $"diff --no-color --no-ext-diff --diff-algorithm=minimal {opt}"; - } - - public async Task ReadAsync() - { - return await ReadToEndAsync().ConfigureAwait(false); - } - } - - public GenerateCommitMessage(Models.OpenAIService service, string repo, List changes, CancellationToken cancelToken, Action onResponse) - { - _service = service; - _repo = repo; - _changes = changes; - _cancelToken = cancelToken; - _onResponse = onResponse; - } - - public async Task ExecAsync() - { - try - { - _onResponse?.Invoke("Waiting for pre-file analyzing to completed...\n\n"); - - var responseBuilder = new StringBuilder(); - var summaryBuilder = new StringBuilder(); - foreach (var change in _changes) - { - if (_cancelToken.IsCancellationRequested) - return; - - responseBuilder.Append("- "); - summaryBuilder.Append("- "); - - var rs = await new GetDiffContent(_repo, new Models.DiffOption(change, false)).ReadAsync(); - if (rs.IsSuccess) - { - await _service.ChatAsync( - _service.AnalyzeDiffPrompt, - $"Here is the `git diff` output: {rs.StdOut}", - _cancelToken, - update => - { - responseBuilder.Append(update); - summaryBuilder.Append(update); - - _onResponse?.Invoke($"Waiting for pre-file analyzing to completed...\n\n{responseBuilder}"); - }); - } - - responseBuilder.AppendLine(); - summaryBuilder.Append("(file: ").Append(change.Path).AppendLine(")"); - } - - if (_cancelToken.IsCancellationRequested) - return; - - var responseBody = responseBuilder.ToString(); - var subjectBuilder = new StringBuilder(); - await _service.ChatAsync( - _service.GenerateSubjectPrompt, - $"Here are the summaries changes:\n{summaryBuilder}", - _cancelToken, - update => - { - subjectBuilder.Append(update); - _onResponse?.Invoke($"{subjectBuilder}\n\n{responseBody}"); - }); - } - catch (Exception e) - { - App.RaiseException(_repo, $"Failed to generate commit message: {e}"); - } - } - - private Models.OpenAIService _service; - private string _repo; - private List _changes; - private CancellationToken _cancelToken; - private Action _onResponse; - } -} diff --git a/src/Models/AIProvider.cs b/src/Models/AIProvider.cs new file mode 100644 index 00000000..1a39e8bd --- /dev/null +++ b/src/Models/AIProvider.cs @@ -0,0 +1,11 @@ +namespace SourceGit.Models +{ + public class AIProvider + { + public string Name { get; set; } + public string Server { get; set; } + public string Model { get; set; } + public string ApiKey { get; set; } + public bool ReadApiKeyFromEnv { get; set; } + } +} diff --git a/src/Models/OpenAI.cs b/src/Models/OpenAI.cs deleted file mode 100644 index c38eb674..00000000 --- a/src/Models/OpenAI.cs +++ /dev/null @@ -1,239 +0,0 @@ -using System; -using System.ClientModel; -using System.Collections.Generic; -using System.Text; -using System.Text.RegularExpressions; -using System.Threading; -using System.Threading.Tasks; -using Azure.AI.OpenAI; -using CommunityToolkit.Mvvm.ComponentModel; -using OpenAI; -using OpenAI.Chat; - -namespace SourceGit.Models -{ - public partial class OpenAIResponse - { - public OpenAIResponse(Action onUpdate) - { - _onUpdate = onUpdate; - } - - public void Append(string text) - { - var buffer = text; - - if (_thinkTail.Length > 0) - { - _thinkTail.Append(buffer); - buffer = _thinkTail.ToString(); - _thinkTail.Clear(); - } - - buffer = REG_COT().Replace(buffer, ""); - - var startIdx = buffer.IndexOf('<'); - if (startIdx >= 0) - { - if (startIdx > 0) - OnReceive(buffer.Substring(0, startIdx)); - - var endIdx = buffer.IndexOf('>', startIdx + 1); - if (endIdx <= startIdx) - { - if (buffer.Length - startIdx <= 15) - _thinkTail.Append(buffer.AsSpan(startIdx)); - else - OnReceive(buffer.Substring(startIdx)); - } - else if (endIdx < startIdx + 15) - { - var tag = buffer.Substring(startIdx + 1, endIdx - startIdx - 1); - if (_thinkTags.Contains(tag)) - _thinkTail.Append(buffer.AsSpan(startIdx)); - else - OnReceive(buffer.Substring(startIdx)); - } - else - { - OnReceive(buffer.Substring(startIdx)); - } - } - else - { - OnReceive(buffer); - } - } - - public void End() - { - if (_thinkTail.Length > 0) - { - OnReceive(_thinkTail.ToString()); - _thinkTail.Clear(); - } - } - - private void OnReceive(string text) - { - if (!_hasTrimmedStart) - { - text = text.TrimStart(); - if (string.IsNullOrEmpty(text)) - return; - - _hasTrimmedStart = true; - } - - _onUpdate?.Invoke(text); - } - - [GeneratedRegex(@"<(think|thought|thinking|thought_chain)>.*?", RegexOptions.Singleline)] - private static partial Regex REG_COT(); - - private Action _onUpdate = null; - private StringBuilder _thinkTail = new StringBuilder(); - private HashSet _thinkTags = ["think", "thought", "thinking", "thought_chain"]; - private bool _hasTrimmedStart = false; - } - - public class OpenAIService : ObservableObject - { - public string Name - { - get => _name; - set => SetProperty(ref _name, value); - } - - public string Server - { - get => _server; - set => SetProperty(ref _server, value); - } - - public string ApiKey - { - get => _apiKey; - set => SetProperty(ref _apiKey, value); - } - - public bool ReadApiKeyFromEnv - { - get => _readApiKeyFromEnv; - set => SetProperty(ref _readApiKeyFromEnv, value); - } - - public string Model - { - get => _model; - set => SetProperty(ref _model, value); - } - - public bool Streaming - { - get => _streaming; - set => SetProperty(ref _streaming, value); - } - - public string AnalyzeDiffPrompt - { - get => _analyzeDiffPrompt; - set => SetProperty(ref _analyzeDiffPrompt, value); - } - - public string GenerateSubjectPrompt - { - get => _generateSubjectPrompt; - set => SetProperty(ref _generateSubjectPrompt, value); - } - - public OpenAIService() - { - AnalyzeDiffPrompt = """ - You are an expert developer specialist in creating commits. - Provide a super concise one sentence overall changes summary of the user `git diff` output following strictly the next rules: - - Do not use any code snippets, imports, file routes or bullets points. - - Do not mention the route of file that has been change. - - Write clear, concise, and descriptive messages that explain the MAIN GOAL made of the changes. - - Use the present tense and active voice in the message, for example, "Fix bug" instead of "Fixed bug.". - - Use the imperative mood, which gives the message a sense of command, e.g. "Add feature" instead of "Added feature". - - Avoid using general terms like "update" or "change", be specific about what was updated or changed. - - Avoid using terms like "The main goal of", just output directly the summary in plain text - """; - - GenerateSubjectPrompt = """ - You are an expert developer specialist in creating commits messages. - Your only goal is to retrieve a single commit message. - Based on the provided user changes, combine them in ONE SINGLE commit message retrieving the global idea, following strictly the next rules: - - Assign the commit {type} according to the next conditions: - feat: Only when adding a new feature. - fix: When fixing a bug. - docs: When updating documentation. - style: When changing elements styles or design and/or making changes to the code style (formatting, missing semicolons, etc.) without changing the code logic. - test: When adding or updating tests. - chore: When making changes to the build process or auxiliary tools and libraries. - revert: When undoing a previous commit. - refactor: When restructuring code without changing its external behavior, or is any of the other refactor types. - - Do not add any issues numeration, explain your output nor introduce your answer. - - Output directly only one commit message in plain text with the next format: {type}: {commit_message}. - - Be as concise as possible, keep the message under 50 characters. - """; - } - - public async Task ChatAsync(string prompt, string question, CancellationToken cancellation, Action onUpdate) - { - var key = _readApiKeyFromEnv ? Environment.GetEnvironmentVariable(_apiKey) : _apiKey; - var endPoint = new Uri(_server); - var credential = new ApiKeyCredential(key); - var client = _server.Contains("openai.azure.com/", StringComparison.Ordinal) - ? new AzureOpenAIClient(endPoint, credential) - : new OpenAIClient(credential, new() { Endpoint = endPoint }); - - var chatClient = client.GetChatClient(_model); - var messages = new List() - { - _model.Equals("o1-mini", StringComparison.Ordinal) ? new UserChatMessage(prompt) : new SystemChatMessage(prompt), - new UserChatMessage(question), - }; - - try - { - var rsp = new OpenAIResponse(onUpdate); - - if (_streaming) - { - var updates = chatClient.CompleteChatStreamingAsync(messages, null, cancellation); - - await foreach (var update in updates) - { - if (update.ContentUpdate.Count > 0) - rsp.Append(update.ContentUpdate[0].Text); - } - } - else - { - var completion = await chatClient.CompleteChatAsync(messages, null, cancellation); - - if (completion.Value.Content.Count > 0) - rsp.Append(completion.Value.Content[0].Text); - } - - rsp.End(); - } - catch - { - if (!cancellation.IsCancellationRequested) - throw; - } - } - - private string _name; - private string _server; - private string _apiKey; - private bool _readApiKeyFromEnv = false; - private string _model; - private bool _streaming = true; - private string _analyzeDiffPrompt; - private string _generateSubjectPrompt; - } -} diff --git a/src/Resources/Locales/en_US.axaml b/src/Resources/Locales/en_US.axaml index e738e223..620442d0 100644 --- a/src/Resources/Locales/en_US.axaml +++ b/src/Resources/Locales/en_US.axaml @@ -610,14 +610,11 @@ Yesterday Preferences AI - Analyze Diff Prompt API Key - Generate Subject Prompt Model Name Entered value is the name to load API key from ENV Server - Enable Streaming APPEARANCE Default Font Editor Tab Width diff --git a/src/SourceGit.csproj b/src/SourceGit.csproj index f57d7ee3..735d7141 100644 --- a/src/SourceGit.csproj +++ b/src/SourceGit.csproj @@ -33,7 +33,7 @@ - + @@ -60,7 +60,7 @@ - + diff --git a/src/ViewModels/AIAssistant.cs b/src/ViewModels/AIAssistant.cs index d538ce1b..07d89c20 100644 --- a/src/ViewModels/AIAssistant.cs +++ b/src/ViewModels/AIAssistant.cs @@ -1,4 +1,6 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; +using System.Text; using System.Threading; using System.Threading.Tasks; @@ -22,13 +24,17 @@ namespace SourceGit.ViewModels private set => SetProperty(ref _text, value); } - public AIAssistant(Repository repo, Models.OpenAIService service, List changes) + public AIAssistant(string repo, Models.AIProvider provider, List changes) { _repo = repo; - _service = service; - _changes = changes; + _provider = provider; _cancel = new CancellationTokenSource(); + var builder = new StringBuilder(); + foreach (var c in changes) + SerializeChange(c, builder); + _changeList = builder.ToString(); + Gen(); } @@ -40,16 +46,32 @@ namespace SourceGit.ViewModels Gen(); } - public void Apply() - { - _repo.SetCommitMessage(Text); - } - public void Cancel() { _cancel?.Cancel(); } + private void SerializeChange(Models.Change c, StringBuilder builder) + { + var status = c.Index switch + { + Models.ChangeState.Added => "A", + Models.ChangeState.Modified => "M", + Models.ChangeState.Deleted => "D", + Models.ChangeState.TypeChanged => "T", + Models.ChangeState.Renamed => "R", + Models.ChangeState.Copied => "C", + _ => " ", + }; + + builder.Append(status).Append('\t'); + + if (c.Index == Models.ChangeState.Renamed || c.Index == Models.ChangeState.Copied) + builder.Append(c.OriginalPath).Append(" -> ").Append(c.Path).AppendLine(); + else + builder.Append(c.Path).AppendLine(); + } + private void Gen() { Text = string.Empty; @@ -58,18 +80,31 @@ namespace SourceGit.ViewModels _cancel = new CancellationTokenSource(); Task.Run(async () => { - await new Commands.GenerateCommitMessage(_service, _repo.FullPath, _changes, _cancel.Token, message => + var server = new AI.Service(_provider); + var builder = new StringBuilder(); + builder.AppendLine("Asking AI to generate commit message...").AppendLine(); + Dispatcher.UIThread.Post(() => Text = builder.ToString()); + + try { - Dispatcher.UIThread.Post(() => Text = message); - }).ExecAsync().ConfigureAwait(false); + await server.GenerateCommitMessage(_repo, _changeList, message => + { + builder.AppendLine(message); + Dispatcher.UIThread.Post(() => Text = builder.ToString()); + }, _cancel.Token).ConfigureAwait(false); + } + catch (Exception e) + { + App.RaiseException(_repo, e.Message); + } Dispatcher.UIThread.Post(() => IsGenerating = false); }, _cancel.Token); } - private readonly Repository _repo = null; - private Models.OpenAIService _service = null; - private List _changes = null; + private readonly string _repo = null; + private readonly Models.AIProvider _provider = null; + private readonly string _changeList = null; private CancellationTokenSource _cancel = null; private bool _isGenerating = false; private string _text = string.Empty; diff --git a/src/ViewModels/Preferences.cs b/src/ViewModels/Preferences.cs index 95817cc1..520fc560 100644 --- a/src/ViewModels/Preferences.cs +++ b/src/ViewModels/Preferences.cs @@ -480,7 +480,7 @@ namespace SourceGit.ViewModels set; } = []; - public AvaloniaList OpenAIServices + public AvaloniaList OpenAIServices { get; set; diff --git a/src/ViewModels/Repository.cs b/src/ViewModels/Repository.cs index 4588647c..eecbfc85 100644 --- a/src/ViewModels/Repository.cs +++ b/src/ViewModels/Repository.cs @@ -1599,7 +1599,7 @@ namespace SourceGit.ViewModels log.Complete(); } - public List GetPreferredOpenAIServices() + public List GetPreferredOpenAIServices() { var services = Preferences.Instance.OpenAIServices; if (services == null || services.Count == 0) @@ -1609,7 +1609,7 @@ namespace SourceGit.ViewModels return [services[0]]; var preferred = _settings.PreferredOpenAIService; - var all = new List(); + var all = new List(); foreach (var service in services) { if (service.Name.Equals(preferred, StringComparison.Ordinal)) diff --git a/src/Views/AIAssistant.axaml b/src/Views/AIAssistant.axaml index bef13df2..2504ac3b 100644 --- a/src/Views/AIAssistant.axaml +++ b/src/Views/AIAssistant.axaml @@ -51,13 +51,6 @@ -