mirror of
https://fastgit.cc/https://github.com/anomalyco/opencode
synced 2026-05-02 23:04:07 +08:00
Compare commits
216 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8e40160934 | ||
|
|
e97ed735d9 | ||
|
|
6d21525e71 | ||
|
|
b4f809559e | ||
|
|
33109bac4d | ||
|
|
f072ab3276 | ||
|
|
3b746162d2 | ||
|
|
6df19f1828 | ||
|
|
fba56d6871 | ||
|
|
1472efcbfe | ||
|
|
56a5d58945 | ||
|
|
f50a57041f | ||
|
|
f3da73553c | ||
|
|
9a26b3058f | ||
|
|
a09be7cf74 | ||
|
|
91a9e455e2 | ||
|
|
c391c6d3f3 | ||
|
|
ca562266b7 | ||
|
|
c69c9327da | ||
|
|
f5e2c596d4 | ||
|
|
e9bad39a7e | ||
|
|
42c7880858 | ||
|
|
017a440a70 | ||
|
|
1ab9547bb2 | ||
|
|
a4e46e6e18 | ||
|
|
680d52016c | ||
|
|
6ebbcb3179 | ||
|
|
437de4ee36 | ||
|
|
a5b28b5cef | ||
|
|
9bf024f8be | ||
|
|
189d0e5fb2 | ||
|
|
b1b402faa7 | ||
|
|
c5413c8c8d | ||
|
|
da1e8484a9 | ||
|
|
4818bc5426 | ||
|
|
8a8c6b14af | ||
|
|
0e31bbcd93 | ||
|
|
913b3434d8 | ||
|
|
1c01ee4834 | ||
|
|
005d6e0bde | ||
|
|
37c0c1f358 | ||
|
|
2a132f86d6 | ||
|
|
50ba0b380b | ||
|
|
6cccbdccd3 | ||
|
|
d0ad09d798 | ||
|
|
4fa4246c10 | ||
|
|
0fe72864f2 | ||
|
|
ce5b3126d3 | ||
|
|
26606ccbf7 | ||
|
|
fce9e79d38 | ||
|
|
6759674c0f | ||
|
|
a9799136fe | ||
|
|
7a29af4e30 | ||
|
|
d398001f96 | ||
|
|
e68747a64a | ||
|
|
d62ce482da | ||
|
|
f9f41e205d | ||
|
|
80597cd3fd | ||
|
|
48f81fe4d3 | ||
|
|
a96c2ce65c | ||
|
|
6f604bd0f9 | ||
|
|
42c1cd6a85 | ||
|
|
33a831d2be | ||
|
|
d70201cd93 | ||
|
|
9f1a75e938 | ||
|
|
2b77a7f714 | ||
|
|
5974a53071 | ||
|
|
3d61cc5d2b | ||
|
|
a22a2f0f37 | ||
|
|
be65ed6f88 | ||
|
|
e88264075a | ||
|
|
041a080a13 | ||
|
|
9d7c5efb9b | ||
|
|
8863a499a9 | ||
|
|
15d21bf04a | ||
|
|
5e738ce7d3 | ||
|
|
641e9ff664 | ||
|
|
d249766777 | ||
|
|
6cf4b7f00b | ||
|
|
6183398543 | ||
|
|
ff786d9139 | ||
|
|
4767276a0e | ||
|
|
71bab45065 | ||
|
|
cb48813c95 | ||
|
|
520cd02dd5 | ||
|
|
afe741b63e | ||
|
|
f3b224090c | ||
|
|
3b7b7f4bea | ||
|
|
3a4d3b249f | ||
|
|
55a6fcdd3f | ||
|
|
4132fcc1b2 | ||
|
|
37082b2176 | ||
|
|
b9f009c529 | ||
|
|
601f610eb7 | ||
|
|
2e2bdd46b4 | ||
|
|
3a28ce9b0a | ||
|
|
bb6fc2a1fd | ||
|
|
ad76fa8616 | ||
|
|
bdac7d10dd | ||
|
|
0ecfdd7501 | ||
|
|
a9758e0db5 | ||
|
|
e98f915fd5 | ||
|
|
07f0fea4bf | ||
|
|
6a43afc4e7 | ||
|
|
c01eefc729 | ||
|
|
5d4ccc8883 | ||
|
|
98b5390a22 | ||
|
|
c040baae11 | ||
|
|
754cc66741 | ||
|
|
6ef0b991ec | ||
|
|
f6ca06b8ea | ||
|
|
4c198940d5 | ||
|
|
2e938d9da1 | ||
|
|
b840a40759 | ||
|
|
a1d40f8f28 | ||
|
|
575d76fa06 | ||
|
|
b75456f5dd | ||
|
|
eb69cc3943 | ||
|
|
e524209352 | ||
|
|
e1c897c1ae | ||
|
|
39f54e83e1 | ||
|
|
d34c974996 | ||
|
|
c203891b84 | ||
|
|
591bd2a4e3 | ||
|
|
94f35130f7 | ||
|
|
f26873f5de | ||
|
|
66b18959eb | ||
|
|
deacf5991a | ||
|
|
25623d1f84 | ||
|
|
de9f144858 | ||
|
|
0ad8738933 | ||
|
|
db5744bbc4 | ||
|
|
b87ba57819 | ||
|
|
802389a90e | ||
|
|
4880b08b8a | ||
|
|
80555f13e0 | ||
|
|
113c49457f | ||
|
|
e1ec815d1b | ||
|
|
2ed17f4877 | ||
|
|
80118212da | ||
|
|
9b3760247a | ||
|
|
a2d652b13d | ||
|
|
5c491758f5 | ||
|
|
5f750b7368 | ||
|
|
e2dc5a8faf | ||
|
|
72d10a0823 | ||
|
|
7623b33f31 | ||
|
|
9b331a917e | ||
|
|
d51b4263ab | ||
|
|
34a2dcb80a | ||
|
|
8cbd59296e | ||
|
|
83974e0c95 | ||
|
|
59d43fa5da | ||
|
|
e01afb407c | ||
|
|
f0f55bc75f | ||
|
|
2860a2bb1a | ||
|
|
9b564f0b73 | ||
|
|
2437ce3f8b | ||
|
|
fa8a46326a | ||
|
|
652429377b | ||
|
|
99af6146d5 | ||
|
|
020e0ca039 | ||
|
|
0439072420 | ||
|
|
49ad2efef6 | ||
|
|
0e303e6508 | ||
|
|
bcd2fd68b7 | ||
|
|
d0d67029f4 | ||
|
|
a34d020bc6 | ||
|
|
96fbc37f01 | ||
|
|
89e3a72ae1 | ||
|
|
b9ebcea82c | ||
|
|
f31f92119d | ||
|
|
da9b2a18b9 | ||
|
|
2b258b1473 | ||
|
|
6f894950a6 | ||
|
|
9049295cc9 | ||
|
|
4526b14b17 | ||
|
|
f768313c4f | ||
|
|
d9c1b2cc90 | ||
|
|
6cfcf51752 | ||
|
|
dff8e77eb6 | ||
|
|
6e854a4df4 | ||
|
|
2f8984fadb | ||
|
|
c84918cb47 | ||
|
|
05bb065d00 | ||
|
|
3742997889 | ||
|
|
daf0305203 | ||
|
|
307982a099 | ||
|
|
ba416e787b | ||
|
|
b71cae63f1 | ||
|
|
c92f7c6630 | ||
|
|
4a444e9c9b | ||
|
|
623d132772 | ||
|
|
d127a1c4eb | ||
|
|
c9cca48d08 | ||
|
|
3944930fc0 | ||
|
|
825c0b64af | ||
|
|
d7af7dd3fe | ||
|
|
b112216241 | ||
|
|
87237b6462 | ||
|
|
5f5f9dad87 | ||
|
|
aa8b3ce1ee | ||
|
|
a65e593ab4 | ||
|
|
5d9058eb74 | ||
|
|
a850320fad | ||
|
|
ddbb217d0d | ||
|
|
ab150be7c3 | ||
|
|
a203fb8ccc | ||
|
|
acc084c9ea | ||
|
|
3ee213081e | ||
|
|
15bf40bc10 | ||
|
|
a33e3e25b6 | ||
|
|
658faab2bf | ||
|
|
797045ee29 | ||
|
|
c8f8d67a88 | ||
|
|
182e32e4f7 |
25
.github/workflows/deploy.yml
vendored
Normal file
25
.github/workflows/deploy.yml
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
name: deploy
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- dontlook
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- uses: oven-sh/setup-bun@v1
|
||||
with:
|
||||
bun-version: latest
|
||||
|
||||
- run: bun install
|
||||
|
||||
- run: bun sst deploy --stage=dev
|
||||
env:
|
||||
CLOUDFLARE_API_TOKEN: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
48
.gitignore
vendored
48
.gitignore
vendored
@@ -1,45 +1,5 @@
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Test binary, built with `go test -c`
|
||||
*.test
|
||||
|
||||
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||
*.out
|
||||
|
||||
# Dependency directories (remove the comment below to include it)
|
||||
# vendor/
|
||||
|
||||
# Go workspace file
|
||||
go.work
|
||||
|
||||
# IDE specific files
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# OS specific files
|
||||
.DS_Store
|
||||
.DS_Store?
|
||||
._*
|
||||
.Spotlight-V100
|
||||
.Trashes
|
||||
ehthumbs.db
|
||||
Thumbs.db
|
||||
*.log
|
||||
|
||||
# Binary output directory
|
||||
/bin/
|
||||
/dist/
|
||||
|
||||
# Local environment variables
|
||||
node_modules
|
||||
.opencode
|
||||
.sst
|
||||
app.log
|
||||
.env
|
||||
.env.local
|
||||
|
||||
.opencode/
|
||||
|
||||
|
||||
24
CONTEXT.md
24
CONTEXT.md
@@ -1,24 +0,0 @@
|
||||
# OpenCode Development Context
|
||||
|
||||
## Build Commands
|
||||
- Build: `go build`
|
||||
- Run: `go run main.go`
|
||||
- Test: `go test ./...`
|
||||
- Test single package: `go test ./internal/package/...`
|
||||
- Test single test: `go test ./internal/package -run TestName`
|
||||
- Verbose test: `go test -v ./...`
|
||||
- Coverage: `go test -cover ./...`
|
||||
- Lint: `go vet ./...`
|
||||
- Format: `go fmt ./...`
|
||||
- Build snapshot: `./scripts/snapshot`
|
||||
|
||||
## Code Style
|
||||
- Use Go 1.24+ features
|
||||
- Follow standard Go formatting (gofmt)
|
||||
- Use table-driven tests with t.Parallel() when possible
|
||||
- Error handling: check errors immediately, return early
|
||||
- Naming: CamelCase for exported, camelCase for unexported
|
||||
- Imports: standard library first, then external, then internal
|
||||
- Use context.Context for cancellation and timeouts
|
||||
- Prefer interfaces for dependencies to enable testing
|
||||
- Use testify for assertions in tests
|
||||
2
LICENSE
2
LICENSE
@@ -1,6 +1,6 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 Kujtim Hoxha
|
||||
Copyright (c) 2025 OpenCode
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
||||
170
README.md
170
README.md
@@ -1,4 +1,6 @@
|
||||
# ⌬ OpenCode
|
||||
# ◍OpenCode
|
||||
|
||||

|
||||
|
||||
> **⚠️ Early Development Notice:** This project is in early development and is not yet ready for production use. Features may change, break, or be incomplete. Use at your own risk.
|
||||
|
||||
@@ -19,6 +21,7 @@ OpenCode is a Go-based CLI application that brings AI assistance to your termina
|
||||
- **LSP Integration**: Language Server Protocol support for code intelligence
|
||||
- **File Change Tracking**: Track and visualize file changes during sessions
|
||||
- **External Editor Support**: Open your preferred editor for composing messages
|
||||
- **Named Arguments for Custom Commands**: Create powerful custom commands with multiple named placeholders
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -71,6 +74,8 @@ You can configure OpenCode using environment variables:
|
||||
| `ANTHROPIC_API_KEY` | For Claude models |
|
||||
| `OPENAI_API_KEY` | For OpenAI models |
|
||||
| `GEMINI_API_KEY` | For Google Gemini models |
|
||||
| `VERTEXAI_PROJECT` | For Google Cloud VertexAI (Gemini) |
|
||||
| `VERTEXAI_LOCATION` | For Google Cloud VertexAI (Gemini) |
|
||||
| `GROQ_API_KEY` | For Groq models |
|
||||
| `AWS_ACCESS_KEY_ID` | For AWS Bedrock (Claude) |
|
||||
| `AWS_SECRET_ACCESS_KEY` | For AWS Bedrock (Claude) |
|
||||
@@ -132,6 +137,10 @@ You can configure OpenCode using environment variables:
|
||||
"command": "gopls"
|
||||
}
|
||||
},
|
||||
"shell": {
|
||||
"path": "/bin/zsh",
|
||||
"args": ["-l"]
|
||||
},
|
||||
"debug": false,
|
||||
"debugLSP": false
|
||||
}
|
||||
@@ -186,7 +195,43 @@ OpenCode supports a variety of AI models from different providers:
|
||||
- O3 family (o3, o3-mini)
|
||||
- O4 Mini
|
||||
|
||||
## Usage
|
||||
### Google Cloud VertexAI
|
||||
|
||||
- Gemini 2.5
|
||||
- Gemini 2.5 Flash
|
||||
|
||||
## Using Bedrock Models
|
||||
|
||||
To use bedrock models with OpenCode you need three things.
|
||||
|
||||
1. Valid AWS credentials (the env vars: `AWS_SECRET_KEY_ID`, `AWS_SECRET_ACCESS_KEY` and `AWS_REGION`)
|
||||
2. Access to the corresponding model in AWS Bedrock in your region.
|
||||
a. You can request access in the AWS console on the Bedrock -> "Model access" page.
|
||||
3. A correct configuration file. You don't need the `providers` key. Instead you have to prefix your models per agent with `bedrock.` and then a valid model. For now only Claude 3.7 is supported.
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"primary": {
|
||||
"model": "bedrock.claude-3.7-sonnet",
|
||||
"maxTokens": 5000,
|
||||
"reasoningEffort": ""
|
||||
},
|
||||
"task": {
|
||||
"model": "bedrock.claude-3.7-sonnet",
|
||||
"maxTokens": 5000,
|
||||
"reasoningEffort": ""
|
||||
},
|
||||
"title": {
|
||||
"model": "bedrock.claude-3.7-sonnet",
|
||||
"maxTokens": 80,
|
||||
"reasoningEffort": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Interactive Mode Usage
|
||||
|
||||
```bash
|
||||
# Start OpenCode
|
||||
@@ -199,13 +244,78 @@ opencode -d
|
||||
opencode -c /path/to/project
|
||||
```
|
||||
|
||||
## Non-interactive Prompt Mode
|
||||
|
||||
You can run OpenCode in non-interactive mode by passing a prompt directly as a command-line argument or by piping text into the command. This is useful for scripting, automation, or when you want a quick answer without launching the full TUI.
|
||||
|
||||
```bash
|
||||
# Run a single prompt and print the AI's response to the terminal
|
||||
opencode -p "Explain the use of context in Go"
|
||||
|
||||
# Pipe input to OpenCode (equivalent to using -p flag)
|
||||
echo "Explain the use of context in Go" | opencode
|
||||
|
||||
# Get response in JSON format
|
||||
opencode -p "Explain the use of context in Go" -f json
|
||||
# Or with piped input
|
||||
echo "Explain the use of context in Go" | opencode -f json
|
||||
|
||||
# Run without showing the spinner
|
||||
opencode -p "Explain the use of context in Go" -q
|
||||
# Or with piped input
|
||||
echo "Explain the use of context in Go" | opencode -q
|
||||
|
||||
# Enable verbose logging to stderr
|
||||
opencode -p "Explain the use of context in Go" --verbose
|
||||
# Or with piped input
|
||||
echo "Explain the use of context in Go" | opencode --verbose
|
||||
|
||||
# Restrict the agent to only use specific tools
|
||||
opencode -p "Explain the use of context in Go" --allowedTools=view,ls,glob
|
||||
# Or with piped input
|
||||
echo "Explain the use of context in Go" | opencode --allowedTools=view,ls,glob
|
||||
|
||||
# Prevent the agent from using specific tools
|
||||
opencode -p "Explain the use of context in Go" --excludedTools=bash,edit
|
||||
# Or with piped input
|
||||
echo "Explain the use of context in Go" | opencode --excludedTools=bash,edit
|
||||
```
|
||||
|
||||
In this mode, OpenCode will process your prompt, print the result to standard output, and then exit. All permissions are auto-approved for the session.
|
||||
|
||||
### Tool Restrictions
|
||||
|
||||
You can control which tools the AI assistant has access to in non-interactive mode:
|
||||
|
||||
- `--allowedTools`: Comma-separated list of tools that the agent is allowed to use. Only these tools will be available.
|
||||
- `--excludedTools`: Comma-separated list of tools that the agent is not allowed to use. All other tools will be available.
|
||||
|
||||
These flags are mutually exclusive - you can use either `--allowedTools` or `--excludedTools`, but not both at the same time.
|
||||
|
||||
### Output Formats
|
||||
|
||||
OpenCode supports the following output formats in non-interactive mode:
|
||||
|
||||
| Format | Description |
|
||||
| ------ | ------------------------------- |
|
||||
| `text` | Plain text output (default) |
|
||||
| `json` | Output wrapped in a JSON object |
|
||||
|
||||
The output format is implemented as a strongly-typed `OutputFormat` in the codebase, ensuring type safety and validation when processing outputs.
|
||||
|
||||
## Command-line Flags
|
||||
|
||||
| Flag | Short | Description |
|
||||
| --------- | ----- | ----------------------------- |
|
||||
| `--help` | `-h` | Display help information |
|
||||
| `--debug` | `-d` | Enable debug mode |
|
||||
| `--cwd` | `-c` | Set current working directory |
|
||||
| Flag | Short | Description |
|
||||
| ----------------- | ----- | --------------------------------------------------- |
|
||||
| `--help` | `-h` | Display help information |
|
||||
| `--debug` | `-d` | Enable debug mode |
|
||||
| `--cwd` | `-c` | Set current working directory |
|
||||
| `--prompt` | `-p` | Run a single prompt in non-interactive mode |
|
||||
| `--output-format` | `-f` | Output format for non-interactive mode (text, json) |
|
||||
| `--quiet` | `-q` | Hide spinner in non-interactive mode |
|
||||
| `--verbose` | | Display logs to stderr in non-interactive mode |
|
||||
| `--allowedTools` | | Restrict the agent to only use specified tools |
|
||||
| `--excludedTools` | | Prevent the agent from using specified tools |
|
||||
|
||||
## Keyboard Shortcuts
|
||||
|
||||
@@ -371,6 +481,36 @@ You can define any of the following color keys in your `customTheme`:
|
||||
|
||||
You don't need to define all colors. Any undefined colors will fall back to the default "opencode" theme colors.
|
||||
|
||||
### Shell Configuration
|
||||
|
||||
OpenCode allows you to configure the shell used by the `bash` tool. By default, it uses:
|
||||
|
||||
1. The shell specified in the config file (if provided)
|
||||
2. The shell from the `$SHELL` environment variable (if available)
|
||||
3. Falls back to `/bin/bash` if neither of the above is available
|
||||
|
||||
To configure a custom shell, add a `shell` section to your `.opencode.json` configuration file:
|
||||
|
||||
```json
|
||||
{
|
||||
"shell": {
|
||||
"path": "/bin/zsh",
|
||||
"args": ["-l"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
You can specify any shell executable and custom arguments:
|
||||
|
||||
```json
|
||||
{
|
||||
"shell": {
|
||||
"path": "/usr/bin/fish",
|
||||
"args": []
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
OpenCode is built with a modular architecture:
|
||||
@@ -426,13 +566,23 @@ This creates a command called `user:prime-context`.
|
||||
|
||||
### Command Arguments
|
||||
|
||||
You can create commands that accept arguments by including the `$ARGUMENTS` placeholder in your command file:
|
||||
OpenCode supports named arguments in custom commands using placeholders in the format `$NAME` (where NAME consists of uppercase letters, numbers, and underscores, and must start with a letter).
|
||||
|
||||
For example:
|
||||
|
||||
```markdown
|
||||
RUN git show $ARGUMENTS
|
||||
# Fetch Context for Issue $ISSUE_NUMBER
|
||||
|
||||
RUN gh issue view $ISSUE_NUMBER --json title,body,comments
|
||||
RUN git grep --author="$AUTHOR_NAME" -n .
|
||||
RUN grep -R "$SEARCH_PATTERN" $DIRECTORY
|
||||
```
|
||||
|
||||
When you run this command, OpenCode will prompt you to enter the text that should replace `$ARGUMENTS`.
|
||||
When you run a command with arguments, OpenCode will prompt you to enter values for each unique placeholder. Named arguments provide several benefits:
|
||||
|
||||
- Clear identification of what each argument represents
|
||||
- Ability to use the same argument multiple times
|
||||
- Better organization for commands with multiple inputs
|
||||
|
||||
### Organizing Commands
|
||||
|
||||
|
||||
2
bunfig.toml
Normal file
2
bunfig.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[install]
|
||||
exact = true
|
||||
@@ -1,65 +0,0 @@
|
||||
# OpenCode Configuration Schema Generator
|
||||
|
||||
This tool generates a JSON Schema for the OpenCode configuration file. The schema can be used to validate configuration files and provide autocompletion in editors that support JSON Schema.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
go run cmd/schema/main.go > opencode-schema.json
|
||||
```
|
||||
|
||||
This will generate a JSON Schema file that can be used to validate configuration files.
|
||||
|
||||
## Schema Features
|
||||
|
||||
The generated schema includes:
|
||||
|
||||
- All configuration options with descriptions
|
||||
- Default values where applicable
|
||||
- Validation for enum values (e.g., model IDs, provider types)
|
||||
- Required fields
|
||||
- Type checking
|
||||
|
||||
## Using the Schema
|
||||
|
||||
You can use the generated schema in several ways:
|
||||
|
||||
1. **Editor Integration**: Many editors (VS Code, JetBrains IDEs, etc.) support JSON Schema for validation and autocompletion. You can configure your editor to use the generated schema for `.opencode.json` files.
|
||||
|
||||
2. **Validation Tools**: You can use tools like [jsonschema](https://github.com/Julian/jsonschema) to validate your configuration files against the schema.
|
||||
|
||||
3. **Documentation**: The schema serves as documentation for the configuration options.
|
||||
|
||||
## Example Configuration
|
||||
|
||||
Here's an example configuration that conforms to the schema:
|
||||
|
||||
```json
|
||||
{
|
||||
"data": {
|
||||
"directory": ".opencode"
|
||||
},
|
||||
"debug": false,
|
||||
"providers": {
|
||||
"anthropic": {
|
||||
"apiKey": "your-api-key"
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"primary": {
|
||||
"model": "claude-3.7-sonnet",
|
||||
"maxTokens": 5000,
|
||||
"reasoningEffort": "medium"
|
||||
},
|
||||
"task": {
|
||||
"model": "claude-3.7-sonnet",
|
||||
"maxTokens": 5000
|
||||
},
|
||||
"title": {
|
||||
"model": "claude-3.7-sonnet",
|
||||
"maxTokens": 80
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1,335 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
)
|
||||
|
||||
// JSONSchemaType represents a JSON Schema type
|
||||
type JSONSchemaType struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Properties map[string]any `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
AdditionalProperties any `json:"additionalProperties,omitempty"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Items map[string]any `json:"items,omitempty"`
|
||||
OneOf []map[string]any `json:"oneOf,omitempty"`
|
||||
AnyOf []map[string]any `json:"anyOf,omitempty"`
|
||||
Default any `json:"default,omitempty"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
schema := generateSchema()
|
||||
|
||||
// Pretty print the schema
|
||||
encoder := json.NewEncoder(os.Stdout)
|
||||
encoder.SetIndent("", " ")
|
||||
if err := encoder.Encode(schema); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error encoding schema: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func generateSchema() map[string]any {
|
||||
schema := map[string]any{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "OpenCode Configuration",
|
||||
"description": "Configuration schema for the OpenCode application",
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
|
||||
// Add Data configuration
|
||||
schema["properties"].(map[string]any)["data"] = map[string]any{
|
||||
"type": "object",
|
||||
"description": "Storage configuration",
|
||||
"properties": map[string]any{
|
||||
"directory": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Directory where application data is stored",
|
||||
"default": ".opencode",
|
||||
},
|
||||
},
|
||||
"required": []string{"directory"},
|
||||
}
|
||||
|
||||
// Add working directory
|
||||
schema["properties"].(map[string]any)["wd"] = map[string]any{
|
||||
"type": "string",
|
||||
"description": "Working directory for the application",
|
||||
}
|
||||
|
||||
// Add debug flags
|
||||
schema["properties"].(map[string]any)["debug"] = map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Enable debug mode",
|
||||
"default": false,
|
||||
}
|
||||
|
||||
schema["properties"].(map[string]any)["debugLSP"] = map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Enable LSP debug mode",
|
||||
"default": false,
|
||||
}
|
||||
|
||||
schema["properties"].(map[string]any)["contextPaths"] = map[string]any{
|
||||
"type": "array",
|
||||
"description": "Context paths for the application",
|
||||
"items": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
"default": []string{
|
||||
".github/copilot-instructions.md",
|
||||
".cursorrules",
|
||||
".cursor/rules/",
|
||||
"CLAUDE.md",
|
||||
"CLAUDE.local.md",
|
||||
"opencode.md",
|
||||
"opencode.local.md",
|
||||
"OpenCode.md",
|
||||
"OpenCode.local.md",
|
||||
"OPENCODE.md",
|
||||
"OPENCODE.local.md",
|
||||
},
|
||||
}
|
||||
|
||||
schema["properties"].(map[string]any)["tui"] = map[string]any{
|
||||
"type": "object",
|
||||
"description": "Terminal User Interface configuration",
|
||||
"properties": map[string]any{
|
||||
"theme": map[string]any{
|
||||
"type": "string",
|
||||
"description": "TUI theme name",
|
||||
"default": "opencode",
|
||||
"enum": []string{
|
||||
"opencode",
|
||||
"catppuccin",
|
||||
"dracula",
|
||||
"flexoki",
|
||||
"gruvbox",
|
||||
"monokai",
|
||||
"onedark",
|
||||
"tokyonight",
|
||||
"tron",
|
||||
"custom",
|
||||
},
|
||||
},
|
||||
"customTheme": map[string]any{
|
||||
"type": "object",
|
||||
"description": "Custom theme color definitions",
|
||||
"additionalProperties": map[string]any{
|
||||
"oneOf": []map[string]any{
|
||||
{
|
||||
"type": "string",
|
||||
"pattern": "^#[0-9a-fA-F]{6}$",
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"dark": map[string]any{
|
||||
"type": "string",
|
||||
"pattern": "^#[0-9a-fA-F]{6}$",
|
||||
},
|
||||
"light": map[string]any{
|
||||
"type": "string",
|
||||
"pattern": "^#[0-9a-fA-F]{6}$",
|
||||
},
|
||||
},
|
||||
"required": []string{"dark", "light"},
|
||||
"additionalProperties": false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add MCP servers
|
||||
schema["properties"].(map[string]any)["mcpServers"] = map[string]any{
|
||||
"type": "object",
|
||||
"description": "Model Control Protocol server configurations",
|
||||
"additionalProperties": map[string]any{
|
||||
"type": "object",
|
||||
"description": "MCP server configuration",
|
||||
"properties": map[string]any{
|
||||
"command": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Command to execute for the MCP server",
|
||||
},
|
||||
"env": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Environment variables for the MCP server",
|
||||
"items": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"args": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Command arguments for the MCP server",
|
||||
"items": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"type": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Type of MCP server",
|
||||
"enum": []string{"stdio", "sse"},
|
||||
"default": "stdio",
|
||||
},
|
||||
"url": map[string]any{
|
||||
"type": "string",
|
||||
"description": "URL for SSE type MCP servers",
|
||||
},
|
||||
"headers": map[string]any{
|
||||
"type": "object",
|
||||
"description": "HTTP headers for SSE type MCP servers",
|
||||
"additionalProperties": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": []string{"command"},
|
||||
},
|
||||
}
|
||||
|
||||
// Add providers
|
||||
providerSchema := map[string]any{
|
||||
"type": "object",
|
||||
"description": "LLM provider configurations",
|
||||
"additionalProperties": map[string]any{
|
||||
"type": "object",
|
||||
"description": "Provider configuration",
|
||||
"properties": map[string]any{
|
||||
"apiKey": map[string]any{
|
||||
"type": "string",
|
||||
"description": "API key for the provider",
|
||||
},
|
||||
"disabled": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Whether the provider is disabled",
|
||||
"default": false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add known providers
|
||||
knownProviders := []string{
|
||||
string(models.ProviderAnthropic),
|
||||
string(models.ProviderOpenAI),
|
||||
string(models.ProviderGemini),
|
||||
string(models.ProviderGROQ),
|
||||
string(models.ProviderOpenRouter),
|
||||
string(models.ProviderBedrock),
|
||||
string(models.ProviderAzure),
|
||||
}
|
||||
|
||||
providerSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["provider"] = map[string]any{
|
||||
"type": "string",
|
||||
"description": "Provider type",
|
||||
"enum": knownProviders,
|
||||
}
|
||||
|
||||
schema["properties"].(map[string]any)["providers"] = providerSchema
|
||||
|
||||
// Add agents
|
||||
agentSchema := map[string]any{
|
||||
"type": "object",
|
||||
"description": "Agent configurations",
|
||||
"additionalProperties": map[string]any{
|
||||
"type": "object",
|
||||
"description": "Agent configuration",
|
||||
"properties": map[string]any{
|
||||
"model": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Model ID for the agent",
|
||||
},
|
||||
"maxTokens": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Maximum tokens for the agent",
|
||||
"minimum": 1,
|
||||
},
|
||||
"reasoningEffort": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Reasoning effort for models that support it (OpenAI, Anthropic)",
|
||||
"enum": []string{"low", "medium", "high"},
|
||||
},
|
||||
},
|
||||
"required": []string{"model"},
|
||||
},
|
||||
}
|
||||
|
||||
// Add model enum
|
||||
modelEnum := []string{}
|
||||
for modelID := range models.SupportedModels {
|
||||
modelEnum = append(modelEnum, string(modelID))
|
||||
}
|
||||
agentSchema["additionalProperties"].(map[string]any)["properties"].(map[string]any)["model"].(map[string]any)["enum"] = modelEnum
|
||||
|
||||
// Add specific agent properties
|
||||
agentProperties := map[string]any{}
|
||||
knownAgents := []string{
|
||||
string(config.AgentPrimary),
|
||||
string(config.AgentTask),
|
||||
string(config.AgentTitle),
|
||||
}
|
||||
|
||||
for _, agentName := range knownAgents {
|
||||
agentProperties[agentName] = map[string]any{
|
||||
"$ref": "#/definitions/agent",
|
||||
}
|
||||
}
|
||||
|
||||
// Create a combined schema that allows both specific agents and additional ones
|
||||
combinedAgentSchema := map[string]any{
|
||||
"type": "object",
|
||||
"description": "Agent configurations",
|
||||
"properties": agentProperties,
|
||||
"additionalProperties": agentSchema["additionalProperties"],
|
||||
}
|
||||
|
||||
schema["properties"].(map[string]any)["agents"] = combinedAgentSchema
|
||||
schema["definitions"] = map[string]any{
|
||||
"agent": agentSchema["additionalProperties"],
|
||||
}
|
||||
|
||||
// Add LSP configuration
|
||||
schema["properties"].(map[string]any)["lsp"] = map[string]any{
|
||||
"type": "object",
|
||||
"description": "Language Server Protocol configurations",
|
||||
"additionalProperties": map[string]any{
|
||||
"type": "object",
|
||||
"description": "LSP configuration for a language",
|
||||
"properties": map[string]any{
|
||||
"disabled": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Whether the LSP is disabled",
|
||||
"default": false,
|
||||
},
|
||||
"command": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Command to execute for the LSP server",
|
||||
},
|
||||
"args": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Command arguments for the LSP server",
|
||||
"items": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"options": map[string]any{
|
||||
"type": "object",
|
||||
"description": "Additional options for the LSP server",
|
||||
},
|
||||
},
|
||||
"required": []string{"command"},
|
||||
},
|
||||
}
|
||||
|
||||
return schema
|
||||
}
|
||||
128
go.mod
128
go.mod
@@ -1,128 +0,0 @@
|
||||
module github.com/sst/opencode
|
||||
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
|
||||
github.com/JohannesKaufmann/html-to-markdown v1.6.0
|
||||
github.com/PuerkitoBio/goquery v1.9.2
|
||||
github.com/alecthomas/chroma/v2 v2.15.0
|
||||
github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2
|
||||
github.com/aymanbagabas/go-udiff v0.2.0
|
||||
github.com/bmatcuk/doublestar/v4 v4.8.1
|
||||
github.com/catppuccin/go v0.3.0
|
||||
github.com/charmbracelet/bubbles v0.20.0
|
||||
github.com/charmbracelet/bubbletea v1.3.4
|
||||
github.com/charmbracelet/glamour v0.9.1
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/charmbracelet/x/ansi v0.8.0
|
||||
github.com/fsnotify/fsnotify v1.8.0
|
||||
github.com/go-logfmt/logfmt v0.6.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/lrstanley/bubblezone v0.0.0-20250315020633-c249a3fe1231
|
||||
github.com/mark3labs/mcp-go v0.17.0
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6
|
||||
github.com/muesli/reflow v0.3.0
|
||||
github.com/muesli/termenv v0.16.0
|
||||
github.com/ncruces/go-sqlite3 v0.25.0
|
||||
github.com/openai/openai-go v0.1.0-beta.2
|
||||
github.com/pressly/goose/v3 v3.24.2
|
||||
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3
|
||||
github.com/spf13/cobra v1.9.1
|
||||
github.com/spf13/viper v1.20.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.116.0 // indirect
|
||||
cloud.google.com/go/auth v0.13.0 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.6.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
|
||||
github.com/andybalholm/cascadia v1.3.2 // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.27.27 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.27 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 // indirect
|
||||
github.com/aws/smithy-go v1.20.3 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/disintegration/imaging v1.6.2
|
||||
github.com/dlclark/regexp2 v1.11.4 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/s2a-go v0.1.8 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/mfridman/interpolate v0.0.2 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/ncruces/julianday v1.0.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
github.com/sagikazarmark/locafero v0.7.0 // indirect
|
||||
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
github.com/spf13/afero v1.12.0 // indirect
|
||||
github.com/spf13/cast v1.7.1 // indirect
|
||||
github.com/spf13/pflag v1.0.6 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.9.0 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.7.8 // indirect
|
||||
github.com/yuin/goldmark-emoji v1.0.5 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
|
||||
go.opentelemetry.io/otel v1.35.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.35.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/crypto v0.37.0 // indirect
|
||||
golang.org/x/image v0.26.0 // indirect
|
||||
golang.org/x/net v0.39.0 // indirect
|
||||
golang.org/x/sync v0.13.0 // indirect
|
||||
golang.org/x/sys v0.32.0 // indirect
|
||||
golang.org/x/term v0.31.0 // indirect
|
||||
golang.org/x/text v0.24.0 // indirect
|
||||
google.golang.org/genai v1.3.0
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463 // indirect
|
||||
google.golang.org/grpc v1.71.0 // indirect
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
44
infra/app.ts
Normal file
44
infra/app.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
export const domain = (() => {
|
||||
if ($app.stage === "production") return "opencode.ai"
|
||||
if ($app.stage === "dev") return "dev.opencode.ai"
|
||||
return `${$app.stage}.dev.opencode.ai`
|
||||
})()
|
||||
|
||||
const bucket = new sst.cloudflare.Bucket("Bucket")
|
||||
|
||||
export const api = new sst.cloudflare.Worker("Api", {
|
||||
domain: `api.${domain}`,
|
||||
handler: "packages/function/src/api.ts",
|
||||
url: true,
|
||||
link: [bucket],
|
||||
transform: {
|
||||
worker: (args) => {
|
||||
args.logpush = true
|
||||
args.bindings = $resolve(args.bindings).apply((bindings) => [
|
||||
...bindings,
|
||||
{
|
||||
name: "SYNC_SERVER",
|
||||
type: "durable_object_namespace",
|
||||
className: "SyncServer",
|
||||
},
|
||||
])
|
||||
args.migrations = {
|
||||
oldTag: "v1",
|
||||
newTag: "v1",
|
||||
//newSqliteClasses: ["SyncServer"],
|
||||
}
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
new sst.cloudflare.StaticSite("Web", {
|
||||
path: "packages/web",
|
||||
domain,
|
||||
environment: {
|
||||
VITE_API_URL: api.url,
|
||||
},
|
||||
build: {
|
||||
command: "bun run build",
|
||||
output: "dist",
|
||||
},
|
||||
})
|
||||
@@ -1,152 +0,0 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"maps"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/history"
|
||||
"github.com/sst/opencode/internal/llm/agent"
|
||||
"github.com/sst/opencode/internal/logging"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
"github.com/sst/opencode/internal/session"
|
||||
"github.com/sst/opencode/internal/status"
|
||||
"github.com/sst/opencode/internal/tui/theme"
|
||||
)
|
||||
|
||||
type App struct {
|
||||
CurrentSession *session.Session
|
||||
Logs logging.Service
|
||||
Sessions session.Service
|
||||
Messages message.Service
|
||||
History history.Service
|
||||
Permissions permission.Service
|
||||
Status status.Service
|
||||
|
||||
PrimaryAgent agent.Service
|
||||
|
||||
LSPClients map[string]*lsp.Client
|
||||
|
||||
clientsMutex sync.RWMutex
|
||||
|
||||
watcherCancelFuncs []context.CancelFunc
|
||||
cancelFuncsMutex sync.Mutex
|
||||
watcherWG sync.WaitGroup
|
||||
}
|
||||
|
||||
func New(ctx context.Context, conn *sql.DB) (*App, error) {
|
||||
err := logging.InitService(conn)
|
||||
if err != nil {
|
||||
slog.Error("Failed to initialize logging service", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
err = session.InitService(conn)
|
||||
if err != nil {
|
||||
slog.Error("Failed to initialize session service", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
err = message.InitService(conn)
|
||||
if err != nil {
|
||||
slog.Error("Failed to initialize message service", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
err = history.InitService(conn)
|
||||
if err != nil {
|
||||
slog.Error("Failed to initialize history service", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
err = permission.InitService()
|
||||
if err != nil {
|
||||
slog.Error("Failed to initialize permission service", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
err = status.InitService()
|
||||
if err != nil {
|
||||
slog.Error("Failed to initialize status service", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
app := &App{
|
||||
CurrentSession: &session.Session{},
|
||||
Logs: logging.GetService(),
|
||||
Sessions: session.GetService(),
|
||||
Messages: message.GetService(),
|
||||
History: history.GetService(),
|
||||
Permissions: permission.GetService(),
|
||||
Status: status.GetService(),
|
||||
LSPClients: make(map[string]*lsp.Client),
|
||||
}
|
||||
|
||||
// Initialize theme based on configuration
|
||||
app.initTheme()
|
||||
|
||||
// Initialize LSP clients in the background
|
||||
go app.initLSPClients(ctx)
|
||||
|
||||
app.PrimaryAgent, err = agent.NewAgent(
|
||||
config.AgentPrimary,
|
||||
app.Sessions,
|
||||
app.Messages,
|
||||
agent.PrimaryAgentTools(
|
||||
app.Permissions,
|
||||
app.Sessions,
|
||||
app.Messages,
|
||||
app.History,
|
||||
app.LSPClients,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
slog.Error("Failed to create primary agent", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return app, nil
|
||||
}
|
||||
|
||||
// initTheme sets the application theme based on the configuration
|
||||
func (app *App) initTheme() {
|
||||
cfg := config.Get()
|
||||
if cfg == nil || cfg.TUI.Theme == "" {
|
||||
return // Use default theme
|
||||
}
|
||||
|
||||
// Try to set the theme from config
|
||||
err := theme.SetTheme(cfg.TUI.Theme)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to set theme from config, using default theme", "theme", cfg.TUI.Theme, "error", err)
|
||||
} else {
|
||||
slog.Debug("Set theme from config", "theme", cfg.TUI.Theme)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown performs a clean shutdown of the application
|
||||
func (app *App) Shutdown() {
|
||||
// Cancel all watcher goroutines
|
||||
app.cancelFuncsMutex.Lock()
|
||||
for _, cancel := range app.watcherCancelFuncs {
|
||||
cancel()
|
||||
}
|
||||
app.cancelFuncsMutex.Unlock()
|
||||
app.watcherWG.Wait()
|
||||
|
||||
// Perform additional cleanup for LSP clients
|
||||
app.clientsMutex.RLock()
|
||||
clients := make(map[string]*lsp.Client, len(app.LSPClients))
|
||||
maps.Copy(clients, app.LSPClients)
|
||||
app.clientsMutex.RUnlock()
|
||||
|
||||
for name, client := range clients {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
if err := client.Shutdown(shutdownCtx); err != nil {
|
||||
slog.Error("Failed to shutdown LSP client", "name", name, "error", err)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
@@ -1,134 +0,0 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/logging"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/lsp/watcher"
|
||||
)
|
||||
|
||||
func (app *App) initLSPClients(ctx context.Context) {
|
||||
cfg := config.Get()
|
||||
|
||||
// Initialize LSP clients
|
||||
for name, clientConfig := range cfg.LSP {
|
||||
// Start each client initialization in its own goroutine
|
||||
go app.createAndStartLSPClient(ctx, name, clientConfig.Command, clientConfig.Args...)
|
||||
}
|
||||
slog.Info("LSP clients initialization started in background")
|
||||
}
|
||||
|
||||
// createAndStartLSPClient creates a new LSP client, initializes it, and starts its workspace watcher
|
||||
func (app *App) createAndStartLSPClient(ctx context.Context, name string, command string, args ...string) {
|
||||
// Create a specific context for initialization with a timeout
|
||||
slog.Info("Creating LSP client", "name", name, "command", command, "args", args)
|
||||
|
||||
// Create the LSP client
|
||||
lspClient, err := lsp.NewClient(ctx, command, args...)
|
||||
if err != nil {
|
||||
slog.Error("Failed to create LSP client for", name, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a longer timeout for initialization (some servers take time to start)
|
||||
initCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Initialize with the initialization context
|
||||
_, err = lspClient.InitializeLSPClient(initCtx, config.WorkingDirectory())
|
||||
if err != nil {
|
||||
slog.Error("Initialize failed", "name", name, "error", err)
|
||||
// Clean up the client to prevent resource leaks
|
||||
lspClient.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// Wait for the server to be ready
|
||||
if err := lspClient.WaitForServerReady(initCtx); err != nil {
|
||||
slog.Error("Server failed to become ready", "name", name, "error", err)
|
||||
// We'll continue anyway, as some functionality might still work
|
||||
lspClient.SetServerState(lsp.StateError)
|
||||
} else {
|
||||
slog.Info("LSP server is ready", "name", name)
|
||||
lspClient.SetServerState(lsp.StateReady)
|
||||
}
|
||||
|
||||
slog.Info("LSP client initialized", "name", name)
|
||||
|
||||
// Create a child context that can be canceled when the app is shutting down
|
||||
watchCtx, cancelFunc := context.WithCancel(ctx)
|
||||
|
||||
// Create a context with the server name for better identification
|
||||
watchCtx = context.WithValue(watchCtx, "serverName", name)
|
||||
|
||||
// Create the workspace watcher
|
||||
workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient)
|
||||
|
||||
// Store the cancel function to be called during cleanup
|
||||
app.cancelFuncsMutex.Lock()
|
||||
app.watcherCancelFuncs = append(app.watcherCancelFuncs, cancelFunc)
|
||||
app.cancelFuncsMutex.Unlock()
|
||||
|
||||
// Add the watcher to a WaitGroup to track active goroutines
|
||||
app.watcherWG.Add(1)
|
||||
|
||||
// Add to map with mutex protection before starting goroutine
|
||||
app.clientsMutex.Lock()
|
||||
app.LSPClients[name] = lspClient
|
||||
app.clientsMutex.Unlock()
|
||||
|
||||
go app.runWorkspaceWatcher(watchCtx, name, workspaceWatcher)
|
||||
}
|
||||
|
||||
// runWorkspaceWatcher executes the workspace watcher for an LSP client
|
||||
func (app *App) runWorkspaceWatcher(ctx context.Context, name string, workspaceWatcher *watcher.WorkspaceWatcher) {
|
||||
defer app.watcherWG.Done()
|
||||
defer logging.RecoverPanic("LSP-"+name, func() {
|
||||
// Try to restart the client
|
||||
app.restartLSPClient(ctx, name)
|
||||
})
|
||||
|
||||
workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory())
|
||||
slog.Info("Workspace watcher stopped", "client", name)
|
||||
}
|
||||
|
||||
// restartLSPClient attempts to restart a crashed or failed LSP client
|
||||
func (app *App) restartLSPClient(ctx context.Context, name string) {
|
||||
// Get the original configuration
|
||||
cfg := config.Get()
|
||||
clientConfig, exists := cfg.LSP[name]
|
||||
if !exists {
|
||||
slog.Error("Cannot restart client, configuration not found", "client", name)
|
||||
return
|
||||
}
|
||||
|
||||
// Clean up the old client if it exists
|
||||
app.clientsMutex.Lock()
|
||||
oldClient, exists := app.LSPClients[name]
|
||||
if exists {
|
||||
delete(app.LSPClients, name) // Remove from map before potentially slow shutdown
|
||||
}
|
||||
app.clientsMutex.Unlock()
|
||||
|
||||
if exists && oldClient != nil {
|
||||
// Try to shut it down gracefully, but don't block on errors
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_ = oldClient.Shutdown(shutdownCtx)
|
||||
cancel()
|
||||
|
||||
// Ensure we close the client to free resources
|
||||
_ = oldClient.Close()
|
||||
}
|
||||
|
||||
// Wait a moment before restarting to avoid rapid restart cycles
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Create a new client using the shared function
|
||||
app.createAndStartLSPClient(ctx, name, clientConfig.Command, clientConfig.Args...)
|
||||
slog.Info("Successfully restarted LSP client", "client", name)
|
||||
}
|
||||
@@ -1,797 +0,0 @@
|
||||
// Package config manages application configuration from various sources.
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
)
|
||||
|
||||
// MCPType defines the type of MCP (Model Control Protocol) server.
|
||||
type MCPType string
|
||||
|
||||
// Supported MCP types
|
||||
const (
|
||||
MCPStdio MCPType = "stdio"
|
||||
MCPSse MCPType = "sse"
|
||||
)
|
||||
|
||||
// MCPServer defines the configuration for a Model Control Protocol server.
|
||||
type MCPServer struct {
|
||||
Command string `json:"command"`
|
||||
Env []string `json:"env"`
|
||||
Args []string `json:"args"`
|
||||
Type MCPType `json:"type"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
}
|
||||
|
||||
type AgentName string
|
||||
|
||||
const (
|
||||
AgentPrimary AgentName = "primary"
|
||||
AgentTask AgentName = "task"
|
||||
AgentTitle AgentName = "title"
|
||||
)
|
||||
|
||||
// Agent defines configuration for different LLM models and their token limits.
|
||||
type Agent struct {
|
||||
Model models.ModelID `json:"model"`
|
||||
MaxTokens int64 `json:"maxTokens"`
|
||||
ReasoningEffort string `json:"reasoningEffort"` // For openai models low,medium,heigh
|
||||
}
|
||||
|
||||
// Provider defines configuration for an LLM provider.
|
||||
type Provider struct {
|
||||
APIKey string `json:"apiKey"`
|
||||
Disabled bool `json:"disabled"`
|
||||
}
|
||||
|
||||
// Data defines storage configuration.
|
||||
type Data struct {
|
||||
Directory string `json:"directory"`
|
||||
}
|
||||
|
||||
// LSPConfig defines configuration for Language Server Protocol integration.
|
||||
type LSPConfig struct {
|
||||
Disabled bool `json:"enabled"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
Options any `json:"options"`
|
||||
}
|
||||
|
||||
// TUIConfig defines the configuration for the Terminal User Interface.
|
||||
type TUIConfig struct {
|
||||
Theme string `json:"theme,omitempty"`
|
||||
CustomTheme map[string]any `json:"customTheme,omitempty"`
|
||||
}
|
||||
|
||||
// Config is the main configuration structure for the application.
|
||||
type Config struct {
|
||||
Data Data `json:"data"`
|
||||
WorkingDir string `json:"wd,omitempty"`
|
||||
MCPServers map[string]MCPServer `json:"mcpServers,omitempty"`
|
||||
Providers map[models.ModelProvider]Provider `json:"providers,omitempty"`
|
||||
LSP map[string]LSPConfig `json:"lsp,omitempty"`
|
||||
Agents map[AgentName]Agent `json:"agents"`
|
||||
Debug bool `json:"debug,omitempty"`
|
||||
DebugLSP bool `json:"debugLSP,omitempty"`
|
||||
ContextPaths []string `json:"contextPaths,omitempty"`
|
||||
TUI TUIConfig `json:"tui"`
|
||||
}
|
||||
|
||||
// Application constants
|
||||
const (
|
||||
defaultDataDirectory = ".opencode"
|
||||
defaultLogLevel = "info"
|
||||
appName = "opencode"
|
||||
|
||||
MaxTokensFallbackDefault = 4096
|
||||
)
|
||||
|
||||
var defaultContextPaths = []string{
|
||||
".github/copilot-instructions.md",
|
||||
".cursorrules",
|
||||
".cursor/rules/",
|
||||
"CLAUDE.md",
|
||||
"CLAUDE.local.md",
|
||||
"CONTEXT.md",
|
||||
"CONTEXT.local.md",
|
||||
"opencode.md",
|
||||
"opencode.local.md",
|
||||
"OpenCode.md",
|
||||
"OpenCode.local.md",
|
||||
"OPENCODE.md",
|
||||
"OPENCODE.local.md",
|
||||
}
|
||||
|
||||
// Global configuration instance
|
||||
var cfg *Config
|
||||
|
||||
// Load initializes the configuration from environment variables and config files.
|
||||
// If debug is true, debug mode is enabled and log level is set to debug.
|
||||
// It returns an error if configuration loading fails.
|
||||
func Load(workingDir string, debug bool, lvl *slog.LevelVar) (*Config, error) {
|
||||
if cfg != nil {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
cfg = &Config{
|
||||
WorkingDir: workingDir,
|
||||
MCPServers: make(map[string]MCPServer),
|
||||
Providers: make(map[models.ModelProvider]Provider),
|
||||
LSP: make(map[string]LSPConfig),
|
||||
}
|
||||
|
||||
configureViper()
|
||||
setDefaults(debug)
|
||||
|
||||
// Read global config
|
||||
if err := readConfig(viper.ReadInConfig()); err != nil {
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
// Load and merge local config
|
||||
mergeLocalConfig(workingDir)
|
||||
|
||||
setProviderDefaults()
|
||||
|
||||
// Apply configuration to the struct
|
||||
if err := viper.Unmarshal(cfg); err != nil {
|
||||
return cfg, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
applyDefaultValues()
|
||||
|
||||
defaultLevel := slog.LevelInfo
|
||||
if cfg.Debug {
|
||||
defaultLevel = slog.LevelDebug
|
||||
}
|
||||
lvl.Set(defaultLevel)
|
||||
slog.SetLogLoggerLevel(defaultLevel)
|
||||
|
||||
// Validate configuration
|
||||
if err := Validate(); err != nil {
|
||||
return cfg, fmt.Errorf("config validation failed: %w", err)
|
||||
}
|
||||
|
||||
if cfg.Agents == nil {
|
||||
cfg.Agents = make(map[AgentName]Agent)
|
||||
}
|
||||
|
||||
// Override the max tokens for title agent
|
||||
cfg.Agents[AgentTitle] = Agent{
|
||||
Model: cfg.Agents[AgentTitle].Model,
|
||||
MaxTokens: 80,
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// configureViper sets up viper's configuration paths and environment variables.
|
||||
func configureViper() {
|
||||
viper.SetConfigName(fmt.Sprintf(".%s", appName))
|
||||
viper.SetConfigType("json")
|
||||
viper.AddConfigPath("$HOME")
|
||||
viper.AddConfigPath(fmt.Sprintf("$XDG_CONFIG_HOME/%s", appName))
|
||||
viper.AddConfigPath(fmt.Sprintf("$HOME/.config/%s", appName))
|
||||
viper.SetEnvPrefix(strings.ToUpper(appName))
|
||||
viper.AutomaticEnv()
|
||||
}
|
||||
|
||||
// setDefaults configures default values for configuration options.
|
||||
func setDefaults(debug bool) {
|
||||
viper.SetDefault("data.directory", defaultDataDirectory)
|
||||
viper.SetDefault("contextPaths", defaultContextPaths)
|
||||
viper.SetDefault("tui.theme", "opencode")
|
||||
|
||||
if debug {
|
||||
viper.SetDefault("debug", true)
|
||||
viper.Set("log.level", "debug")
|
||||
} else {
|
||||
viper.SetDefault("debug", false)
|
||||
viper.SetDefault("log.level", defaultLogLevel)
|
||||
}
|
||||
}
|
||||
|
||||
// setProviderDefaults configures LLM provider defaults based on provider provided by
|
||||
// environment variables and configuration file.
|
||||
func setProviderDefaults() {
|
||||
// Set all API keys we can find in the environment
|
||||
if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" {
|
||||
viper.SetDefault("providers.anthropic.apiKey", apiKey)
|
||||
}
|
||||
if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" {
|
||||
viper.SetDefault("providers.openai.apiKey", apiKey)
|
||||
}
|
||||
if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" {
|
||||
viper.SetDefault("providers.gemini.apiKey", apiKey)
|
||||
}
|
||||
if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" {
|
||||
viper.SetDefault("providers.groq.apiKey", apiKey)
|
||||
}
|
||||
if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" {
|
||||
viper.SetDefault("providers.openrouter.apiKey", apiKey)
|
||||
}
|
||||
if apiKey := os.Getenv("XAI_API_KEY"); apiKey != "" {
|
||||
viper.SetDefault("providers.xai.apiKey", apiKey)
|
||||
}
|
||||
if apiKey := os.Getenv("AZURE_OPENAI_ENDPOINT"); apiKey != "" {
|
||||
// api-key may be empty when using Entra ID credentials – that's okay
|
||||
viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY"))
|
||||
}
|
||||
|
||||
// Use this order to set the default models
|
||||
// 1. Anthropic
|
||||
// 2. OpenAI
|
||||
// 3. Google Gemini
|
||||
// 4. Groq
|
||||
// 5. OpenRouter
|
||||
// 6. AWS Bedrock
|
||||
// 7. Azure
|
||||
|
||||
// Anthropic configuration
|
||||
if key := viper.GetString("providers.anthropic.apiKey"); strings.TrimSpace(key) != "" {
|
||||
viper.SetDefault("agents.primary.model", models.Claude37Sonnet)
|
||||
viper.SetDefault("agents.task.model", models.Claude37Sonnet)
|
||||
viper.SetDefault("agents.title.model", models.Claude37Sonnet)
|
||||
return
|
||||
}
|
||||
|
||||
// OpenAI configuration
|
||||
if key := viper.GetString("providers.openai.apiKey"); strings.TrimSpace(key) != "" {
|
||||
viper.SetDefault("agents.primary.model", models.GPT41)
|
||||
viper.SetDefault("agents.task.model", models.GPT41Mini)
|
||||
viper.SetDefault("agents.title.model", models.GPT41Mini)
|
||||
return
|
||||
}
|
||||
|
||||
// Google Gemini configuration
|
||||
if key := viper.GetString("providers.gemini.apiKey"); strings.TrimSpace(key) != "" {
|
||||
viper.SetDefault("agents.primary.model", models.Gemini25)
|
||||
viper.SetDefault("agents.task.model", models.Gemini25Flash)
|
||||
viper.SetDefault("agents.title.model", models.Gemini25Flash)
|
||||
return
|
||||
}
|
||||
|
||||
// Groq configuration
|
||||
if key := viper.GetString("providers.groq.apiKey"); strings.TrimSpace(key) != "" {
|
||||
viper.SetDefault("agents.primary.model", models.QWENQwq)
|
||||
viper.SetDefault("agents.task.model", models.QWENQwq)
|
||||
viper.SetDefault("agents.title.model", models.QWENQwq)
|
||||
return
|
||||
}
|
||||
|
||||
// OpenRouter configuration
|
||||
if key := viper.GetString("providers.openrouter.apiKey"); strings.TrimSpace(key) != "" {
|
||||
viper.SetDefault("agents.primary.model", models.OpenRouterClaude37Sonnet)
|
||||
viper.SetDefault("agents.task.model", models.OpenRouterClaude37Sonnet)
|
||||
viper.SetDefault("agents.title.model", models.OpenRouterClaude35Haiku)
|
||||
return
|
||||
}
|
||||
|
||||
// XAI configuration
|
||||
if key := viper.GetString("providers.xai.apiKey"); strings.TrimSpace(key) != "" {
|
||||
viper.SetDefault("agents.primary.model", models.XAIGrok3Beta)
|
||||
viper.SetDefault("agents.task.model", models.XAIGrok3Beta)
|
||||
viper.SetDefault("agents.title.model", models.XAiGrok3MiniFastBeta)
|
||||
return
|
||||
}
|
||||
|
||||
// AWS Bedrock configuration
|
||||
if hasAWSCredentials() {
|
||||
viper.SetDefault("agents.primary.model", models.BedrockClaude37Sonnet)
|
||||
viper.SetDefault("agents.task.model", models.BedrockClaude37Sonnet)
|
||||
viper.SetDefault("agents.title.model", models.BedrockClaude37Sonnet)
|
||||
return
|
||||
}
|
||||
|
||||
// Azure OpenAI configuration
|
||||
if os.Getenv("AZURE_OPENAI_ENDPOINT") != "" {
|
||||
viper.SetDefault("agents.primary.model", models.AzureGPT41)
|
||||
viper.SetDefault("agents.task.model", models.AzureGPT41Mini)
|
||||
viper.SetDefault("agents.title.model", models.AzureGPT41Mini)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// hasAWSCredentials checks if AWS credentials are available in the environment.
|
||||
func hasAWSCredentials() bool {
|
||||
// Check for explicit AWS credentials
|
||||
if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for AWS profile
|
||||
if os.Getenv("AWS_PROFILE") != "" || os.Getenv("AWS_DEFAULT_PROFILE") != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for AWS region
|
||||
if os.Getenv("AWS_REGION") != "" || os.Getenv("AWS_DEFAULT_REGION") != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if running on EC2 with instance profile
|
||||
if os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
|
||||
os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// readConfig handles the result of reading a configuration file.
|
||||
func readConfig(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// It's okay if the config file doesn't exist
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to read config: %w", err)
|
||||
}
|
||||
|
||||
// mergeLocalConfig loads and merges configuration from the local directory.
|
||||
func mergeLocalConfig(workingDir string) {
|
||||
local := viper.New()
|
||||
local.SetConfigName(fmt.Sprintf(".%s", appName))
|
||||
local.SetConfigType("json")
|
||||
local.AddConfigPath(workingDir)
|
||||
|
||||
// Merge local config if it exists
|
||||
if err := local.ReadInConfig(); err == nil {
|
||||
viper.MergeConfigMap(local.AllSettings())
|
||||
}
|
||||
}
|
||||
|
||||
// applyDefaultValues sets default values for configuration fields that need processing.
|
||||
func applyDefaultValues() {
|
||||
// Set default MCP type if not specified
|
||||
for k, v := range cfg.MCPServers {
|
||||
if v.Type == "" {
|
||||
v.Type = MCPStdio
|
||||
cfg.MCPServers[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// It validates model IDs and providers, ensuring they are supported.
|
||||
func validateAgent(cfg *Config, name AgentName, agent Agent) error {
|
||||
// Check if model exists
|
||||
model, modelExists := models.SupportedModels[agent.Model]
|
||||
if !modelExists {
|
||||
slog.Warn("unsupported model configured, reverting to default",
|
||||
"agent", name,
|
||||
"configured_model", agent.Model)
|
||||
|
||||
// Set default model based on available providers
|
||||
if setDefaultModelForAgent(name) {
|
||||
slog.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
|
||||
} else {
|
||||
return fmt.Errorf("no valid provider available for agent %s", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if provider for the model is configured
|
||||
provider := model.Provider
|
||||
providerCfg, providerExists := cfg.Providers[provider]
|
||||
|
||||
if !providerExists {
|
||||
// Provider not configured, check if we have environment variables
|
||||
apiKey := getProviderAPIKey(provider)
|
||||
if apiKey == "" {
|
||||
slog.Warn("provider not configured for model, reverting to default",
|
||||
"agent", name,
|
||||
"model", agent.Model,
|
||||
"provider", provider)
|
||||
|
||||
// Set default model based on available providers
|
||||
if setDefaultModelForAgent(name) {
|
||||
slog.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
|
||||
} else {
|
||||
return fmt.Errorf("no valid provider available for agent %s", name)
|
||||
}
|
||||
} else {
|
||||
// Add provider with API key from environment
|
||||
cfg.Providers[provider] = Provider{
|
||||
APIKey: apiKey,
|
||||
}
|
||||
slog.Info("added provider from environment", "provider", provider)
|
||||
}
|
||||
} else if providerCfg.Disabled || providerCfg.APIKey == "" {
|
||||
// Provider is disabled or has no API key
|
||||
slog.Warn("provider is disabled or has no API key, reverting to default",
|
||||
"agent", name,
|
||||
"model", agent.Model,
|
||||
"provider", provider)
|
||||
|
||||
// Set default model based on available providers
|
||||
if setDefaultModelForAgent(name) {
|
||||
slog.Info("set default model for agent", "agent", name, "model", cfg.Agents[name].Model)
|
||||
} else {
|
||||
return fmt.Errorf("no valid provider available for agent %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate max tokens
|
||||
if agent.MaxTokens <= 0 {
|
||||
slog.Warn("invalid max tokens, setting to default",
|
||||
"agent", name,
|
||||
"model", agent.Model,
|
||||
"max_tokens", agent.MaxTokens)
|
||||
|
||||
// Update the agent with default max tokens
|
||||
updatedAgent := cfg.Agents[name]
|
||||
if model.DefaultMaxTokens > 0 {
|
||||
updatedAgent.MaxTokens = model.DefaultMaxTokens
|
||||
} else {
|
||||
updatedAgent.MaxTokens = MaxTokensFallbackDefault
|
||||
}
|
||||
cfg.Agents[name] = updatedAgent
|
||||
} else if model.ContextWindow > 0 && agent.MaxTokens > model.ContextWindow/2 {
|
||||
// Ensure max tokens doesn't exceed half the context window (reasonable limit)
|
||||
slog.Warn("max tokens exceeds half the context window, adjusting",
|
||||
"agent", name,
|
||||
"model", agent.Model,
|
||||
"max_tokens", agent.MaxTokens,
|
||||
"context_window", model.ContextWindow)
|
||||
|
||||
// Update the agent with adjusted max tokens
|
||||
updatedAgent := cfg.Agents[name]
|
||||
updatedAgent.MaxTokens = model.ContextWindow / 2
|
||||
cfg.Agents[name] = updatedAgent
|
||||
}
|
||||
|
||||
// Validate reasoning effort for models that support reasoning
|
||||
if model.CanReason && provider == models.ProviderOpenAI {
|
||||
if agent.ReasoningEffort == "" {
|
||||
// Set default reasoning effort for models that support it
|
||||
slog.Info("setting default reasoning effort for model that supports reasoning",
|
||||
"agent", name,
|
||||
"model", agent.Model)
|
||||
|
||||
// Update the agent with default reasoning effort
|
||||
updatedAgent := cfg.Agents[name]
|
||||
updatedAgent.ReasoningEffort = "medium"
|
||||
cfg.Agents[name] = updatedAgent
|
||||
} else {
|
||||
// Check if reasoning effort is valid (low, medium, high)
|
||||
effort := strings.ToLower(agent.ReasoningEffort)
|
||||
if effort != "low" && effort != "medium" && effort != "high" {
|
||||
slog.Warn("invalid reasoning effort, setting to medium",
|
||||
"agent", name,
|
||||
"model", agent.Model,
|
||||
"reasoning_effort", agent.ReasoningEffort)
|
||||
|
||||
// Update the agent with valid reasoning effort
|
||||
updatedAgent := cfg.Agents[name]
|
||||
updatedAgent.ReasoningEffort = "medium"
|
||||
cfg.Agents[name] = updatedAgent
|
||||
}
|
||||
}
|
||||
} else if !model.CanReason && agent.ReasoningEffort != "" {
|
||||
// Model doesn't support reasoning but reasoning effort is set
|
||||
slog.Warn("model doesn't support reasoning but reasoning effort is set, ignoring",
|
||||
"agent", name,
|
||||
"model", agent.Model,
|
||||
"reasoning_effort", agent.ReasoningEffort)
|
||||
|
||||
// Update the agent to remove reasoning effort
|
||||
updatedAgent := cfg.Agents[name]
|
||||
updatedAgent.ReasoningEffort = ""
|
||||
cfg.Agents[name] = updatedAgent
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks if the configuration is valid and applies defaults where needed.
|
||||
func Validate() error {
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("config not loaded")
|
||||
}
|
||||
|
||||
// Validate agent models
|
||||
for name, agent := range cfg.Agents {
|
||||
if err := validateAgent(cfg, name, agent); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate providers
|
||||
for provider, providerCfg := range cfg.Providers {
|
||||
if providerCfg.APIKey == "" && !providerCfg.Disabled {
|
||||
slog.Warn("provider has no API key, marking as disabled", "provider", provider)
|
||||
providerCfg.Disabled = true
|
||||
cfg.Providers[provider] = providerCfg
|
||||
}
|
||||
}
|
||||
|
||||
// Validate LSP configurations
|
||||
for language, lspConfig := range cfg.LSP {
|
||||
if lspConfig.Command == "" && !lspConfig.Disabled {
|
||||
slog.Warn("LSP configuration has no command, marking as disabled", "language", language)
|
||||
lspConfig.Disabled = true
|
||||
cfg.LSP[language] = lspConfig
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getProviderAPIKey gets the API key for a provider from environment variables
|
||||
func getProviderAPIKey(provider models.ModelProvider) string {
|
||||
switch provider {
|
||||
case models.ProviderAnthropic:
|
||||
return os.Getenv("ANTHROPIC_API_KEY")
|
||||
case models.ProviderOpenAI:
|
||||
return os.Getenv("OPENAI_API_KEY")
|
||||
case models.ProviderGemini:
|
||||
return os.Getenv("GEMINI_API_KEY")
|
||||
case models.ProviderGROQ:
|
||||
return os.Getenv("GROQ_API_KEY")
|
||||
case models.ProviderAzure:
|
||||
return os.Getenv("AZURE_OPENAI_API_KEY")
|
||||
case models.ProviderOpenRouter:
|
||||
return os.Getenv("OPENROUTER_API_KEY")
|
||||
case models.ProviderBedrock:
|
||||
if hasAWSCredentials() {
|
||||
return "aws-credentials-available"
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// setDefaultModelForAgent sets a default model for an agent based on available providers
|
||||
func setDefaultModelForAgent(agent AgentName) bool {
|
||||
// Check providers in order of preference
|
||||
if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" {
|
||||
maxTokens := int64(5000)
|
||||
if agent == AgentTitle {
|
||||
maxTokens = 80
|
||||
}
|
||||
cfg.Agents[agent] = Agent{
|
||||
Model: models.Claude37Sonnet,
|
||||
MaxTokens: maxTokens,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" {
|
||||
var model models.ModelID
|
||||
maxTokens := int64(5000)
|
||||
reasoningEffort := ""
|
||||
|
||||
switch agent {
|
||||
case AgentTitle:
|
||||
model = models.GPT41Mini
|
||||
maxTokens = 80
|
||||
case AgentTask:
|
||||
model = models.GPT41Mini
|
||||
default:
|
||||
model = models.GPT41
|
||||
}
|
||||
|
||||
// Check if model supports reasoning
|
||||
if modelInfo, ok := models.SupportedModels[model]; ok && modelInfo.CanReason {
|
||||
reasoningEffort = "medium"
|
||||
}
|
||||
|
||||
cfg.Agents[agent] = Agent{
|
||||
Model: model,
|
||||
MaxTokens: maxTokens,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" {
|
||||
var model models.ModelID
|
||||
maxTokens := int64(5000)
|
||||
reasoningEffort := ""
|
||||
|
||||
switch agent {
|
||||
case AgentTitle:
|
||||
model = models.OpenRouterClaude35Haiku
|
||||
maxTokens = 80
|
||||
case AgentTask:
|
||||
model = models.OpenRouterClaude37Sonnet
|
||||
default:
|
||||
model = models.OpenRouterClaude37Sonnet
|
||||
}
|
||||
|
||||
// Check if model supports reasoning
|
||||
if modelInfo, ok := models.SupportedModels[model]; ok && modelInfo.CanReason {
|
||||
reasoningEffort = "medium"
|
||||
}
|
||||
|
||||
cfg.Agents[agent] = Agent{
|
||||
Model: model,
|
||||
MaxTokens: maxTokens,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" {
|
||||
var model models.ModelID
|
||||
maxTokens := int64(5000)
|
||||
|
||||
if agent == AgentTitle {
|
||||
model = models.Gemini25Flash
|
||||
maxTokens = 80
|
||||
} else {
|
||||
model = models.Gemini25
|
||||
}
|
||||
|
||||
cfg.Agents[agent] = Agent{
|
||||
Model: model,
|
||||
MaxTokens: maxTokens,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" {
|
||||
maxTokens := int64(5000)
|
||||
if agent == AgentTitle {
|
||||
maxTokens = 80
|
||||
}
|
||||
|
||||
cfg.Agents[agent] = Agent{
|
||||
Model: models.QWENQwq,
|
||||
MaxTokens: maxTokens,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
if hasAWSCredentials() {
|
||||
maxTokens := int64(5000)
|
||||
if agent == AgentTitle {
|
||||
maxTokens = 80
|
||||
}
|
||||
|
||||
cfg.Agents[agent] = Agent{
|
||||
Model: models.BedrockClaude37Sonnet,
|
||||
MaxTokens: maxTokens,
|
||||
ReasoningEffort: "medium", // Claude models support reasoning
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Get returns the current configuration.
// It's safe to call this function multiple times.
// The result is nil until the configuration has been loaded, so callers that
// may run before initialization must nil-check it.
func Get() *Config {
	return cfg
}
|
||||
|
||||
// WorkingDirectory returns the current working directory from the configuration.
|
||||
func WorkingDirectory() string {
|
||||
if cfg == nil {
|
||||
panic("config not loaded")
|
||||
}
|
||||
return cfg.WorkingDir
|
||||
}
|
||||
|
||||
// GetHostname returns the system hostname, falling back to "User" (together
// with the underlying error) when it cannot be determined.
func GetHostname() (string, error) {
	hostname, err := os.Hostname()
	if err == nil {
		return hostname, nil
	}
	return "User", err
}
|
||||
|
||||
// GetUsername returns the current user's username, falling back to "User"
// (together with the underlying error) when it cannot be determined.
func GetUsername() (string, error) {
	currentUser, err := user.Current()
	if err == nil {
		return currentUser.Username, nil
	}
	return "User", err
}
|
||||
|
||||
func UpdateAgentModel(agentName AgentName, modelID models.ModelID) error {
|
||||
if cfg == nil {
|
||||
panic("config not loaded")
|
||||
}
|
||||
|
||||
existingAgentCfg := cfg.Agents[agentName]
|
||||
|
||||
model, ok := models.SupportedModels[modelID]
|
||||
if !ok {
|
||||
return fmt.Errorf("model %s not supported", modelID)
|
||||
}
|
||||
|
||||
maxTokens := existingAgentCfg.MaxTokens
|
||||
if model.DefaultMaxTokens > 0 {
|
||||
maxTokens = model.DefaultMaxTokens
|
||||
}
|
||||
|
||||
newAgentCfg := Agent{
|
||||
Model: modelID,
|
||||
MaxTokens: maxTokens,
|
||||
ReasoningEffort: existingAgentCfg.ReasoningEffort,
|
||||
}
|
||||
cfg.Agents[agentName] = newAgentCfg
|
||||
|
||||
if err := validateAgent(cfg, agentName, newAgentCfg); err != nil {
|
||||
// revert config update on failure
|
||||
cfg.Agents[agentName] = existingAgentCfg
|
||||
return fmt.Errorf("failed to update agent model: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTheme updates the theme in the configuration and writes it to the config file.
//
// The in-memory configuration is updated first, then the change is persisted
// by rewriting only the "tui.theme" key of the JSON config file, preserving
// every other key in the file. When viper has no config file on record, a new
// one is created in the user's home directory.
//
// NOTE(review): if any of the file operations below fail, the in-memory
// theme has already been changed, so memory and disk can diverge — confirm
// this is acceptable to callers.
func UpdateTheme(themeName string) error {
	if cfg == nil {
		return fmt.Errorf("config not loaded")
	}

	// Update the in-memory config
	cfg.TUI.Theme = themeName

	// Get the config file path
	configFile := viper.ConfigFileUsed()
	var configData []byte
	if configFile == "" {
		homeDir, err := os.UserHomeDir()
		if err != nil {
			return fmt.Errorf("failed to get home directory: %w", err)
		}
		configFile = filepath.Join(homeDir, fmt.Sprintf(".%s.json", appName))
		slog.Info("config file not found, creating new one", "path", configFile)
		// Start from an empty JSON object so the unmarshal below succeeds.
		configData = []byte(`{}`)
	} else {
		// Read the existing config file
		data, err := os.ReadFile(configFile)
		if err != nil {
			return fmt.Errorf("failed to read config file: %w", err)
		}
		configData = data
	}

	// Parse the JSON
	var configMap map[string]any
	if err := json.Unmarshal(configData, &configMap); err != nil {
		return fmt.Errorf("failed to parse config file: %w", err)
	}

	// Update just the theme value
	tuiConfig, ok := configMap["tui"].(map[string]any)
	if !ok {
		// TUI config doesn't exist yet, create it
		configMap["tui"] = map[string]any{"theme": themeName}
	} else {
		// Update existing TUI config
		tuiConfig["theme"] = themeName
		configMap["tui"] = tuiConfig
	}

	// Write the updated config back to file
	updatedData, err := json.MarshalIndent(configMap, "", "  ")
	if err != nil {
		return fmt.Errorf("failed to marshal config: %w", err)
	}

	if err := os.WriteFile(configFile, updatedData, 0o644); err != nil {
		return fmt.Errorf("failed to write config file: %w", err)
	}

	return nil
}
|
||||
@@ -1,68 +0,0 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"log/slog"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
)
|
||||
|
||||
func Connect() (*sql.DB, error) {
|
||||
dataDir := config.Get().Data.Directory
|
||||
if dataDir == "" {
|
||||
return nil, fmt.Errorf("data.dir is not set")
|
||||
}
|
||||
if err := os.MkdirAll(dataDir, 0o700); err != nil {
|
||||
return nil, fmt.Errorf("failed to create data directory: %w", err)
|
||||
}
|
||||
dbPath := filepath.Join(dataDir, "opencode.db")
|
||||
// Open the SQLite database
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
// Verify connection
|
||||
if err = db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
// Set pragmas for better performance
|
||||
pragmas := []string{
|
||||
"PRAGMA foreign_keys = ON;",
|
||||
"PRAGMA journal_mode = WAL;",
|
||||
"PRAGMA page_size = 4096;",
|
||||
"PRAGMA cache_size = -8000;",
|
||||
"PRAGMA synchronous = NORMAL;",
|
||||
}
|
||||
|
||||
for _, pragma := range pragmas {
|
||||
if _, err = db.Exec(pragma); err != nil {
|
||||
slog.Error("Failed to set pragma", pragma, err)
|
||||
} else {
|
||||
slog.Debug("Set pragma", "pragma", pragma)
|
||||
}
|
||||
}
|
||||
|
||||
goose.SetBaseFS(FS)
|
||||
|
||||
if err := goose.SetDialect("sqlite3"); err != nil {
|
||||
slog.Error("Failed to set dialect", "error", err)
|
||||
return nil, fmt.Errorf("failed to set dialect: %w", err)
|
||||
}
|
||||
|
||||
if err := goose.Up(db, "migrations"); err != nil {
|
||||
slog.Error("Failed to apply migrations", "error", err)
|
||||
return nil, fmt.Errorf("failed to apply migrations: %w", err)
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
@@ -1,328 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// NOTE(review): everything below is sqlc-generated code (see the "DO NOT
// EDIT" header at the top of this file). Do not hand-edit; change the SQL
// sources and regenerate with sqlc instead. Comments added here are review
// notes only and will be lost on regeneration.

// DBTX is the minimal database surface shared by *sql.DB and *sql.Tx, which
// lets Queries run against either a raw connection pool or a transaction.
type DBTX interface {
	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
	PrepareContext(context.Context, string) (*sql.Stmt, error)
	QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
	QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}

// New returns a Queries value that issues every statement directly against
// db, without prepared statements.
func New(db DBTX) *Queries {
	return &Queries{db: db}
}

// Prepare returns a Queries value with every query pre-compiled against db.
// If any single preparation fails, the error is returned immediately and the
// statements prepared so far are not closed (the underlying db owns them).
func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
	q := Queries{db: db}
	var err error
	if q.createFileStmt, err = db.PrepareContext(ctx, createFile); err != nil {
		return nil, fmt.Errorf("error preparing query CreateFile: %w", err)
	}
	if q.createLogStmt, err = db.PrepareContext(ctx, createLog); err != nil {
		return nil, fmt.Errorf("error preparing query CreateLog: %w", err)
	}
	if q.createMessageStmt, err = db.PrepareContext(ctx, createMessage); err != nil {
		return nil, fmt.Errorf("error preparing query CreateMessage: %w", err)
	}
	if q.createSessionStmt, err = db.PrepareContext(ctx, createSession); err != nil {
		return nil, fmt.Errorf("error preparing query CreateSession: %w", err)
	}
	if q.deleteFileStmt, err = db.PrepareContext(ctx, deleteFile); err != nil {
		return nil, fmt.Errorf("error preparing query DeleteFile: %w", err)
	}
	if q.deleteMessageStmt, err = db.PrepareContext(ctx, deleteMessage); err != nil {
		return nil, fmt.Errorf("error preparing query DeleteMessage: %w", err)
	}
	if q.deleteSessionStmt, err = db.PrepareContext(ctx, deleteSession); err != nil {
		return nil, fmt.Errorf("error preparing query DeleteSession: %w", err)
	}
	if q.deleteSessionFilesStmt, err = db.PrepareContext(ctx, deleteSessionFiles); err != nil {
		return nil, fmt.Errorf("error preparing query DeleteSessionFiles: %w", err)
	}
	if q.deleteSessionMessagesStmt, err = db.PrepareContext(ctx, deleteSessionMessages); err != nil {
		return nil, fmt.Errorf("error preparing query DeleteSessionMessages: %w", err)
	}
	if q.getFileStmt, err = db.PrepareContext(ctx, getFile); err != nil {
		return nil, fmt.Errorf("error preparing query GetFile: %w", err)
	}
	if q.getFileByPathAndSessionStmt, err = db.PrepareContext(ctx, getFileByPathAndSession); err != nil {
		return nil, fmt.Errorf("error preparing query GetFileByPathAndSession: %w", err)
	}
	if q.getMessageStmt, err = db.PrepareContext(ctx, getMessage); err != nil {
		return nil, fmt.Errorf("error preparing query GetMessage: %w", err)
	}
	if q.getSessionByIDStmt, err = db.PrepareContext(ctx, getSessionByID); err != nil {
		return nil, fmt.Errorf("error preparing query GetSessionByID: %w", err)
	}
	if q.listAllLogsStmt, err = db.PrepareContext(ctx, listAllLogs); err != nil {
		return nil, fmt.Errorf("error preparing query ListAllLogs: %w", err)
	}
	if q.listFilesByPathStmt, err = db.PrepareContext(ctx, listFilesByPath); err != nil {
		return nil, fmt.Errorf("error preparing query ListFilesByPath: %w", err)
	}
	if q.listFilesBySessionStmt, err = db.PrepareContext(ctx, listFilesBySession); err != nil {
		return nil, fmt.Errorf("error preparing query ListFilesBySession: %w", err)
	}
	if q.listLatestSessionFilesStmt, err = db.PrepareContext(ctx, listLatestSessionFiles); err != nil {
		return nil, fmt.Errorf("error preparing query ListLatestSessionFiles: %w", err)
	}
	if q.listLogsBySessionStmt, err = db.PrepareContext(ctx, listLogsBySession); err != nil {
		return nil, fmt.Errorf("error preparing query ListLogsBySession: %w", err)
	}
	if q.listMessagesBySessionStmt, err = db.PrepareContext(ctx, listMessagesBySession); err != nil {
		return nil, fmt.Errorf("error preparing query ListMessagesBySession: %w", err)
	}
	if q.listMessagesBySessionAfterStmt, err = db.PrepareContext(ctx, listMessagesBySessionAfter); err != nil {
		return nil, fmt.Errorf("error preparing query ListMessagesBySessionAfter: %w", err)
	}
	if q.listNewFilesStmt, err = db.PrepareContext(ctx, listNewFiles); err != nil {
		return nil, fmt.Errorf("error preparing query ListNewFiles: %w", err)
	}
	if q.listSessionsStmt, err = db.PrepareContext(ctx, listSessions); err != nil {
		return nil, fmt.Errorf("error preparing query ListSessions: %w", err)
	}
	if q.updateFileStmt, err = db.PrepareContext(ctx, updateFile); err != nil {
		return nil, fmt.Errorf("error preparing query UpdateFile: %w", err)
	}
	if q.updateMessageStmt, err = db.PrepareContext(ctx, updateMessage); err != nil {
		return nil, fmt.Errorf("error preparing query UpdateMessage: %w", err)
	}
	if q.updateSessionStmt, err = db.PrepareContext(ctx, updateSession); err != nil {
		return nil, fmt.Errorf("error preparing query UpdateSession: %w", err)
	}
	return &q, nil
}

// Close closes every prepared statement. It attempts all of them even when
// one fails; only the last close error encountered is returned.
func (q *Queries) Close() error {
	var err error
	if q.createFileStmt != nil {
		if cerr := q.createFileStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing createFileStmt: %w", cerr)
		}
	}
	if q.createLogStmt != nil {
		if cerr := q.createLogStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing createLogStmt: %w", cerr)
		}
	}
	if q.createMessageStmt != nil {
		if cerr := q.createMessageStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing createMessageStmt: %w", cerr)
		}
	}
	if q.createSessionStmt != nil {
		if cerr := q.createSessionStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing createSessionStmt: %w", cerr)
		}
	}
	if q.deleteFileStmt != nil {
		if cerr := q.deleteFileStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing deleteFileStmt: %w", cerr)
		}
	}
	if q.deleteMessageStmt != nil {
		if cerr := q.deleteMessageStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing deleteMessageStmt: %w", cerr)
		}
	}
	if q.deleteSessionStmt != nil {
		if cerr := q.deleteSessionStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing deleteSessionStmt: %w", cerr)
		}
	}
	if q.deleteSessionFilesStmt != nil {
		if cerr := q.deleteSessionFilesStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing deleteSessionFilesStmt: %w", cerr)
		}
	}
	if q.deleteSessionMessagesStmt != nil {
		if cerr := q.deleteSessionMessagesStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing deleteSessionMessagesStmt: %w", cerr)
		}
	}
	if q.getFileStmt != nil {
		if cerr := q.getFileStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing getFileStmt: %w", cerr)
		}
	}
	if q.getFileByPathAndSessionStmt != nil {
		if cerr := q.getFileByPathAndSessionStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing getFileByPathAndSessionStmt: %w", cerr)
		}
	}
	if q.getMessageStmt != nil {
		if cerr := q.getMessageStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing getMessageStmt: %w", cerr)
		}
	}
	if q.getSessionByIDStmt != nil {
		if cerr := q.getSessionByIDStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing getSessionByIDStmt: %w", cerr)
		}
	}
	if q.listAllLogsStmt != nil {
		if cerr := q.listAllLogsStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing listAllLogsStmt: %w", cerr)
		}
	}
	if q.listFilesByPathStmt != nil {
		if cerr := q.listFilesByPathStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing listFilesByPathStmt: %w", cerr)
		}
	}
	if q.listFilesBySessionStmt != nil {
		if cerr := q.listFilesBySessionStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing listFilesBySessionStmt: %w", cerr)
		}
	}
	if q.listLatestSessionFilesStmt != nil {
		if cerr := q.listLatestSessionFilesStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing listLatestSessionFilesStmt: %w", cerr)
		}
	}
	if q.listLogsBySessionStmt != nil {
		if cerr := q.listLogsBySessionStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing listLogsBySessionStmt: %w", cerr)
		}
	}
	if q.listMessagesBySessionStmt != nil {
		if cerr := q.listMessagesBySessionStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing listMessagesBySessionStmt: %w", cerr)
		}
	}
	if q.listMessagesBySessionAfterStmt != nil {
		if cerr := q.listMessagesBySessionAfterStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing listMessagesBySessionAfterStmt: %w", cerr)
		}
	}
	if q.listNewFilesStmt != nil {
		if cerr := q.listNewFilesStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing listNewFilesStmt: %w", cerr)
		}
	}
	if q.listSessionsStmt != nil {
		if cerr := q.listSessionsStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing listSessionsStmt: %w", cerr)
		}
	}
	if q.updateFileStmt != nil {
		if cerr := q.updateFileStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing updateFileStmt: %w", cerr)
		}
	}
	if q.updateMessageStmt != nil {
		if cerr := q.updateMessageStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing updateMessageStmt: %w", cerr)
		}
	}
	if q.updateSessionStmt != nil {
		if cerr := q.updateSessionStmt.Close(); cerr != nil {
			err = fmt.Errorf("error closing updateSessionStmt: %w", cerr)
		}
	}
	return err
}

// exec runs an Exec-style query, preferring the prepared statement (rebound
// to the transaction when one is active) and falling back to the raw SQL.
func (q *Queries) exec(ctx context.Context, stmt *sql.Stmt, query string, args ...interface{}) (sql.Result, error) {
	switch {
	case stmt != nil && q.tx != nil:
		return q.tx.StmtContext(ctx, stmt).ExecContext(ctx, args...)
	case stmt != nil:
		return stmt.ExecContext(ctx, args...)
	default:
		return q.db.ExecContext(ctx, query, args...)
	}
}

// query runs a multi-row query with the same statement-selection rules as exec.
func (q *Queries) query(ctx context.Context, stmt *sql.Stmt, query string, args ...interface{}) (*sql.Rows, error) {
	switch {
	case stmt != nil && q.tx != nil:
		return q.tx.StmtContext(ctx, stmt).QueryContext(ctx, args...)
	case stmt != nil:
		return stmt.QueryContext(ctx, args...)
	default:
		return q.db.QueryContext(ctx, query, args...)
	}
}

// queryRow runs a single-row query with the same statement-selection rules as exec.
func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, args ...interface{}) *sql.Row {
	switch {
	case stmt != nil && q.tx != nil:
		return q.tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...)
	case stmt != nil:
		return stmt.QueryRowContext(ctx, args...)
	default:
		return q.db.QueryRowContext(ctx, query, args...)
	}
}

// Queries holds the database handle, an optional active transaction, and one
// optional prepared statement per generated query (nil when unprepared).
type Queries struct {
	db                             DBTX
	tx                             *sql.Tx
	createFileStmt                 *sql.Stmt
	createLogStmt                  *sql.Stmt
	createMessageStmt              *sql.Stmt
	createSessionStmt              *sql.Stmt
	deleteFileStmt                 *sql.Stmt
	deleteMessageStmt              *sql.Stmt
	deleteSessionStmt              *sql.Stmt
	deleteSessionFilesStmt         *sql.Stmt
	deleteSessionMessagesStmt      *sql.Stmt
	getFileStmt                    *sql.Stmt
	getFileByPathAndSessionStmt    *sql.Stmt
	getMessageStmt                 *sql.Stmt
	getSessionByIDStmt             *sql.Stmt
	listAllLogsStmt                *sql.Stmt
	listFilesByPathStmt            *sql.Stmt
	listFilesBySessionStmt         *sql.Stmt
	listLatestSessionFilesStmt     *sql.Stmt
	listLogsBySessionStmt          *sql.Stmt
	listMessagesBySessionStmt      *sql.Stmt
	listMessagesBySessionAfterStmt *sql.Stmt
	listNewFilesStmt               *sql.Stmt
	listSessionsStmt               *sql.Stmt
	updateFileStmt                 *sql.Stmt
	updateMessageStmt              *sql.Stmt
	updateSessionStmt              *sql.Stmt
}

// WithTx returns a copy of q whose operations run inside tx, sharing the
// already-prepared statements (rebound to the transaction at call time).
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
	return &Queries{
		db:                             tx,
		tx:                             tx,
		createFileStmt:                 q.createFileStmt,
		createLogStmt:                  q.createLogStmt,
		createMessageStmt:              q.createMessageStmt,
		createSessionStmt:              q.createSessionStmt,
		deleteFileStmt:                 q.deleteFileStmt,
		deleteMessageStmt:              q.deleteMessageStmt,
		deleteSessionStmt:              q.deleteSessionStmt,
		deleteSessionFilesStmt:         q.deleteSessionFilesStmt,
		deleteSessionMessagesStmt:      q.deleteSessionMessagesStmt,
		getFileStmt:                    q.getFileStmt,
		getFileByPathAndSessionStmt:    q.getFileByPathAndSessionStmt,
		getMessageStmt:                 q.getMessageStmt,
		getSessionByIDStmt:             q.getSessionByIDStmt,
		listAllLogsStmt:                q.listAllLogsStmt,
		listFilesByPathStmt:            q.listFilesByPathStmt,
		listFilesBySessionStmt:         q.listFilesBySessionStmt,
		listLatestSessionFilesStmt:     q.listLatestSessionFilesStmt,
		listLogsBySessionStmt:          q.listLogsBySessionStmt,
		listMessagesBySessionStmt:      q.listMessagesBySessionStmt,
		listMessagesBySessionAfterStmt: q.listMessagesBySessionAfterStmt,
		listNewFilesStmt:               q.listNewFilesStmt,
		listSessionsStmt:               q.listSessionsStmt,
		updateFileStmt:                 q.updateFileStmt,
		updateMessageStmt:              q.updateMessageStmt,
		updateSessionStmt:              q.updateSessionStmt,
	}
}
|
||||
@@ -1,6 +0,0 @@
|
||||
package db

import "embed"

// FS embeds the SQL migration files that goose applies at startup
// (see Connect).
//
//go:embed migrations/*.sql
var FS embed.FS
|
||||
@@ -1,317 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
// source: files.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// NOTE(review): everything below is sqlc-generated code from files.sql (see
// the "DO NOT EDIT" header at the top of this file). Do not hand-edit;
// change the SQL and regenerate with sqlc. Comments added here are review
// notes only and will be lost on regeneration.

const createFile = `-- name: CreateFile :one
INSERT INTO files (
    id,
    session_id,
    path,
    content,
    version
) VALUES (
    ?, ?, ?, ?, ?
)
RETURNING id, session_id, path, content, version, is_new, created_at, updated_at
`

// CreateFileParams carries the column values for CreateFile, in the order of
// the INSERT's placeholders.
type CreateFileParams struct {
	ID        string `json:"id"`
	SessionID string `json:"session_id"`
	Path      string `json:"path"`
	Content   string `json:"content"`
	Version   string `json:"version"`
}

// CreateFile inserts one file row and returns the fully-populated row
// (including the database-assigned is_new/created_at/updated_at columns).
func (q *Queries) CreateFile(ctx context.Context, arg CreateFileParams) (File, error) {
	row := q.queryRow(ctx, q.createFileStmt, createFile,
		arg.ID,
		arg.SessionID,
		arg.Path,
		arg.Content,
		arg.Version,
	)
	var i File
	err := row.Scan(
		&i.ID,
		&i.SessionID,
		&i.Path,
		&i.Content,
		&i.Version,
		&i.IsNew,
		&i.CreatedAt,
		&i.UpdatedAt,
	)
	return i, err
}

const deleteFile = `-- name: DeleteFile :exec
DELETE FROM files
WHERE id = ?
`

// DeleteFile removes the file row with the given id.
func (q *Queries) DeleteFile(ctx context.Context, id string) error {
	_, err := q.exec(ctx, q.deleteFileStmt, deleteFile, id)
	return err
}

const deleteSessionFiles = `-- name: DeleteSessionFiles :exec
DELETE FROM files
WHERE session_id = ?
`

// DeleteSessionFiles removes every file row belonging to the given session.
func (q *Queries) DeleteSessionFiles(ctx context.Context, sessionID string) error {
	_, err := q.exec(ctx, q.deleteSessionFilesStmt, deleteSessionFiles, sessionID)
	return err
}

const getFile = `-- name: GetFile :one
SELECT id, session_id, path, content, version, is_new, created_at, updated_at
FROM files
WHERE id = ? LIMIT 1
`

// GetFile fetches one file row by id.
func (q *Queries) GetFile(ctx context.Context, id string) (File, error) {
	row := q.queryRow(ctx, q.getFileStmt, getFile, id)
	var i File
	err := row.Scan(
		&i.ID,
		&i.SessionID,
		&i.Path,
		&i.Content,
		&i.Version,
		&i.IsNew,
		&i.CreatedAt,
		&i.UpdatedAt,
	)
	return i, err
}

const getFileByPathAndSession = `-- name: GetFileByPathAndSession :one
SELECT id, session_id, path, content, version, is_new, created_at, updated_at
FROM files
WHERE path = ? AND session_id = ?
ORDER BY created_at DESC
LIMIT 1
`

// GetFileByPathAndSessionParams carries the lookup key for
// GetFileByPathAndSession.
type GetFileByPathAndSessionParams struct {
	Path      string `json:"path"`
	SessionID string `json:"session_id"`
}

// GetFileByPathAndSession fetches the most recently created file row for the
// given path within the given session.
func (q *Queries) GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error) {
	row := q.queryRow(ctx, q.getFileByPathAndSessionStmt, getFileByPathAndSession, arg.Path, arg.SessionID)
	var i File
	err := row.Scan(
		&i.ID,
		&i.SessionID,
		&i.Path,
		&i.Content,
		&i.Version,
		&i.IsNew,
		&i.CreatedAt,
		&i.UpdatedAt,
	)
	return i, err
}

const listFilesByPath = `-- name: ListFilesByPath :many
SELECT id, session_id, path, content, version, is_new, created_at, updated_at
FROM files
WHERE path = ?
ORDER BY created_at DESC
`

// ListFilesByPath fetches every file row with the given path, newest first.
func (q *Queries) ListFilesByPath(ctx context.Context, path string) ([]File, error) {
	rows, err := q.query(ctx, q.listFilesByPathStmt, listFilesByPath, path)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	items := []File{}
	for rows.Next() {
		var i File
		if err := rows.Scan(
			&i.ID,
			&i.SessionID,
			&i.Path,
			&i.Content,
			&i.Version,
			&i.IsNew,
			&i.CreatedAt,
			&i.UpdatedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
|
||||
|
||||
const listFilesBySession = `-- name: ListFilesBySession :many
|
||||
SELECT id, session_id, path, content, version, is_new, created_at, updated_at
|
||||
FROM files
|
||||
WHERE session_id = ?
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
func (q *Queries) ListFilesBySession(ctx context.Context, sessionID string) ([]File, error) {
|
||||
rows, err := q.query(ctx, q.listFilesBySessionStmt, listFilesBySession, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []File{}
|
||||
for rows.Next() {
|
||||
var i File
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Path,
|
||||
&i.Content,
|
||||
&i.Version,
|
||||
&i.IsNew,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listLatestSessionFiles = `-- name: ListLatestSessionFiles :many
|
||||
SELECT f.id, f.session_id, f.path, f.content, f.version, f.is_new, f.created_at, f.updated_at
|
||||
FROM files f
|
||||
INNER JOIN (
|
||||
SELECT path, MAX(created_at) as max_created_at
|
||||
FROM files
|
||||
GROUP BY path
|
||||
) latest ON f.path = latest.path AND f.created_at = latest.max_created_at
|
||||
WHERE f.session_id = ?
|
||||
ORDER BY f.path
|
||||
`
|
||||
|
||||
func (q *Queries) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
|
||||
rows, err := q.query(ctx, q.listLatestSessionFilesStmt, listLatestSessionFiles, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []File{}
|
||||
for rows.Next() {
|
||||
var i File
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Path,
|
||||
&i.Content,
|
||||
&i.Version,
|
||||
&i.IsNew,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listNewFiles = `-- name: ListNewFiles :many
|
||||
SELECT id, session_id, path, content, version, is_new, created_at, updated_at
|
||||
FROM files
|
||||
WHERE is_new = 1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListNewFiles(ctx context.Context) ([]File, error) {
|
||||
rows, err := q.query(ctx, q.listNewFilesStmt, listNewFiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []File{}
|
||||
for rows.Next() {
|
||||
var i File
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Path,
|
||||
&i.Content,
|
||||
&i.Version,
|
||||
&i.IsNew,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateFile = `-- name: UpdateFile :one
|
||||
UPDATE files
|
||||
SET
|
||||
content = ?,
|
||||
version = ?,
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = ?
|
||||
RETURNING id, session_id, path, content, version, is_new, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateFileParams struct {
|
||||
Content string `json:"content"`
|
||||
Version string `json:"version"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateFile(ctx context.Context, arg UpdateFileParams) (File, error) {
|
||||
row := q.queryRow(ctx, q.updateFileStmt, updateFile, arg.Content, arg.Version, arg.ID)
|
||||
var i File
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Path,
|
||||
&i.Content,
|
||||
&i.Version,
|
||||
&i.IsNew,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -1,137 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
// source: logs.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
const createLog = `-- name: CreateLog :one
|
||||
INSERT INTO logs (
|
||||
id,
|
||||
session_id,
|
||||
timestamp,
|
||||
level,
|
||||
message,
|
||||
attributes
|
||||
) VALUES (
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?
|
||||
) RETURNING id, session_id, timestamp, level, message, attributes, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateLogParams struct {
|
||||
ID string `json:"id"`
|
||||
SessionID sql.NullString `json:"session_id"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Level string `json:"level"`
|
||||
Message string `json:"message"`
|
||||
Attributes sql.NullString `json:"attributes"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateLog(ctx context.Context, arg CreateLogParams) (Log, error) {
|
||||
row := q.queryRow(ctx, q.createLogStmt, createLog,
|
||||
arg.ID,
|
||||
arg.SessionID,
|
||||
arg.Timestamp,
|
||||
arg.Level,
|
||||
arg.Message,
|
||||
arg.Attributes,
|
||||
)
|
||||
var i Log
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Timestamp,
|
||||
&i.Level,
|
||||
&i.Message,
|
||||
&i.Attributes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listAllLogs = `-- name: ListAllLogs :many
|
||||
SELECT id, session_id, timestamp, level, message, attributes, created_at, updated_at FROM logs
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
func (q *Queries) ListAllLogs(ctx context.Context, limit int64) ([]Log, error) {
|
||||
rows, err := q.query(ctx, q.listAllLogsStmt, listAllLogs, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []Log{}
|
||||
for rows.Next() {
|
||||
var i Log
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Timestamp,
|
||||
&i.Level,
|
||||
&i.Message,
|
||||
&i.Attributes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listLogsBySession = `-- name: ListLogsBySession :many
|
||||
SELECT id, session_id, timestamp, level, message, attributes, created_at, updated_at FROM logs
|
||||
WHERE session_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListLogsBySession(ctx context.Context, sessionID sql.NullString) ([]Log, error) {
|
||||
rows, err := q.query(ctx, q.listLogsBySessionStmt, listLogsBySession, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []Log{}
|
||||
for rows.Next() {
|
||||
var i Log
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Timestamp,
|
||||
&i.Level,
|
||||
&i.Message,
|
||||
&i.Attributes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
@@ -1,199 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
// source: messages.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
const createMessage = `-- name: CreateMessage :one
|
||||
INSERT INTO messages (
|
||||
id,
|
||||
session_id,
|
||||
role,
|
||||
parts,
|
||||
model
|
||||
) VALUES (
|
||||
?, ?, ?, ?, ?
|
||||
)
|
||||
RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at
|
||||
`
|
||||
|
||||
type CreateMessageParams struct {
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Role string `json:"role"`
|
||||
Parts string `json:"parts"`
|
||||
Model sql.NullString `json:"model"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) {
|
||||
row := q.queryRow(ctx, q.createMessageStmt, createMessage,
|
||||
arg.ID,
|
||||
arg.SessionID,
|
||||
arg.Role,
|
||||
arg.Parts,
|
||||
arg.Model,
|
||||
)
|
||||
var i Message
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Role,
|
||||
&i.Parts,
|
||||
&i.Model,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.FinishedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteMessage = `-- name: DeleteMessage :exec
|
||||
DELETE FROM messages
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteMessage(ctx context.Context, id string) error {
|
||||
_, err := q.exec(ctx, q.deleteMessageStmt, deleteMessage, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteSessionMessages = `-- name: DeleteSessionMessages :exec
|
||||
DELETE FROM messages
|
||||
WHERE session_id = ?
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteSessionMessages(ctx context.Context, sessionID string) error {
|
||||
_, err := q.exec(ctx, q.deleteSessionMessagesStmt, deleteSessionMessages, sessionID)
|
||||
return err
|
||||
}
|
||||
|
||||
const getMessage = `-- name: GetMessage :one
|
||||
SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at
|
||||
FROM messages
|
||||
WHERE id = ? LIMIT 1
|
||||
`
|
||||
|
||||
func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) {
|
||||
row := q.queryRow(ctx, q.getMessageStmt, getMessage, id)
|
||||
var i Message
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Role,
|
||||
&i.Parts,
|
||||
&i.Model,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.FinishedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listMessagesBySession = `-- name: ListMessagesBySession :many
|
||||
SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at
|
||||
FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) ([]Message, error) {
|
||||
rows, err := q.query(ctx, q.listMessagesBySessionStmt, listMessagesBySession, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []Message{}
|
||||
for rows.Next() {
|
||||
var i Message
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Role,
|
||||
&i.Parts,
|
||||
&i.Model,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.FinishedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listMessagesBySessionAfter = `-- name: ListMessagesBySessionAfter :many
|
||||
SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at
|
||||
FROM messages
|
||||
WHERE session_id = ? AND created_at > ?
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
type ListMessagesBySessionAfterParams struct {
|
||||
SessionID string `json:"session_id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) ListMessagesBySessionAfter(ctx context.Context, arg ListMessagesBySessionAfterParams) ([]Message, error) {
|
||||
rows, err := q.query(ctx, q.listMessagesBySessionAfterStmt, listMessagesBySessionAfter, arg.SessionID, arg.CreatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []Message{}
|
||||
for rows.Next() {
|
||||
var i Message
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Role,
|
||||
&i.Parts,
|
||||
&i.Model,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.FinishedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateMessage = `-- name: UpdateMessage :exec
|
||||
UPDATE messages
|
||||
SET
|
||||
parts = ?,
|
||||
finished_at = ?,
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
type UpdateMessageParams struct {
|
||||
Parts string `json:"parts"`
|
||||
FinishedAt sql.NullString `json:"finished_at"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateMessage(ctx context.Context, arg UpdateMessageParams) error {
|
||||
_, err := q.exec(ctx, q.updateMessageStmt, updateMessage, arg.Parts, arg.FinishedAt, arg.ID)
|
||||
return err
|
||||
}
|
||||
@@ -1,125 +0,0 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
-- Sessions
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
parent_session_id TEXT,
|
||||
title TEXT NOT NULL,
|
||||
message_count INTEGER NOT NULL DEFAULT 0 CHECK (message_count >= 0),
|
||||
prompt_tokens INTEGER NOT NULL DEFAULT 0 CHECK (prompt_tokens >= 0),
|
||||
completion_tokens INTEGER NOT NULL DEFAULT 0 CHECK (completion_tokens >= 0),
|
||||
cost REAL NOT NULL DEFAULT 0.0 CHECK (cost >= 0.0),
|
||||
summary TEXT,
|
||||
summarized_at TEXT,
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now'))
|
||||
);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS update_sessions_updated_at
|
||||
AFTER UPDATE ON sessions
|
||||
BEGIN
|
||||
UPDATE sessions SET updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = new.id;
|
||||
END;
|
||||
|
||||
-- Files
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
path TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
version TEXT NOT NULL,
|
||||
is_new INTEGER DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
|
||||
FOREIGN KEY (session_id) REFERENCES sessions (id) ON DELETE CASCADE,
|
||||
UNIQUE(path, session_id, version)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_files_session_id ON files (session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_files_path ON files (path);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS update_files_updated_at
|
||||
AFTER UPDATE ON files
|
||||
BEGIN
|
||||
UPDATE files SET updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = new.id;
|
||||
END;
|
||||
|
||||
-- Messages
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
parts TEXT NOT NULL default '[]',
|
||||
model TEXT,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
|
||||
finished_at TEXT,
|
||||
FOREIGN KEY (session_id) REFERENCES sessions (id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages (session_id);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS update_messages_updated_at
|
||||
AFTER UPDATE ON messages
|
||||
BEGIN
|
||||
UPDATE messages SET updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = new.id;
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS update_session_message_count_on_insert
|
||||
AFTER INSERT ON messages
|
||||
BEGIN
|
||||
UPDATE sessions SET
|
||||
message_count = message_count + 1
|
||||
WHERE id = new.session_id;
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS update_session_message_count_on_delete
|
||||
AFTER DELETE ON messages
|
||||
BEGIN
|
||||
UPDATE sessions SET
|
||||
message_count = message_count - 1
|
||||
WHERE id = old.session_id;
|
||||
END;
|
||||
|
||||
-- Logs
|
||||
CREATE TABLE IF NOT EXISTS logs (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT REFERENCES sessions(id) ON DELETE CASCADE,
|
||||
timestamp TEXT NOT NULL,
|
||||
level TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
attributes TEXT,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f000Z', 'now'))
|
||||
);
|
||||
|
||||
CREATE INDEX logs_session_id_idx ON logs(session_id);
|
||||
CREATE INDEX logs_timestamp_idx ON logs(timestamp);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS update_logs_updated_at
|
||||
AFTER UPDATE ON logs
|
||||
BEGIN
|
||||
UPDATE logs SET updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = new.id;
|
||||
END;
|
||||
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
DROP TRIGGER IF EXISTS update_sessions_updated_at;
|
||||
DROP TRIGGER IF EXISTS update_messages_updated_at;
|
||||
DROP TRIGGER IF EXISTS update_files_updated_at;
|
||||
DROP TRIGGER IF EXISTS update_logs_updated_at;
|
||||
|
||||
DROP TRIGGER IF EXISTS update_session_message_count_on_delete;
|
||||
DROP TRIGGER IF EXISTS update_session_message_count_on_insert;
|
||||
|
||||
DROP TABLE IF EXISTS logs;
|
||||
DROP TABLE IF EXISTS messages;
|
||||
DROP TABLE IF EXISTS files;
|
||||
DROP TABLE IF EXISTS sessions;
|
||||
-- +goose StatementEnd
|
||||
@@ -1,56 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type File struct {
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Path string `json:"path"`
|
||||
Content string `json:"content"`
|
||||
Version string `json:"version"`
|
||||
IsNew sql.NullInt64 `json:"is_new"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type Log struct {
|
||||
ID string `json:"id"`
|
||||
SessionID sql.NullString `json:"session_id"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Level string `json:"level"`
|
||||
Message string `json:"message"`
|
||||
Attributes sql.NullString `json:"attributes"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Role string `json:"role"`
|
||||
Parts string `json:"parts"`
|
||||
Model sql.NullString `json:"model"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
FinishedAt sql.NullString `json:"finished_at"`
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
ParentSessionID sql.NullString `json:"parent_session_id"`
|
||||
Title string `json:"title"`
|
||||
MessageCount int64 `json:"message_count"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
Summary sql.NullString `json:"summary"`
|
||||
SummarizedAt sql.NullString `json:"summarized_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type Querier interface {
|
||||
CreateFile(ctx context.Context, arg CreateFileParams) (File, error)
|
||||
CreateLog(ctx context.Context, arg CreateLogParams) (Log, error)
|
||||
CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error)
|
||||
CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error)
|
||||
DeleteFile(ctx context.Context, id string) error
|
||||
DeleteMessage(ctx context.Context, id string) error
|
||||
DeleteSession(ctx context.Context, id string) error
|
||||
DeleteSessionFiles(ctx context.Context, sessionID string) error
|
||||
DeleteSessionMessages(ctx context.Context, sessionID string) error
|
||||
GetFile(ctx context.Context, id string) (File, error)
|
||||
GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error)
|
||||
GetMessage(ctx context.Context, id string) (Message, error)
|
||||
GetSessionByID(ctx context.Context, id string) (Session, error)
|
||||
ListAllLogs(ctx context.Context, limit int64) ([]Log, error)
|
||||
ListFilesByPath(ctx context.Context, path string) ([]File, error)
|
||||
ListFilesBySession(ctx context.Context, sessionID string) ([]File, error)
|
||||
ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
|
||||
ListLogsBySession(ctx context.Context, sessionID sql.NullString) ([]Log, error)
|
||||
ListMessagesBySession(ctx context.Context, sessionID string) ([]Message, error)
|
||||
ListMessagesBySessionAfter(ctx context.Context, arg ListMessagesBySessionAfterParams) ([]Message, error)
|
||||
ListNewFiles(ctx context.Context) ([]File, error)
|
||||
ListSessions(ctx context.Context) ([]Session, error)
|
||||
UpdateFile(ctx context.Context, arg UpdateFileParams) (File, error)
|
||||
UpdateMessage(ctx context.Context, arg UpdateMessageParams) error
|
||||
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
|
||||
}
|
||||
|
||||
var _ Querier = (*Queries)(nil)
|
||||
@@ -1,203 +0,0 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.29.0
|
||||
// source: sessions.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
const createSession = `-- name: CreateSession :one
|
||||
INSERT INTO sessions (
|
||||
id,
|
||||
parent_session_id,
|
||||
title,
|
||||
message_count,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
cost,
|
||||
summary,
|
||||
summarized_at
|
||||
) VALUES (
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?
|
||||
) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, summary, summarized_at, updated_at, created_at
|
||||
`
|
||||
|
||||
type CreateSessionParams struct {
|
||||
ID string `json:"id"`
|
||||
ParentSessionID sql.NullString `json:"parent_session_id"`
|
||||
Title string `json:"title"`
|
||||
MessageCount int64 `json:"message_count"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
Summary sql.NullString `json:"summary"`
|
||||
SummarizedAt sql.NullString `json:"summarized_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
|
||||
row := q.queryRow(ctx, q.createSessionStmt, createSession,
|
||||
arg.ID,
|
||||
arg.ParentSessionID,
|
||||
arg.Title,
|
||||
arg.MessageCount,
|
||||
arg.PromptTokens,
|
||||
arg.CompletionTokens,
|
||||
arg.Cost,
|
||||
arg.Summary,
|
||||
arg.SummarizedAt,
|
||||
)
|
||||
var i Session
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ParentSessionID,
|
||||
&i.Title,
|
||||
&i.MessageCount,
|
||||
&i.PromptTokens,
|
||||
&i.CompletionTokens,
|
||||
&i.Cost,
|
||||
&i.Summary,
|
||||
&i.SummarizedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteSession = `-- name: DeleteSession :exec
|
||||
DELETE FROM sessions
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteSession(ctx context.Context, id string) error {
|
||||
_, err := q.exec(ctx, q.deleteSessionStmt, deleteSession, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const getSessionByID = `-- name: GetSessionByID :one
|
||||
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, summary, summarized_at, updated_at, created_at
|
||||
FROM sessions
|
||||
WHERE id = ? LIMIT 1
|
||||
`
|
||||
|
||||
func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error) {
|
||||
row := q.queryRow(ctx, q.getSessionByIDStmt, getSessionByID, id)
|
||||
var i Session
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ParentSessionID,
|
||||
&i.Title,
|
||||
&i.MessageCount,
|
||||
&i.PromptTokens,
|
||||
&i.CompletionTokens,
|
||||
&i.Cost,
|
||||
&i.Summary,
|
||||
&i.SummarizedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listSessions = `-- name: ListSessions :many
|
||||
SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, summary, summarized_at, updated_at, created_at
|
||||
FROM sessions
|
||||
WHERE parent_session_id is NULL
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) {
|
||||
rows, err := q.query(ctx, q.listSessionsStmt, listSessions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []Session{}
|
||||
for rows.Next() {
|
||||
var i Session
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ParentSessionID,
|
||||
&i.Title,
|
||||
&i.MessageCount,
|
||||
&i.PromptTokens,
|
||||
&i.CompletionTokens,
|
||||
&i.Cost,
|
||||
&i.Summary,
|
||||
&i.SummarizedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateSession = `-- name: UpdateSession :one
|
||||
UPDATE sessions
|
||||
SET
|
||||
title = ?,
|
||||
prompt_tokens = ?,
|
||||
completion_tokens = ?,
|
||||
cost = ?,
|
||||
summary = ?,
|
||||
summarized_at = ?
|
||||
WHERE id = ?
|
||||
RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, summary, summarized_at, updated_at, created_at
|
||||
`
|
||||
|
||||
type UpdateSessionParams struct {
|
||||
Title string `json:"title"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
Summary sql.NullString `json:"summary"`
|
||||
SummarizedAt sql.NullString `json:"summarized_at"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) {
|
||||
row := q.queryRow(ctx, q.updateSessionStmt, updateSession,
|
||||
arg.Title,
|
||||
arg.PromptTokens,
|
||||
arg.CompletionTokens,
|
||||
arg.Cost,
|
||||
arg.Summary,
|
||||
arg.SummarizedAt,
|
||||
arg.ID,
|
||||
)
|
||||
var i Session
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ParentSessionID,
|
||||
&i.Title,
|
||||
&i.MessageCount,
|
||||
&i.PromptTokens,
|
||||
&i.CompletionTokens,
|
||||
&i.Cost,
|
||||
&i.Summary,
|
||||
&i.SummarizedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
-- name: GetFile :one
|
||||
SELECT *
|
||||
FROM files
|
||||
WHERE id = ? LIMIT 1;
|
||||
|
||||
-- name: GetFileByPathAndSession :one
|
||||
SELECT *
|
||||
FROM files
|
||||
WHERE path = ? AND session_id = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1;
|
||||
|
||||
-- name: ListFilesBySession :many
|
||||
SELECT *
|
||||
FROM files
|
||||
WHERE session_id = ?
|
||||
ORDER BY created_at ASC;
|
||||
|
||||
-- name: ListFilesByPath :many
|
||||
SELECT *
|
||||
FROM files
|
||||
WHERE path = ?
|
||||
ORDER BY created_at DESC;
|
||||
|
||||
-- name: CreateFile :one
|
||||
INSERT INTO files (
|
||||
id,
|
||||
session_id,
|
||||
path,
|
||||
content,
|
||||
version
|
||||
) VALUES (
|
||||
?, ?, ?, ?, ?
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateFile :one
|
||||
UPDATE files
|
||||
SET
|
||||
content = ?,
|
||||
version = ?,
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = ?
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteFile :exec
|
||||
DELETE FROM files
|
||||
WHERE id = ?;
|
||||
|
||||
-- name: DeleteSessionFiles :exec
|
||||
DELETE FROM files
|
||||
WHERE session_id = ?;
|
||||
|
||||
-- name: ListLatestSessionFiles :many
|
||||
SELECT f.*
|
||||
FROM files f
|
||||
INNER JOIN (
|
||||
SELECT path, MAX(created_at) as max_created_at
|
||||
FROM files
|
||||
GROUP BY path
|
||||
) latest ON f.path = latest.path AND f.created_at = latest.max_created_at
|
||||
WHERE f.session_id = ?
|
||||
ORDER BY f.path;
|
||||
|
||||
-- name: ListNewFiles :many
|
||||
SELECT *
|
||||
FROM files
|
||||
WHERE is_new = 1
|
||||
ORDER BY created_at DESC;
|
||||
@@ -1,26 +0,0 @@
|
||||
-- name: CreateLog :one
|
||||
INSERT INTO logs (
|
||||
id,
|
||||
session_id,
|
||||
timestamp,
|
||||
level,
|
||||
message,
|
||||
attributes
|
||||
) VALUES (
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?,
|
||||
?
|
||||
) RETURNING *;
|
||||
|
||||
-- name: ListLogsBySession :many
|
||||
SELECT * FROM logs
|
||||
WHERE session_id = ?
|
||||
ORDER BY timestamp DESC;
|
||||
|
||||
-- name: ListAllLogs :many
|
||||
SELECT * FROM logs
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?;
|
||||
@@ -1,45 +0,0 @@
|
||||
-- name: GetMessage :one
|
||||
SELECT *
|
||||
FROM messages
|
||||
WHERE id = ? LIMIT 1;
|
||||
|
||||
-- name: ListMessagesBySession :many
|
||||
SELECT *
|
||||
FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY created_at ASC;
|
||||
|
||||
-- name: ListMessagesBySessionAfter :many
|
||||
SELECT *
|
||||
FROM messages
|
||||
WHERE session_id = ? AND created_at > ?
|
||||
ORDER BY created_at ASC;
|
||||
|
||||
-- name: CreateMessage :one
|
||||
INSERT INTO messages (
|
||||
id,
|
||||
session_id,
|
||||
role,
|
||||
parts,
|
||||
model
|
||||
) VALUES (
|
||||
?, ?, ?, ?, ?
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateMessage :exec
|
||||
UPDATE messages
|
||||
SET
|
||||
parts = ?,
|
||||
finished_at = ?,
|
||||
updated_at = strftime('%Y-%m-%dT%H:%M:%f000Z', 'now')
|
||||
WHERE id = ?;
|
||||
|
||||
|
||||
-- name: DeleteMessage :exec
|
||||
DELETE FROM messages
|
||||
WHERE id = ?;
|
||||
|
||||
-- name: DeleteSessionMessages :exec
|
||||
DELETE FROM messages
|
||||
WHERE session_id = ?;
|
||||
@@ -1,50 +0,0 @@
|
||||
-- Inserts a session row with all counters and summary state supplied
-- explicitly by the caller; returns the inserted row.
-- name: CreateSession :one
INSERT INTO sessions (
    id,
    parent_session_id,
    title,
    message_count,
    prompt_tokens,
    completion_tokens,
    cost,
    summary,
    summarized_at
) VALUES (
    ?,
    ?,
    ?,
    ?,
    ?,
    ?,
    ?,
    ?,
    ?
) RETURNING *;

-- Single session lookup by primary key.
-- name: GetSessionByID :one
SELECT *
FROM sessions
WHERE id = ? LIMIT 1;

-- Top-level sessions only (child/task sessions have a parent and are
-- excluded), newest first.
-- name: ListSessions :many
SELECT *
FROM sessions
WHERE parent_session_id is NULL
ORDER BY created_at DESC;

-- Updates the mutable fields of a session; note message_count is NOT
-- updated here.
-- name: UpdateSession :one
UPDATE sessions
SET
    title = ?,
    prompt_tokens = ?,
    completion_tokens = ?,
    cost = ?,
    summary = ?,
    summarized_at = ?
WHERE id = ?
RETURNING *;

-- Removes a session row. Child rows (messages, files) are deleted by
-- their own queries — TODO confirm whether FK cascade also applies.
-- name: DeleteSession :exec
DELETE FROM sessions
WHERE id = ?;
|
||||
@@ -1,441 +0,0 @@
|
||||
package history
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sst/opencode/internal/db"
|
||||
"github.com/sst/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
const (
	// InitialVersion is the version label given to the first stored copy
	// of a file; subsequent versions are labeled "v1", "v2", ...
	// (see CreateVersion).
	InitialVersion = "initial"
)

// File is one stored version of a file's content, scoped to the session
// that produced it. Timestamps are parsed from the database's ISO-8601
// strings in fromDBItem.
type File struct {
	ID        string
	SessionID string
	Path      string
	Content   string
	Version   string
	CreatedAt time.Time
	UpdatedAt time.Time
}

// Event types published on the history broker for each mutation.
const (
	EventFileCreated         pubsub.EventType = "history_file_created"
	EventFileVersionCreated  pubsub.EventType = "history_file_version_created"
	EventFileUpdated         pubsub.EventType = "history_file_updated"
	EventFileDeleted         pubsub.EventType = "history_file_deleted"
	EventSessionFilesDeleted pubsub.EventType = "history_session_files_deleted"
)
|
||||
|
||||
// Service is the persistent file-history API: it versions file contents
// per session and broadcasts change events to subscribers.
type Service interface {
	pubsub.Subscriber[File]

	// Create stores the initial version of a file under InitialVersion.
	Create(ctx context.Context, sessionID, path, content string) (File, error)
	// CreateVersion stores a subsequent version labeled "vN".
	CreateVersion(ctx context.Context, sessionID, path, content string) (File, error)
	// Get fetches a single version row by ID.
	Get(ctx context.Context, id string) (File, error)
	// GetByPathAndVersion fetches one exact (session, path, version) row.
	GetByPathAndVersion(ctx context.Context, sessionID, path, version string) (File, error)
	// GetLatestByPathAndSession fetches the newest version of path in a session.
	GetLatestByPathAndSession(ctx context.Context, path, sessionID string) (File, error)
	// ListBySession returns every version row for a session.
	ListBySession(ctx context.Context, sessionID string) ([]File, error)
	// ListLatestSessionFiles returns only the newest version per path.
	ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
	// ListVersionsByPath returns all versions of a path across sessions.
	ListVersionsByPath(ctx context.Context, path string) ([]File, error)
	// Update rewrites content/version of an existing row.
	Update(ctx context.Context, file File) (File, error)
	// Delete removes one version row.
	Delete(ctx context.Context, id string) error
	// DeleteSessionFiles removes every version row for a session.
	DeleteSessionFiles(ctx context.Context, sessionID string) error
}
|
||||
|
||||
// service is the concrete Service implementation backed by SQLite via
// sqlc-generated queries. mu serializes multi-statement operations;
// broker fans File events out to subscribers.
type service struct {
	db     *db.Queries
	sqlDB  *sql.DB
	broker *pubsub.Broker[File]
	mu     sync.RWMutex
}
|
||||
|
||||
// globalHistoryService is the process-wide singleton set by InitService.
var globalHistoryService *service

// InitService wires the history service to the given database handle.
// It must be called exactly once, before any call to GetService.
// NOTE(review): the nil check is not goroutine-safe — concurrent calls
// could both pass it; confirm initialization happens on one goroutine.
func InitService(sqlDatabase *sql.DB) error {
	if globalHistoryService != nil {
		return fmt.Errorf("history service already initialized")
	}
	queries := db.New(sqlDatabase)
	broker := pubsub.NewBroker[File]()

	globalHistoryService = &service{
		db:     queries,
		sqlDB:  sqlDatabase,
		broker: broker,
	}
	return nil
}
|
||||
|
||||
func GetService() Service {
|
||||
if globalHistoryService == nil {
|
||||
panic("history service not initialized. Call history.InitService() first.")
|
||||
}
|
||||
return globalHistoryService
|
||||
}
|
||||
|
||||
// Create records the first version of a file under the fixed
// InitialVersion label and publishes EventFileCreated.
func (s *service) Create(ctx context.Context, sessionID, path, content string) (File, error) {
	return s.createWithVersion(ctx, sessionID, path, content, InitialVersion, EventFileCreated)
}
|
||||
|
||||
// CreateVersion stores a new version of path for the session. The next
// label is derived from the highest existing "vN" across ALL sessions
// for this path (ListFilesByPath is not session-scoped); if only the
// "initial" version exists, the first generated label is "v1".
func (s *service) CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) {
	s.mu.RLock()
	files, err := s.db.ListFilesByPath(ctx, path)
	s.mu.RUnlock()

	if err != nil && err != sql.ErrNoRows {
		return File{}, fmt.Errorf("db.ListFilesByPath for next version: %w", err)
	}

	latestVersionNumber := 0
	if len(files) > 0 {
		// Sort to be absolutely sure about the latest version globally for
		// this path. Ordering: numeric "vN" descending, then "initial"
		// after any vX, then newest created_at first.
		slices.SortFunc(files, func(a, b db.File) int {
			if strings.HasPrefix(a.Version, "v") && strings.HasPrefix(b.Version, "v") {
				vA, _ := strconv.Atoi(a.Version[1:])
				vB, _ := strconv.Atoi(b.Version[1:])
				return vB - vA // Descending to get latest first
			}
			if a.Version == InitialVersion && b.Version != InitialVersion {
				return 1 // initial comes after vX
			}
			if b.Version == InitialVersion && a.Version != InitialVersion {
				return -1
			}
			// Compare timestamps as strings (ISO format sorts correctly)
			if b.CreatedAt > a.CreatedAt {
				return 1
			} else if a.CreatedAt > b.CreatedAt {
				return -1
			}
			return 0 // Equal timestamps
		})

		// After the sort, files[0] is the globally-latest version; only a
		// parseable "vN" label bumps the counter (parse failures keep 0).
		latestFile := files[0]
		if strings.HasPrefix(latestFile.Version, "v") {
			vNum, parseErr := strconv.Atoi(latestFile.Version[1:])
			if parseErr == nil {
				latestVersionNumber = vNum
			}
		}
	}
	nextVersionStr := fmt.Sprintf("v%d", latestVersionNumber+1)
	return s.createWithVersion(ctx, sessionID, path, content, nextVersionStr, EventFileVersionCreated)
}
|
||||
|
||||
func (s *service) createWithVersion(ctx context.Context, sessionID, path, content, version string, eventType pubsub.EventType) (File, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
const maxRetries = 3
|
||||
var file File
|
||||
var err error
|
||||
|
||||
for attempt := range maxRetries {
|
||||
tx, txErr := s.sqlDB.BeginTx(ctx, nil)
|
||||
if txErr != nil {
|
||||
return File{}, fmt.Errorf("failed to begin transaction: %w", txErr)
|
||||
}
|
||||
qtx := s.db.WithTx(tx)
|
||||
|
||||
dbFile, createErr := qtx.CreateFile(ctx, db.CreateFileParams{
|
||||
ID: uuid.New().String(),
|
||||
SessionID: sessionID,
|
||||
Path: path,
|
||||
Content: content,
|
||||
Version: version,
|
||||
})
|
||||
|
||||
if createErr != nil {
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
slog.Error("Failed to rollback transaction on create error", "error", rbErr)
|
||||
}
|
||||
if strings.Contains(createErr.Error(), "UNIQUE constraint failed: files.path, files.session_id, files.version") {
|
||||
if attempt < maxRetries-1 {
|
||||
slog.Warn("Unique constraint violation for file version, retrying with incremented version", "path", path, "session", sessionID, "attempted_version", version, "attempt", attempt+1)
|
||||
// Increment version string like v1, v2, v3...
|
||||
if strings.HasPrefix(version, "v") {
|
||||
numPart := version[1:]
|
||||
num, parseErr := strconv.Atoi(numPart)
|
||||
if parseErr == nil {
|
||||
version = fmt.Sprintf("v%d", num+1)
|
||||
continue // Retry with new version
|
||||
}
|
||||
}
|
||||
// Fallback if version is not "vX" or parsing failed
|
||||
version = fmt.Sprintf("%s-retry%d", version, attempt+1)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return File{}, fmt.Errorf("db.CreateFile within transaction: %w", createErr)
|
||||
}
|
||||
|
||||
if commitErr := tx.Commit(); commitErr != nil {
|
||||
return File{}, fmt.Errorf("failed to commit transaction: %w", commitErr)
|
||||
}
|
||||
|
||||
file = s.fromDBItem(dbFile)
|
||||
s.broker.Publish(eventType, file)
|
||||
return file, nil // Success
|
||||
}
|
||||
|
||||
return File{}, fmt.Errorf("failed to create file after %d retries due to version conflicts: %w", maxRetries, err)
|
||||
}
|
||||
|
||||
func (s *service) Get(ctx context.Context, id string) (File, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbFile, err := s.db.GetFile(ctx, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return File{}, fmt.Errorf("file with ID '%s' not found", id)
|
||||
}
|
||||
return File{}, fmt.Errorf("db.GetFile: %w", err)
|
||||
}
|
||||
return s.fromDBItem(dbFile), nil
|
||||
}
|
||||
|
||||
func (s *service) GetByPathAndVersion(ctx context.Context, sessionID, path, version string) (File, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// sqlc doesn't directly support GetyByPathAndVersionAndSession
|
||||
// We list and filter. This could be optimized with a custom query if performance is an issue.
|
||||
allFilesForPath, err := s.db.ListFilesByPath(ctx, path)
|
||||
if err != nil {
|
||||
return File{}, fmt.Errorf("db.ListFilesByPath for GetByPathAndVersion: %w", err)
|
||||
}
|
||||
|
||||
for _, dbFile := range allFilesForPath {
|
||||
if dbFile.SessionID == sessionID && dbFile.Version == version {
|
||||
return s.fromDBItem(dbFile), nil
|
||||
}
|
||||
}
|
||||
return File{}, fmt.Errorf("file not found for session '%s', path '%s', version '%s'", sessionID, path, version)
|
||||
}
|
||||
|
||||
func (s *service) GetLatestByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
// GetFileByPathAndSession in sqlc already orders by created_at DESC and takes LIMIT 1
|
||||
dbFile, err := s.db.GetFileByPathAndSession(ctx, db.GetFileByPathAndSessionParams{
|
||||
Path: path,
|
||||
SessionID: sessionID,
|
||||
})
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return File{}, fmt.Errorf("no file found for path '%s' in session '%s'", path, sessionID)
|
||||
}
|
||||
return File{}, fmt.Errorf("db.GetFileByPathAndSession: %w", err)
|
||||
}
|
||||
return s.fromDBItem(dbFile), nil
|
||||
}
|
||||
|
||||
func (s *service) ListBySession(ctx context.Context, sessionID string) ([]File, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbFiles, err := s.db.ListFilesBySession(ctx, sessionID) // Assumes this orders by created_at ASC
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListFilesBySession: %w", err)
|
||||
}
|
||||
files := make([]File, len(dbFiles))
|
||||
for i, dbF := range dbFiles {
|
||||
files[i] = s.fromDBItem(dbF)
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (s *service) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbFiles, err := s.db.ListLatestSessionFiles(ctx, sessionID) // Uses the specific sqlc query
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListLatestSessionFiles: %w", err)
|
||||
}
|
||||
files := make([]File, len(dbFiles))
|
||||
for i, dbF := range dbFiles {
|
||||
files[i] = s.fromDBItem(dbF)
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (s *service) ListVersionsByPath(ctx context.Context, path string) ([]File, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbFiles, err := s.db.ListFilesByPath(ctx, path) // sqlc query orders by created_at DESC
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListFilesByPath: %w", err)
|
||||
}
|
||||
files := make([]File, len(dbFiles))
|
||||
for i, dbF := range dbFiles {
|
||||
files[i] = s.fromDBItem(dbF)
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (s *service) Update(ctx context.Context, file File) (File, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if file.ID == "" {
|
||||
return File{}, fmt.Errorf("cannot update file with empty ID")
|
||||
}
|
||||
// UpdatedAt is handled by DB trigger
|
||||
dbFile, err := s.db.UpdateFile(ctx, db.UpdateFileParams{
|
||||
ID: file.ID,
|
||||
Content: file.Content,
|
||||
Version: file.Version,
|
||||
})
|
||||
if err != nil {
|
||||
return File{}, fmt.Errorf("db.UpdateFile: %w", err)
|
||||
}
|
||||
updatedFile := s.fromDBItem(dbFile)
|
||||
s.broker.Publish(EventFileUpdated, updatedFile)
|
||||
return updatedFile, nil
|
||||
}
|
||||
|
||||
func (s *service) Delete(ctx context.Context, id string) error {
|
||||
s.mu.Lock()
|
||||
fileToPublish, err := s.getServiceForPublish(ctx, id) // Use internal method with appropriate locking
|
||||
s.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
slog.Warn("Attempted to delete non-existent file history", "id", id)
|
||||
return nil // Or return specific error if needed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
err = s.db.DeleteFile(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.DeleteFile: %w", err)
|
||||
}
|
||||
if fileToPublish != nil {
|
||||
s.broker.Publish(EventFileDeleted, *fileToPublish)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getServiceForPublish loads a file row and converts it to a File for
// event publishing. It takes no locks itself; the caller is responsible
// for holding s.mu as appropriate. (The previous comment claimed
// GetFile "has its own RLock" — it does not; it is a plain sqlc query.)
func (s *service) getServiceForPublish(ctx context.Context, id string) (*File, error) {
	dbFile, err := s.db.GetFile(ctx, id)
	if err != nil {
		return nil, err
	}
	file := s.fromDBItem(dbFile)
	return &file, nil
}
|
||||
|
||||
func (s *service) DeleteSessionFiles(ctx context.Context, sessionID string) error {
|
||||
s.mu.Lock() // Lock for the entire operation
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Get files first for publishing events
|
||||
filesToDelete, err := s.db.ListFilesBySession(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.ListFilesBySession for deletion: %w", err)
|
||||
}
|
||||
|
||||
err = s.db.DeleteSessionFiles(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.DeleteSessionFiles: %w", err)
|
||||
}
|
||||
|
||||
for _, dbFile := range filesToDelete {
|
||||
file := s.fromDBItem(dbFile)
|
||||
s.broker.Publish(EventFileDeleted, file) // Individual delete events
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Subscribe registers a listener for file history events; channel
// lifetime is managed by the underlying pubsub broker.
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[File] {
	return s.broker.Subscribe(ctx)
}
|
||||
|
||||
func (s *service) fromDBItem(item db.File) File {
|
||||
// Parse timestamps from ISO strings
|
||||
createdAt, err := time.Parse(time.RFC3339Nano, item.CreatedAt)
|
||||
if err != nil {
|
||||
slog.Error("Failed to parse created_at", "value", item.CreatedAt, "error", err)
|
||||
createdAt = time.Now() // Fallback
|
||||
}
|
||||
|
||||
updatedAt, err := time.Parse(time.RFC3339Nano, item.UpdatedAt)
|
||||
if err != nil {
|
||||
slog.Error("Failed to parse created_at", "value", item.CreatedAt, "error", err)
|
||||
updatedAt = time.Now() // Fallback
|
||||
}
|
||||
|
||||
return File{
|
||||
ID: item.ID,
|
||||
SessionID: item.SessionID,
|
||||
Path: item.Path,
|
||||
Content: item.Content,
|
||||
Version: item.Version,
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// Package-level convenience wrappers that delegate to the singleton
// service (see InitService/GetService). Each panics, via GetService,
// if the service has not been initialized.

func Create(ctx context.Context, sessionID, path, content string) (File, error) {
	return GetService().Create(ctx, sessionID, path, content)
}

func CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) {
	return GetService().CreateVersion(ctx, sessionID, path, content)
}

func Get(ctx context.Context, id string) (File, error) {
	return GetService().Get(ctx, id)
}

func GetByPathAndVersion(ctx context.Context, sessionID, path, version string) (File, error) {
	return GetService().GetByPathAndVersion(ctx, sessionID, path, version)
}

func GetLatestByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
	return GetService().GetLatestByPathAndSession(ctx, path, sessionID)
}

func ListBySession(ctx context.Context, sessionID string) ([]File, error) {
	return GetService().ListBySession(ctx, sessionID)
}

func ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
	return GetService().ListLatestSessionFiles(ctx, sessionID)
}

func ListVersionsByPath(ctx context.Context, path string) ([]File, error) {
	return GetService().ListVersionsByPath(ctx, path)
}

func Update(ctx context.Context, file File) (File, error) {
	return GetService().Update(ctx, file)
}

func Delete(ctx context.Context, id string) error {
	return GetService().Delete(ctx, id)
}

func DeleteSessionFiles(ctx context.Context, sessionID string) error {
	return GetService().DeleteSessionFiles(ctx, sessionID)
}

func Subscribe(ctx context.Context) <-chan pubsub.Event[File] {
	return GetService().Subscribe(ctx)
}
|
||||
@@ -1,109 +0,0 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/session"
|
||||
)
|
||||
|
||||
// agentTool is the "agent" tool: it lets the calling agent delegate a
// read-only search/research task to a sub-agent running in its own
// child session (see Run).
type agentTool struct {
	sessions   session.Service
	messages   message.Service
	lspClients map[string]*lsp.Client
}

const (
	// AgentToolName is the tool identifier exposed to the model.
	AgentToolName = "agent"
)

// AgentParams is the JSON payload the model supplies when invoking the
// tool.
type AgentParams struct {
	Prompt string `json:"prompt"`
}
|
||||
|
||||
// Info describes the agent tool to the model: its name, the long usage
// description, and the single required "prompt" parameter. The
// description text is part of the model-facing contract — do not edit
// casually.
func (b *agentTool) Info() tools.ToolInfo {
	return tools.ToolInfo{
		Name:        AgentToolName,
		Description: "Launch a new agent that has access to the following tools: GlobTool, GrepTool, LS, View. When you are searching for a keyword or file and are not confident that you will find the right match on the first try, use the Agent tool to perform the search for you. For example:\n\n- If you are searching for a keyword like \"config\" or \"logger\", or for questions like \"which file does X?\", the Agent tool is strongly recommended\n- If you want to read a specific file path, use the View or GlobTool tool instead of the Agent tool, to find the match more quickly\n- If you are searching for a specific class definition like \"class Foo\", use the GlobTool tool instead, to find the match more quickly\n\nUsage notes:\n1. Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses\n2. When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result.\n3. Each agent invocation is stateless. You will not be able to send additional messages to the agent, nor will the agent be able to communicate with you outside of its final report. Therefore, your prompt should contain a highly detailed task description for the agent to perform autonomously and you should specify exactly what information the agent should return back to you in its final and only message to you.\n4. The agent's outputs should generally be trusted\n5. IMPORTANT: The agent can not use Bash, Replace, Edit, so can not modify files. If you want to use these tools, use them directly instead of going through the agent.",
		Parameters: map[string]any{
			"prompt": map[string]any{
				"type":        "string",
				"description": "The task for the agent to perform",
			},
		},
		Required: []string{"prompt"},
	}
}
|
||||
|
||||
func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
|
||||
var params AgentParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
if params.Prompt == "" {
|
||||
return tools.NewTextErrorResponse("prompt is required"), nil
|
||||
}
|
||||
|
||||
sessionID, messageID := tools.GetContextValues(ctx)
|
||||
if sessionID == "" || messageID == "" {
|
||||
return tools.ToolResponse{}, fmt.Errorf("session_id and message_id are required")
|
||||
}
|
||||
|
||||
agent, err := NewAgent(config.AgentTask, b.sessions, b.messages, TaskAgentTools(b.lspClients))
|
||||
if err != nil {
|
||||
return tools.ToolResponse{}, fmt.Errorf("error creating agent: %s", err)
|
||||
}
|
||||
|
||||
session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session")
|
||||
if err != nil {
|
||||
return tools.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
|
||||
}
|
||||
|
||||
done, err := agent.Run(ctx, session.ID, params.Prompt)
|
||||
if err != nil {
|
||||
return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", err)
|
||||
}
|
||||
result := <-done
|
||||
if result.Err() != nil {
|
||||
return tools.ToolResponse{}, fmt.Errorf("error generating agent: %s", result.Err())
|
||||
}
|
||||
|
||||
response := result.Response()
|
||||
if response.Role != message.Assistant {
|
||||
return tools.NewTextErrorResponse("no response"), nil
|
||||
}
|
||||
|
||||
updatedSession, err := b.sessions.Get(ctx, session.ID)
|
||||
if err != nil {
|
||||
return tools.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
|
||||
}
|
||||
parentSession, err := b.sessions.Get(ctx, sessionID)
|
||||
if err != nil {
|
||||
return tools.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
|
||||
}
|
||||
|
||||
parentSession.Cost += updatedSession.Cost
|
||||
|
||||
_, err = b.sessions.Update(ctx, parentSession)
|
||||
if err != nil {
|
||||
return tools.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
|
||||
}
|
||||
return tools.NewTextResponse(response.Content().String()), nil
|
||||
}
|
||||
|
||||
func NewAgentTool(
|
||||
Sessions session.Service,
|
||||
Messages message.Service,
|
||||
LspClients map[string]*lsp.Client,
|
||||
) tools.BaseTool {
|
||||
return &agentTool{
|
||||
sessions: Sessions,
|
||||
messages: Messages,
|
||||
lspClients: LspClients,
|
||||
}
|
||||
}
|
||||
@@ -1,814 +0,0 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/llm/prompt"
|
||||
"github.com/sst/opencode/internal/llm/provider"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/logging"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
"github.com/sst/opencode/internal/session"
|
||||
"github.com/sst/opencode/internal/status"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
// Common errors returned by the agent service.
var (
	// ErrRequestCancelled is returned when a run is cancelled via Cancel
	// or context cancellation.
	ErrRequestCancelled = errors.New("request cancelled by user")
	// ErrSessionBusy is returned by Run when the session already has a
	// request in flight.
	ErrSessionBusy = errors.New("session is currently processing another request")
)

// AgentEvent is the terminal result of one agent run: either the final
// assistant message or an error.
type AgentEvent struct {
	message message.Message
	err     error
}

// Err returns the run's error, or nil on success.
func (e *AgentEvent) Err() error {
	return e.err
}

// Response returns the final assistant message of a successful run.
func (e *AgentEvent) Response() message.Message {
	return e.message
}
|
||||
|
||||
// Service drives LLM conversations for sessions: starting and
// cancelling runs, switching models, and managing context-window
// compaction.
type Service interface {
	// Run starts processing content for the session and returns a channel
	// that delivers the single terminal AgentEvent.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the in-flight request for the session, if any.
	Cancel(sessionID string)
	// IsSessionBusy reports whether the session has a request in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has a request in flight.
	IsBusy() bool
	// Update switches the named agent to a different model.
	Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
	// CompactSession summarizes the session's history to shrink context.
	CompactSession(ctx context.Context, sessionID string, force bool) error
	GetUsage(ctx context.Context, sessionID string) (*int64, error)
	// EstimateContextWindowUsage returns (usage %, needs-compaction, error).
	EstimateContextWindowUsage(ctx context.Context, sessionID string) (float64, bool, error)
}
|
||||
|
||||
// agent implements Service. activeRequests maps sessionID to the
// context.CancelFunc of the in-flight run and doubles as the per-session
// busy flag (see Run / Cancel / IsSessionBusy).
type agent struct {
	sessions session.Service
	messages message.Service

	tools    []tools.BaseTool
	provider provider.Provider

	// titleProvider is non-nil only for the primary agent (see NewAgent);
	// it generates session titles from the first user message.
	titleProvider provider.Provider

	activeRequests sync.Map
}
|
||||
|
||||
// NewAgent builds an agent Service for the named agent configuration,
// wiring up its LLM provider and tool set. Only the primary agent also
// gets a secondary "title" provider, used to name new sessions.
func NewAgent(
	agentName config.AgentName,
	sessions session.Service,
	messages message.Service,
	agentTools []tools.BaseTool,
) (Service, error) {
	agentProvider, err := createAgentProvider(agentName)
	if err != nil {
		return nil, err
	}
	var titleProvider provider.Provider
	// Only generate titles for the primary agent
	if agentName == config.AgentPrimary {
		titleProvider, err = createAgentProvider(config.AgentTitle)
		if err != nil {
			return nil, err
		}
	}

	agent := &agent{
		provider:       agentProvider,
		messages:       messages,
		sessions:       sessions,
		tools:          agentTools,
		titleProvider:  titleProvider,
		activeRequests: sync.Map{},
	}

	return agent, nil
}
|
||||
|
||||
func (a *agent) Cancel(sessionID string) {
|
||||
if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
|
||||
if cancel, ok := cancelFunc.(context.CancelFunc); ok {
|
||||
status.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *agent) IsBusy() bool {
|
||||
busy := false
|
||||
a.activeRequests.Range(func(key, value interface{}) bool {
|
||||
if cancelFunc, ok := value.(context.CancelFunc); ok {
|
||||
if cancelFunc != nil {
|
||||
busy = true
|
||||
return false // Stop iterating
|
||||
}
|
||||
}
|
||||
return true // Continue iterating
|
||||
})
|
||||
return busy
|
||||
}
|
||||
|
||||
func (a *agent) IsSessionBusy(sessionID string) bool {
|
||||
_, busy := a.activeRequests.Load(sessionID)
|
||||
return busy
|
||||
}
|
||||
|
||||
// generateTitle asks the title provider to produce a short session
// title from the user's first message and saves it on the session.
// It is a no-op when content is empty or no title provider is
// configured (non-primary agents).
func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
	if content == "" {
		return nil
	}
	if a.titleProvider == nil {
		return nil
	}
	session, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return err
	}
	parts := []message.ContentPart{message.TextContent{Text: content}}
	response, err := a.titleProvider.SendMessages(
		ctx,
		[]message.Message{
			{
				Role:  message.User,
				Parts: parts,
			},
		},
		make([]tools.BaseTool, 0), // the title model gets no tools
	)
	if err != nil {
		return err
	}

	// Collapse the model output to a single trimmed line.
	title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
	if title == "" {
		return nil
	}

	session.Title = title
	_, err = a.sessions.Update(ctx, session)
	return err
}
|
||||
|
||||
func (a *agent) err(err error) AgentEvent {
|
||||
return AgentEvent{
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts processing content (plus optional attachments) for the
// session on a background goroutine and returns an unbuffered channel
// that yields exactly one AgentEvent with the final result, then
// closes. Returns ErrSessionBusy if a request is already in flight for
// this session.
func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
	// Silently drop attachments the selected model cannot accept.
	if !a.provider.Model().SupportsAttachments && attachments != nil {
		attachments = nil
	}
	events := make(chan AgentEvent)
	if a.IsSessionBusy(sessionID) {
		return nil, ErrSessionBusy
	}

	genCtx, cancel := context.WithCancel(ctx)

	// The stored cancel func doubles as the "busy" marker for the session
	// (see IsSessionBusy / Cancel).
	a.activeRequests.Store(sessionID, cancel)
	go func() {
		slog.Debug("Request started", "sessionID", sessionID)
		defer logging.RecoverPanic("agent.Run", func() {
			events <- a.err(fmt.Errorf("panic while running the agent"))
		})
		var attachmentParts []message.ContentPart
		for _, attachment := range attachments {
			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
		}
		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
		// Surface only unexpected failures; cancellations are normal.
		if result.Err() != nil && !errors.Is(result.Err(), ErrRequestCancelled) && !errors.Is(result.Err(), context.Canceled) {
			status.Error(result.Err().Error())
		}
		slog.Debug("Request completed", "sessionID", sessionID)
		a.activeRequests.Delete(sessionID)
		cancel()
		// events is unbuffered: this blocks until the caller receives.
		events <- result
		close(events)
	}()

	return events, nil
}
|
||||
|
||||
func (a *agent) prepareMessageHistory(ctx context.Context, sessionID string) (session.Session, []message.Message, error) {
|
||||
currentSession, err := a.sessions.Get(ctx, sessionID)
|
||||
if err != nil {
|
||||
return currentSession, nil, fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
|
||||
var sessionMessages []message.Message
|
||||
if currentSession.Summary != "" && !currentSession.SummarizedAt.IsZero() {
|
||||
// If summary exists, only fetch messages after the summarization timestamp
|
||||
sessionMessages, err = a.messages.ListAfter(ctx, sessionID, currentSession.SummarizedAt)
|
||||
if err != nil {
|
||||
return currentSession, nil, fmt.Errorf("failed to list messages after summary: %w", err)
|
||||
}
|
||||
} else {
|
||||
// If no summary, fetch all messages
|
||||
sessionMessages, err = a.messages.List(ctx, sessionID)
|
||||
if err != nil {
|
||||
return currentSession, nil, fmt.Errorf("failed to list messages: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var messages []message.Message
|
||||
if currentSession.Summary != "" && !currentSession.SummarizedAt.IsZero() {
|
||||
// If summary exists, create a temporary message for the summary
|
||||
summaryMessage := message.Message{
|
||||
Role: message.Assistant,
|
||||
Parts: []message.ContentPart{
|
||||
message.TextContent{Text: currentSession.Summary},
|
||||
},
|
||||
}
|
||||
// Start with the summary, then add messages after the summary timestamp
|
||||
messages = append([]message.Message{summaryMessage}, sessionMessages...)
|
||||
} else {
|
||||
// If no summary, just use all messages
|
||||
messages = sessionMessages
|
||||
}
|
||||
|
||||
return currentSession, messages, nil
|
||||
}
|
||||
|
||||
func (a *agent) triggerTitleGeneration(sessionID string, content string) {
|
||||
go func() {
|
||||
defer logging.RecoverPanic("agent.Run", func() {
|
||||
status.Error("panic while generating title")
|
||||
})
|
||||
titleErr := a.generateTitle(context.Background(), sessionID, content)
|
||||
if titleErr != nil {
|
||||
status.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// processGeneration runs one complete agent turn for a session: it loads the
// message history, optionally kicks off title generation for brand-new
// sessions, persists the user's message, then loops — streaming a model
// response and executing requested tools — until the model finishes without
// asking for tools. Auto-compaction of the context window is attempted
// before each provider call. Returns the final AgentEvent (or an error event).
func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
	currentSession, sessionMessages, err := a.prepareMessageHistory(ctx, sessionID)
	if err != nil {
		return a.err(err)
	}

	// If this is a new session, start title generation asynchronously
	if len(sessionMessages) == 0 && currentSession.Summary == "" {
		a.triggerTitleGeneration(sessionID, content)
	}

	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
	if err != nil {
		return a.err(fmt.Errorf("failed to create user message: %w", err))
	}

	// NOTE(review): append may alias sessionMessages' backing array; safe here
	// because sessionMessages is not reused independently afterwards.
	messages := append(sessionMessages, userMsg)

	for {
		// Check for cancellation before each iteration
		select {
		case <-ctx.Done():
			return a.err(ctx.Err())
		default:
			// Continue processing
		}

		// Check if auto-compaction is needed before calling the provider
		usagePercentage, needsCompaction, errEstimate := a.EstimateContextWindowUsage(ctx, sessionID)
		if errEstimate != nil {
			// Estimation failure is non-fatal: proceed without compacting.
			slog.Warn("Failed to estimate context window usage for auto-compaction", "error", errEstimate, "sessionID", sessionID)
		} else if needsCompaction {
			status.Info(fmt.Sprintf("Context window usage is at %.2f%%. Auto-compacting conversation...", usagePercentage))

			// Run compaction synchronously
			compactCtx, cancelCompact := context.WithTimeout(ctx, 30*time.Second) // Use appropriate context
			errCompact := a.CompactSession(compactCtx, sessionID, true)
			cancelCompact()

			if errCompact != nil {
				// Best effort: on failure we keep going with the full history.
				status.Warn(fmt.Sprintf("Auto-compaction failed: %v. Context window usage may continue to grow.", errCompact))
			} else {
				status.Info("Auto-compaction completed successfully.")
				// After compaction, message history needs to be re-prepared.
				// The 'messages' slice needs to be updated with the new summary and subsequent messages,
				// ensuring the latest user message is correctly appended.
				_, sessionMessagesFromCompact, errPrepare := a.prepareMessageHistory(ctx, sessionID)
				if errPrepare != nil {
					return a.err(fmt.Errorf("failed to re-prepare message history after compaction: %w", errPrepare))
				}
				messages = sessionMessagesFromCompact

				// Ensure the user message that triggered this cycle is the last one.
				// 'userMsg' was created before this loop using a.createUserMessage.
				// It should be appended to the 'messages' slice if it's not already the last element.
				if len(messages) == 0 || (len(messages) > 0 && messages[len(messages)-1].ID != userMsg.ID) {
					messages = append(messages, userMsg)
				}
			}
		}

		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, messages)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				// Persist the cancellation with a fresh context, since ctx is done.
				agentMessage.AddFinish(message.FinishReasonCanceled)
				a.messages.Update(context.Background(), agentMessage)
				return a.err(ErrRequestCancelled)
			}
			return a.err(fmt.Errorf("failed to process events: %w", err))
		}
		slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
			// We are not done, we need to respond with the tool response
			messages = append(messages, agentMessage, *toolResults)
			continue
		}
		return AgentEvent{
			message: agentMessage,
		}
	}
}
|
||||
|
||||
func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
|
||||
parts := []message.ContentPart{message.TextContent{Text: content}}
|
||||
parts = append(parts, attachmentParts...)
|
||||
return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
|
||||
Role: message.User,
|
||||
Parts: parts,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *agent) createToolResponseMessage(ctx context.Context, sessionID string, toolResults []message.ToolResult) (*message.Message, error) {
|
||||
if len(toolResults) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
parts := make([]message.ContentPart, 0, len(toolResults))
|
||||
for _, tr := range toolResults {
|
||||
parts = append(parts, tr)
|
||||
}
|
||||
|
||||
msg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
|
||||
Role: message.Tool,
|
||||
Parts: parts,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create tool response message: %w", err)
|
||||
}
|
||||
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
// streamAndHandleEvents creates the assistant message, folds the provider's
// streaming events into it, and — when the model requested tool use —
// executes the tools and persists their results. It returns the assistant
// message and, when tools ran, the tool-response message to feed back to
// the model on the next iteration; otherwise the second return is nil.
func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)

	// Persist an empty assistant message up front so deltas have a target.
	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:  message.Assistant,
		Parts: []message.ContentPart{},
		Model: a.provider.Model().ID,
	})
	if err != nil {
		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
	}

	// Add the session and message ID into the context if needed by tools.
	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)

	// Process each event in the stream.
	for event := range eventChan {
		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
			return assistantMsg, nil, processErr
		}
		if ctx.Err() != nil {
			// ctx is already done — use a background context so the
			// cancellation state can still be persisted.
			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
			return assistantMsg, nil, ctx.Err()
		}
	}

	// If the assistant wants to use tools, execute them
	if assistantMsg.FinishReason() == message.FinishReasonToolUse {
		toolCalls := assistantMsg.ToolCalls()
		if len(toolCalls) > 0 {
			toolResults, err := a.executeToolCalls(ctx, toolCalls)
			if err != nil {
				if errors.Is(err, context.Canceled) {
					a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
				}
				return assistantMsg, nil, err
			}

			// Create a message with the tool results
			toolResponseMsg, err := a.createToolResponseMessage(ctx, sessionID, toolResults)
			if err != nil {
				return assistantMsg, nil, err
			}

			return assistantMsg, toolResponseMsg, nil
		}
	}

	return assistantMsg, nil, nil
}
|
||||
|
||||
func (a *agent) executeToolCalls(ctx context.Context, toolCalls []message.ToolCall) ([]message.ToolResult, error) {
|
||||
toolResults := make([]message.ToolResult, len(toolCalls))
|
||||
|
||||
for i, toolCall := range toolCalls {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Make all future tool calls cancelled
|
||||
for j := i; j < len(toolCalls); j++ {
|
||||
toolResults[j] = message.ToolResult{
|
||||
ToolCallID: toolCalls[j].ID,
|
||||
Content: "Tool execution canceled by user",
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
return toolResults, ctx.Err()
|
||||
default:
|
||||
// Continue processing
|
||||
var tool tools.BaseTool
|
||||
for _, availableTools := range a.tools {
|
||||
if availableTools.Info().Name == toolCall.Name {
|
||||
tool = availableTools
|
||||
}
|
||||
}
|
||||
|
||||
// Tool not found
|
||||
if tool == nil {
|
||||
toolResults[i] = message.ToolResult{
|
||||
ToolCallID: toolCall.ID,
|
||||
Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
|
||||
IsError: true,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
|
||||
ID: toolCall.ID,
|
||||
Name: toolCall.Name,
|
||||
Input: toolCall.Input,
|
||||
})
|
||||
|
||||
if toolErr != nil {
|
||||
if errors.Is(toolErr, permission.ErrorPermissionDenied) {
|
||||
toolResults[i] = message.ToolResult{
|
||||
ToolCallID: toolCall.ID,
|
||||
Content: "Permission denied",
|
||||
IsError: true,
|
||||
}
|
||||
// Cancel all remaining tool calls if permission is denied
|
||||
for j := i + 1; j < len(toolCalls); j++ {
|
||||
toolResults[j] = message.ToolResult{
|
||||
ToolCallID: toolCalls[j].ID,
|
||||
Content: "Tool execution canceled by user",
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
return toolResults, nil
|
||||
}
|
||||
|
||||
// Handle other errors
|
||||
toolResults[i] = message.ToolResult{
|
||||
ToolCallID: toolCall.ID,
|
||||
Content: toolErr.Error(),
|
||||
IsError: true,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
toolResults[i] = message.ToolResult{
|
||||
ToolCallID: toolCall.ID,
|
||||
Content: toolResult.Content,
|
||||
Metadata: toolResult.Metadata,
|
||||
IsError: toolResult.IsError,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return toolResults, nil
|
||||
}
|
||||
|
||||
func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
|
||||
msg.AddFinish(finishReson)
|
||||
_, _ = a.messages.Update(ctx, *msg)
|
||||
}
|
||||
|
||||
// processEvent applies a single streaming provider event to the in-progress
// assistant message and persists the change after each mutation. For
// EventError it passes context.Canceled through (so callers can detect
// cancellation) and returns other errors as-is; for EventComplete it also
// records token usage. Unhandled event types return nil.
func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
	// Bail out early if the request was canceled.
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
		// Continue processing
	}

	switch event.Type {
	case provider.EventThinkingDelta:
		assistantMsg.AppendReasoningContent(event.Content)
		_, err := a.messages.Update(ctx, *assistantMsg)
		return err
	case provider.EventContentDelta:
		assistantMsg.AppendContent(event.Content)
		_, err := a.messages.Update(ctx, *assistantMsg)
		return err
	case provider.EventToolUseStart:
		assistantMsg.AddToolCall(*event.ToolCall)
		_, err := a.messages.Update(ctx, *assistantMsg)
		return err
	// TODO: see how to handle this
	// case provider.EventToolUseDelta:
	// 	tm := time.Unix(assistantMsg.UpdatedAt, 0)
	// 	assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
	// 	if time.Since(tm) > 1000*time.Millisecond {
	// 		err := a.messages.Update(ctx, *assistantMsg)
	// 		assistantMsg.UpdatedAt = time.Now().Unix()
	// 		return err
	// 	}
	case provider.EventToolUseStop:
		assistantMsg.FinishToolCall(event.ToolCall.ID)
		_, err := a.messages.Update(ctx, *assistantMsg)
		return err
	case provider.EventError:
		if errors.Is(event.Error, context.Canceled) {
			status.Info(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
			return context.Canceled
		}
		status.Error(event.Error.Error())
		return event.Error
	case provider.EventComplete:
		assistantMsg.SetToolCalls(event.Response.ToolCalls)
		assistantMsg.AddFinish(event.Response.FinishReason)
		if _, err := a.messages.Update(ctx, *assistantMsg); err != nil {
			return fmt.Errorf("failed to update message: %w", err)
		}
		// Terminal event: fold the reported token usage into session totals.
		return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
	}

	return nil
}
|
||||
|
||||
func (a *agent) GetUsage(ctx context.Context, sessionID string) (*int64, error) {
|
||||
session, err := a.sessions.Get(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
|
||||
usage := session.PromptTokens + session.CompletionTokens
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func (a *agent) EstimateContextWindowUsage(ctx context.Context, sessionID string) (float64, bool, error) {
|
||||
session, err := a.sessions.Get(ctx, sessionID)
|
||||
if err != nil {
|
||||
return 0, false, fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
|
||||
// Get the model's context window size
|
||||
model := a.provider.Model()
|
||||
contextWindow := model.ContextWindow
|
||||
if contextWindow <= 0 {
|
||||
// Default to a reasonable size if not specified
|
||||
contextWindow = 100000
|
||||
}
|
||||
|
||||
// Calculate current token usage
|
||||
currentTokens := session.PromptTokens + session.CompletionTokens
|
||||
|
||||
// Get the max tokens setting for the agent
|
||||
maxTokens := a.provider.MaxTokens()
|
||||
|
||||
// Calculate percentage of context window used
|
||||
usagePercentage := float64(currentTokens) / float64(contextWindow)
|
||||
|
||||
// Check if we need to auto-compact
|
||||
// Auto-compact when:
|
||||
// 1. Usage exceeds 90% of context window, OR
|
||||
// 2. Current usage + maxTokens would exceed 100% of context window
|
||||
needsCompaction := usagePercentage >= 0.9 ||
|
||||
float64(currentTokens+maxTokens) > float64(contextWindow)
|
||||
|
||||
return usagePercentage * 100, needsCompaction, nil
|
||||
}
|
||||
|
||||
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
|
||||
sess, err := a.sessions.Get(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
|
||||
cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
|
||||
model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
|
||||
model.CostPer1MIn/1e6*float64(usage.InputTokens) +
|
||||
model.CostPer1MOut/1e6*float64(usage.OutputTokens)
|
||||
|
||||
sess.Cost += cost
|
||||
sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
|
||||
sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
|
||||
|
||||
_, err = a.sessions.Update(ctx, sess)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
|
||||
if a.IsBusy() {
|
||||
return models.Model{}, fmt.Errorf("cannot change model while processing requests")
|
||||
}
|
||||
|
||||
if err := config.UpdateAgentModel(agentName, modelID); err != nil {
|
||||
return models.Model{}, fmt.Errorf("failed to update config: %w", err)
|
||||
}
|
||||
|
||||
provider, err := createAgentProvider(agentName)
|
||||
if err != nil {
|
||||
return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
|
||||
}
|
||||
|
||||
a.provider = provider
|
||||
|
||||
return a.provider.Model(), nil
|
||||
}
|
||||
|
||||
// CompactSession summarizes the session's conversation so the context window
// can be reclaimed. When a previous summary exists, only messages created
// after the last summarization are summarized and the old summary is fed to
// the model as prior context. The new summary and timestamp are stored on
// the session and the summarization's token usage is tracked. With force
// false, a busy session is rejected with ErrSessionBusy.
func (a *agent) CompactSession(ctx context.Context, sessionID string, force bool) error {
	// Check if the session is busy
	if a.IsSessionBusy(sessionID) && !force {
		return ErrSessionBusy
	}

	// Create a cancellable context
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Mark the session as busy during compaction. The stored func is a
	// no-op: the map entry's presence is what marks the session busy.
	compactionCancelFunc := func() {}
	a.activeRequests.Store(sessionID+"-compact", compactionCancelFunc)
	defer a.activeRequests.Delete(sessionID + "-compact")

	// Fetch the session
	session, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}

	// Fetch all messages for the session
	sessionMessages, err := a.messages.List(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to list messages: %w", err)
	}

	var existingSummary string
	if session.Summary != "" && !session.SummarizedAt.IsZero() {
		// Filter messages that were created after the last summarization
		var newMessages []message.Message
		for _, msg := range sessionMessages {
			if msg.CreatedAt.After(session.SummarizedAt) {
				newMessages = append(newMessages, msg)
			}
		}
		sessionMessages = newMessages
		existingSummary = session.Summary
	}

	// If there are no messages to summarize and no existing summary, return early
	if len(sessionMessages) == 0 && existingSummary == "" {
		return nil
	}

	// Build the summarization prompt: system instructions first.
	messages := []message.Message{
		message.Message{
			Role: message.System,
			Parts: []message.ContentPart{
				message.TextContent{
					Text: `You are a helpful AI assistant tasked with summarizing conversations.

When asked to summarize, provide a detailed but concise summary of the conversation.
Focus on information that would be helpful for continuing the conversation, including:
- What was done
- What is currently being worked on
- Which files are being modified
- What needs to be done next

Your summary should be comprehensive enough to provide context but concise enough to be quickly understood.`,
				},
			},
		},
	}

	// If there's an existing summary, include it
	if existingSummary != "" {
		messages = append(messages, message.Message{
			Role: message.Assistant,
			Parts: []message.ContentPart{
				message.TextContent{
					Text: existingSummary,
				},
			},
		})
	}

	// Add all messages since the last summarized message
	messages = append(messages, sessionMessages...)

	// Add a final user message requesting the summary
	messages = append(messages, message.Message{
		Role: message.User,
		Parts: []message.ContentPart{
			message.TextContent{
				Text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
			},
		},
	})

	// Call provider to get the summary
	response, err := a.provider.SendMessages(ctx, messages, a.tools)
	if err != nil {
		return fmt.Errorf("failed to get summary from the assistant: %w", err)
	}

	// Extract the summary text
	summaryText := strings.TrimSpace(response.Content)
	if summaryText == "" {
		return fmt.Errorf("received empty summary from the assistant")
	}

	// Update the session with the new summary
	session.Summary = summaryText
	session.SummarizedAt = time.Now()

	// Save the updated session
	_, err = a.sessions.Update(ctx, session)
	if err != nil {
		return fmt.Errorf("failed to save session with summary: %w", err)
	}

	// Track token usage
	err = a.TrackUsage(ctx, sessionID, a.provider.Model(), response.Usage)
	if err != nil {
		return fmt.Errorf("failed to track usage: %w", err)
	}

	return nil
}
|
||||
|
||||
func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
|
||||
cfg := config.Get()
|
||||
agentConfig, ok := cfg.Agents[agentName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("agent %s not found", agentName)
|
||||
}
|
||||
model, ok := models.SupportedModels[agentConfig.Model]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
|
||||
}
|
||||
|
||||
providerCfg, ok := cfg.Providers[model.Provider]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("provider %s not supported", model.Provider)
|
||||
}
|
||||
if providerCfg.Disabled {
|
||||
return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
|
||||
}
|
||||
maxTokens := model.DefaultMaxTokens
|
||||
if agentConfig.MaxTokens > 0 {
|
||||
maxTokens = agentConfig.MaxTokens
|
||||
}
|
||||
opts := []provider.ProviderClientOption{
|
||||
provider.WithAPIKey(providerCfg.APIKey),
|
||||
provider.WithModel(model),
|
||||
provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
|
||||
provider.WithMaxTokens(maxTokens),
|
||||
}
|
||||
if model.Provider == models.ProviderOpenAI && model.CanReason {
|
||||
opts = append(
|
||||
opts,
|
||||
provider.WithOpenAIOptions(
|
||||
provider.WithReasoningEffort(agentConfig.ReasoningEffort),
|
||||
),
|
||||
)
|
||||
} else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentPrimary {
|
||||
opts = append(
|
||||
opts,
|
||||
provider.WithAnthropicOptions(
|
||||
provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
|
||||
),
|
||||
)
|
||||
}
|
||||
agentProvider, err := provider.NewProvider(
|
||||
model.Provider,
|
||||
opts...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create provider: %v", err)
|
||||
}
|
||||
|
||||
return agentProvider, nil
|
||||
}
|
||||
@@ -1,198 +0,0 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
"github.com/sst/opencode/internal/version"
|
||||
"log/slog"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// mcpTool adapts a single tool exposed by an MCP (Model Context Protocol)
// server to the internal tools.BaseTool interface. Each Run dials the
// configured server anew.
type mcpTool struct {
	mcpName     string             // name of the MCP server this tool belongs to
	tool        mcp.Tool           // tool definition as advertised by the server
	mcpConfig   config.MCPServer   // connection settings (stdio command or SSE URL)
	permissions permission.Service // user permission checks performed before each run
}
|
||||
|
||||
// MCPClient is the subset of MCP client behavior this package relies on:
// protocol initialization, tool discovery, tool invocation, and shutdown.
// The stdio and SSE clients from the mcp-go library both satisfy it.
type MCPClient interface {
	Initialize(
		ctx context.Context,
		request mcp.InitializeRequest,
	) (*mcp.InitializeResult, error)
	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
	Close() error
}
|
||||
|
||||
func (b *mcpTool) Info() tools.ToolInfo {
|
||||
return tools.ToolInfo{
|
||||
Name: fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
|
||||
Description: b.tool.Description,
|
||||
Parameters: b.tool.InputSchema.Properties,
|
||||
Required: b.tool.InputSchema.Required,
|
||||
}
|
||||
}
|
||||
|
||||
func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
|
||||
defer c.Close()
|
||||
initRequest := mcp.InitializeRequest{}
|
||||
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
|
||||
initRequest.Params.ClientInfo = mcp.Implementation{
|
||||
Name: "OpenCode",
|
||||
Version: version.Version,
|
||||
}
|
||||
|
||||
_, err := c.Initialize(ctx, initRequest)
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
toolRequest := mcp.CallToolRequest{}
|
||||
toolRequest.Params.Name = toolName
|
||||
var args map[string]any
|
||||
if err = json.Unmarshal([]byte(input), &args); err != nil {
|
||||
return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
toolRequest.Params.Arguments = args
|
||||
result, err := c.CallTool(ctx, toolRequest)
|
||||
if err != nil {
|
||||
return tools.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
output := ""
|
||||
for _, v := range result.Content {
|
||||
if v, ok := v.(mcp.TextContent); ok {
|
||||
output = v.Text
|
||||
} else {
|
||||
output = fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
return tools.NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
// Run executes the wrapped MCP tool. It requires session/message IDs on the
// context, asks the permission service to approve the invocation, then
// dials the configured MCP server (stdio or SSE) and forwards the call via
// runTool. Most failures are reported as error responses (nil Go error) so
// the model can react to them.
func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
	sessionID, messageID := tools.GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		// NOTE(review): error text mentions "creating a new file" — looks
		// copied from another tool; confirm before changing the message.
		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
	}
	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
	p := b.permissions.Request(
		ctx,
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			Path:        config.WorkingDirectory(),
			ToolName:    b.Info().Name,
			Action:      "execute",
			Description: permissionDescription,
			Params:      params.Input,
		},
	)
	if !p {
		return tools.NewTextErrorResponse("permission denied"), nil
	}

	// A fresh client is dialed per invocation; runTool closes it.
	switch b.mcpConfig.Type {
	case config.MCPStdio:
		c, err := client.NewStdioMCPClient(
			b.mcpConfig.Command,
			b.mcpConfig.Env,
			b.mcpConfig.Args...,
		)
		if err != nil {
			return tools.NewTextErrorResponse(err.Error()), nil
		}
		return runTool(ctx, c, b.tool.Name, params.Input)
	case config.MCPSse:
		c, err := client.NewSSEMCPClient(
			b.mcpConfig.URL,
			client.WithHeaders(b.mcpConfig.Headers),
		)
		if err != nil {
			return tools.NewTextErrorResponse(err.Error()), nil
		}
		return runTool(ctx, c, b.tool.Name, params.Input)
	}

	return tools.NewTextErrorResponse("invalid mcp type"), nil
}
|
||||
|
||||
func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPServer) tools.BaseTool {
|
||||
return &mcpTool{
|
||||
mcpName: name,
|
||||
tool: tool,
|
||||
mcpConfig: mcpConfig,
|
||||
permissions: permissions,
|
||||
}
|
||||
}
|
||||
|
||||
// mcpTools caches the tools discovered from configured MCP servers so
// repeated GetMcpTools calls do not re-query the servers.
var mcpTools []tools.BaseTool
|
||||
|
||||
func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool {
|
||||
var stdioTools []tools.BaseTool
|
||||
initRequest := mcp.InitializeRequest{}
|
||||
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
|
||||
initRequest.Params.ClientInfo = mcp.Implementation{
|
||||
Name: "OpenCode",
|
||||
Version: version.Version,
|
||||
}
|
||||
|
||||
_, err := c.Initialize(ctx, initRequest)
|
||||
if err != nil {
|
||||
slog.Error("error initializing mcp client", "error", err)
|
||||
return stdioTools
|
||||
}
|
||||
toolsRequest := mcp.ListToolsRequest{}
|
||||
tools, err := c.ListTools(ctx, toolsRequest)
|
||||
if err != nil {
|
||||
slog.Error("error listing tools", "error", err)
|
||||
return stdioTools
|
||||
}
|
||||
for _, t := range tools.Tools {
|
||||
stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m))
|
||||
}
|
||||
defer c.Close()
|
||||
return stdioTools
|
||||
}
|
||||
|
||||
// GetMcpTools returns the tools exposed by all configured MCP servers,
// discovering them on first call and serving the cached result afterwards.
// Servers that fail to connect are logged and skipped.
// NOTE(review): the package-level cache is not guarded by a lock — appears
// to assume single-goroutine initialization; confirm before concurrent use.
func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
	if len(mcpTools) > 0 {
		return mcpTools
	}
	for name, m := range config.Get().MCPServers {
		switch m.Type {
		case config.MCPStdio:
			c, err := client.NewStdioMCPClient(
				m.Command,
				m.Env,
				m.Args...,
			)
			if err != nil {
				slog.Error("error creating mcp client", "error", err)
				continue
			}

			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
		case config.MCPSse:
			c, err := client.NewSSEMCPClient(
				m.URL,
				client.WithHeaders(m.Headers),
			)
			if err != nil {
				slog.Error("error creating mcp client", "error", err)
				continue
			}
			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
		}
	}

	return mcpTools
}
|
||||
@@ -1,56 +0,0 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sst/opencode/internal/history"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
"github.com/sst/opencode/internal/session"
|
||||
)
|
||||
|
||||
func PrimaryAgentTools(
|
||||
permissions permission.Service,
|
||||
sessions session.Service,
|
||||
messages message.Service,
|
||||
history history.Service,
|
||||
lspClients map[string]*lsp.Client,
|
||||
) []tools.BaseTool {
|
||||
ctx := context.Background()
|
||||
mcpTools := GetMcpTools(ctx, permissions)
|
||||
|
||||
return append(
|
||||
[]tools.BaseTool{
|
||||
tools.NewBashTool(permissions),
|
||||
tools.NewEditTool(lspClients, permissions, history),
|
||||
tools.NewFetchTool(permissions),
|
||||
tools.NewGlobTool(),
|
||||
tools.NewGrepTool(),
|
||||
tools.NewLsTool(),
|
||||
tools.NewViewTool(lspClients),
|
||||
tools.NewPatchTool(lspClients, permissions, history),
|
||||
tools.NewWriteTool(lspClients, permissions, history),
|
||||
tools.NewDiagnosticsTool(lspClients),
|
||||
tools.NewDefinitionTool(lspClients),
|
||||
tools.NewReferencesTool(lspClients),
|
||||
tools.NewDocSymbolsTool(lspClients),
|
||||
tools.NewWorkspaceSymbolsTool(lspClients),
|
||||
NewAgentTool(sessions, messages, lspClients),
|
||||
}, mcpTools...,
|
||||
)
|
||||
}
|
||||
|
||||
// TaskAgentTools returns the restricted tool set for the task (sub-agent)
// agent: search, listing, file viewing, and LSP symbol lookups. Unlike
// PrimaryAgentTools it excludes every tool that can modify the workspace.
func TaskAgentTools(lspClients map[string]*lsp.Client) []tools.BaseTool {
	return []tools.BaseTool{
		tools.NewGlobTool(),
		tools.NewGrepTool(),
		tools.NewLsTool(),
		tools.NewViewTool(lspClients),
		tools.NewDefinitionTool(lspClients),
		tools.NewReferencesTool(lspClients),
		tools.NewDocSymbolsTool(lspClients),
		tools.NewWorkspaceSymbolsTool(lspClients),
	}
}
|
||||
@@ -1,82 +0,0 @@
|
||||
package models
|
||||
|
||||
// Anthropic provider and model identifiers.
const (
	ProviderAnthropic ModelProvider = "anthropic"

	// Models
	Claude35Sonnet ModelID = "claude-3.5-sonnet"
	Claude3Haiku   ModelID = "claude-3-haiku"
	Claude37Sonnet ModelID = "claude-3.7-sonnet"
	Claude35Haiku  ModelID = "claude-3.5-haiku"
	Claude3Opus    ModelID = "claude-3-opus"
)

// AnthropicModels describes the supported Anthropic models: API model
// identifiers, pricing (USD per million tokens), context window sizes,
// default response budgets, and capability flags.
// https://docs.anthropic.com/en/docs/about-claude/models/all-models
var AnthropicModels = map[ModelID]Model{
	Claude35Sonnet: {
		ID:                  Claude35Sonnet,
		Name:                "Claude 3.5 Sonnet",
		Provider:            ProviderAnthropic,
		APIModel:            "claude-3-5-sonnet-latest",
		CostPer1MIn:         3.0,
		CostPer1MInCached:   3.75,
		CostPer1MOutCached:  0.30,
		CostPer1MOut:        15.0,
		ContextWindow:       200000,
		DefaultMaxTokens:    5000,
		SupportsAttachments: true,
	},
	Claude3Haiku: {
		ID:                  Claude3Haiku,
		Name:                "Claude 3 Haiku",
		Provider:            ProviderAnthropic,
		APIModel:            "claude-3-haiku-20240307", // doesn't support "-latest"
		CostPer1MIn:         0.25,
		CostPer1MInCached:   0.30,
		CostPer1MOutCached:  0.03,
		CostPer1MOut:        1.25,
		ContextWindow:       200000,
		DefaultMaxTokens:    4096,
		SupportsAttachments: true,
	},
	Claude37Sonnet: {
		ID:                  Claude37Sonnet,
		Name:                "Claude 3.7 Sonnet",
		Provider:            ProviderAnthropic,
		APIModel:            "claude-3-7-sonnet-latest",
		CostPer1MIn:         3.0,
		CostPer1MInCached:   3.75,
		CostPer1MOutCached:  0.30,
		CostPer1MOut:        15.0,
		ContextWindow:       200000,
		DefaultMaxTokens:    50000,
		CanReason:           true, // extended-thinking capable model
		SupportsAttachments: true,
	},
	Claude35Haiku: {
		ID:                  Claude35Haiku,
		Name:                "Claude 3.5 Haiku",
		Provider:            ProviderAnthropic,
		APIModel:            "claude-3-5-haiku-latest",
		CostPer1MIn:         0.80,
		CostPer1MInCached:   1.0,
		CostPer1MOutCached:  0.08,
		CostPer1MOut:        4.0,
		ContextWindow:       200000,
		DefaultMaxTokens:    4096,
		SupportsAttachments: true,
	},
	Claude3Opus: {
		ID:                  Claude3Opus,
		Name:                "Claude 3 Opus",
		Provider:            ProviderAnthropic,
		APIModel:            "claude-3-opus-latest",
		CostPer1MIn:         15.0,
		CostPer1MInCached:   18.75,
		CostPer1MOutCached:  1.50,
		CostPer1MOut:        75.0,
		ContextWindow:       200000,
		DefaultMaxTokens:    4096,
		SupportsAttachments: true,
	},
}
|
||||
@@ -1,168 +0,0 @@
|
||||
package models
|
||||
|
||||
const ProviderAzure ModelProvider = "azure"
|
||||
|
||||
const (
|
||||
AzureGPT41 ModelID = "azure.gpt-4.1"
|
||||
AzureGPT41Mini ModelID = "azure.gpt-4.1-mini"
|
||||
AzureGPT41Nano ModelID = "azure.gpt-4.1-nano"
|
||||
AzureGPT45Preview ModelID = "azure.gpt-4.5-preview"
|
||||
AzureGPT4o ModelID = "azure.gpt-4o"
|
||||
AzureGPT4oMini ModelID = "azure.gpt-4o-mini"
|
||||
AzureO1 ModelID = "azure.o1"
|
||||
AzureO1Mini ModelID = "azure.o1-mini"
|
||||
AzureO3 ModelID = "azure.o3"
|
||||
AzureO3Mini ModelID = "azure.o3-mini"
|
||||
AzureO4Mini ModelID = "azure.o4-mini"
|
||||
)
|
||||
|
||||
var AzureModels = map[ModelID]Model{
|
||||
AzureGPT41: {
|
||||
ID: AzureGPT41,
|
||||
Name: "Azure OpenAI – GPT 4.1",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4.1",
|
||||
CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT41].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureGPT41Mini: {
|
||||
ID: AzureGPT41Mini,
|
||||
Name: "Azure OpenAI – GPT 4.1 mini",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4.1-mini",
|
||||
CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT41Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureGPT41Nano: {
|
||||
ID: AzureGPT41Nano,
|
||||
Name: "Azure OpenAI – GPT 4.1 nano",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4.1-nano",
|
||||
CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT41Nano].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureGPT45Preview: {
|
||||
ID: AzureGPT45Preview,
|
||||
Name: "Azure OpenAI – GPT 4.5 preview",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4.5-preview",
|
||||
CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT45Preview].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureGPT4o: {
|
||||
ID: AzureGPT4o,
|
||||
Name: "Azure OpenAI – GPT-4o",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4o",
|
||||
CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT4o].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureGPT4oMini: {
|
||||
ID: AzureGPT4oMini,
|
||||
Name: "Azure OpenAI – GPT-4o mini",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "gpt-4o-mini",
|
||||
CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT4oMini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT4oMini].DefaultMaxTokens,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureO1: {
|
||||
ID: AzureO1,
|
||||
Name: "Azure OpenAI – O1",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "o1",
|
||||
CostPer1MIn: OpenAIModels[O1].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O1].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O1].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O1].CanReason,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureO1Mini: {
|
||||
ID: AzureO1Mini,
|
||||
Name: "Azure OpenAI – O1 mini",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "o1-mini",
|
||||
CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O1Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O1Mini].CanReason,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureO3: {
|
||||
ID: AzureO3,
|
||||
Name: "Azure OpenAI – O3",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "o3",
|
||||
CostPer1MIn: OpenAIModels[O3].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O3].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O3].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O3].CanReason,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
AzureO3Mini: {
|
||||
ID: AzureO3Mini,
|
||||
Name: "Azure OpenAI – O3 mini",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "o3-mini",
|
||||
CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O3Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O3Mini].CanReason,
|
||||
SupportsAttachments: false,
|
||||
},
|
||||
AzureO4Mini: {
|
||||
ID: AzureO4Mini,
|
||||
Name: "Azure OpenAI – O4 mini",
|
||||
Provider: ProviderAzure,
|
||||
APIModel: "o4-mini",
|
||||
CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O4Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O4Mini].CanReason,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
package models
|
||||
|
||||
const (
|
||||
ProviderGemini ModelProvider = "gemini"
|
||||
|
||||
// Models
|
||||
Gemini25Flash ModelID = "gemini-2.5-flash"
|
||||
Gemini25 ModelID = "gemini-2.5"
|
||||
Gemini20Flash ModelID = "gemini-2.0-flash"
|
||||
Gemini20FlashLite ModelID = "gemini-2.0-flash-lite"
|
||||
)
|
||||
|
||||
var GeminiModels = map[ModelID]Model{
|
||||
Gemini25Flash: {
|
||||
ID: Gemini25Flash,
|
||||
Name: "Gemini 2.5 Flash",
|
||||
Provider: ProviderGemini,
|
||||
APIModel: "gemini-2.5-flash-preview-04-17",
|
||||
CostPer1MIn: 0.15,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.60,
|
||||
ContextWindow: 1000000,
|
||||
DefaultMaxTokens: 50000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
Gemini25: {
|
||||
ID: Gemini25,
|
||||
Name: "Gemini 2.5 Pro",
|
||||
Provider: ProviderGemini,
|
||||
APIModel: "gemini-2.5-pro-preview-03-25",
|
||||
CostPer1MIn: 1.25,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 10,
|
||||
ContextWindow: 1000000,
|
||||
DefaultMaxTokens: 50000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
|
||||
Gemini20Flash: {
|
||||
ID: Gemini20Flash,
|
||||
Name: "Gemini 2.0 Flash",
|
||||
Provider: ProviderGemini,
|
||||
APIModel: "gemini-2.0-flash",
|
||||
CostPer1MIn: 0.10,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.40,
|
||||
ContextWindow: 1000000,
|
||||
DefaultMaxTokens: 6000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
Gemini20FlashLite: {
|
||||
ID: Gemini20FlashLite,
|
||||
Name: "Gemini 2.0 Flash Lite",
|
||||
Provider: ProviderGemini,
|
||||
APIModel: "gemini-2.0-flash-lite",
|
||||
CostPer1MIn: 0.05,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.30,
|
||||
ContextWindow: 1000000,
|
||||
DefaultMaxTokens: 6000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
package models
|
||||
|
||||
const (
|
||||
ProviderGROQ ModelProvider = "groq"
|
||||
|
||||
// GROQ
|
||||
QWENQwq ModelID = "qwen-qwq"
|
||||
|
||||
// GROQ preview models
|
||||
Llama4Scout ModelID = "meta-llama/llama-4-scout-17b-16e-instruct"
|
||||
Llama4Maverick ModelID = "meta-llama/llama-4-maverick-17b-128e-instruct"
|
||||
Llama3_3_70BVersatile ModelID = "llama-3.3-70b-versatile"
|
||||
DeepseekR1DistillLlama70b ModelID = "deepseek-r1-distill-llama-70b"
|
||||
)
|
||||
|
||||
var GroqModels = map[ModelID]Model{
|
||||
//
|
||||
// GROQ
|
||||
QWENQwq: {
|
||||
ID: QWENQwq,
|
||||
Name: "Qwen Qwq",
|
||||
Provider: ProviderGROQ,
|
||||
APIModel: "qwen-qwq-32b",
|
||||
CostPer1MIn: 0.29,
|
||||
CostPer1MInCached: 0.275,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 0.39,
|
||||
ContextWindow: 128_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
// for some reason, the groq api doesn't like the reasoningEffort parameter
|
||||
CanReason: false,
|
||||
SupportsAttachments: false,
|
||||
},
|
||||
|
||||
Llama4Scout: {
|
||||
ID: Llama4Scout,
|
||||
Name: "Llama4Scout",
|
||||
Provider: ProviderGROQ,
|
||||
APIModel: "meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
CostPer1MIn: 0.11,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.34,
|
||||
ContextWindow: 128_000, // 10M when?
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
|
||||
Llama4Maverick: {
|
||||
ID: Llama4Maverick,
|
||||
Name: "Llama4Maverick",
|
||||
Provider: ProviderGROQ,
|
||||
APIModel: "meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
CostPer1MIn: 0.20,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.20,
|
||||
ContextWindow: 128_000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
|
||||
Llama3_3_70BVersatile: {
|
||||
ID: Llama3_3_70BVersatile,
|
||||
Name: "Llama3_3_70BVersatile",
|
||||
Provider: ProviderGROQ,
|
||||
APIModel: "llama-3.3-70b-versatile",
|
||||
CostPer1MIn: 0.59,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.79,
|
||||
ContextWindow: 128_000,
|
||||
SupportsAttachments: false,
|
||||
},
|
||||
|
||||
DeepseekR1DistillLlama70b: {
|
||||
ID: DeepseekR1DistillLlama70b,
|
||||
Name: "DeepseekR1DistillLlama70b",
|
||||
Provider: ProviderGROQ,
|
||||
APIModel: "deepseek-r1-distill-llama-70b",
|
||||
CostPer1MIn: 0.75,
|
||||
CostPer1MInCached: 0,
|
||||
CostPer1MOutCached: 0,
|
||||
CostPer1MOut: 0.99,
|
||||
ContextWindow: 128_000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: false,
|
||||
},
|
||||
}
|
||||
@@ -1,98 +0,0 @@
|
||||
package models
|
||||
|
||||
import "maps"
|
||||
|
||||
type (
|
||||
ModelID string
|
||||
ModelProvider string
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
ID ModelID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider ModelProvider `json:"provider"`
|
||||
APIModel string `json:"api_model"`
|
||||
CostPer1MIn float64 `json:"cost_per_1m_in"`
|
||||
CostPer1MOut float64 `json:"cost_per_1m_out"`
|
||||
CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
|
||||
CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
|
||||
ContextWindow int64 `json:"context_window"`
|
||||
DefaultMaxTokens int64 `json:"default_max_tokens"`
|
||||
CanReason bool `json:"can_reason"`
|
||||
SupportsAttachments bool `json:"supports_attachments"`
|
||||
}
|
||||
|
||||
// Model IDs
|
||||
const ( // GEMINI
|
||||
// Bedrock
|
||||
BedrockClaude37Sonnet ModelID = "bedrock.claude-3.7-sonnet"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderBedrock ModelProvider = "bedrock"
|
||||
// ForTests
|
||||
ProviderMock ModelProvider = "__mock"
|
||||
)
|
||||
|
||||
// Providers in order of popularity
|
||||
var ProviderPopularity = map[ModelProvider]int{
|
||||
ProviderAnthropic: 1,
|
||||
ProviderOpenAI: 2,
|
||||
ProviderGemini: 3,
|
||||
ProviderGROQ: 4,
|
||||
ProviderOpenRouter: 5,
|
||||
ProviderBedrock: 6,
|
||||
ProviderAzure: 7,
|
||||
}
|
||||
|
||||
var SupportedModels = map[ModelID]Model{
|
||||
//
|
||||
// // GEMINI
|
||||
// GEMINI25: {
|
||||
// ID: GEMINI25,
|
||||
// Name: "Gemini 2.5 Pro",
|
||||
// Provider: ProviderGemini,
|
||||
// APIModel: "gemini-2.5-pro-exp-03-25",
|
||||
// CostPer1MIn: 0,
|
||||
// CostPer1MInCached: 0,
|
||||
// CostPer1MOutCached: 0,
|
||||
// CostPer1MOut: 0,
|
||||
// },
|
||||
//
|
||||
// GRMINI20Flash: {
|
||||
// ID: GRMINI20Flash,
|
||||
// Name: "Gemini 2.0 Flash",
|
||||
// Provider: ProviderGemini,
|
||||
// APIModel: "gemini-2.0-flash",
|
||||
// CostPer1MIn: 0.1,
|
||||
// CostPer1MInCached: 0,
|
||||
// CostPer1MOutCached: 0.025,
|
||||
// CostPer1MOut: 0.4,
|
||||
// },
|
||||
//
|
||||
// // Bedrock
|
||||
BedrockClaude37Sonnet: {
|
||||
ID: BedrockClaude37Sonnet,
|
||||
Name: "Bedrock: Claude 3.7 Sonnet",
|
||||
Provider: ProviderBedrock,
|
||||
APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
CostPer1MIn: 3.0,
|
||||
CostPer1MInCached: 3.75,
|
||||
CostPer1MOutCached: 0.30,
|
||||
CostPer1MOut: 15.0,
|
||||
ContextWindow: 200_000,
|
||||
DefaultMaxTokens: 50_000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
maps.Copy(SupportedModels, AnthropicModels)
|
||||
maps.Copy(SupportedModels, OpenAIModels)
|
||||
maps.Copy(SupportedModels, GeminiModels)
|
||||
maps.Copy(SupportedModels, GroqModels)
|
||||
maps.Copy(SupportedModels, AzureModels)
|
||||
maps.Copy(SupportedModels, OpenRouterModels)
|
||||
maps.Copy(SupportedModels, XAIModels)
|
||||
}
|
||||
@@ -1,181 +0,0 @@
|
||||
package models
|
||||
|
||||
const (
|
||||
ProviderOpenAI ModelProvider = "openai"
|
||||
|
||||
GPT41 ModelID = "gpt-4.1"
|
||||
GPT41Mini ModelID = "gpt-4.1-mini"
|
||||
GPT41Nano ModelID = "gpt-4.1-nano"
|
||||
GPT45Preview ModelID = "gpt-4.5-preview"
|
||||
GPT4o ModelID = "gpt-4o"
|
||||
GPT4oMini ModelID = "gpt-4o-mini"
|
||||
O1 ModelID = "o1"
|
||||
O1Pro ModelID = "o1-pro"
|
||||
O1Mini ModelID = "o1-mini"
|
||||
O3 ModelID = "o3"
|
||||
O3Mini ModelID = "o3-mini"
|
||||
O4Mini ModelID = "o4-mini"
|
||||
)
|
||||
|
||||
var OpenAIModels = map[ModelID]Model{
|
||||
GPT41: {
|
||||
ID: GPT41,
|
||||
Name: "GPT 4.1",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4.1",
|
||||
CostPer1MIn: 2.00,
|
||||
CostPer1MInCached: 0.50,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 8.00,
|
||||
ContextWindow: 1_047_576,
|
||||
DefaultMaxTokens: 20000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
GPT41Mini: {
|
||||
ID: GPT41Mini,
|
||||
Name: "GPT 4.1 mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4.1",
|
||||
CostPer1MIn: 0.40,
|
||||
CostPer1MInCached: 0.10,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 1.60,
|
||||
ContextWindow: 200_000,
|
||||
DefaultMaxTokens: 20000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
GPT41Nano: {
|
||||
ID: GPT41Nano,
|
||||
Name: "GPT 4.1 nano",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4.1-nano",
|
||||
CostPer1MIn: 0.10,
|
||||
CostPer1MInCached: 0.025,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 0.40,
|
||||
ContextWindow: 1_047_576,
|
||||
DefaultMaxTokens: 20000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
GPT45Preview: {
|
||||
ID: GPT45Preview,
|
||||
Name: "GPT 4.5 preview",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4.5-preview",
|
||||
CostPer1MIn: 75.00,
|
||||
CostPer1MInCached: 37.50,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 150.00,
|
||||
ContextWindow: 128_000,
|
||||
DefaultMaxTokens: 15000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
GPT4o: {
|
||||
ID: GPT4o,
|
||||
Name: "GPT 4o",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4o",
|
||||
CostPer1MIn: 2.50,
|
||||
CostPer1MInCached: 1.25,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 10.00,
|
||||
ContextWindow: 128_000,
|
||||
DefaultMaxTokens: 4096,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
GPT4oMini: {
|
||||
ID: GPT4oMini,
|
||||
Name: "GPT 4o mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "gpt-4o-mini",
|
||||
CostPer1MIn: 0.15,
|
||||
CostPer1MInCached: 0.075,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 0.60,
|
||||
ContextWindow: 128_000,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
O1: {
|
||||
ID: O1,
|
||||
Name: "O1",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o1",
|
||||
CostPer1MIn: 15.00,
|
||||
CostPer1MInCached: 7.50,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 60.00,
|
||||
ContextWindow: 200_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
O1Pro: {
|
||||
ID: O1Pro,
|
||||
Name: "o1 pro",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o1-pro",
|
||||
CostPer1MIn: 150.00,
|
||||
CostPer1MInCached: 0.0,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 600.00,
|
||||
ContextWindow: 200_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
O1Mini: {
|
||||
ID: O1Mini,
|
||||
Name: "o1 mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o1-mini",
|
||||
CostPer1MIn: 1.10,
|
||||
CostPer1MInCached: 0.55,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 4.40,
|
||||
ContextWindow: 128_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
O3: {
|
||||
ID: O3,
|
||||
Name: "o3",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o3",
|
||||
CostPer1MIn: 10.00,
|
||||
CostPer1MInCached: 2.50,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 40.00,
|
||||
ContextWindow: 200_000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
O3Mini: {
|
||||
ID: O3Mini,
|
||||
Name: "o3 mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o3-mini",
|
||||
CostPer1MIn: 1.10,
|
||||
CostPer1MInCached: 0.55,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 4.40,
|
||||
ContextWindow: 200_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: false,
|
||||
},
|
||||
O4Mini: {
|
||||
ID: O4Mini,
|
||||
Name: "o4 mini",
|
||||
Provider: ProviderOpenAI,
|
||||
APIModel: "o4-mini",
|
||||
CostPer1MIn: 1.10,
|
||||
CostPer1MInCached: 0.275,
|
||||
CostPer1MOutCached: 0.0,
|
||||
CostPer1MOut: 4.40,
|
||||
ContextWindow: 128_000,
|
||||
DefaultMaxTokens: 50000,
|
||||
CanReason: true,
|
||||
SupportsAttachments: true,
|
||||
},
|
||||
}
|
||||
@@ -1,327 +0,0 @@
|
||||
package models
|
||||
|
||||
const (
|
||||
ProviderOpenRouter ModelProvider = "openrouter"
|
||||
|
||||
OpenRouterGPT41 ModelID = "openrouter.gpt-4.1"
|
||||
OpenRouterGPT41Mini ModelID = "openrouter.gpt-4.1-mini"
|
||||
OpenRouterGPT41Nano ModelID = "openrouter.gpt-4.1-nano"
|
||||
OpenRouterGPT45Preview ModelID = "openrouter.gpt-4.5-preview"
|
||||
OpenRouterGPT4o ModelID = "openrouter.gpt-4o"
|
||||
OpenRouterGPT4oMini ModelID = "openrouter.gpt-4o-mini"
|
||||
OpenRouterO1 ModelID = "openrouter.o1"
|
||||
OpenRouterO1Pro ModelID = "openrouter.o1-pro"
|
||||
OpenRouterO1Mini ModelID = "openrouter.o1-mini"
|
||||
OpenRouterO3 ModelID = "openrouter.o3"
|
||||
OpenRouterO3Mini ModelID = "openrouter.o3-mini"
|
||||
OpenRouterO4Mini ModelID = "openrouter.o4-mini"
|
||||
OpenRouterGemini25Flash ModelID = "openrouter.gemini-2.5-flash"
|
||||
OpenRouterGemini25 ModelID = "openrouter.gemini-2.5"
|
||||
OpenRouterClaude35Sonnet ModelID = "openrouter.claude-3.5-sonnet"
|
||||
OpenRouterClaude3Haiku ModelID = "openrouter.claude-3-haiku"
|
||||
OpenRouterClaude37Sonnet ModelID = "openrouter.claude-3.7-sonnet"
|
||||
OpenRouterClaude35Haiku ModelID = "openrouter.claude-3.5-haiku"
|
||||
OpenRouterClaude3Opus ModelID = "openrouter.claude-3-opus"
|
||||
OpenRouterQwen235B ModelID = "openrouter.qwen-3-235b"
|
||||
OpenRouterQwen32B ModelID = "openrouter.qwen-3-32b"
|
||||
OpenRouterQwen30B ModelID = "openrouter.qwen-3-30b"
|
||||
OpenRouterQwen14B ModelID = "openrouter.qwen-3-14b"
|
||||
OpenRouterQwen8B ModelID = "openrouter.qwen-3-8b"
|
||||
)
|
||||
|
||||
var OpenRouterModels = map[ModelID]Model{
|
||||
OpenRouterGPT41: {
|
||||
ID: OpenRouterGPT41,
|
||||
Name: "OpenRouter: GPT 4.1",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/gpt-4.1",
|
||||
CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT41].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens,
|
||||
},
|
||||
OpenRouterGPT41Mini: {
|
||||
ID: OpenRouterGPT41Mini,
|
||||
Name: "OpenRouter: GPT 4.1 mini",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/gpt-4.1-mini",
|
||||
CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT41Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens,
|
||||
},
|
||||
OpenRouterGPT41Nano: {
|
||||
ID: OpenRouterGPT41Nano,
|
||||
Name: "OpenRouter: GPT 4.1 nano",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/gpt-4.1-nano",
|
||||
CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT41Nano].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens,
|
||||
},
|
||||
OpenRouterGPT45Preview: {
|
||||
ID: OpenRouterGPT45Preview,
|
||||
Name: "OpenRouter: GPT 4.5 preview",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/gpt-4.5-preview",
|
||||
CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT45Preview].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens,
|
||||
},
|
||||
OpenRouterGPT4o: {
|
||||
ID: OpenRouterGPT4o,
|
||||
Name: "OpenRouter: GPT 4o",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/gpt-4o",
|
||||
CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT4o].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens,
|
||||
},
|
||||
OpenRouterGPT4oMini: {
|
||||
ID: OpenRouterGPT4oMini,
|
||||
Name: "OpenRouter: GPT 4o mini",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/gpt-4o-mini",
|
||||
CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[GPT4oMini].ContextWindow,
|
||||
},
|
||||
OpenRouterO1: {
|
||||
ID: OpenRouterO1,
|
||||
Name: "OpenRouter: O1",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/o1",
|
||||
CostPer1MIn: OpenAIModels[O1].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O1].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O1].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O1].CanReason,
|
||||
},
|
||||
OpenRouterO1Pro: {
|
||||
ID: OpenRouterO1Pro,
|
||||
Name: "OpenRouter: o1 pro",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/o1-pro",
|
||||
CostPer1MIn: OpenAIModels[O1Pro].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O1Pro].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O1Pro].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O1Pro].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O1Pro].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O1Pro].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O1Pro].CanReason,
|
||||
},
|
||||
OpenRouterO1Mini: {
|
||||
ID: OpenRouterO1Mini,
|
||||
Name: "OpenRouter: o1 mini",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/o1-mini",
|
||||
CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O1Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O1Mini].CanReason,
|
||||
},
|
||||
OpenRouterO3: {
|
||||
ID: OpenRouterO3,
|
||||
Name: "OpenRouter: o3",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/o3",
|
||||
CostPer1MIn: OpenAIModels[O3].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O3].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O3].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O3].CanReason,
|
||||
},
|
||||
OpenRouterO3Mini: {
|
||||
ID: OpenRouterO3Mini,
|
||||
Name: "OpenRouter: o3 mini",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/o3-mini-high",
|
||||
CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O3Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O3Mini].CanReason,
|
||||
},
|
||||
OpenRouterO4Mini: {
|
||||
ID: OpenRouterO4Mini,
|
||||
Name: "OpenRouter: o4 mini",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "openai/o4-mini-high",
|
||||
CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn,
|
||||
CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached,
|
||||
CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut,
|
||||
CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached,
|
||||
ContextWindow: OpenAIModels[O4Mini].ContextWindow,
|
||||
DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens,
|
||||
CanReason: OpenAIModels[O4Mini].CanReason,
|
||||
},
|
||||
OpenRouterGemini25Flash: {
|
||||
ID: OpenRouterGemini25Flash,
|
||||
Name: "OpenRouter: Gemini 2.5 Flash",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "google/gemini-2.5-flash-preview:thinking",
|
||||
CostPer1MIn: GeminiModels[Gemini25Flash].CostPer1MIn,
|
||||
CostPer1MInCached: GeminiModels[Gemini25Flash].CostPer1MInCached,
|
||||
CostPer1MOut: GeminiModels[Gemini25Flash].CostPer1MOut,
|
||||
CostPer1MOutCached: GeminiModels[Gemini25Flash].CostPer1MOutCached,
|
||||
ContextWindow: GeminiModels[Gemini25Flash].ContextWindow,
|
||||
DefaultMaxTokens: GeminiModels[Gemini25Flash].DefaultMaxTokens,
|
||||
},
|
||||
OpenRouterGemini25: {
|
||||
ID: OpenRouterGemini25,
|
||||
Name: "OpenRouter: Gemini 2.5 Pro",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "google/gemini-2.5-pro-preview-03-25",
|
||||
CostPer1MIn: GeminiModels[Gemini25].CostPer1MIn,
|
||||
CostPer1MInCached: GeminiModels[Gemini25].CostPer1MInCached,
|
||||
CostPer1MOut: GeminiModels[Gemini25].CostPer1MOut,
|
||||
CostPer1MOutCached: GeminiModels[Gemini25].CostPer1MOutCached,
|
||||
ContextWindow: GeminiModels[Gemini25].ContextWindow,
|
||||
DefaultMaxTokens: GeminiModels[Gemini25].DefaultMaxTokens,
|
||||
},
|
||||
OpenRouterClaude35Sonnet: {
|
||||
ID: OpenRouterClaude35Sonnet,
|
||||
Name: "OpenRouter: Claude 3.5 Sonnet",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "anthropic/claude-3.5-sonnet",
|
||||
CostPer1MIn: AnthropicModels[Claude35Sonnet].CostPer1MIn,
|
||||
CostPer1MInCached: AnthropicModels[Claude35Sonnet].CostPer1MInCached,
|
||||
CostPer1MOut: AnthropicModels[Claude35Sonnet].CostPer1MOut,
|
||||
CostPer1MOutCached: AnthropicModels[Claude35Sonnet].CostPer1MOutCached,
|
||||
ContextWindow: AnthropicModels[Claude35Sonnet].ContextWindow,
|
||||
DefaultMaxTokens: AnthropicModels[Claude35Sonnet].DefaultMaxTokens,
|
||||
},
|
||||
OpenRouterClaude3Haiku: {
|
||||
ID: OpenRouterClaude3Haiku,
|
||||
Name: "OpenRouter: Claude 3 Haiku",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "anthropic/claude-3-haiku",
|
||||
CostPer1MIn: AnthropicModels[Claude3Haiku].CostPer1MIn,
|
||||
CostPer1MInCached: AnthropicModels[Claude3Haiku].CostPer1MInCached,
|
||||
CostPer1MOut: AnthropicModels[Claude3Haiku].CostPer1MOut,
|
||||
CostPer1MOutCached: AnthropicModels[Claude3Haiku].CostPer1MOutCached,
|
||||
ContextWindow: AnthropicModels[Claude3Haiku].ContextWindow,
|
||||
DefaultMaxTokens: AnthropicModels[Claude3Haiku].DefaultMaxTokens,
|
||||
},
|
||||
OpenRouterClaude37Sonnet: {
|
||||
ID: OpenRouterClaude37Sonnet,
|
||||
Name: "OpenRouter: Claude 3.7 Sonnet",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "anthropic/claude-3.7-sonnet",
|
||||
CostPer1MIn: AnthropicModels[Claude37Sonnet].CostPer1MIn,
|
||||
CostPer1MInCached: AnthropicModels[Claude37Sonnet].CostPer1MInCached,
|
||||
CostPer1MOut: AnthropicModels[Claude37Sonnet].CostPer1MOut,
|
||||
CostPer1MOutCached: AnthropicModels[Claude37Sonnet].CostPer1MOutCached,
|
||||
ContextWindow: AnthropicModels[Claude37Sonnet].ContextWindow,
|
||||
DefaultMaxTokens: AnthropicModels[Claude37Sonnet].DefaultMaxTokens,
|
||||
CanReason: AnthropicModels[Claude37Sonnet].CanReason,
|
||||
},
|
||||
OpenRouterClaude35Haiku: {
|
||||
ID: OpenRouterClaude35Haiku,
|
||||
Name: "OpenRouter: Claude 3.5 Haiku",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "anthropic/claude-3.5-haiku",
|
||||
CostPer1MIn: AnthropicModels[Claude35Haiku].CostPer1MIn,
|
||||
CostPer1MInCached: AnthropicModels[Claude35Haiku].CostPer1MInCached,
|
||||
CostPer1MOut: AnthropicModels[Claude35Haiku].CostPer1MOut,
|
||||
CostPer1MOutCached: AnthropicModels[Claude35Haiku].CostPer1MOutCached,
|
||||
ContextWindow: AnthropicModels[Claude35Haiku].ContextWindow,
|
||||
DefaultMaxTokens: AnthropicModels[Claude35Haiku].DefaultMaxTokens,
|
||||
},
|
||||
OpenRouterClaude3Opus: {
|
||||
ID: OpenRouterClaude3Opus,
|
||||
Name: "OpenRouter: Claude 3 Opus",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "anthropic/claude-3-opus",
|
||||
CostPer1MIn: AnthropicModels[Claude3Opus].CostPer1MIn,
|
||||
CostPer1MInCached: AnthropicModels[Claude3Opus].CostPer1MInCached,
|
||||
CostPer1MOut: AnthropicModels[Claude3Opus].CostPer1MOut,
|
||||
CostPer1MOutCached: AnthropicModels[Claude3Opus].CostPer1MOutCached,
|
||||
ContextWindow: AnthropicModels[Claude3Opus].ContextWindow,
|
||||
DefaultMaxTokens: AnthropicModels[Claude3Opus].DefaultMaxTokens,
|
||||
},
|
||||
OpenRouterQwen235B: {
|
||||
ID: OpenRouterQwen235B,
|
||||
Name: "OpenRouter: Qwen3 235B A22B",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "qwen/qwen3-235b-a22b",
|
||||
CostPer1MIn: 0.1,
|
||||
CostPer1MInCached: 0.1,
|
||||
CostPer1MOut: 0.1,
|
||||
CostPer1MOutCached: 0.1,
|
||||
ContextWindow: 40960,
|
||||
DefaultMaxTokens: 4096,
|
||||
},
|
||||
OpenRouterQwen32B: {
|
||||
ID: OpenRouterQwen32B,
|
||||
Name: "OpenRouter: Qwen3 32B",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "qwen/qwen3-32b",
|
||||
CostPer1MIn: 0.1,
|
||||
CostPer1MInCached: 0.1,
|
||||
CostPer1MOut: 0.3,
|
||||
CostPer1MOutCached: 0.3,
|
||||
ContextWindow: 40960,
|
||||
DefaultMaxTokens: 4096,
|
||||
},
|
||||
OpenRouterQwen30B: {
|
||||
ID: OpenRouterQwen30B,
|
||||
Name: "OpenRouter: Qwen3 30B A3B",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "qwen/qwen3-30b-a3b",
|
||||
CostPer1MIn: 0.1,
|
||||
CostPer1MInCached: 0.1,
|
||||
CostPer1MOut: 0.3,
|
||||
CostPer1MOutCached: 0.3,
|
||||
ContextWindow: 40960,
|
||||
DefaultMaxTokens: 4096,
|
||||
},
|
||||
OpenRouterQwen14B: {
|
||||
ID: OpenRouterQwen14B,
|
||||
Name: "OpenRouter: Qwen3 14B",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "qwen/qwen3-14b",
|
||||
CostPer1MIn: 0.7,
|
||||
CostPer1MInCached: 0.7,
|
||||
CostPer1MOut: 0.24,
|
||||
CostPer1MOutCached: 0.24,
|
||||
ContextWindow: 40960,
|
||||
DefaultMaxTokens: 4096,
|
||||
},
|
||||
OpenRouterQwen8B: {
|
||||
ID: OpenRouterQwen8B,
|
||||
Name: "OpenRouter: Qwen3 8B",
|
||||
Provider: ProviderOpenRouter,
|
||||
APIModel: "qwen/qwen3-8b",
|
||||
CostPer1MIn: 0.35,
|
||||
CostPer1MInCached: 0.35,
|
||||
CostPer1MOut: 0.138,
|
||||
CostPer1MOutCached: 0.138,
|
||||
ContextWindow: 128000,
|
||||
DefaultMaxTokens: 4096,
|
||||
},
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
package models
|
||||
|
||||
const (
	// ProviderXAI identifies the xAI (Grok) model provider.
	ProviderXAI ModelProvider = "xai"

	// Model IDs for the xAI Grok 3 beta family.
	// NOTE(review): XAiGrok3MiniFastBeta breaks the "XAI" initialism casing
	// used by its siblings; renaming the exported constant would break
	// callers, so it is left as-is.
	XAIGrok3Beta         ModelID = "grok-3-beta"
	XAIGrok3MiniBeta     ModelID = "grok-3-mini-beta"
	XAIGrok3FastBeta     ModelID = "grok-3-fast-beta"
	XAiGrok3MiniFastBeta ModelID = "grok-3-mini-fast-beta"
)
|
||||
|
||||
// XAIModels describes the xAI Grok models available to the application,
// keyed by ModelID. Costs are USD per one million tokens; the cached-token
// costs are 0 here — presumably xAI does not discount cached tokens, but
// confirm against current pricing before relying on this for billing.
var XAIModels = map[ModelID]Model{
	XAIGrok3Beta: {
		ID:                 XAIGrok3Beta,
		Name:               "Grok3 Beta",
		Provider:           ProviderXAI,
		APIModel:           "grok-3-beta",
		CostPer1MIn:        3.0,
		CostPer1MInCached:  0,
		CostPer1MOut:       15,
		CostPer1MOutCached: 0,
		ContextWindow:      131_072,
		DefaultMaxTokens:   20_000,
	},
	XAIGrok3MiniBeta: {
		ID:                 XAIGrok3MiniBeta,
		Name:               "Grok3 Mini Beta",
		Provider:           ProviderXAI,
		APIModel:           "grok-3-mini-beta",
		CostPer1MIn:        0.3,
		CostPer1MInCached:  0,
		CostPer1MOut:       0.5,
		CostPer1MOutCached: 0,
		ContextWindow:      131_072,
		DefaultMaxTokens:   20_000,
	},
	XAIGrok3FastBeta: {
		ID:                 XAIGrok3FastBeta,
		Name:               "Grok3 Fast Beta",
		Provider:           ProviderXAI,
		APIModel:           "grok-3-fast-beta",
		CostPer1MIn:        5,
		CostPer1MInCached:  0,
		CostPer1MOut:       25,
		CostPer1MOutCached: 0,
		ContextWindow:      131_072,
		DefaultMaxTokens:   20_000,
	},
	XAiGrok3MiniFastBeta: {
		ID:                 XAiGrok3MiniFastBeta,
		Name:               "Grok3 Mini Fast Beta",
		Provider:           ProviderXAI,
		APIModel:           "grok-3-mini-fast-beta",
		CostPer1MIn:        0.6,
		CostPer1MInCached:  0,
		CostPer1MOut:       4.0,
		CostPer1MOutCached: 0,
		ContextWindow:      131_072,
		DefaultMaxTokens:   20_000,
	},
}
|
||||
@@ -1,135 +0,0 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
func GetAgentPrompt(agentName config.AgentName, provider models.ModelProvider) string {
|
||||
basePrompt := ""
|
||||
switch agentName {
|
||||
case config.AgentPrimary:
|
||||
basePrompt = PrimaryPrompt(provider)
|
||||
case config.AgentTitle:
|
||||
basePrompt = TitlePrompt(provider)
|
||||
case config.AgentTask:
|
||||
basePrompt = TaskPrompt(provider)
|
||||
default:
|
||||
basePrompt = "You are a helpful assistant"
|
||||
}
|
||||
|
||||
if agentName == config.AgentPrimary || agentName == config.AgentTask {
|
||||
// Add context from project-specific instruction files if they exist
|
||||
contextContent := getContextFromPaths()
|
||||
slog.Debug("Context content", "Context", contextContent)
|
||||
if contextContent != "" {
|
||||
return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
|
||||
}
|
||||
}
|
||||
return basePrompt
|
||||
}
|
||||
|
||||
var (
|
||||
onceContext sync.Once
|
||||
contextContent string
|
||||
)
|
||||
|
||||
func getContextFromPaths() string {
|
||||
onceContext.Do(func() {
|
||||
var (
|
||||
cfg = config.Get()
|
||||
workDir = cfg.WorkingDir
|
||||
contextPaths = cfg.ContextPaths
|
||||
)
|
||||
|
||||
contextContent = processContextPaths(workDir, contextPaths)
|
||||
})
|
||||
|
||||
return contextContent
|
||||
}
|
||||
|
||||
// processContextPaths reads every configured context path concurrently and
// returns the collected file sections joined by newlines.
//
// A path ending in "/" is treated as a directory and walked recursively;
// anything else is read as a single file. Duplicate files (compared
// case-insensitively by full path) are processed only once.
//
// NOTE(review): the order of sections in the result depends on goroutine
// scheduling when more than one path is configured — callers must not rely
// on a stable ordering.
func processContextPaths(workDir string, paths []string) string {
	var (
		wg       sync.WaitGroup
		resultCh = make(chan string)
	)

	// Track processed files to avoid duplicates
	processedFiles := make(map[string]bool)
	var processedMutex sync.Mutex

	// One goroutine per configured path; results funnel into resultCh.
	for _, path := range paths {
		wg.Add(1)
		go func(p string) {
			defer wg.Done()

			if strings.HasSuffix(p, "/") {
				// Directory: walk it recursively.
				// NOTE(review): the walk's return value is discarded, and
				// returning err from the callback aborts the whole walk on
				// the first unreadable entry — confirm this is intended.
				filepath.WalkDir(filepath.Join(workDir, p), func(path string, d os.DirEntry, err error) error {
					if err != nil {
						return err
					}
					if !d.IsDir() {
						// Check if we've already processed this file (case-insensitive)
						processedMutex.Lock()
						lowerPath := strings.ToLower(path)
						if !processedFiles[lowerPath] {
							processedFiles[lowerPath] = true
							processedMutex.Unlock()

							if result := processFile(path); result != "" {
								resultCh <- result
							}
						} else {
							processedMutex.Unlock()
						}
					}
					return nil
				})
			} else {
				fullPath := filepath.Join(workDir, p)

				// Check if we've already processed this file (case-insensitive)
				processedMutex.Lock()
				lowerPath := strings.ToLower(fullPath)
				if !processedFiles[lowerPath] {
					processedFiles[lowerPath] = true
					processedMutex.Unlock()

					result := processFile(fullPath)
					if result != "" {
						resultCh <- result
					}
				} else {
					processedMutex.Unlock()
				}
			}
		}(path)
	}

	// Close the channel once every worker is done so the collector below
	// terminates.
	go func() {
		wg.Wait()
		close(resultCh)
	}()

	results := make([]string, 0)
	for result := range resultCh {
		results = append(results, result)
	}

	return strings.Join(results, "\n")
}
|
||||
|
||||
// processFile reads filePath and formats it as a context section: a
// "# From:<path>" header followed by the raw file contents. Files that
// cannot be read are skipped by returning "" (best-effort behavior).
func processFile(filePath string) string {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return ""
	}
	var section strings.Builder
	section.WriteString("# From:")
	section.WriteString(filePath)
	section.WriteString("\n")
	section.Write(data)
	return section.String()
}
|
||||
@@ -1,61 +0,0 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetContextFromPaths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lvl := new(slog.LevelVar)
|
||||
lvl.Set(slog.LevelDebug)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
_, err := config.Load(tmpDir, false, lvl)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
cfg := config.Get()
|
||||
cfg.WorkingDir = tmpDir
|
||||
cfg.ContextPaths = []string{
|
||||
"file.txt",
|
||||
"directory/",
|
||||
}
|
||||
testFiles := []string{
|
||||
"file.txt",
|
||||
"directory/file_a.txt",
|
||||
"directory/file_b.txt",
|
||||
"directory/file_c.txt",
|
||||
}
|
||||
|
||||
createTestFiles(t, tmpDir, testFiles)
|
||||
|
||||
context := getContextFromPaths()
|
||||
expectedContext := fmt.Sprintf("# From:%s/file.txt\nfile.txt: test content\n# From:%s/directory/file_a.txt\ndirectory/file_a.txt: test content\n# From:%s/directory/file_b.txt\ndirectory/file_b.txt: test content\n# From:%s/directory/file_c.txt\ndirectory/file_c.txt: test content", tmpDir, tmpDir, tmpDir, tmpDir)
|
||||
assert.Equal(t, expectedContext, context)
|
||||
}
|
||||
|
||||
func createTestFiles(t *testing.T, tmpDir string, testFiles []string) {
|
||||
t.Helper()
|
||||
for _, path := range testFiles {
|
||||
fullPath := filepath.Join(tmpDir, path)
|
||||
if path[len(path)-1] == '/' {
|
||||
err := os.MkdirAll(fullPath, 0755)
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
dir := filepath.Dir(fullPath)
|
||||
err := os.MkdirAll(dir, 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(fullPath, []byte(path+": test content"), 0644)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
)
|
||||
|
||||
func TaskPrompt(_ models.ModelProvider) string {
|
||||
agentPrompt := `You are an agent for OpenCode. Given the user's prompt, you should use the tools available to you to answer the user's question.
|
||||
Notes:
|
||||
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
|
||||
2. When relevant, share file names and code snippets relevant to the query
|
||||
3. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.`
|
||||
|
||||
return fmt.Sprintf("%s\n%s\n", agentPrompt, getEnvironmentInfo())
|
||||
}
|
||||
@@ -1,470 +0,0 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/bedrock"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/status"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// anthropicOptions holds Anthropic-specific knobs layered on top of the
// generic provider options.
type anthropicOptions struct {
	// useBedrock routes requests through AWS Bedrock instead of the
	// Anthropic API directly.
	useBedrock bool
	// disableCache suppresses the "ephemeral" cache-control markers that are
	// otherwise attached to recent messages, the tool list and the system
	// prompt.
	disableCache bool
	// shouldThink decides, from the latest user message, whether extended
	// thinking is enabled for the request. A nil func disables the feature.
	shouldThink func(userMessage string) bool
}

// AnthropicOption mutates anthropicOptions; applied by newAnthropicClient.
type AnthropicOption func(*anthropicOptions)

// anthropicClient implements the provider client against the Anthropic SDK.
type anthropicClient struct {
	providerOptions providerClientOptions
	options         anthropicOptions
	client          anthropic.Client
}

// AnthropicClient is the provider-facing alias for this client type.
type AnthropicClient ProviderClient
|
||||
|
||||
func newAnthropicClient(opts providerClientOptions) AnthropicClient {
|
||||
anthropicOpts := anthropicOptions{}
|
||||
for _, o := range opts.anthropicOptions {
|
||||
o(&anthropicOpts)
|
||||
}
|
||||
|
||||
anthropicClientOptions := []option.RequestOption{}
|
||||
if opts.apiKey != "" {
|
||||
anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
|
||||
}
|
||||
if anthropicOpts.useBedrock {
|
||||
anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
|
||||
}
|
||||
|
||||
client := anthropic.NewClient(anthropicClientOptions...)
|
||||
return &anthropicClient{
|
||||
providerOptions: opts,
|
||||
options: anthropicOpts,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// convertMessages maps the internal message history onto Anthropic message
// params. The last two messages are marked with an ephemeral cache-control
// block (unless caching is disabled) so the tail of the conversation can be
// prompt-cached. Tool results are sent back as user messages, per the
// Anthropic messages API.
func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
	for i, msg := range messages {
		// Cache only the most recent messages (the last two indices).
		cache := false
		if i > len(messages)-3 {
			cache = true
		}
		switch msg.Role {
		case message.User:
			content := anthropic.NewTextBlock(msg.Content().String())
			if cache && !a.options.disableCache {
				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				}
			}
			var contentBlocks []anthropic.ContentBlockParamUnion
			contentBlocks = append(contentBlocks, content)
			// Attach any images as base64 blocks after the text.
			for _, binaryContent := range msg.BinaryContent() {
				base64Image := binaryContent.String(models.ProviderAnthropic)
				imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
				contentBlocks = append(contentBlocks, imageBlock)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...))

		case message.Assistant:
			blocks := []anthropic.ContentBlockParamUnion{}

			if msg.Content() != nil {
				content := msg.Content().String()
				// Anthropic rejects empty text blocks, so skip whitespace-only content.
				if strings.TrimSpace(content) != "" {
					block := anthropic.NewTextBlock(content)
					if cache && !a.options.disableCache {
						block.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
							Type: "ephemeral",
						}
					}
					blocks = append(blocks, block)
				}
			}

			// Replay tool calls; calls whose recorded input is not valid JSON
			// are silently dropped.
			for _, toolCall := range msg.ToolCalls() {
				var inputMap map[string]any
				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
				if err != nil {
					continue
				}
				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
			}

			// An assistant message with no blocks would be rejected; skip it.
			if len(blocks) == 0 {
				slog.Warn("There is a message without content, investigate, this should not happen")
				continue
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))

		case message.Tool:
			// Tool results go back to the API wrapped in a user message.
			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
			for i, toolResult := range msg.ToolResults() {
				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
		}
	}
	return
}
|
||||
|
||||
func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
|
||||
anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
|
||||
|
||||
for i, tool := range tools {
|
||||
info := tool.Info()
|
||||
toolParam := anthropic.ToolParam{
|
||||
Name: info.Name,
|
||||
Description: anthropic.String(info.Description),
|
||||
InputSchema: anthropic.ToolInputSchemaParam{
|
||||
Properties: info.Parameters,
|
||||
// TODO: figure out how we can tell claude the required fields?
|
||||
},
|
||||
}
|
||||
|
||||
if i == len(tools)-1 && !a.options.disableCache {
|
||||
toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
|
||||
Type: "ephemeral",
|
||||
}
|
||||
}
|
||||
|
||||
anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
|
||||
}
|
||||
|
||||
return anthropicTools
|
||||
}
|
||||
|
||||
func (a *anthropicClient) finishReason(reason string) message.FinishReason {
|
||||
switch reason {
|
||||
case "end_turn":
|
||||
return message.FinishReasonEndTurn
|
||||
case "max_tokens":
|
||||
return message.FinishReasonMaxTokens
|
||||
case "tool_use":
|
||||
return message.FinishReasonToolUse
|
||||
case "stop_sequence":
|
||||
return message.FinishReasonEndTurn
|
||||
default:
|
||||
return message.FinishReasonUnknown
|
||||
}
|
||||
}
|
||||
|
||||
// preparedMessages assembles the final request params: model, token budget,
// message history, tools, system prompt (always cache-marked) and, when the
// latest user message triggers the shouldThink predicate, extended thinking
// with 80% of the max-token budget and temperature forced to 1.
//
// NOTE(review): indexes messages[len(messages)-1] unconditionally — panics on
// an empty slice; callers presumably always pass at least one message.
func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
	var thinkingParam anthropic.ThinkingConfigParamUnion
	lastMessage := messages[len(messages)-1]
	isUser := lastMessage.Role == anthropic.MessageParamRoleUser
	messageContent := ""
	temperature := anthropic.Float(0)
	if isUser {
		// Pull the text of the last user message (last text block wins).
		for _, m := range lastMessage.Content {
			if m.OfRequestTextBlock != nil && m.OfRequestTextBlock.Text != "" {
				messageContent = m.OfRequestTextBlock.Text
			}
		}
		if messageContent != "" && a.options.shouldThink != nil && a.options.shouldThink(messageContent) {
			thinkingParam = anthropic.ThinkingConfigParamUnion{
				OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
					// 80% of the response budget goes to thinking tokens.
					BudgetTokens: int64(float64(a.providerOptions.maxTokens) * 0.8),
					Type:         "enabled",
				},
			}
			// Extended thinking requires temperature 1.
			temperature = anthropic.Float(1)
		}
	}

	return anthropic.MessageNewParams{
		Model:       anthropic.Model(a.providerOptions.model.APIModel),
		MaxTokens:   a.providerOptions.maxTokens,
		Temperature: temperature,
		Messages:    messages,
		Tools:       tools,
		Thinking:    thinkingParam,
		System: []anthropic.TextBlockParam{
			{
				Text: a.providerOptions.systemMessage,
				// The system prompt is always cache-marked (it rarely changes).
				CacheControl: anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				},
			},
		},
	}
}
|
||||
|
||||
func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (resposne *ProviderResponse, err error) {
|
||||
preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
|
||||
cfg := config.Get()
|
||||
if cfg.Debug {
|
||||
jsonData, _ := json.Marshal(preparedMessages)
|
||||
slog.Debug("Prepared messages", "messages", string(jsonData))
|
||||
}
|
||||
|
||||
attempts := 0
|
||||
for {
|
||||
attempts++
|
||||
anthropicResponse, err := a.client.Messages.New(
|
||||
ctx,
|
||||
preparedMessages,
|
||||
)
|
||||
// If there is an error we are going to see if we can retry the call
|
||||
if err != nil {
|
||||
slog.Error("Error in Anthropic API call", "error", err)
|
||||
retry, after, retryErr := a.shouldRetry(attempts, err)
|
||||
if retryErr != nil {
|
||||
return nil, retryErr
|
||||
}
|
||||
if retry {
|
||||
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(time.Duration(after) * time.Millisecond):
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, retryErr
|
||||
}
|
||||
|
||||
content := ""
|
||||
for _, block := range anthropicResponse.Content {
|
||||
if text, ok := block.AsAny().(anthropic.TextBlock); ok {
|
||||
content += text.Text
|
||||
}
|
||||
}
|
||||
|
||||
return &ProviderResponse{
|
||||
Content: content,
|
||||
ToolCalls: a.toolCalls(*anthropicResponse),
|
||||
Usage: a.usage(*anthropicResponse),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// stream performs a streaming Anthropic completion call and translates SDK
// stream events into ProviderEvents on the returned channel. The channel is
// closed when the stream ends (successfully, with an error event, or on
// context cancellation). Rate-limit/overload errors are retried with backoff.
//
// Event mapping: text block start/stop -> ContentStart/ContentStop; tool_use
// block start/delta/stop -> ToolUseStart/Delta/Stop; thinking/text deltas ->
// ThinkingDelta/ContentDelta; message stop -> Complete with the accumulated
// response.
func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(preparedMessages)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	eventChan := make(chan ProviderEvent)
	go func() {
		for {
			attempts++
			anthropicStream := a.client.Messages.NewStreaming(
				ctx,
				preparedMessages,
			)
			// The SDK accumulator rebuilds the full message from the events.
			accumulatedMessage := anthropic.Message{}

			// ID of the tool_use block currently being streamed, if any.
			currentToolCallID := ""
			for anthropicStream.Next() {
				event := anthropicStream.Current()
				err := accumulatedMessage.Accumulate(event)
				if err != nil {
					slog.Warn("Error accumulating message", "error", err)
					continue
				}

				switch event := event.AsAny().(type) {
				case anthropic.ContentBlockStartEvent:
					if event.ContentBlock.Type == "text" {
						eventChan <- ProviderEvent{Type: EventContentStart}
					} else if event.ContentBlock.Type == "tool_use" {
						currentToolCallID = event.ContentBlock.ID
						eventChan <- ProviderEvent{
							Type: EventToolUseStart,
							ToolCall: &message.ToolCall{
								ID:       event.ContentBlock.ID,
								Name:     event.ContentBlock.Name,
								Finished: false,
							},
						}
					}

				case anthropic.ContentBlockDeltaEvent:
					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
						eventChan <- ProviderEvent{
							Type:     EventThinkingDelta,
							Thinking: event.Delta.Thinking,
						}
					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: event.Delta.Text,
						}
					} else if event.Delta.Type == "input_json_delta" {
						// Partial JSON for the in-flight tool call's input.
						if currentToolCallID != "" {
							eventChan <- ProviderEvent{
								Type: EventToolUseDelta,
								ToolCall: &message.ToolCall{
									ID:       currentToolCallID,
									Finished: false,
									Input:    event.Delta.JSON.PartialJSON.Raw(),
								},
							}
						}
					}
				case anthropic.ContentBlockStopEvent:
					// A stop event ends either the in-flight tool call or a
					// plain text block.
					if currentToolCallID != "" {
						eventChan <- ProviderEvent{
							Type: EventToolUseStop,
							ToolCall: &message.ToolCall{
								ID: currentToolCallID,
							},
						}
						currentToolCallID = ""
					} else {
						eventChan <- ProviderEvent{Type: EventContentStop}
					}

				case anthropic.MessageStopEvent:
					// Emit the final, fully accumulated response.
					content := ""
					for _, block := range accumulatedMessage.Content {
						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
							content += text.Text
						}
					}

					eventChan <- ProviderEvent{
						Type: EventComplete,
						Response: &ProviderResponse{
							Content:      content,
							ToolCalls:    a.toolCalls(accumulatedMessage),
							Usage:        a.usage(accumulatedMessage),
							FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
						},
					}
				}
			}

			err := anthropicStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				close(eventChan)
				return
			}
			// If there is an error we are going to see if we can retry the call
			retry, after, retryErr := a.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
				select {
				case <-ctx.Done():
					// context cancelled
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			if ctx.Err() != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
			}

			close(eventChan)
			return
		}
	}()
	return eventChan
}
|
||||
|
||||
func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
|
||||
var apierr *anthropic.Error
|
||||
if !errors.As(err, &apierr) {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
if attempts > maxRetries {
|
||||
return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
|
||||
}
|
||||
|
||||
retryMs := 0
|
||||
retryAfterValues := apierr.Response.Header.Values("Retry-After")
|
||||
|
||||
backoffMs := 2000 * (1 << (attempts - 1))
|
||||
jitterMs := int(float64(backoffMs) * 0.2)
|
||||
retryMs = backoffMs + jitterMs
|
||||
if len(retryAfterValues) > 0 {
|
||||
if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
|
||||
retryMs = retryMs * 1000
|
||||
}
|
||||
}
|
||||
return true, int64(retryMs), nil
|
||||
}
|
||||
|
||||
func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
|
||||
var toolCalls []message.ToolCall
|
||||
|
||||
for _, block := range msg.Content {
|
||||
switch variant := block.AsAny().(type) {
|
||||
case anthropic.ToolUseBlock:
|
||||
toolCall := message.ToolCall{
|
||||
ID: variant.ID,
|
||||
Name: variant.Name,
|
||||
Input: string(variant.Input),
|
||||
Type: string(variant.Type),
|
||||
Finished: true,
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
}
|
||||
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
|
||||
return TokenUsage{
|
||||
InputTokens: msg.Usage.InputTokens,
|
||||
OutputTokens: msg.Usage.OutputTokens,
|
||||
CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: msg.Usage.CacheReadInputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func WithAnthropicBedrock(useBedrock bool) AnthropicOption {
|
||||
return func(options *anthropicOptions) {
|
||||
options.useBedrock = useBedrock
|
||||
}
|
||||
}
|
||||
|
||||
func WithAnthropicDisableCache() AnthropicOption {
|
||||
return func(options *anthropicOptions) {
|
||||
options.disableCache = true
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultShouldThinkFn is the default predicate for enabling extended
// thinking: it fires when the user's message mentions "think" in any casing.
func DefaultShouldThinkFn(s string) bool {
	lowered := strings.ToLower(s)
	return strings.Contains(lowered, "think")
}
|
||||
|
||||
func WithAnthropicShouldThinkFn(fn func(string) bool) AnthropicOption {
|
||||
return func(options *anthropicOptions) {
|
||||
options.shouldThink = fn
|
||||
}
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
"github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/azure"
|
||||
"github.com/openai/openai-go/option"
|
||||
)
|
||||
|
||||
// azureClient is a thin wrapper around openaiClient: Azure OpenAI speaks the
// OpenAI wire protocol, so all request/stream behavior is inherited through
// the embedded client.
type azureClient struct {
	*openaiClient
}

// AzureClient is the provider-facing alias for this client type.
type AzureClient ProviderClient
|
||||
|
||||
func newAzureClient(opts providerClientOptions) AzureClient {
|
||||
|
||||
endpoint := os.Getenv("AZURE_OPENAI_ENDPOINT") // ex: https://foo.openai.azure.com
|
||||
apiVersion := os.Getenv("AZURE_OPENAI_API_VERSION") // ex: 2025-04-01-preview
|
||||
|
||||
if endpoint == "" || apiVersion == "" {
|
||||
return &azureClient{openaiClient: newOpenAIClient(opts).(*openaiClient)}
|
||||
}
|
||||
|
||||
reqOpts := []option.RequestOption{
|
||||
azure.WithEndpoint(endpoint, apiVersion),
|
||||
}
|
||||
|
||||
if opts.apiKey != "" || os.Getenv("AZURE_OPENAI_API_KEY") != "" {
|
||||
key := opts.apiKey
|
||||
if key == "" {
|
||||
key = os.Getenv("AZURE_OPENAI_API_KEY")
|
||||
}
|
||||
reqOpts = append(reqOpts, azure.WithAPIKey(key))
|
||||
} else if cred, err := azidentity.NewDefaultAzureCredential(nil); err == nil {
|
||||
reqOpts = append(reqOpts, azure.WithTokenCredential(cred))
|
||||
}
|
||||
|
||||
base := &openaiClient{
|
||||
providerOptions: opts,
|
||||
client: openai.NewClient(reqOpts...),
|
||||
}
|
||||
|
||||
return &azureClient{openaiClient: base}
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
)
|
||||
|
||||
// bedrockOptions holds Bedrock-specific configuration.
type bedrockOptions struct {
	// Bedrock specific options can be added here
}

// BedrockOption mutates bedrockOptions; reserved for future use.
type BedrockOption func(*bedrockOptions)

// bedrockClient implements the provider client by delegating to a child
// provider configured for Bedrock (currently only Anthropic models). A nil
// childProvider marks an unsupported model: send/stream then report an error.
type bedrockClient struct {
	providerOptions providerClientOptions
	options         bedrockOptions
	childProvider   ProviderClient
}

// BedrockClient is the provider-facing alias for this client type.
type BedrockClient ProviderClient
|
||||
|
||||
func newBedrockClient(opts providerClientOptions) BedrockClient {
|
||||
bedrockOpts := bedrockOptions{}
|
||||
// Apply bedrock specific options if they are added in the future
|
||||
|
||||
// Get AWS region from environment
|
||||
region := os.Getenv("AWS_REGION")
|
||||
if region == "" {
|
||||
region = os.Getenv("AWS_DEFAULT_REGION")
|
||||
}
|
||||
|
||||
if region == "" {
|
||||
region = "us-east-1" // default region
|
||||
}
|
||||
if len(region) < 2 {
|
||||
return &bedrockClient{
|
||||
providerOptions: opts,
|
||||
options: bedrockOpts,
|
||||
childProvider: nil, // Will cause an error when used
|
||||
}
|
||||
}
|
||||
|
||||
// Prefix the model name with region
|
||||
regionPrefix := region[:2]
|
||||
modelName := opts.model.APIModel
|
||||
opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName)
|
||||
|
||||
// Determine which provider to use based on the model
|
||||
if strings.Contains(string(opts.model.APIModel), "anthropic") {
|
||||
// Create Anthropic client with Bedrock configuration
|
||||
anthropicOpts := opts
|
||||
anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions,
|
||||
WithAnthropicBedrock(true),
|
||||
WithAnthropicDisableCache(),
|
||||
)
|
||||
return &bedrockClient{
|
||||
providerOptions: opts,
|
||||
options: bedrockOpts,
|
||||
childProvider: newAnthropicClient(anthropicOpts),
|
||||
}
|
||||
}
|
||||
|
||||
// Return client with nil childProvider if model is not supported
|
||||
// This will cause an error when used
|
||||
return &bedrockClient{
|
||||
providerOptions: opts,
|
||||
options: bedrockOpts,
|
||||
childProvider: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// send forwards the request to the wrapped child provider, failing fast when
// the configured model is not supported on Bedrock (nil childProvider).
func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	if b.childProvider == nil {
		return nil, errors.New("unsupported model for bedrock provider")
	}
	return b.childProvider.send(ctx, messages, tools)
}
|
||||
|
||||
func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
|
||||
eventChan := make(chan ProviderEvent)
|
||||
|
||||
if b.childProvider == nil {
|
||||
go func() {
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventError,
|
||||
Error: errors.New("unsupported model for bedrock provider"),
|
||||
}
|
||||
close(eventChan)
|
||||
}()
|
||||
return eventChan
|
||||
}
|
||||
|
||||
return b.childProvider.stream(ctx, messages, tools)
|
||||
}
|
||||
@@ -1,547 +0,0 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/status"
|
||||
"google.golang.org/genai"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// geminiOptions holds Gemini-specific knobs applied on top of the generic
// provider options.
type geminiOptions struct {
	// disableCache is declared for parity with other providers.
	// NOTE(review): nothing in this file reads it — confirm whether it is
	// still needed.
	disableCache bool
}

// GeminiOption mutates geminiOptions; applied by newGeminiClient.
type GeminiOption func(*geminiOptions)

// geminiClient implements the provider client against the Google genai SDK.
type geminiClient struct {
	providerOptions providerClientOptions
	options         geminiOptions
	client          *genai.Client
}

// GeminiClient is the provider-facing alias for this client type.
type GeminiClient ProviderClient
|
||||
|
||||
func newGeminiClient(opts providerClientOptions) GeminiClient {
|
||||
geminiOpts := geminiOptions{}
|
||||
for _, o := range opts.geminiOptions {
|
||||
o(&geminiOpts)
|
||||
}
|
||||
|
||||
client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
|
||||
if err != nil {
|
||||
slog.Error("Failed to create Gemini client", "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &geminiClient{
|
||||
providerOptions: opts,
|
||||
options: geminiOpts,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// convertMessages maps the internal message history onto genai content
// entries: user text/images as "user" content, assistant text and tool calls
// as "model" content, and tool results as "function" content. For each tool
// result, the original assistant tool call is looked up by ID so the
// FunctionResponse can carry the function's name.
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
	var history []*genai.Content
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var parts []*genai.Part
			parts = append(parts, &genai.Part{Text: msg.Content().String()})
			for _, binaryContent := range msg.BinaryContent() {
				imageFormat := strings.Split(binaryContent.MIMEType, "/")
				// NOTE(review): only the MIME subtype (e.g. "png") is passed
				// here, not the full MIME type ("image/png") — confirm the
				// genai Blob API actually expects the bare subtype.
				parts = append(parts, &genai.Part{InlineData: &genai.Blob{
					MIMEType: imageFormat[1],
					Data:     binaryContent.Data,
				}})
			}
			history = append(history, &genai.Content{
				Parts: parts,
				Role:  "user",
			})
		case message.Assistant:
			content := &genai.Content{
				Role:  "model",
				Parts: []*genai.Part{},
			}

			if msg.Content().String() != "" {
				content.Parts = append(content.Parts, &genai.Part{Text: msg.Content().String()})
			}

			// Replay each tool call as a FunctionCall part; parse errors in
			// the recorded input are ignored by parseJsonToMap.
			if len(msg.ToolCalls()) > 0 {
				for _, call := range msg.ToolCalls() {
					args, _ := parseJsonToMap(call.Input)
					content.Parts = append(content.Parts, &genai.Part{
						FunctionCall: &genai.FunctionCall{
							Name: call.Name,
							Args: args,
						},
					})
				}
			}

			history = append(history, content)

		case message.Tool:
			for _, result := range msg.ToolResults() {
				// Prefer structured JSON results; fall back to wrapping the
				// raw content under a "result" key.
				response := map[string]interface{}{"result": result.Content}
				parsed, err := parseJsonToMap(result.Content)
				if err == nil {
					response = parsed
				}

				// Find the originating tool call to recover the function
				// name. NOTE(review): this rescans the whole history per
				// result (O(n*m)) and keeps searching after a match —
				// harmless for short conversations.
				var toolCall message.ToolCall
				for _, m := range messages {
					if m.Role == message.Assistant {
						for _, call := range m.ToolCalls() {
							if call.ID == result.ToolCallID {
								toolCall = call
								break
							}
						}
					}
				}

				history = append(history, &genai.Content{
					Parts: []*genai.Part{
						{
							FunctionResponse: &genai.FunctionResponse{
								Name:     toolCall.Name,
								Response: response,
							},
						},
					},
					Role: "function",
				})
			}
		}
	}

	return history
}
|
||||
|
||||
func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
|
||||
geminiTool := &genai.Tool{}
|
||||
geminiTool.FunctionDeclarations = make([]*genai.FunctionDeclaration, 0, len(tools))
|
||||
|
||||
for _, tool := range tools {
|
||||
info := tool.Info()
|
||||
declaration := &genai.FunctionDeclaration{
|
||||
Name: info.Name,
|
||||
Description: info.Description,
|
||||
Parameters: &genai.Schema{
|
||||
Type: genai.TypeObject,
|
||||
Properties: convertSchemaProperties(info.Parameters),
|
||||
Required: info.Required,
|
||||
},
|
||||
}
|
||||
|
||||
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, declaration)
|
||||
}
|
||||
|
||||
return []*genai.Tool{geminiTool}
|
||||
}
|
||||
|
||||
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
|
||||
switch {
|
||||
case reason == genai.FinishReasonStop:
|
||||
return message.FinishReasonEndTurn
|
||||
case reason == genai.FinishReasonMaxTokens:
|
||||
return message.FinishReasonMaxTokens
|
||||
default:
|
||||
return message.FinishReasonUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
|
||||
// Convert messages
|
||||
geminiMessages := g.convertMessages(messages)
|
||||
|
||||
cfg := config.Get()
|
||||
if cfg.Debug {
|
||||
jsonData, _ := json.Marshal(geminiMessages)
|
||||
slog.Debug("Prepared messages", "messages", string(jsonData))
|
||||
}
|
||||
|
||||
history := geminiMessages[:len(geminiMessages)-1] // All but last message
|
||||
lastMsg := geminiMessages[len(geminiMessages)-1]
|
||||
chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
|
||||
MaxOutputTokens: int32(g.providerOptions.maxTokens),
|
||||
SystemInstruction: &genai.Content{
|
||||
Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
|
||||
},
|
||||
Tools: g.convertTools(tools),
|
||||
}, history)
|
||||
|
||||
attempts := 0
|
||||
for {
|
||||
attempts++
|
||||
var toolCalls []message.ToolCall
|
||||
|
||||
var lastMsgParts []genai.Part
|
||||
for _, part := range lastMsg.Parts {
|
||||
lastMsgParts = append(lastMsgParts, *part)
|
||||
}
|
||||
resp, err := chat.SendMessage(ctx, lastMsgParts...)
|
||||
// If there is an error we are going to see if we can retry the call
|
||||
if err != nil {
|
||||
retry, after, retryErr := g.shouldRetry(attempts, err)
|
||||
if retryErr != nil {
|
||||
return nil, retryErr
|
||||
}
|
||||
if retry {
|
||||
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(time.Duration(after) * time.Millisecond):
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, retryErr
|
||||
}
|
||||
|
||||
content := ""
|
||||
|
||||
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
||||
for _, part := range resp.Candidates[0].Content.Parts {
|
||||
switch {
|
||||
case part.Text != "":
|
||||
content = string(part.Text)
|
||||
case part.FunctionCall != nil:
|
||||
id := "call_" + uuid.New().String()
|
||||
args, _ := json.Marshal(part.FunctionCall.Args)
|
||||
toolCalls = append(toolCalls, message.ToolCall{
|
||||
ID: id,
|
||||
Name: part.FunctionCall.Name,
|
||||
Input: string(args),
|
||||
Type: "function",
|
||||
Finished: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
finishReason := message.FinishReasonEndTurn
|
||||
if len(resp.Candidates) > 0 {
|
||||
finishReason = g.finishReason(resp.Candidates[0].FinishReason)
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = message.FinishReasonToolUse
|
||||
}
|
||||
|
||||
return &ProviderResponse{
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: g.usage(resp),
|
||||
FinishReason: finishReason,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// stream sends the conversation to Gemini and streams the reply as
// ProviderEvents on the returned channel. The channel is closed when the
// stream completes, errors out, or the context is cancelled.
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	// Convert messages
	geminiMessages := g.convertMessages(messages)

	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(geminiMessages)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	// Seed the chat with all but the final message; the final one is streamed
	// below. NOTE(review): the error from Chats.Create is discarded — a
	// failure here leaves chat nil and panics on SendMessageStream. Confirm
	// and handle.
	history := geminiMessages[:len(geminiMessages)-1] // All but last message
	lastMsg := geminiMessages[len(geminiMessages)-1]
	chat, _ := g.client.Chats.Create(ctx, g.providerOptions.model.APIModel, &genai.GenerateContentConfig{
		MaxOutputTokens: int32(g.providerOptions.maxTokens),
		SystemInstruction: &genai.Content{
			Parts: []*genai.Part{{Text: g.providerOptions.systemMessage}},
		},
		Tools: g.convertTools(tools),
	}, history)

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		for {
			attempts++

			currentContent := ""
			toolCalls := []message.ToolCall{}
			var finalResp *genai.GenerateContentResponse

			eventChan <- ProviderEvent{Type: EventContentStart}

			// SendMessageStream takes parts by value, so dereference each one.
			var lastMsgParts []genai.Part

			for _, part := range lastMsg.Parts {
				lastMsgParts = append(lastMsgParts, *part)
			}
			for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) {
				if err != nil {
					retry, after, retryErr := g.shouldRetry(attempts, err)
					if retryErr != nil {
						// Retries exhausted or non-retryable: surface and stop.
						eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
						return
					}
					if retry {
						status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
						select {
						case <-ctx.Done():
							if ctx.Err() != nil {
								eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
							}

							return
						case <-time.After(time.Duration(after) * time.Millisecond):
							// NOTE(review): this break only exits the select;
							// execution then falls through to `finalResp = resp`
							// with resp possibly nil, instead of restarting the
							// outer retry loop as the retry path seems to
							// intend. Confirm intended behavior.
							break
						}
					} else {
						eventChan <- ProviderEvent{Type: EventError, Error: err}
						return
					}
				}

				finalResp = resp

				if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
					for _, part := range resp.Candidates[0].Content.Parts {
						switch {
						case part.Text != "":
							// Forward each text fragment as a delta event and
							// accumulate it for the final response.
							delta := string(part.Text)
							if delta != "" {
								eventChan <- ProviderEvent{
									Type:    EventContentDelta,
									Content: delta,
								}
								currentContent += delta
							}
						case part.FunctionCall != nil:
							// Gemini supplies no call IDs; synthesize one.
							id := "call_" + uuid.New().String()
							args, _ := json.Marshal(part.FunctionCall.Args)
							newCall := message.ToolCall{
								ID:       id,
								Name:     part.FunctionCall.Name,
								Input:    string(args),
								Type:     "function",
								Finished: true,
							}

							// De-duplicate identical calls repeated across
							// streamed chunks (same name and arguments).
							isNew := true
							for _, existing := range toolCalls {
								if existing.Name == newCall.Name && existing.Input == newCall.Input {
									isNew = false
									break
								}
							}

							if isNew {
								toolCalls = append(toolCalls, newCall)
							}
						}
					}
				}
			}

			eventChan <- ProviderEvent{Type: EventContentStop}

			if finalResp != nil {

				finishReason := message.FinishReasonEndTurn
				if len(finalResp.Candidates) > 0 {
					finishReason = g.finishReason(finalResp.Candidates[0].FinishReason)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}
				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        g.usage(finalResp),
						FinishReason: finishReason,
					},
				}
				return
			}

		}
	}()

	return eventChan
}
|
||||
|
||||
func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
|
||||
// Check if error is a rate limit error
|
||||
if attempts > maxRetries {
|
||||
return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
|
||||
}
|
||||
|
||||
// Gemini doesn't have a standard error type we can check against
|
||||
// So we'll check the error message for rate limit indicators
|
||||
if errors.Is(err, io.EOF) {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
errMsg := err.Error()
|
||||
isRateLimit := false
|
||||
|
||||
// Check for common rate limit error messages
|
||||
if contains(errMsg, "rate limit", "quota exceeded", "too many requests") {
|
||||
isRateLimit = true
|
||||
}
|
||||
|
||||
if !isRateLimit {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
// Calculate backoff with jitter
|
||||
backoffMs := 2000 * (1 << (attempts - 1))
|
||||
jitterMs := int(float64(backoffMs) * 0.2)
|
||||
retryMs := backoffMs + jitterMs
|
||||
|
||||
return true, int64(retryMs), nil
|
||||
}
|
||||
|
||||
func (g *geminiClient) toolCalls(resp *genai.GenerateContentResponse) []message.ToolCall {
|
||||
var toolCalls []message.ToolCall
|
||||
|
||||
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
||||
for _, part := range resp.Candidates[0].Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
id := "call_" + uuid.New().String()
|
||||
args, _ := json.Marshal(part.FunctionCall.Args)
|
||||
toolCalls = append(toolCalls, message.ToolCall{
|
||||
ID: id,
|
||||
Name: part.FunctionCall.Name,
|
||||
Input: string(args),
|
||||
Type: "function",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
|
||||
if resp == nil || resp.UsageMetadata == nil {
|
||||
return TokenUsage{}
|
||||
}
|
||||
|
||||
return TokenUsage{
|
||||
InputTokens: int64(resp.UsageMetadata.PromptTokenCount),
|
||||
OutputTokens: int64(resp.UsageMetadata.CandidatesTokenCount),
|
||||
CacheCreationTokens: 0, // Not directly provided by Gemini
|
||||
CacheReadTokens: int64(resp.UsageMetadata.CachedContentTokenCount),
|
||||
}
|
||||
}
|
||||
|
||||
func WithGeminiDisableCache() GeminiOption {
|
||||
return func(options *geminiOptions) {
|
||||
options.disableCache = true
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
// parseJsonToMap decodes a JSON object string into a generic map. On invalid
// input the returned map is the zero value and the error is non-nil.
func parseJsonToMap(jsonStr string) (map[string]interface{}, error) {
	var parsed map[string]interface{}
	err := json.Unmarshal([]byte(jsonStr), &parsed)
	return parsed, err
}
|
||||
|
||||
func convertSchemaProperties(parameters map[string]interface{}) map[string]*genai.Schema {
|
||||
properties := make(map[string]*genai.Schema)
|
||||
|
||||
for name, param := range parameters {
|
||||
properties[name] = convertToSchema(param)
|
||||
}
|
||||
|
||||
return properties
|
||||
}
|
||||
|
||||
func convertToSchema(param interface{}) *genai.Schema {
|
||||
schema := &genai.Schema{Type: genai.TypeString}
|
||||
|
||||
paramMap, ok := param.(map[string]interface{})
|
||||
if !ok {
|
||||
return schema
|
||||
}
|
||||
|
||||
if desc, ok := paramMap["description"].(string); ok {
|
||||
schema.Description = desc
|
||||
}
|
||||
|
||||
typeVal, hasType := paramMap["type"]
|
||||
if !hasType {
|
||||
return schema
|
||||
}
|
||||
|
||||
typeStr, ok := typeVal.(string)
|
||||
if !ok {
|
||||
return schema
|
||||
}
|
||||
|
||||
schema.Type = mapJSONTypeToGenAI(typeStr)
|
||||
|
||||
switch typeStr {
|
||||
case "array":
|
||||
schema.Items = processArrayItems(paramMap)
|
||||
case "object":
|
||||
if props, ok := paramMap["properties"].(map[string]interface{}); ok {
|
||||
schema.Properties = convertSchemaProperties(props)
|
||||
}
|
||||
}
|
||||
|
||||
return schema
|
||||
}
|
||||
|
||||
func processArrayItems(paramMap map[string]interface{}) *genai.Schema {
|
||||
items, ok := paramMap["items"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return convertToSchema(items)
|
||||
}
|
||||
|
||||
func mapJSONTypeToGenAI(jsonType string) genai.Type {
|
||||
switch jsonType {
|
||||
case "string":
|
||||
return genai.TypeString
|
||||
case "number":
|
||||
return genai.TypeNumber
|
||||
case "integer":
|
||||
return genai.TypeInteger
|
||||
case "boolean":
|
||||
return genai.TypeBoolean
|
||||
case "array":
|
||||
return genai.TypeArray
|
||||
case "object":
|
||||
return genai.TypeObject
|
||||
default:
|
||||
return genai.TypeString // Default to string for unknown types
|
||||
}
|
||||
}
|
||||
|
||||
// contains reports whether s contains any of substrs, compared
// case-insensitively. With no substrings it returns false.
func contains(s string, substrs ...string) bool {
	// Lowercase the haystack once instead of on every iteration.
	lower := strings.ToLower(s)
	for _, substr := range substrs {
		if strings.Contains(lower, strings.ToLower(substr)) {
			return true
		}
	}
	return false
}
|
||||
@@ -1,434 +0,0 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/option"
|
||||
"github.com/openai/openai-go/shared"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/status"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// openaiOptions holds OpenAI-specific configuration, populated via
// OpenAIOption functions during client construction.
type openaiOptions struct {
	baseURL         string            // alternate API endpoint; applied via option.WithBaseURL when non-empty
	disableCache    bool              // set by WithOpenAIDisableCache
	reasoningEffort string            // "low", "medium" or "high" for reasoning-capable models
	extraHeaders    map[string]string // extra HTTP headers applied via option.WithHeader
}

// OpenAIOption mutates openaiOptions during client construction.
type OpenAIOption func(*openaiOptions)

// openaiClient implements the provider client against the OpenAI chat API
// (and OpenAI-compatible endpoints selected via baseURL).
type openaiClient struct {
	providerOptions providerClientOptions
	options         openaiOptions
	client          openai.Client
}

// OpenAIClient is the provider-client type exposed for OpenAI backends.
type OpenAIClient ProviderClient
|
||||
|
||||
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
|
||||
openaiOpts := openaiOptions{
|
||||
reasoningEffort: "medium",
|
||||
}
|
||||
for _, o := range opts.openaiOptions {
|
||||
o(&openaiOpts)
|
||||
}
|
||||
|
||||
openaiClientOptions := []option.RequestOption{}
|
||||
if opts.apiKey != "" {
|
||||
openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
|
||||
}
|
||||
if openaiOpts.baseURL != "" {
|
||||
openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
|
||||
}
|
||||
|
||||
if openaiOpts.extraHeaders != nil {
|
||||
for key, value := range openaiOpts.extraHeaders {
|
||||
openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
|
||||
}
|
||||
}
|
||||
|
||||
client := openai.NewClient(openaiClientOptions...)
|
||||
return &openaiClient{
|
||||
providerOptions: opts,
|
||||
options: openaiOpts,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// convertMessages flattens the internal message history into OpenAI chat
// messages. The configured system message is always emitted first; user
// attachments become image-URL parts, assistant tool calls ride on the
// assistant message, and tool results become tool-role messages keyed by
// ToolCallID.
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	// Add system message first
	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			// Attach any binary content (e.g. images) as image-URL parts.
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(models.ProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			// Only set Content when non-empty; tool-call-only turns have none.
			if msg.Content().String() != "" {
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if len(msg.ToolCalls()) > 0 {
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			// Each tool result is linked back to its originating call.
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}
|
||||
|
||||
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
|
||||
openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
|
||||
|
||||
for i, tool := range tools {
|
||||
info := tool.Info()
|
||||
openaiTools[i] = openai.ChatCompletionToolParam{
|
||||
Function: openai.FunctionDefinitionParam{
|
||||
Name: info.Name,
|
||||
Description: openai.String(info.Description),
|
||||
Parameters: openai.FunctionParameters{
|
||||
"type": "object",
|
||||
"properties": info.Parameters,
|
||||
"required": info.Required,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return openaiTools
|
||||
}
|
||||
|
||||
func (o *openaiClient) finishReason(reason string) message.FinishReason {
|
||||
switch reason {
|
||||
case "stop":
|
||||
return message.FinishReasonEndTurn
|
||||
case "length":
|
||||
return message.FinishReasonMaxTokens
|
||||
case "tool_calls":
|
||||
return message.FinishReasonToolUse
|
||||
default:
|
||||
return message.FinishReasonUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
|
||||
params := openai.ChatCompletionNewParams{
|
||||
Model: openai.ChatModel(o.providerOptions.model.APIModel),
|
||||
Messages: messages,
|
||||
Tools: tools,
|
||||
}
|
||||
|
||||
if o.providerOptions.model.CanReason == true {
|
||||
params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
|
||||
switch o.options.reasoningEffort {
|
||||
case "low":
|
||||
params.ReasoningEffort = shared.ReasoningEffortLow
|
||||
case "medium":
|
||||
params.ReasoningEffort = shared.ReasoningEffortMedium
|
||||
case "high":
|
||||
params.ReasoningEffort = shared.ReasoningEffortHigh
|
||||
default:
|
||||
params.ReasoningEffort = shared.ReasoningEffortMedium
|
||||
}
|
||||
} else {
|
||||
params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
|
||||
}
|
||||
|
||||
if o.providerOptions.model.Provider == models.ProviderOpenRouter {
|
||||
params.WithExtraFields(map[string]any{
|
||||
"provider": map[string]any{
|
||||
"require_parameters": true,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
|
||||
params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
|
||||
cfg := config.Get()
|
||||
if cfg.Debug {
|
||||
jsonData, _ := json.Marshal(params)
|
||||
slog.Debug("Prepared messages", "messages", string(jsonData))
|
||||
}
|
||||
attempts := 0
|
||||
for {
|
||||
attempts++
|
||||
openaiResponse, err := o.client.Chat.Completions.New(
|
||||
ctx,
|
||||
params,
|
||||
)
|
||||
// If there is an error we are going to see if we can retry the call
|
||||
if err != nil {
|
||||
retry, after, retryErr := o.shouldRetry(attempts, err)
|
||||
if retryErr != nil {
|
||||
return nil, retryErr
|
||||
}
|
||||
if retry {
|
||||
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(time.Duration(after) * time.Millisecond):
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, retryErr
|
||||
}
|
||||
|
||||
content := ""
|
||||
if openaiResponse.Choices[0].Message.Content != "" {
|
||||
content = openaiResponse.Choices[0].Message.Content
|
||||
}
|
||||
|
||||
toolCalls := o.toolCalls(*openaiResponse)
|
||||
finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = message.FinishReasonToolUse
|
||||
}
|
||||
|
||||
return &ProviderResponse{
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: o.usage(*openaiResponse),
|
||||
FinishReason: finishReason,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
|
||||
params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
|
||||
params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
|
||||
IncludeUsage: openai.Bool(true),
|
||||
}
|
||||
|
||||
cfg := config.Get()
|
||||
if cfg.Debug {
|
||||
jsonData, _ := json.Marshal(params)
|
||||
slog.Debug("Prepared messages", "messages", string(jsonData))
|
||||
}
|
||||
|
||||
attempts := 0
|
||||
eventChan := make(chan ProviderEvent)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
attempts++
|
||||
openaiStream := o.client.Chat.Completions.NewStreaming(
|
||||
ctx,
|
||||
params,
|
||||
)
|
||||
|
||||
acc := openai.ChatCompletionAccumulator{}
|
||||
currentContent := ""
|
||||
toolCalls := make([]message.ToolCall, 0)
|
||||
|
||||
for openaiStream.Next() {
|
||||
chunk := openaiStream.Current()
|
||||
acc.AddChunk(chunk)
|
||||
|
||||
for _, choice := range chunk.Choices {
|
||||
if choice.Delta.Content != "" {
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventContentDelta,
|
||||
Content: choice.Delta.Content,
|
||||
}
|
||||
currentContent += choice.Delta.Content
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err := openaiStream.Err()
|
||||
if err == nil || errors.Is(err, io.EOF) {
|
||||
// Stream completed successfully
|
||||
finishReason := o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
|
||||
if len(acc.ChatCompletion.Choices[0].Message.ToolCalls) > 0 {
|
||||
toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = message.FinishReasonToolUse
|
||||
}
|
||||
|
||||
eventChan <- ProviderEvent{
|
||||
Type: EventComplete,
|
||||
Response: &ProviderResponse{
|
||||
Content: currentContent,
|
||||
ToolCalls: toolCalls,
|
||||
Usage: o.usage(acc.ChatCompletion),
|
||||
FinishReason: finishReason,
|
||||
},
|
||||
}
|
||||
close(eventChan)
|
||||
return
|
||||
}
|
||||
|
||||
// If there is an error we are going to see if we can retry the call
|
||||
retry, after, retryErr := o.shouldRetry(attempts, err)
|
||||
if retryErr != nil {
|
||||
eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
|
||||
close(eventChan)
|
||||
return
|
||||
}
|
||||
if retry {
|
||||
status.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// context cancelled
|
||||
if ctx.Err() == nil {
|
||||
eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
|
||||
}
|
||||
close(eventChan)
|
||||
return
|
||||
case <-time.After(time.Duration(after) * time.Millisecond):
|
||||
continue
|
||||
}
|
||||
}
|
||||
eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
|
||||
close(eventChan)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
return eventChan
|
||||
}
|
||||
|
||||
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
|
||||
var apierr *openai.Error
|
||||
if !errors.As(err, &apierr) {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
if attempts > maxRetries {
|
||||
return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
|
||||
}
|
||||
|
||||
retryMs := 0
|
||||
retryAfterValues := apierr.Response.Header.Values("Retry-After")
|
||||
|
||||
backoffMs := 2000 * (1 << (attempts - 1))
|
||||
jitterMs := int(float64(backoffMs) * 0.2)
|
||||
retryMs = backoffMs + jitterMs
|
||||
if len(retryAfterValues) > 0 {
|
||||
if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
|
||||
retryMs = retryMs * 1000
|
||||
}
|
||||
}
|
||||
return true, int64(retryMs), nil
|
||||
}
|
||||
|
||||
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
|
||||
var toolCalls []message.ToolCall
|
||||
|
||||
if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
|
||||
for _, call := range completion.Choices[0].Message.ToolCalls {
|
||||
toolCall := message.ToolCall{
|
||||
ID: call.ID,
|
||||
Name: call.Function.Name,
|
||||
Input: call.Function.Arguments,
|
||||
Type: "function",
|
||||
Finished: true,
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
}
|
||||
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
|
||||
cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
|
||||
inputTokens := completion.Usage.PromptTokens - cachedTokens
|
||||
|
||||
return TokenUsage{
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: completion.Usage.CompletionTokens,
|
||||
CacheCreationTokens: 0, // OpenAI doesn't provide this directly
|
||||
CacheReadTokens: cachedTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func WithOpenAIBaseURL(baseURL string) OpenAIOption {
|
||||
return func(options *openaiOptions) {
|
||||
options.baseURL = baseURL
|
||||
}
|
||||
}
|
||||
|
||||
func WithOpenAIExtraHeaders(headers map[string]string) OpenAIOption {
|
||||
return func(options *openaiOptions) {
|
||||
options.extraHeaders = headers
|
||||
}
|
||||
}
|
||||
|
||||
func WithOpenAIDisableCache() OpenAIOption {
|
||||
return func(options *openaiOptions) {
|
||||
options.disableCache = true
|
||||
}
|
||||
}
|
||||
|
||||
func WithReasoningEffort(effort string) OpenAIOption {
|
||||
return func(options *openaiOptions) {
|
||||
defaultReasoningEffort := "medium"
|
||||
switch effort {
|
||||
case "low", "medium", "high":
|
||||
defaultReasoningEffort = effort
|
||||
default:
|
||||
slog.Warn("Invalid reasoning effort, using default: medium")
|
||||
}
|
||||
options.reasoningEffort = defaultReasoningEffort
|
||||
}
|
||||
}
|
||||
@@ -1,264 +0,0 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// EventType identifies the kind of ProviderEvent emitted while streaming.
type EventType string

// maxRetries bounds retry attempts for rate-limited provider calls.
const maxRetries = 8

const (
	EventContentStart  EventType = "content_start"  // streaming of a reply began
	EventToolUseStart  EventType = "tool_use_start" // a tool invocation started
	EventToolUseDelta  EventType = "tool_use_delta" // incremental tool-call arguments
	EventToolUseStop   EventType = "tool_use_stop"  // a tool invocation finished
	EventContentDelta  EventType = "content_delta"  // incremental reply text
	EventThinkingDelta EventType = "thinking_delta" // incremental reasoning text
	EventContentStop   EventType = "content_stop"   // streaming of a reply ended
	EventComplete      EventType = "complete"       // final response is available
	EventError         EventType = "error"          // the stream failed
	EventWarning       EventType = "warning"        // non-fatal problem
)

// TokenUsage reports token counts for a single provider request.
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}

// ProviderResponse is the aggregated result of one model request.
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}

// ProviderEvent is one item in a streaming response; which of the optional
// fields are populated depends on Type.
type ProviderEvent struct {
	Type EventType

	Content  string
	Thinking string
	Response *ProviderResponse
	ToolCall *message.ToolCall
	Error    error
}

// Provider is the public interface implemented by every model backend.
type Provider interface {
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	Model() models.Model

	MaxTokens() int64
}

// providerClientOptions carries the settings shared by all provider clients
// plus the per-provider option lists applied at construction time.
type providerClientOptions struct {
	apiKey        string
	model         models.Model
	maxTokens     int64
	systemMessage string

	anthropicOptions []AnthropicOption
	openaiOptions    []OpenAIOption
	geminiOptions    []GeminiOption
	bedrockOptions   []BedrockOption
}

// ProviderClientOption mutates providerClientOptions during construction.
type ProviderClientOption func(*providerClientOptions)

// ProviderClient is the minimal internal contract each backend implements.
type ProviderClient interface {
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
}

// baseProvider adapts a concrete ProviderClient to the Provider interface.
type baseProvider[C ProviderClient] struct {
	options providerClientOptions
	client  C
}
|
||||
|
||||
func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
|
||||
clientOptions := providerClientOptions{}
|
||||
for _, o := range opts {
|
||||
o(&clientOptions)
|
||||
}
|
||||
switch providerName {
|
||||
case models.ProviderAnthropic:
|
||||
return &baseProvider[AnthropicClient]{
|
||||
options: clientOptions,
|
||||
client: newAnthropicClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderOpenAI:
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderGemini:
|
||||
return &baseProvider[GeminiClient]{
|
||||
options: clientOptions,
|
||||
client: newGeminiClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderBedrock:
|
||||
return &baseProvider[BedrockClient]{
|
||||
options: clientOptions,
|
||||
client: newBedrockClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderGROQ:
|
||||
clientOptions.openaiOptions = append(clientOptions.openaiOptions,
|
||||
WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
|
||||
)
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderAzure:
|
||||
return &baseProvider[AzureClient]{
|
||||
options: clientOptions,
|
||||
client: newAzureClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderOpenRouter:
|
||||
clientOptions.openaiOptions = append(clientOptions.openaiOptions,
|
||||
WithOpenAIBaseURL("https://openrouter.ai/api/v1"),
|
||||
WithOpenAIExtraHeaders(map[string]string{
|
||||
"HTTP-Referer": "opencode.ai",
|
||||
"X-Title": "OpenCode",
|
||||
}),
|
||||
)
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
case models.ProviderXAI:
|
||||
clientOptions.openaiOptions = append(clientOptions.openaiOptions,
|
||||
WithOpenAIBaseURL("https://api.x.ai/v1"),
|
||||
)
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
|
||||
case models.ProviderMock:
|
||||
// TODO: implement mock client for test
|
||||
panic("not implemented")
|
||||
}
|
||||
return nil, fmt.Errorf("provider not supported: %s", providerName)
|
||||
}
|
||||
|
||||
func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
|
||||
for _, msg := range messages {
|
||||
// The message has no content
|
||||
if len(msg.Parts) == 0 {
|
||||
continue
|
||||
}
|
||||
cleaned = append(cleaned, msg)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
|
||||
messages = p.cleanMessages(messages)
|
||||
response, err := p.client.send(ctx, messages, tools)
|
||||
if err == nil && response != nil {
|
||||
slog.Debug("API request token usage",
|
||||
"model", p.options.model.Name,
|
||||
"input_tokens", response.Usage.InputTokens,
|
||||
"output_tokens", response.Usage.OutputTokens,
|
||||
"cache_creation_tokens", response.Usage.CacheCreationTokens,
|
||||
"cache_read_tokens", response.Usage.CacheReadTokens,
|
||||
"total_tokens", response.Usage.InputTokens+response.Usage.OutputTokens)
|
||||
}
|
||||
return response, err
|
||||
}
|
||||
|
||||
func (p *baseProvider[C]) Model() models.Model {
|
||||
return p.options.model
|
||||
}
|
||||
|
||||
func (p *baseProvider[C]) MaxTokens() int64 {
|
||||
return p.options.maxTokens
|
||||
}
|
||||
|
||||
func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
|
||||
messages = p.cleanMessages(messages)
|
||||
eventChan := p.client.stream(ctx, messages, tools)
|
||||
|
||||
// Create a new channel to intercept events
|
||||
wrappedChan := make(chan ProviderEvent)
|
||||
|
||||
go func() {
|
||||
defer close(wrappedChan)
|
||||
|
||||
for event := range eventChan {
|
||||
// Pass the event through
|
||||
wrappedChan <- event
|
||||
|
||||
// Log token usage when we get the complete event
|
||||
if event.Type == EventComplete && event.Response != nil {
|
||||
slog.Debug("API streaming request token usage",
|
||||
"model", p.options.model.Name,
|
||||
"input_tokens", event.Response.Usage.InputTokens,
|
||||
"output_tokens", event.Response.Usage.OutputTokens,
|
||||
"cache_creation_tokens", event.Response.Usage.CacheCreationTokens,
|
||||
"cache_read_tokens", event.Response.Usage.CacheReadTokens,
|
||||
"total_tokens", event.Response.Usage.InputTokens+event.Response.Usage.OutputTokens)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return wrappedChan
|
||||
}
|
||||
|
||||
func WithAPIKey(apiKey string) ProviderClientOption {
|
||||
return func(options *providerClientOptions) {
|
||||
options.apiKey = apiKey
|
||||
}
|
||||
}
|
||||
|
||||
func WithModel(model models.Model) ProviderClientOption {
|
||||
return func(options *providerClientOptions) {
|
||||
options.model = model
|
||||
}
|
||||
}
|
||||
|
||||
func WithMaxTokens(maxTokens int64) ProviderClientOption {
|
||||
return func(options *providerClientOptions) {
|
||||
options.maxTokens = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
func WithSystemMessage(systemMessage string) ProviderClientOption {
|
||||
return func(options *providerClientOptions) {
|
||||
options.systemMessage = systemMessage
|
||||
}
|
||||
}
|
||||
|
||||
func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
|
||||
return func(options *providerClientOptions) {
|
||||
options.anthropicOptions = anthropicOptions
|
||||
}
|
||||
}
|
||||
|
||||
func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
|
||||
return func(options *providerClientOptions) {
|
||||
options.openaiOptions = openaiOptions
|
||||
}
|
||||
}
|
||||
|
||||
func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
|
||||
return func(options *providerClientOptions) {
|
||||
options.geminiOptions = geminiOptions
|
||||
}
|
||||
}
|
||||
|
||||
func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
|
||||
return func(options *providerClientOptions) {
|
||||
options.bedrockOptions = bedrockOptions
|
||||
}
|
||||
}
|
||||
@@ -1,492 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/diff"
|
||||
"github.com/sst/opencode/internal/history"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// EditParams is the JSON input schema for the edit tool.
type EditParams struct {
	FilePath  string `json:"file_path"`  // absolute path of the file to modify
	OldString string `json:"old_string"` // exact text to replace; empty means "create a new file"
	NewString string `json:"new_string"` // replacement text; empty means "delete old_string"
}
|
||||
|
||||
// EditPermissionsParams is attached to permission requests raised by the
// edit tool so the user can review the proposed change before approving it.
type EditPermissionsParams struct {
	FilePath string `json:"file_path"` // file the edit applies to
	Diff     string `json:"diff"`      // diff of the proposed change
}
|
||||
|
||||
// EditResponseMetadata summarizes an applied edit for the tool response.
type EditResponseMetadata struct {
	Diff      string `json:"diff"`      // diff of the change that was made
	Additions int    `json:"additions"` // number of added lines reported by diff.GenerateDiff
	Removals  int    `json:"removals"`  // number of removed lines reported by diff.GenerateDiff
}
|
||||
|
||||
// editTool edits files by replacing, creating, or deleting text. Writes are
// gated by the permission service and recorded in the file history service.
type editTool struct {
	lspClients  map[string]*lsp.Client // used for post-edit diagnostics; presumably keyed by language/server name — TODO confirm
	permissions permission.Service     // user approval for write actions
	history     history.Service        // per-session file version history
}
|
||||
|
||||
const (
	// EditToolName is the identifier the model uses to invoke this tool.
	EditToolName = "edit"
	// editDescription is the model-facing description of the edit tool.
	editDescription = `Edits files by replacing text, creating new files, or deleting content. For moving or renaming files, use the Bash tool with the 'mv' command instead. For larger file edits, use the FileWrite tool to overwrite files.

Before using this tool:

1. Use the FileRead tool to understand the file's contents and context

2. Verify the directory path is correct (only applicable when creating new files):
- Use the LS tool to verify the parent directory exists and is the correct location

To make a file edit, provide the following:
1. file_path: The absolute path to the file to modify (must be absolute, not relative)
2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation)
3. new_string: The edited text to replace the old_string

Special cases:
- To create a new file: provide file_path and new_string, leave old_string empty
- To delete content: provide file_path and old_string, leave new_string empty

The tool will replace ONE occurrence of old_string with new_string in the specified file.

CRITICAL REQUIREMENTS FOR USING THIS TOOL:

1. UNIQUENESS: The old_string MUST uniquely identify the specific instance you want to change. This means:
- Include AT LEAST 3-5 lines of context BEFORE the change point
- Include AT LEAST 3-5 lines of context AFTER the change point
- Include all whitespace, indentation, and surrounding code exactly as it appears in the file

2. SINGLE INSTANCE: This tool can only change ONE instance at a time. If you need to change multiple instances:
- Make separate calls to this tool for each instance
- Each call must uniquely identify its specific instance using extensive context

3. VERIFICATION: Before using this tool:
- Check how many instances of the target text exist in the file
- If multiple instances exist, gather enough context to uniquely identify each one
- Plan separate tool calls for each instance

WARNING: If you do not follow these requirements:
- The tool will fail if old_string matches multiple locations
- The tool will fail if old_string doesn't match exactly (including whitespace)
- You may change the wrong instance if you don't include enough context

When making edits:
- Ensure the edit results in idiomatic, correct code
- Do not leave the code in a broken state
- Always use absolute file paths (starting with /)

Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.`
)
|
||||
|
||||
func NewEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool {
|
||||
return &editTool{
|
||||
lspClients: lspClients,
|
||||
permissions: permissions,
|
||||
history: files,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *editTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: EditToolName,
|
||||
Description: editDescription,
|
||||
Parameters: map[string]any{
|
||||
"file_path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The absolute path to the file to modify",
|
||||
},
|
||||
"old_string": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The text to replace",
|
||||
},
|
||||
"new_string": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The text to replace it with",
|
||||
},
|
||||
},
|
||||
Required: []string{"file_path", "old_string", "new_string"},
|
||||
}
|
||||
}
|
||||
|
||||
func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params EditParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse("invalid parameters"), nil
|
||||
}
|
||||
|
||||
if params.FilePath == "" {
|
||||
return NewTextErrorResponse("file_path is required"), nil
|
||||
}
|
||||
|
||||
if !filepath.IsAbs(params.FilePath) {
|
||||
wd := config.WorkingDirectory()
|
||||
params.FilePath = filepath.Join(wd, params.FilePath)
|
||||
}
|
||||
|
||||
var response ToolResponse
|
||||
var err error
|
||||
|
||||
if params.OldString == "" {
|
||||
response, err = e.createNewFile(ctx, params.FilePath, params.NewString)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
}
|
||||
|
||||
if params.NewString == "" {
|
||||
response, err = e.deleteContent(ctx, params.FilePath, params.OldString)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
}
|
||||
|
||||
response, err = e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
if response.IsError {
|
||||
// Return early if there was an error during content replacement
|
||||
// This prevents unnecessary LSP diagnostics processing
|
||||
return response, nil
|
||||
}
|
||||
|
||||
waitForLspDiagnostics(ctx, params.FilePath, e.lspClients)
|
||||
text := fmt.Sprintf("<result>\n%s\n</result>\n", response.Content)
|
||||
text += getDiagnostics(params.FilePath, e.lspClients)
|
||||
response.Content = text
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// createNewFile creates filePath with the given content. It refuses to
// overwrite an existing path, creates missing parent directories, asks the
// permission service for approval (showing a diff against an empty file),
// writes the file, and seeds the session's file history with an empty base
// record followed by a version holding the new content.
func (e *editTool) createNewFile(ctx context.Context, filePath, content string) (ToolResponse, error) {
	fileInfo, err := os.Stat(filePath)
	if err == nil {
		if fileInfo.IsDir() {
			return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
		}
		return NewTextErrorResponse(fmt.Sprintf("file already exists: %s", filePath)), nil
	} else if !os.IsNotExist(err) {
		return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
	}

	dir := filepath.Dir(filePath)
	if err = os.MkdirAll(dir, 0o755); err != nil {
		return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
	}

	sessionID, messageID := GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
	}

	// Diff from empty content, so the permission prompt shows the whole file
	// as additions.
	diff, additions, removals := diff.GenerateDiff(
		"",
		content,
		filePath,
	)
	// Request permission on the project root when the file lives inside it,
	// otherwise on the file's parent directory.
	rootDir := config.WorkingDirectory()
	permissionPath := filepath.Dir(filePath)
	if strings.HasPrefix(filePath, rootDir) {
		permissionPath = rootDir
	}
	p := e.permissions.Request(
		ctx,
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			Path:        permissionPath,
			ToolName:    EditToolName,
			Action:      "write",
			Description: fmt.Sprintf("Create file %s", filePath),
			Params: EditPermissionsParams{
				FilePath: filePath,
				Diff:     diff,
			},
		},
	)
	if !p {
		return ToolResponse{}, permission.ErrorPermissionDenied
	}

	err = os.WriteFile(filePath, []byte(content), 0o644)
	if err != nil {
		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
	}

	// The file can't already be in the history, so create a new history
	// record with an empty base version. NOTE: unlike the version step below,
	// a failure here does fail the whole operation.
	_, err = e.history.Create(ctx, sessionID, filePath, "")
	if err != nil {
		return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
	}

	// Add the new content to the file history.
	_, err = e.history.CreateVersion(ctx, sessionID, filePath, content)
	if err != nil {
		// Log error but don't fail the operation
		slog.Debug("Error creating file history version", "error", err)
	}

	// Mark the file as written and read so subsequent edits in this session
	// pass the read-before-edit and staleness checks.
	recordFileWrite(filePath)
	recordFileRead(filePath)

	return WithResponseMetadata(
		NewTextResponse("File created: "+filePath),
		EditResponseMetadata{
			Diff:      diff,
			Additions: additions,
			Removals:  removals,
		},
	), nil
}
|
||||
|
||||
func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string) (ToolResponse, error) {
|
||||
fileInfo, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil
|
||||
}
|
||||
return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
|
||||
}
|
||||
|
||||
if fileInfo.IsDir() {
|
||||
return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
|
||||
}
|
||||
|
||||
if getLastReadTime(filePath).IsZero() {
|
||||
return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
|
||||
}
|
||||
|
||||
modTime := fileInfo.ModTime()
|
||||
lastRead := getLastReadTime(filePath)
|
||||
if modTime.After(lastRead) {
|
||||
return NewTextErrorResponse(
|
||||
fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
|
||||
filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
|
||||
)), nil
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
oldContent := string(content)
|
||||
|
||||
index := strings.Index(oldContent, oldString)
|
||||
if index == -1 {
|
||||
return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
|
||||
}
|
||||
|
||||
lastIndex := strings.LastIndex(oldContent, oldString)
|
||||
if index != lastIndex {
|
||||
return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match"), nil
|
||||
}
|
||||
|
||||
newContent := oldContent[:index] + oldContent[index+len(oldString):]
|
||||
|
||||
sessionID, messageID := GetContextValues(ctx)
|
||||
|
||||
if sessionID == "" || messageID == "" {
|
||||
return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
|
||||
}
|
||||
|
||||
diff, additions, removals := diff.GenerateDiff(
|
||||
oldContent,
|
||||
newContent,
|
||||
filePath,
|
||||
)
|
||||
|
||||
rootDir := config.WorkingDirectory()
|
||||
permissionPath := filepath.Dir(filePath)
|
||||
if strings.HasPrefix(filePath, rootDir) {
|
||||
permissionPath = rootDir
|
||||
}
|
||||
p := e.permissions.Request(
|
||||
ctx,
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: permissionPath,
|
||||
ToolName: EditToolName,
|
||||
Action: "write",
|
||||
Description: fmt.Sprintf("Delete content from file %s", filePath),
|
||||
Params: EditPermissionsParams{
|
||||
FilePath: filePath,
|
||||
Diff: diff,
|
||||
},
|
||||
},
|
||||
)
|
||||
if !p {
|
||||
return ToolResponse{}, permission.ErrorPermissionDenied
|
||||
}
|
||||
|
||||
err = os.WriteFile(filePath, []byte(newContent), 0o644)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
// Check if file exists in history
|
||||
file, err := e.history.GetLatestByPathAndSession(ctx, filePath, sessionID)
|
||||
if err != nil {
|
||||
_, err = e.history.Create(ctx, sessionID, filePath, oldContent)
|
||||
if err != nil {
|
||||
// Log error but don't fail the operation
|
||||
return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
|
||||
}
|
||||
}
|
||||
if file.Content != oldContent {
|
||||
// User Manually changed the content store an intermediate version
|
||||
_, err = e.history.CreateVersion(ctx, sessionID, filePath, oldContent)
|
||||
if err != nil {
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
}
|
||||
// Store the new version
|
||||
_, err = e.history.CreateVersion(ctx, sessionID, filePath, "")
|
||||
if err != nil {
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
|
||||
recordFileWrite(filePath)
|
||||
recordFileRead(filePath)
|
||||
|
||||
return WithResponseMetadata(
|
||||
NewTextResponse("Content deleted from file: "+filePath),
|
||||
EditResponseMetadata{
|
||||
Diff: diff,
|
||||
Additions: additions,
|
||||
Removals: removals,
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string) (ToolResponse, error) {
|
||||
fileInfo, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil
|
||||
}
|
||||
return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
|
||||
}
|
||||
|
||||
if fileInfo.IsDir() {
|
||||
return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
|
||||
}
|
||||
|
||||
if getLastReadTime(filePath).IsZero() {
|
||||
return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
|
||||
}
|
||||
|
||||
modTime := fileInfo.ModTime()
|
||||
lastRead := getLastReadTime(filePath)
|
||||
if modTime.After(lastRead) {
|
||||
return NewTextErrorResponse(
|
||||
fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
|
||||
filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
|
||||
)), nil
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
oldContent := string(content)
|
||||
|
||||
index := strings.Index(oldContent, oldString)
|
||||
if index == -1 {
|
||||
return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
|
||||
}
|
||||
|
||||
lastIndex := strings.LastIndex(oldContent, oldString)
|
||||
if index != lastIndex {
|
||||
return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match"), nil
|
||||
}
|
||||
|
||||
newContent := oldContent[:index] + newString + oldContent[index+len(oldString):]
|
||||
|
||||
if oldContent == newContent {
|
||||
return NewTextErrorResponse("new content is the same as old content. No changes made."), nil
|
||||
}
|
||||
sessionID, messageID := GetContextValues(ctx)
|
||||
|
||||
if sessionID == "" || messageID == "" {
|
||||
return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
|
||||
}
|
||||
diff, additions, removals := diff.GenerateDiff(
|
||||
oldContent,
|
||||
newContent,
|
||||
filePath,
|
||||
)
|
||||
rootDir := config.WorkingDirectory()
|
||||
permissionPath := filepath.Dir(filePath)
|
||||
if strings.HasPrefix(filePath, rootDir) {
|
||||
permissionPath = rootDir
|
||||
}
|
||||
p := e.permissions.Request(
|
||||
ctx,
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: permissionPath,
|
||||
ToolName: EditToolName,
|
||||
Action: "write",
|
||||
Description: fmt.Sprintf("Replace content in file %s", filePath),
|
||||
Params: EditPermissionsParams{
|
||||
FilePath: filePath,
|
||||
Diff: diff,
|
||||
},
|
||||
},
|
||||
)
|
||||
if !p {
|
||||
return ToolResponse{}, permission.ErrorPermissionDenied
|
||||
}
|
||||
|
||||
err = os.WriteFile(filePath, []byte(newContent), 0o644)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
// Check if file exists in history
|
||||
file, err := e.history.GetLatestByPathAndSession(ctx, filePath, sessionID)
|
||||
if err != nil {
|
||||
_, err = e.history.Create(ctx, sessionID, filePath, oldContent)
|
||||
if err != nil {
|
||||
// Log error but don't fail the operation
|
||||
return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
|
||||
}
|
||||
}
|
||||
if file.Content != oldContent {
|
||||
// User Manually changed the content store an intermediate version
|
||||
_, err = e.history.CreateVersion(ctx, sessionID, filePath, oldContent)
|
||||
if err != nil {
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
}
|
||||
// Store the new version
|
||||
_, err = e.history.CreateVersion(ctx, sessionID, filePath, newContent)
|
||||
if err != nil {
|
||||
slog.Debug("Error creating file history version", "error", err)
|
||||
}
|
||||
|
||||
recordFileWrite(filePath)
|
||||
recordFileRead(filePath)
|
||||
|
||||
return WithResponseMetadata(
|
||||
NewTextResponse("Content replaced in file: "+filePath),
|
||||
EditResponseMetadata{
|
||||
Diff: diff,
|
||||
Additions: additions,
|
||||
Removals: removals,
|
||||
}), nil
|
||||
}
|
||||
@@ -1,228 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
md "github.com/JohannesKaufmann/html-to-markdown"
|
||||
"github.com/PuerkitoBio/goquery"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
)
|
||||
|
||||
// FetchParams is the JSON input schema for the fetch tool.
type FetchParams struct {
	URL     string `json:"url"`               // URL to fetch; must start with http:// or https://
	Format  string `json:"format"`            // output format: "text", "markdown", or "html"
	Timeout int    `json:"timeout,omitempty"` // optional request timeout in seconds (capped at 120)
}
|
||||
|
||||
// FetchPermissionsParams mirrors FetchParams and is attached to permission
// requests so the user can review the outbound request before approving it.
type FetchPermissionsParams struct {
	URL     string `json:"url"`               // URL that will be fetched
	Format  string `json:"format"`            // requested output format
	Timeout int    `json:"timeout,omitempty"` // requested timeout in seconds, if any
}
|
||||
|
||||
// fetchTool downloads a URL and returns the body in the requested format.
// Outbound requests are gated by the permission service.
type fetchTool struct {
	client      *http.Client       // default client (30s timeout); replaced per-request when a custom timeout is given
	permissions permission.Service // user approval for outbound requests
}
|
||||
|
||||
const (
	// FetchToolName is the identifier the model uses to invoke this tool.
	FetchToolName = "fetch"
	// fetchToolDescription is the model-facing description of the fetch tool.
	fetchToolDescription = `Fetches content from a URL and returns it in the specified format.

WHEN TO USE THIS TOOL:
- Use when you need to download content from a URL
- Helpful for retrieving documentation, API responses, or web content
- Useful for getting external information to assist with tasks

HOW TO USE:
- Provide the URL to fetch content from
- Specify the desired output format (text, markdown, or html)
- Optionally set a timeout for the request

FEATURES:
- Supports three output formats: text, markdown, and html
- Automatically handles HTTP redirects
- Sets reasonable timeouts to prevent hanging
- Validates input parameters before making requests

LIMITATIONS:
- Maximum response size is 5MB
- Only supports HTTP and HTTPS protocols
- Cannot handle authentication or cookies
- Some websites may block automated requests

TIPS:
- Use text format for plain text content or simple API responses
- Use markdown format for content that should be rendered with formatting
- Use html format when you need the raw HTML structure
- Set appropriate timeouts for potentially slow websites`
)
|
||||
|
||||
func NewFetchTool(permissions permission.Service) BaseTool {
|
||||
return &fetchTool{
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
permissions: permissions,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *fetchTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: FetchToolName,
|
||||
Description: fetchToolDescription,
|
||||
Parameters: map[string]any{
|
||||
"url": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The URL to fetch content from",
|
||||
},
|
||||
"format": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The format to return the content in (text, markdown, or html)",
|
||||
"enum": []string{"text", "markdown", "html"},
|
||||
},
|
||||
"timeout": map[string]any{
|
||||
"type": "number",
|
||||
"description": "Optional timeout in seconds (max 120)",
|
||||
},
|
||||
},
|
||||
Required: []string{"url", "format"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params FetchParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse("Failed to parse fetch parameters: " + err.Error()), nil
|
||||
}
|
||||
|
||||
if params.URL == "" {
|
||||
return NewTextErrorResponse("URL parameter is required"), nil
|
||||
}
|
||||
|
||||
format := strings.ToLower(params.Format)
|
||||
if format != "text" && format != "markdown" && format != "html" {
|
||||
return NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
|
||||
return NewTextErrorResponse("URL must start with http:// or https://"), nil
|
||||
}
|
||||
|
||||
sessionID, messageID := GetContextValues(ctx)
|
||||
if sessionID == "" || messageID == "" {
|
||||
return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
|
||||
}
|
||||
|
||||
p := t.permissions.Request(
|
||||
ctx,
|
||||
permission.CreatePermissionRequest{
|
||||
SessionID: sessionID,
|
||||
Path: config.WorkingDirectory(),
|
||||
ToolName: FetchToolName,
|
||||
Action: "fetch",
|
||||
Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
|
||||
Params: FetchPermissionsParams(params),
|
||||
},
|
||||
)
|
||||
|
||||
if !p {
|
||||
return ToolResponse{}, permission.ErrorPermissionDenied
|
||||
}
|
||||
|
||||
client := t.client
|
||||
if params.Timeout > 0 {
|
||||
maxTimeout := 120 // 2 minutes
|
||||
if params.Timeout > maxTimeout {
|
||||
params.Timeout = maxTimeout
|
||||
}
|
||||
client = &http.Client{
|
||||
Timeout: time.Duration(params.Timeout) * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", params.URL, nil)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "opencode/1.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
|
||||
}
|
||||
|
||||
maxSize := int64(5 * 1024 * 1024) // 5MB
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
|
||||
if err != nil {
|
||||
return NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
|
||||
}
|
||||
|
||||
content := string(body)
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
|
||||
switch format {
|
||||
case "text":
|
||||
if strings.Contains(contentType, "text/html") {
|
||||
text, err := extractTextFromHTML(content)
|
||||
if err != nil {
|
||||
return NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
|
||||
}
|
||||
return NewTextResponse(text), nil
|
||||
}
|
||||
return NewTextResponse(content), nil
|
||||
|
||||
case "markdown":
|
||||
if strings.Contains(contentType, "text/html") {
|
||||
markdown, err := convertHTMLToMarkdown(content)
|
||||
if err != nil {
|
||||
return NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
|
||||
}
|
||||
return NewTextResponse(markdown), nil
|
||||
}
|
||||
|
||||
return NewTextResponse("```\n" + content + "\n```"), nil
|
||||
|
||||
case "html":
|
||||
return NewTextResponse(content), nil
|
||||
|
||||
default:
|
||||
return NewTextResponse(content), nil
|
||||
}
|
||||
}
|
||||
|
||||
func extractTextFromHTML(html string) (string, error) {
|
||||
doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
text := doc.Text()
|
||||
text = strings.Join(strings.Fields(text), " ")
|
||||
|
||||
return text, nil
|
||||
}
|
||||
|
||||
func convertHTMLToMarkdown(html string) (string, error) {
|
||||
converter := md.NewConverter("", true, nil)
|
||||
|
||||
markdown, err := converter.ConvertString(html)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return markdown, nil
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fileRecord tracks the most recent times a file was read and written by the
// tools in this package; records live in the fileRecords map keyed by path.
type fileRecord struct {
	path      string    // file path (duplicates the map key)
	readTime  time.Time // last time recordFileRead was called for this path
	writeTime time.Time // last time recordFileWrite was called for this path
}
|
||||
|
||||
var (
	// fileRecords maps file path -> read/write timestamps.
	fileRecords = make(map[string]fileRecord)
	// fileRecordMutex guards fileRecords against concurrent tool invocations.
	fileRecordMutex sync.RWMutex
)
|
||||
|
||||
func recordFileRead(path string) {
|
||||
fileRecordMutex.Lock()
|
||||
defer fileRecordMutex.Unlock()
|
||||
|
||||
record, exists := fileRecords[path]
|
||||
if !exists {
|
||||
record = fileRecord{path: path}
|
||||
}
|
||||
record.readTime = time.Now()
|
||||
fileRecords[path] = record
|
||||
}
|
||||
|
||||
func getLastReadTime(path string) time.Time {
|
||||
fileRecordMutex.RLock()
|
||||
defer fileRecordMutex.RUnlock()
|
||||
|
||||
record, exists := fileRecords[path]
|
||||
if !exists {
|
||||
return time.Time{}
|
||||
}
|
||||
return record.readTime
|
||||
}
|
||||
|
||||
func recordFileWrite(path string) {
|
||||
fileRecordMutex.Lock()
|
||||
defer fileRecordMutex.Unlock()
|
||||
|
||||
record, exists := fileRecords[path]
|
||||
if !exists {
|
||||
record = fileRecord{path: path}
|
||||
}
|
||||
record.writeTime = time.Now()
|
||||
fileRecords[path] = record
|
||||
}
|
||||
@@ -1,298 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bmatcuk/doublestar/v4"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
GlobToolName = "glob"
|
||||
globDescription = `Fast file pattern matching tool that finds files by name and pattern, returning matching paths sorted by modification time (newest first).
|
||||
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to find files by name patterns or extensions
|
||||
- Great for finding specific file types across a directory structure
|
||||
- Useful for discovering files that match certain naming conventions
|
||||
|
||||
HOW TO USE:
|
||||
- Provide a glob pattern to match against file paths
|
||||
- Optionally specify a starting directory (defaults to current working directory)
|
||||
- Results are sorted with most recently modified files first
|
||||
|
||||
GLOB PATTERN SYNTAX:
|
||||
- '*' matches any sequence of non-separator characters
|
||||
- '**' matches any sequence of characters, including separators
|
||||
- '?' matches any single non-separator character
|
||||
- '[...]' matches any character in the brackets
|
||||
- '[!...]' matches any character not in the brackets
|
||||
|
||||
COMMON PATTERN EXAMPLES:
|
||||
- '*.js' - Find all JavaScript files in the current directory
|
||||
- '**/*.js' - Find all JavaScript files in any subdirectory
|
||||
- 'src/**/*.{ts,tsx}' - Find all TypeScript files in the src directory
|
||||
- '*.{html,css,js}' - Find all HTML, CSS, and JS files
|
||||
|
||||
LIMITATIONS:
|
||||
- Results are limited to 100 files (newest first)
|
||||
- Does not search file contents (use Grep tool for that)
|
||||
- Hidden files (starting with '.') are skipped
|
||||
|
||||
TIPS:
|
||||
- For the most useful results, combine with the Grep tool: first find files with Glob, then search their contents with Grep
|
||||
- When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead
|
||||
- Always check if results are truncated and refine your search pattern if needed`
|
||||
)
|
||||
|
||||
// fileInfo pairs a matched path with its modification time so glob
// results can be sorted newest-first.
type fileInfo struct {
	path    string
	modTime time.Time
}

// GlobParams is the JSON input accepted by the glob tool.
type GlobParams struct {
	Pattern string `json:"pattern"` // glob pattern to match (required)
	Path    string `json:"path"`    // directory to search; defaults to the working directory
}

// GlobResponseMetadata is attached to glob responses for programmatic consumers.
type GlobResponseMetadata struct {
	NumberOfFiles int  `json:"number_of_files"`
	Truncated     bool `json:"truncated"`
}

// globTool implements BaseTool for fast file-name pattern matching.
type globTool struct{}
|
||||
|
||||
func NewGlobTool() BaseTool {
|
||||
return &globTool{}
|
||||
}
|
||||
|
||||
// Info describes the glob tool's name, usage documentation, and JSON
// parameter schema; only "pattern" is required.
func (g *globTool) Info() ToolInfo {
	return ToolInfo{
		Name:        GlobToolName,
		Description: globDescription,
		Parameters: map[string]any{
			"pattern": map[string]any{
				"type":        "string",
				"description": "The glob pattern to match files against",
			},
			"path": map[string]any{
				"type":        "string",
				"description": "The directory to search in. Defaults to the current working directory.",
			},
		},
		Required: []string{"pattern"},
	}
}
|
||||
|
||||
// Run executes a glob search. Parameter problems are reported back to the
// model as text error responses; only unexpected search failures abort the
// call with a Go error.
func (g *globTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
	var params GlobParams
	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	if params.Pattern == "" {
		return NewTextErrorResponse("pattern is required"), nil
	}

	searchPath := params.Path
	if searchPath == "" {
		searchPath = config.WorkingDirectory()
	}

	// Results are capped at 100 paths; truncated tells the model to
	// narrow the pattern or path.
	files, truncated, err := globFiles(params.Pattern, searchPath, 100)
	if err != nil {
		return ToolResponse{}, fmt.Errorf("error finding files: %w", err)
	}

	var output string
	if len(files) == 0 {
		output = "No files found"
	} else {
		output = strings.Join(files, "\n")
		if truncated {
			output += "\n\n(Results are truncated. Consider using a more specific path or pattern.)"
		}
	}

	return WithResponseMetadata(
		NewTextResponse(output),
		GlobResponseMetadata{
			NumberOfFiles: len(files),
			Truncated:     truncated,
		},
	), nil
}
|
||||
|
||||
func globFiles(pattern, searchPath string, limit int) ([]string, bool, error) {
|
||||
matches, err := globWithRipgrep(pattern, searchPath, limit)
|
||||
if err == nil {
|
||||
return matches, len(matches) >= limit, nil
|
||||
}
|
||||
|
||||
return globWithDoublestar(pattern, searchPath, limit)
|
||||
}
|
||||
|
||||
func globWithRipgrep(
|
||||
pattern, searchRoot string,
|
||||
limit int,
|
||||
) ([]string, error) {
|
||||
if searchRoot == "" {
|
||||
searchRoot = "."
|
||||
}
|
||||
|
||||
rgBin, err := exec.LookPath("rg")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ripgrep not found in $PATH: %w", err)
|
||||
}
|
||||
|
||||
if !filepath.IsAbs(pattern) && !strings.HasPrefix(pattern, "/") {
|
||||
pattern = "/" + pattern
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"--files",
|
||||
"--null",
|
||||
"--glob", pattern,
|
||||
"-L",
|
||||
}
|
||||
|
||||
cmd := exec.Command(rgBin, args...)
|
||||
cmd.Dir = searchRoot
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
if ee, ok := err.(*exec.ExitError); ok && ee.ExitCode() == 1 {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("ripgrep: %w\n%s", err, out)
|
||||
}
|
||||
|
||||
var matches []string
|
||||
for _, p := range bytes.Split(out, []byte{0}) {
|
||||
if len(p) == 0 {
|
||||
continue
|
||||
}
|
||||
abs := filepath.Join(searchRoot, string(p))
|
||||
if skipHidden(abs) {
|
||||
continue
|
||||
}
|
||||
matches = append(matches, abs)
|
||||
}
|
||||
|
||||
sort.SliceStable(matches, func(i, j int) bool {
|
||||
return len(matches[i]) < len(matches[j])
|
||||
})
|
||||
|
||||
if len(matches) > limit {
|
||||
matches = matches[:limit]
|
||||
}
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
func globWithDoublestar(pattern, searchPath string, limit int) ([]string, bool, error) {
|
||||
fsys := os.DirFS(searchPath)
|
||||
|
||||
relPattern := strings.TrimPrefix(pattern, "/")
|
||||
|
||||
var matches []fileInfo
|
||||
|
||||
err := doublestar.GlobWalk(fsys, relPattern, func(path string, d fs.DirEntry) error {
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if skipHidden(path) {
|
||||
return nil
|
||||
}
|
||||
|
||||
info, err := d.Info()
|
||||
if err != nil {
|
||||
return nil // Skip files we can't access
|
||||
}
|
||||
|
||||
absPath := path // Restore absolute path
|
||||
if !strings.HasPrefix(absPath, searchPath) {
|
||||
absPath = filepath.Join(searchPath, absPath)
|
||||
}
|
||||
|
||||
matches = append(matches, fileInfo{
|
||||
path: absPath,
|
||||
modTime: info.ModTime(),
|
||||
})
|
||||
|
||||
if len(matches) >= limit*2 { // Collect more than needed for sorting
|
||||
return fs.SkipAll
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("glob walk error: %w", err)
|
||||
}
|
||||
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
return matches[i].modTime.After(matches[j].modTime)
|
||||
})
|
||||
|
||||
truncated := len(matches) > limit
|
||||
if truncated {
|
||||
matches = matches[:limit]
|
||||
}
|
||||
|
||||
results := make([]string, len(matches))
|
||||
for i, m := range matches {
|
||||
results[i] = m.path
|
||||
}
|
||||
|
||||
return results, truncated, nil
|
||||
}
|
||||
|
||||
func skipHidden(path string) bool {
|
||||
// Check for hidden files (starting with a dot)
|
||||
base := filepath.Base(path)
|
||||
if base != "." && strings.HasPrefix(base, ".") {
|
||||
return true
|
||||
}
|
||||
|
||||
// List of commonly ignored directories in development projects
|
||||
commonIgnoredDirs := map[string]bool{
|
||||
"node_modules": true,
|
||||
"vendor": true,
|
||||
"dist": true,
|
||||
"build": true,
|
||||
"target": true,
|
||||
".git": true,
|
||||
".idea": true,
|
||||
".vscode": true,
|
||||
"__pycache__": true,
|
||||
"bin": true,
|
||||
"obj": true,
|
||||
"out": true,
|
||||
"coverage": true,
|
||||
"tmp": true,
|
||||
"temp": true,
|
||||
"logs": true,
|
||||
"generated": true,
|
||||
"bower_components": true,
|
||||
"jspm_packages": true,
|
||||
}
|
||||
|
||||
// Check if any path component is in our ignore list
|
||||
parts := strings.SplitSeq(path, string(os.PathSeparator))
|
||||
for part := range parts {
|
||||
if commonIgnoredDirs[part] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -1,358 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
)
|
||||
|
||||
// GrepParams is the JSON input accepted by the grep tool.
type GrepParams struct {
	Pattern     string `json:"pattern"`      // regex (or literal text) to search for (required)
	Path        string `json:"path"`         // directory to search; defaults to the working directory
	Include     string `json:"include"`      // optional file glob filter, e.g. "*.go"
	LiteralText bool   `json:"literal_text"` // when true, Pattern is escaped and matched verbatim
}

// grepMatch is one matching line found in a file, together with the file's
// modification time so results can be sorted newest-first.
type grepMatch struct {
	path     string
	modTime  time.Time
	lineNum  int
	lineText string
}

// GrepResponseMetadata is attached to grep responses for programmatic consumers.
type GrepResponseMetadata struct {
	NumberOfMatches int  `json:"number_of_matches"`
	Truncated       bool `json:"truncated"`
}

// grepTool implements BaseTool for content searching.
type grepTool struct{}
|
||||
|
||||
const (
|
||||
GrepToolName = "grep"
|
||||
grepDescription = `Fast content search tool that finds files containing specific text or patterns, returning matching file paths sorted by modification time (newest first).
|
||||
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to find files containing specific text or patterns
|
||||
- Great for searching code bases for function names, variable declarations, or error messages
|
||||
- Useful for finding all files that use a particular API or pattern
|
||||
|
||||
HOW TO USE:
|
||||
- Provide a regex pattern to search for within file contents
|
||||
- Set literal_text=true if you want to search for the exact text with special characters (recommended for non-regex users)
|
||||
- Optionally specify a starting directory (defaults to current working directory)
|
||||
- Optionally provide an include pattern to filter which files to search
|
||||
- Results are sorted with most recently modified files first
|
||||
|
||||
REGEX PATTERN SYNTAX (when literal_text=false):
|
||||
- Supports standard regular expression syntax
|
||||
- 'function' searches for the literal text "function"
|
||||
- 'log\..*Error' finds text starting with "log." and ending with "Error"
|
||||
- 'import\s+.*\s+from' finds import statements in JavaScript/TypeScript
|
||||
|
||||
COMMON INCLUDE PATTERN EXAMPLES:
|
||||
- '*.js' - Only search JavaScript files
|
||||
- '*.{ts,tsx}' - Only search TypeScript files
|
||||
- '*.go' - Only search Go files
|
||||
|
||||
LIMITATIONS:
|
||||
- Results are limited to 100 files (newest first)
|
||||
- Performance depends on the number of files being searched
|
||||
- Very large binary files may be skipped
|
||||
- Hidden files (starting with '.') are skipped
|
||||
|
||||
TIPS:
|
||||
- For faster, more targeted searches, first use Glob to find relevant files, then use Grep
|
||||
- When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead
|
||||
- Always check if results are truncated and refine your search pattern if needed
|
||||
- Use literal_text=true when searching for exact text containing special characters like dots, parentheses, etc.`
|
||||
)
|
||||
|
||||
func NewGrepTool() BaseTool {
|
||||
return &grepTool{}
|
||||
}
|
||||
|
||||
// Info describes the grep tool: its name, usage documentation, and JSON
// parameter schema. Only "pattern" is required; "literal_text" switches
// the pattern from regex to exact-text matching.
func (g *grepTool) Info() ToolInfo {
	return ToolInfo{
		Name:        GrepToolName,
		Description: grepDescription,
		Parameters: map[string]any{
			"pattern": map[string]any{
				"type":        "string",
				"description": "The regex pattern to search for in file contents",
			},
			"path": map[string]any{
				"type":        "string",
				"description": "The directory to search in. Defaults to the current working directory.",
			},
			"include": map[string]any{
				"type":        "string",
				"description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")",
			},
			"literal_text": map[string]any{
				"type":        "boolean",
				"description": "If true, the pattern will be treated as literal text with special regex characters escaped. Default is false.",
			},
		},
		Required: []string{"pattern"},
	}
}
|
||||
|
||||
// escapeRegexPattern escapes special regex characters so they're treated as literal characters
|
||||
// escapeRegexPattern returns pattern with every regular-expression
// metacharacter escaped, so the result matches the input verbatim. It
// delegates to regexp.QuoteMeta, which escapes exactly the characters the
// previous hand-rolled loop handled (\ . + * ? ( ) [ ] { } ^ $ |) without
// the risk of the list drifting out of sync with the regexp package.
func escapeRegexPattern(pattern string) string {
	return regexp.QuoteMeta(pattern)
}
|
||||
|
||||
// Run executes a grep search. Parameter problems are reported back to the
// model as text error responses; only unexpected search failures abort the
// call with a Go error.
func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
	var params GrepParams
	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	if params.Pattern == "" {
		return NewTextErrorResponse("pattern is required"), nil
	}

	// If literal_text is true, escape the pattern so regex metacharacters
	// are matched verbatim.
	searchPattern := params.Pattern
	if params.LiteralText {
		searchPattern = escapeRegexPattern(params.Pattern)
	}

	searchPath := params.Path
	if searchPath == "" {
		searchPath = config.WorkingDirectory()
	}

	// Results are capped at 100 matches; truncated tells the model to
	// narrow the pattern or path.
	matches, truncated, err := searchFiles(searchPattern, searchPath, params.Include, 100)
	if err != nil {
		return ToolResponse{}, fmt.Errorf("error searching files: %w", err)
	}

	var output string
	if len(matches) == 0 {
		output = "No files found"
	} else {
		output = fmt.Sprintf("Found %d matches\n", len(matches))

		// Group matching lines under one heading per file.
		currentFile := ""
		for _, match := range matches {
			if currentFile != match.path {
				if currentFile != "" {
					output += "\n"
				}
				currentFile = match.path
				output += fmt.Sprintf("%s:\n", match.path)
			}
			if match.lineNum > 0 {
				output += fmt.Sprintf(" Line %d: %s\n", match.lineNum, match.lineText)
			} else {
				output += fmt.Sprintf(" %s\n", match.path)
			}
		}

		if truncated {
			output += "\n(Results are truncated. Consider using a more specific path or pattern.)"
		}
	}

	return WithResponseMetadata(
		NewTextResponse(output),
		GrepResponseMetadata{
			NumberOfMatches: len(matches),
			Truncated:       truncated,
		},
	), nil
}
|
||||
|
||||
func searchFiles(pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
|
||||
matches, err := searchWithRipgrep(pattern, rootPath, include)
|
||||
if err != nil {
|
||||
matches, err = searchFilesWithRegex(pattern, rootPath, include)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
return matches[i].modTime.After(matches[j].modTime)
|
||||
})
|
||||
|
||||
truncated := len(matches) > limit
|
||||
if truncated {
|
||||
matches = matches[:limit]
|
||||
}
|
||||
|
||||
return matches, truncated, nil
|
||||
}
|
||||
|
||||
// searchWithRipgrep runs `rg -n pattern [--glob include] path` and parses
// its file:line:content output into grepMatch values. It returns an error
// when rg is not installed or fails abnormally; rg exit code 1 (no
// matches) yields an empty, non-nil slice.
func searchWithRipgrep(pattern, path, include string) ([]grepMatch, error) {
	_, err := exec.LookPath("rg")
	if err != nil {
		return nil, fmt.Errorf("ripgrep not found: %w", err)
	}

	// Use -n to show line numbers and include the matched line
	args := []string{"-n", pattern}
	if include != "" {
		args = append(args, "--glob", include)
	}
	args = append(args, path)

	cmd := exec.Command("rg", args...)
	output, err := cmd.Output()
	if err != nil {
		// Exit code 1 means "no matches found", which is not an error here.
		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
			return []grepMatch{}, nil
		}
		return nil, err
	}

	lines := strings.Split(strings.TrimSpace(string(output)), "\n")
	matches := make([]grepMatch, 0, len(lines))

	for _, line := range lines {
		if line == "" {
			continue
		}

		// Parse ripgrep output format: file:line:content
		parts := strings.SplitN(line, ":", 3)
		if len(parts) < 3 {
			continue
		}

		filePath := parts[0]
		lineNum, err := strconv.Atoi(parts[1])
		if err != nil {
			continue
		}
		lineText := parts[2]

		// Stat to get the modification time used for newest-first sorting.
		fileInfo, err := os.Stat(filePath)
		if err != nil {
			continue // Skip files we can't access
		}

		matches = append(matches, grepMatch{
			path:     filePath,
			modTime:  fileInfo.ModTime(),
			lineNum:  lineNum,
			lineText: lineText,
		})
	}

	return matches, nil
}
|
||||
|
||||
func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
|
||||
matches := []grepMatch{}
|
||||
|
||||
regex, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid regex pattern: %w", err)
|
||||
}
|
||||
|
||||
var includePattern *regexp.Regexp
|
||||
if include != "" {
|
||||
regexPattern := globToRegex(include)
|
||||
includePattern, err = regexp.Compile(regexPattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid include pattern: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil // Skip errors
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return nil // Skip directories
|
||||
}
|
||||
|
||||
if skipHidden(path) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if includePattern != nil && !includePattern.MatchString(path) {
|
||||
return nil
|
||||
}
|
||||
|
||||
match, lineNum, lineText, err := fileContainsPattern(path, regex)
|
||||
if err != nil {
|
||||
return nil // Skip files we can't read
|
||||
}
|
||||
|
||||
if match {
|
||||
matches = append(matches, grepMatch{
|
||||
path: path,
|
||||
modTime: info.ModTime(),
|
||||
lineNum: lineNum,
|
||||
lineText: lineText,
|
||||
})
|
||||
|
||||
if len(matches) >= 200 {
|
||||
return filepath.SkipAll
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return false, 0, "", err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
lineNum := 0
|
||||
for scanner.Scan() {
|
||||
lineNum++
|
||||
line := scanner.Text()
|
||||
if pattern.MatchString(line) {
|
||||
return true, lineNum, line, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, 0, "", scanner.Err()
|
||||
}
|
||||
|
||||
// globToRegex converts a shell-style glob (supporting *, ?, and {a,b}
// alternation) into an equivalent regular-expression fragment. The result
// is unanchored, mirroring how the include filter is applied.
func globToRegex(glob string) string {
	// Escape literal dots, then widen the glob wildcards. These three
	// replacements never feed into each other, so a single pass suffices.
	replacer := strings.NewReplacer(".", "\\.", "*", ".*", "?", ".")
	pattern := replacer.Replace(glob)

	// Rewrite each {a,b,c} group as a regex alternation (a|b|c).
	braces := regexp.MustCompile(`\{([^}]+)\}`)
	return braces.ReplaceAllStringFunc(pattern, func(group string) string {
		alternatives := strings.ReplaceAll(group[1:len(group)-1], ",", "|")
		return "(" + alternatives + ")"
	})
}
|
||||
@@ -1,316 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
)
|
||||
|
||||
// LSParams is the JSON input accepted by the ls tool.
type LSParams struct {
	Path   string   `json:"path"`   // directory to list; defaults to the working directory
	Ignore []string `json:"ignore"` // glob patterns (matched against base names) to exclude
}

// TreeNode is one entry in the rendered directory tree.
type TreeNode struct {
	Name     string      `json:"name"`
	Path     string      `json:"path"`
	Type     string      `json:"type"` // "file" or "directory"
	Children []*TreeNode `json:"children,omitempty"`
}

// LSResponseMetadata is attached to ls responses for programmatic consumers.
type LSResponseMetadata struct {
	NumberOfFiles int  `json:"number_of_files"`
	Truncated     bool `json:"truncated"`
}

// lsTool implements BaseTool for directory listing.
type lsTool struct{}
|
||||
|
||||
const (
|
||||
LSToolName = "ls"
|
||||
MaxLSFiles = 1000
|
||||
lsDescription = `Directory listing tool that shows files and subdirectories in a tree structure, helping you explore and understand the project organization.
|
||||
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to explore the structure of a directory
|
||||
- Helpful for understanding the organization of a project
|
||||
- Good first step when getting familiar with a new codebase
|
||||
|
||||
HOW TO USE:
|
||||
- Provide a path to list (defaults to current working directory)
|
||||
- Optionally specify glob patterns to ignore
|
||||
- Results are displayed in a tree structure
|
||||
|
||||
FEATURES:
|
||||
- Displays a hierarchical view of files and directories
|
||||
- Automatically skips hidden files/directories (starting with '.')
|
||||
- Skips common system directories like __pycache__
|
||||
- Can filter out files matching specific patterns
|
||||
|
||||
LIMITATIONS:
|
||||
- Results are limited to 1000 files
|
||||
- Very large directories will be truncated
|
||||
- Does not show file sizes or permissions
|
||||
- Cannot recursively list all directories in a large project
|
||||
|
||||
TIPS:
|
||||
- Use Glob tool for finding files by name patterns instead of browsing
|
||||
- Use Grep tool for searching file contents
|
||||
- Combine with other tools for more effective exploration`
|
||||
)
|
||||
|
||||
func NewLsTool() BaseTool {
|
||||
return &lsTool{}
|
||||
}
|
||||
|
||||
// Info describes the ls tool: its name, usage documentation, and JSON
// parameter schema. "path" is required; "ignore" is an optional list of
// glob patterns to exclude.
func (l *lsTool) Info() ToolInfo {
	return ToolInfo{
		Name:        LSToolName,
		Description: lsDescription,
		Parameters: map[string]any{
			"path": map[string]any{
				"type":        "string",
				"description": "The path to the directory to list (defaults to current working directory)",
			},
			"ignore": map[string]any{
				"type":        "array",
				"description": "List of glob patterns to ignore",
				"items": map[string]any{
					"type": "string",
				},
			},
		},
		Required: []string{"path"},
	}
}
|
||||
|
||||
// Run lists the requested directory as an indented tree. Relative paths
// are resolved against the working directory; a missing path is reported
// to the model as a text error rather than as a Go error.
func (l *lsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
	var params LSParams
	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	searchPath := params.Path
	if searchPath == "" {
		searchPath = config.WorkingDirectory()
	}

	if !filepath.IsAbs(searchPath) {
		searchPath = filepath.Join(config.WorkingDirectory(), searchPath)
	}

	if _, err := os.Stat(searchPath); os.IsNotExist(err) {
		return NewTextErrorResponse(fmt.Sprintf("path does not exist: %s", searchPath)), nil
	}

	// Walk the directory, collecting at most MaxLSFiles entries.
	files, truncated, err := listDirectory(searchPath, params.Ignore, MaxLSFiles)
	if err != nil {
		return ToolResponse{}, fmt.Errorf("error listing directory: %w", err)
	}

	tree := createFileTree(files)
	output := printTree(tree, searchPath)

	if truncated {
		output = fmt.Sprintf("There are more than %d files in the directory. Use a more specific path or use the Glob tool to find specific files. The first %d files and directories are included below:\n\n%s", MaxLSFiles, MaxLSFiles, output)
	}

	return WithResponseMetadata(
		NewTextResponse(output),
		LSResponseMetadata{
			NumberOfFiles: len(files),
			Truncated:     truncated,
		},
	), nil
}
|
||||
|
||||
func listDirectory(initialPath string, ignorePatterns []string, limit int) ([]string, bool, error) {
|
||||
var results []string
|
||||
truncated := false
|
||||
|
||||
err := filepath.Walk(initialPath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil // Skip files we don't have permission to access
|
||||
}
|
||||
|
||||
if shouldSkip(path, ignorePatterns) {
|
||||
if info.IsDir() {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if path != initialPath {
|
||||
if info.IsDir() {
|
||||
path = path + string(filepath.Separator)
|
||||
}
|
||||
results = append(results, path)
|
||||
}
|
||||
|
||||
if len(results) >= limit {
|
||||
truncated = true
|
||||
return filepath.SkipAll
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, truncated, err
|
||||
}
|
||||
|
||||
return results, truncated, nil
|
||||
}
|
||||
|
||||
func shouldSkip(path string, ignorePatterns []string) bool {
|
||||
base := filepath.Base(path)
|
||||
|
||||
if base != "." && strings.HasPrefix(base, ".") {
|
||||
return true
|
||||
}
|
||||
|
||||
commonIgnored := []string{
|
||||
"__pycache__",
|
||||
"node_modules",
|
||||
"dist",
|
||||
"build",
|
||||
"target",
|
||||
"vendor",
|
||||
"bin",
|
||||
"obj",
|
||||
".git",
|
||||
".idea",
|
||||
".vscode",
|
||||
".DS_Store",
|
||||
"*.pyc",
|
||||
"*.pyo",
|
||||
"*.pyd",
|
||||
"*.so",
|
||||
"*.dll",
|
||||
"*.exe",
|
||||
}
|
||||
|
||||
if strings.Contains(path, filepath.Join("__pycache__", "")) {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, ignored := range commonIgnored {
|
||||
if strings.HasSuffix(ignored, "/") {
|
||||
if strings.Contains(path, filepath.Join(ignored[:len(ignored)-1], "")) {
|
||||
return true
|
||||
}
|
||||
} else if strings.HasPrefix(ignored, "*.") {
|
||||
if strings.HasSuffix(base, ignored[1:]) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
if base == ignored {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, pattern := range ignorePatterns {
|
||||
matched, err := filepath.Match(pattern, base)
|
||||
if err == nil && matched {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// createFileTree converts a flat list of paths (as produced by
// listDirectory) into a forest of TreeNodes. A path ending in the OS
// separator is treated as a directory; every non-final component is also
// a directory. Nodes are deduplicated by their cumulative path.
func createFileTree(sortedPaths []string) []*TreeNode {
	root := []*TreeNode{}
	pathMap := make(map[string]*TreeNode) // cumulative path -> node, for dedup and parent lookup

	for _, path := range sortedPaths {
		parts := strings.Split(path, string(filepath.Separator))
		currentPath := ""
		var parentPath string

		// Drop empty components produced by leading/trailing separators.
		var cleanParts []string
		for _, part := range parts {
			if part != "" {
				cleanParts = append(cleanParts, part)
			}
		}
		parts = cleanParts

		if len(parts) == 0 {
			continue
		}

		for i, part := range parts {
			if currentPath == "" {
				currentPath = part
			} else {
				currentPath = filepath.Join(currentPath, part)
			}

			// Node already created while processing an earlier path;
			// just remember it as the parent for the next component.
			if _, exists := pathMap[currentPath]; exists {
				parentPath = currentPath
				continue
			}

			isLastPart := i == len(parts)-1
			isDir := !isLastPart || strings.HasSuffix(path, string(filepath.Separator))
			nodeType := "file"
			if isDir {
				nodeType = "directory"
			}
			newNode := &TreeNode{
				Name:     part,
				Path:     currentPath,
				Type:     nodeType,
				Children: []*TreeNode{},
			}

			pathMap[currentPath] = newNode

			// Attach to the parent created on a previous iteration, or to
			// the forest root for a first component.
			if i > 0 && parentPath != "" {
				if parent, ok := pathMap[parentPath]; ok {
					parent.Children = append(parent.Children, newNode)
				}
			} else {
				root = append(root, newNode)
			}

			parentPath = currentPath
		}
	}

	return root
}
|
||||
|
||||
func printTree(tree []*TreeNode, rootPath string) string {
|
||||
var result strings.Builder
|
||||
|
||||
result.WriteString(fmt.Sprintf("- %s%s\n", rootPath, string(filepath.Separator)))
|
||||
|
||||
for _, node := range tree {
|
||||
printNode(&result, node, 1)
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
func printNode(builder *strings.Builder, node *TreeNode, level int) {
|
||||
indent := strings.Repeat(" ", level)
|
||||
|
||||
nodeName := node.Name
|
||||
if node.Type == "directory" {
|
||||
nodeName += string(filepath.Separator)
|
||||
}
|
||||
|
||||
fmt.Fprintf(builder, "%s- %s\n", indent, nodeName)
|
||||
|
||||
if node.Type == "directory" && len(node.Children) > 0 {
|
||||
for _, child := range node.Children {
|
||||
printNode(builder, child, level+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,457 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestLsTool_Info verifies the ls tool advertises its name, a non-empty
// description, both parameters, and that "path" is required.
func TestLsTool_Info(t *testing.T) {
	tool := NewLsTool()
	info := tool.Info()

	assert.Equal(t, LSToolName, info.Name)
	assert.NotEmpty(t, info.Description)
	assert.Contains(t, info.Parameters, "path")
	assert.Contains(t, info.Parameters, "ignore")
	assert.Contains(t, info.Required, "path")
}
|
||||
|
||||
func TestLsTool_Run(t *testing.T) {
|
||||
// Create a temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "ls_tool_test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a test directory structure
|
||||
testDirs := []string{
|
||||
"dir1",
|
||||
"dir2",
|
||||
"dir2/subdir1",
|
||||
"dir2/subdir2",
|
||||
"dir3",
|
||||
"dir3/.hidden_dir",
|
||||
"__pycache__",
|
||||
}
|
||||
|
||||
testFiles := []string{
|
||||
"file1.txt",
|
||||
"file2.txt",
|
||||
"dir1/file3.txt",
|
||||
"dir2/file4.txt",
|
||||
"dir2/subdir1/file5.txt",
|
||||
"dir2/subdir2/file6.txt",
|
||||
"dir3/file7.txt",
|
||||
"dir3/.hidden_file.txt",
|
||||
"__pycache__/cache.pyc",
|
||||
".hidden_root_file.txt",
|
||||
}
|
||||
|
||||
// Create directories
|
||||
for _, dir := range testDirs {
|
||||
dirPath := filepath.Join(tempDir, dir)
|
||||
err := os.MkdirAll(dirPath, 0755)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create files
|
||||
for _, file := range testFiles {
|
||||
filePath := filepath.Join(tempDir, file)
|
||||
err := os.WriteFile(filePath, []byte("test content"), 0644)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("lists directory successfully", func(t *testing.T) {
|
||||
tool := NewLsTool()
|
||||
params := LSParams{
|
||||
Path: tempDir,
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: LSToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that visible directories and files are included
|
||||
assert.Contains(t, response.Content, "dir1")
|
||||
assert.Contains(t, response.Content, "dir2")
|
||||
assert.Contains(t, response.Content, "dir3")
|
||||
assert.Contains(t, response.Content, "file1.txt")
|
||||
assert.Contains(t, response.Content, "file2.txt")
|
||||
|
||||
// Check that hidden files and directories are not included
|
||||
assert.NotContains(t, response.Content, ".hidden_dir")
|
||||
assert.NotContains(t, response.Content, ".hidden_file.txt")
|
||||
assert.NotContains(t, response.Content, ".hidden_root_file.txt")
|
||||
|
||||
// Check that __pycache__ is not included
|
||||
assert.NotContains(t, response.Content, "__pycache__")
|
||||
})
|
||||
|
||||
t.Run("handles non-existent path", func(t *testing.T) {
|
||||
tool := NewLsTool()
|
||||
params := LSParams{
|
||||
Path: filepath.Join(tempDir, "non_existent_dir"),
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: LSToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "path does not exist")
|
||||
})
|
||||
|
||||
t.Run("handles empty path parameter", func(t *testing.T) {
|
||||
// For this test, we need to mock the config.WorkingDirectory function
|
||||
// Since we can't easily do that, we'll just check that the response doesn't contain an error message
|
||||
|
||||
tool := NewLsTool()
|
||||
params := LSParams{
|
||||
Path: "",
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: LSToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The response should either contain a valid directory listing or an error
|
||||
// We'll just check that it's not empty
|
||||
assert.NotEmpty(t, response.Content)
|
||||
})
|
||||
|
||||
t.Run("handles invalid parameters", func(t *testing.T) {
|
||||
tool := NewLsTool()
|
||||
call := ToolCall{
|
||||
Name: LSToolName,
|
||||
Input: "invalid json",
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Content, "error parsing parameters")
|
||||
})
|
||||
|
||||
t.Run("respects ignore patterns", func(t *testing.T) {
|
||||
tool := NewLsTool()
|
||||
params := LSParams{
|
||||
Path: tempDir,
|
||||
Ignore: []string{"file1.txt", "dir1"},
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: LSToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The output format is a tree, so we need to check for specific patterns
|
||||
// Check that file1.txt is not directly mentioned
|
||||
assert.NotContains(t, response.Content, "- file1.txt")
|
||||
|
||||
// Check that dir1/ is not directly mentioned
|
||||
assert.NotContains(t, response.Content, "- dir1/")
|
||||
})
|
||||
|
||||
t.Run("handles relative path", func(t *testing.T) {
|
||||
// Save original working directory
|
||||
origWd, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
os.Chdir(origWd)
|
||||
}()
|
||||
|
||||
// Change to a directory above the temp directory
|
||||
parentDir := filepath.Dir(tempDir)
|
||||
err = os.Chdir(parentDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
tool := NewLsTool()
|
||||
params := LSParams{
|
||||
Path: filepath.Base(tempDir),
|
||||
}
|
||||
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
require.NoError(t, err)
|
||||
|
||||
call := ToolCall{
|
||||
Name: LSToolName,
|
||||
Input: string(paramsJSON),
|
||||
}
|
||||
|
||||
response, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should list the temp directory contents
|
||||
assert.Contains(t, response.Content, "dir1")
|
||||
assert.Contains(t, response.Content, "file1.txt")
|
||||
})
|
||||
}
|
||||
|
||||
func TestShouldSkip(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
path string
|
||||
ignorePatterns []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "hidden file",
|
||||
path: "/path/to/.hidden_file",
|
||||
ignorePatterns: []string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "hidden directory",
|
||||
path: "/path/to/.hidden_dir",
|
||||
ignorePatterns: []string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "pycache directory",
|
||||
path: "/path/to/__pycache__/file.pyc",
|
||||
ignorePatterns: []string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "node_modules directory",
|
||||
path: "/path/to/node_modules/package",
|
||||
ignorePatterns: []string{},
|
||||
expected: false, // The shouldSkip function doesn't directly check for node_modules in the path
|
||||
},
|
||||
{
|
||||
name: "normal file",
|
||||
path: "/path/to/normal_file.txt",
|
||||
ignorePatterns: []string{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "normal directory",
|
||||
path: "/path/to/normal_dir",
|
||||
ignorePatterns: []string{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "ignored by pattern",
|
||||
path: "/path/to/ignore_me.txt",
|
||||
ignorePatterns: []string{"ignore_*.txt"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "not ignored by pattern",
|
||||
path: "/path/to/keep_me.txt",
|
||||
ignorePatterns: []string{"ignore_*.txt"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := shouldSkip(tc.path, tc.ignorePatterns)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateFileTree(t *testing.T) {
|
||||
paths := []string{
|
||||
"/path/to/file1.txt",
|
||||
"/path/to/dir1/file2.txt",
|
||||
"/path/to/dir1/subdir/file3.txt",
|
||||
"/path/to/dir2/file4.txt",
|
||||
}
|
||||
|
||||
tree := createFileTree(paths)
|
||||
|
||||
// Check the structure of the tree
|
||||
assert.Len(t, tree, 1) // Should have one root node
|
||||
|
||||
// Check the root node
|
||||
rootNode := tree[0]
|
||||
assert.Equal(t, "path", rootNode.Name)
|
||||
assert.Equal(t, "directory", rootNode.Type)
|
||||
assert.Len(t, rootNode.Children, 1)
|
||||
|
||||
// Check the "to" node
|
||||
toNode := rootNode.Children[0]
|
||||
assert.Equal(t, "to", toNode.Name)
|
||||
assert.Equal(t, "directory", toNode.Type)
|
||||
assert.Len(t, toNode.Children, 3) // file1.txt, dir1, dir2
|
||||
|
||||
// Find the dir1 node
|
||||
var dir1Node *TreeNode
|
||||
for _, child := range toNode.Children {
|
||||
if child.Name == "dir1" {
|
||||
dir1Node = child
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, dir1Node)
|
||||
assert.Equal(t, "directory", dir1Node.Type)
|
||||
assert.Len(t, dir1Node.Children, 2) // file2.txt and subdir
|
||||
}
|
||||
|
||||
func TestPrintTree(t *testing.T) {
|
||||
// Create a simple tree
|
||||
tree := []*TreeNode{
|
||||
{
|
||||
Name: "dir1",
|
||||
Path: "dir1",
|
||||
Type: "directory",
|
||||
Children: []*TreeNode{
|
||||
{
|
||||
Name: "file1.txt",
|
||||
Path: "dir1/file1.txt",
|
||||
Type: "file",
|
||||
},
|
||||
{
|
||||
Name: "subdir",
|
||||
Path: "dir1/subdir",
|
||||
Type: "directory",
|
||||
Children: []*TreeNode{
|
||||
{
|
||||
Name: "file2.txt",
|
||||
Path: "dir1/subdir/file2.txt",
|
||||
Type: "file",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "file3.txt",
|
||||
Path: "file3.txt",
|
||||
Type: "file",
|
||||
},
|
||||
}
|
||||
|
||||
result := printTree(tree, "/root")
|
||||
|
||||
// Check the output format
|
||||
assert.Contains(t, result, "- /root/")
|
||||
assert.Contains(t, result, " - dir1/")
|
||||
assert.Contains(t, result, " - file1.txt")
|
||||
assert.Contains(t, result, " - subdir/")
|
||||
assert.Contains(t, result, " - file2.txt")
|
||||
assert.Contains(t, result, " - file3.txt")
|
||||
}
|
||||
|
||||
func TestListDirectory(t *testing.T) {
|
||||
// Create a temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "list_directory_test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a test directory structure
|
||||
testDirs := []string{
|
||||
"dir1",
|
||||
"dir1/subdir1",
|
||||
".hidden_dir",
|
||||
}
|
||||
|
||||
testFiles := []string{
|
||||
"file1.txt",
|
||||
"file2.txt",
|
||||
"dir1/file3.txt",
|
||||
"dir1/subdir1/file4.txt",
|
||||
".hidden_file.txt",
|
||||
}
|
||||
|
||||
// Create directories
|
||||
for _, dir := range testDirs {
|
||||
dirPath := filepath.Join(tempDir, dir)
|
||||
err := os.MkdirAll(dirPath, 0755)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create files
|
||||
for _, file := range testFiles {
|
||||
filePath := filepath.Join(tempDir, file)
|
||||
err := os.WriteFile(filePath, []byte("test content"), 0644)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("lists files with no limit", func(t *testing.T) {
|
||||
files, truncated, err := listDirectory(tempDir, []string{}, 1000)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, truncated)
|
||||
|
||||
// Check that visible files and directories are included
|
||||
containsPath := func(paths []string, target string) bool {
|
||||
targetPath := filepath.Join(tempDir, target)
|
||||
for _, path := range paths {
|
||||
if strings.HasPrefix(path, targetPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
assert.True(t, containsPath(files, "dir1"))
|
||||
assert.True(t, containsPath(files, "file1.txt"))
|
||||
assert.True(t, containsPath(files, "file2.txt"))
|
||||
assert.True(t, containsPath(files, "dir1/file3.txt"))
|
||||
|
||||
// Check that hidden files and directories are not included
|
||||
assert.False(t, containsPath(files, ".hidden_dir"))
|
||||
assert.False(t, containsPath(files, ".hidden_file.txt"))
|
||||
})
|
||||
|
||||
t.Run("respects limit and returns truncated flag", func(t *testing.T) {
|
||||
files, truncated, err := listDirectory(tempDir, []string{}, 2)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, truncated)
|
||||
assert.Len(t, files, 2)
|
||||
})
|
||||
|
||||
t.Run("respects ignore patterns", func(t *testing.T) {
|
||||
files, truncated, err := listDirectory(tempDir, []string{"*.txt"}, 1000)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, truncated)
|
||||
|
||||
// Check that no .txt files are included
|
||||
for _, file := range files {
|
||||
assert.False(t, strings.HasSuffix(file, ".txt"), "Found .txt file: %s", file)
|
||||
}
|
||||
|
||||
// But directories should still be included
|
||||
containsDir := false
|
||||
for _, file := range files {
|
||||
if strings.Contains(file, "dir1") {
|
||||
containsDir = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, containsDir)
|
||||
})
|
||||
}
|
||||
@@ -1,198 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
type DefinitionParams struct {
|
||||
FilePath string `json:"file_path"`
|
||||
Line int `json:"line"`
|
||||
Column int `json:"column"`
|
||||
}
|
||||
|
||||
type definitionTool struct {
|
||||
lspClients map[string]*lsp.Client
|
||||
}
|
||||
|
||||
const (
|
||||
DefinitionToolName = "definition"
|
||||
definitionDescription = `Find the definition of a symbol at a specific position in a file.
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to find where a symbol is defined
|
||||
- Helpful for understanding code structure and relationships
|
||||
- Great for navigating between implementation and interface
|
||||
|
||||
HOW TO USE:
|
||||
- Provide the path to the file containing the symbol
|
||||
- Specify the line number (1-based) where the symbol appears
|
||||
- Specify the column number (1-based) where the symbol appears
|
||||
- Results show the location of the symbol's definition
|
||||
|
||||
FEATURES:
|
||||
- Finds definitions across files in the project
|
||||
- Works with variables, functions, classes, interfaces, etc.
|
||||
- Returns file path, line, and column of the definition
|
||||
|
||||
LIMITATIONS:
|
||||
- Requires a functioning LSP server for the file type
|
||||
- May not work for all symbols depending on LSP capabilities
|
||||
- Results depend on the accuracy of the LSP server
|
||||
|
||||
TIPS:
|
||||
- Use in conjunction with References tool to understand usage
|
||||
- Combine with View tool to examine the definition
|
||||
`
|
||||
)
|
||||
|
||||
func NewDefinitionTool(lspClients map[string]*lsp.Client) BaseTool {
|
||||
return &definitionTool{
|
||||
lspClients,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *definitionTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: DefinitionToolName,
|
||||
Description: definitionDescription,
|
||||
Parameters: map[string]any{
|
||||
"file_path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The path to the file containing the symbol",
|
||||
},
|
||||
"line": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The line number (1-based) where the symbol appears",
|
||||
},
|
||||
"column": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The column number (1-based) where the symbol appears",
|
||||
},
|
||||
},
|
||||
Required: []string{"file_path", "line", "column"},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *definitionTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params DefinitionParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
lsps := b.lspClients
|
||||
|
||||
if len(lsps) == 0 {
|
||||
return NewTextResponse("\nLSP clients are still initializing. Definition lookup will be available once they're ready.\n"), nil
|
||||
}
|
||||
|
||||
// Ensure file is open in LSP
|
||||
notifyLspOpenFile(ctx, params.FilePath, lsps)
|
||||
|
||||
// Convert 1-based line/column to 0-based for LSP protocol
|
||||
line := max(0, params.Line-1)
|
||||
column := max(0, params.Column-1)
|
||||
|
||||
output := getDefinition(ctx, params.FilePath, line, column, lsps)
|
||||
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func getDefinition(ctx context.Context, filePath string, line, column int, lsps map[string]*lsp.Client) string {
|
||||
var results []string
|
||||
|
||||
slog.Debug(fmt.Sprintf("Looking for definition in %s at line %d, column %d", filePath, line+1, column+1))
|
||||
slog.Debug(fmt.Sprintf("Available LSP clients: %v", getClientNames(lsps)))
|
||||
|
||||
for lspName, client := range lsps {
|
||||
slog.Debug(fmt.Sprintf("Trying LSP client: %s", lspName))
|
||||
// Create definition params
|
||||
uri := fmt.Sprintf("file://%s", filePath)
|
||||
definitionParams := protocol.DefinitionParams{
|
||||
TextDocumentPositionParams: protocol.TextDocumentPositionParams{
|
||||
TextDocument: protocol.TextDocumentIdentifier{
|
||||
URI: protocol.DocumentUri(uri),
|
||||
},
|
||||
Position: protocol.Position{
|
||||
Line: uint32(line),
|
||||
Character: uint32(column),
|
||||
},
|
||||
},
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("Sending definition request with params: %+v", definitionParams))
|
||||
|
||||
// Get definition
|
||||
definition, err := client.Definition(ctx, definitionParams)
|
||||
if err != nil {
|
||||
slog.Debug(fmt.Sprintf("Error from %s: %s", lspName, err))
|
||||
results = append(results, fmt.Sprintf("Error from %s: %s", lspName, err))
|
||||
continue
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("Got definition result type: %T", definition.Value))
|
||||
|
||||
// Process the definition result
|
||||
locations := processDefinitionResult(definition)
|
||||
slog.Debug(fmt.Sprintf("Processed locations count: %d", len(locations)))
|
||||
if len(locations) == 0 {
|
||||
results = append(results, fmt.Sprintf("No definition found by %s", lspName))
|
||||
continue
|
||||
}
|
||||
|
||||
// Format the locations
|
||||
for _, loc := range locations {
|
||||
path := strings.TrimPrefix(string(loc.URI), "file://")
|
||||
// Convert 0-based line/column to 1-based for display
|
||||
defLine := loc.Range.Start.Line + 1
|
||||
defColumn := loc.Range.Start.Character + 1
|
||||
slog.Debug(fmt.Sprintf("Found definition at %s:%d:%d", path, defLine, defColumn))
|
||||
results = append(results, fmt.Sprintf("Definition found by %s: %s:%d:%d", lspName, path, defLine, defColumn))
|
||||
}
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return "No definition found for the symbol at the specified position."
|
||||
}
|
||||
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
|
||||
func processDefinitionResult(result protocol.Or_Result_textDocument_definition) []protocol.Location {
|
||||
var locations []protocol.Location
|
||||
|
||||
switch v := result.Value.(type) {
|
||||
case protocol.Location:
|
||||
locations = append(locations, v)
|
||||
case []protocol.Location:
|
||||
locations = append(locations, v...)
|
||||
case []protocol.DefinitionLink:
|
||||
for _, link := range v {
|
||||
locations = append(locations, protocol.Location{
|
||||
URI: link.TargetURI,
|
||||
Range: link.TargetRange,
|
||||
})
|
||||
}
|
||||
case protocol.Or_Definition:
|
||||
switch d := v.Value.(type) {
|
||||
case protocol.Location:
|
||||
locations = append(locations, d)
|
||||
case []protocol.Location:
|
||||
locations = append(locations, d...)
|
||||
}
|
||||
}
|
||||
|
||||
return locations
|
||||
}
|
||||
|
||||
// Helper function to get LSP client names for debugging
|
||||
func getClientNames(lsps map[string]*lsp.Client) []string {
|
||||
names := make([]string, 0, len(lsps))
|
||||
for name := range lsps {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
@@ -1,296 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
type DiagnosticsParams struct {
|
||||
FilePath string `json:"file_path"`
|
||||
}
|
||||
type diagnosticsTool struct {
|
||||
lspClients map[string]*lsp.Client
|
||||
}
|
||||
|
||||
const (
|
||||
DiagnosticsToolName = "diagnostics"
|
||||
diagnosticsDescription = `Get diagnostics for a file and/or project.
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to check for errors or warnings in your code
|
||||
- Helpful for debugging and ensuring code quality
|
||||
- Good for getting a quick overview of issues in a file or project
|
||||
HOW TO USE:
|
||||
- Provide a path to a file to get diagnostics for that file
|
||||
- Leave the path empty to get diagnostics for the entire project
|
||||
- Results are displayed in a structured format with severity levels
|
||||
FEATURES:
|
||||
- Displays errors, warnings, and hints
|
||||
- Groups diagnostics by severity
|
||||
- Provides detailed information about each diagnostic
|
||||
LIMITATIONS:
|
||||
- Results are limited to the diagnostics provided by the LSP clients
|
||||
- May not cover all possible issues in the code
|
||||
- Does not provide suggestions for fixing issues
|
||||
TIPS:
|
||||
- Use in conjunction with other tools for a comprehensive code review
|
||||
- Combine with the LSP client for real-time diagnostics
|
||||
`
|
||||
)
|
||||
|
||||
func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
|
||||
return &diagnosticsTool{
|
||||
lspClients,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *diagnosticsTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: DiagnosticsToolName,
|
||||
Description: diagnosticsDescription,
|
||||
Parameters: map[string]any{
|
||||
"file_path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
|
||||
},
|
||||
},
|
||||
Required: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params DiagnosticsParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
lsps := b.lspClients
|
||||
|
||||
if len(lsps) == 0 {
|
||||
// Return a more helpful message when LSP clients aren't ready yet
|
||||
return NewTextResponse("\n<diagnostic_summary>\nLSP clients are still initializing. Diagnostics will be available once they're ready.\n</diagnostic_summary>\n"), nil
|
||||
}
|
||||
|
||||
if params.FilePath != "" {
|
||||
notifyLspOpenFile(ctx, params.FilePath, lsps)
|
||||
waitForLspDiagnostics(ctx, params.FilePath, lsps)
|
||||
}
|
||||
|
||||
output := getDiagnostics(params.FilePath, lsps)
|
||||
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
|
||||
for _, client := range lsps {
|
||||
err := client.OpenFile(ctx, filePath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func waitForLspDiagnostics(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
|
||||
if len(lsps) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
diagChan := make(chan struct{}, 1)
|
||||
|
||||
for _, client := range lsps {
|
||||
originalDiags := make(map[protocol.DocumentUri][]protocol.Diagnostic)
|
||||
maps.Copy(originalDiags, client.GetDiagnostics())
|
||||
|
||||
handler := func(params json.RawMessage) {
|
||||
lsp.HandleDiagnostics(client, params)
|
||||
var diagParams protocol.PublishDiagnosticsParams
|
||||
if err := json.Unmarshal(params, &diagParams); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if diagParams.URI.Path() == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
|
||||
select {
|
||||
case diagChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)
|
||||
|
||||
if client.IsFileOpen(filePath) {
|
||||
err := client.NotifyChange(ctx, filePath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
err := client.OpenFile(ctx, filePath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-diagChan:
|
||||
case <-time.After(5 * time.Second):
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func hasDiagnosticsChanged(current, original map[protocol.DocumentUri][]protocol.Diagnostic) bool {
|
||||
for uri, diags := range current {
|
||||
origDiags, exists := original[uri]
|
||||
if !exists || len(diags) != len(origDiags) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
|
||||
fileDiagnostics := []string{}
|
||||
projectDiagnostics := []string{}
|
||||
|
||||
formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
|
||||
severity := "Info"
|
||||
switch diagnostic.Severity {
|
||||
case protocol.SeverityError:
|
||||
severity = "Error"
|
||||
case protocol.SeverityWarning:
|
||||
severity = "Warn"
|
||||
case protocol.SeverityHint:
|
||||
severity = "Hint"
|
||||
}
|
||||
|
||||
location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
|
||||
|
||||
sourceInfo := ""
|
||||
if diagnostic.Source != "" {
|
||||
sourceInfo = diagnostic.Source
|
||||
} else if source != "" {
|
||||
sourceInfo = source
|
||||
}
|
||||
|
||||
codeInfo := ""
|
||||
if diagnostic.Code != nil {
|
||||
codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
|
||||
}
|
||||
|
||||
tagsInfo := ""
|
||||
if len(diagnostic.Tags) > 0 {
|
||||
tags := []string{}
|
||||
for _, tag := range diagnostic.Tags {
|
||||
switch tag {
|
||||
case protocol.Unnecessary:
|
||||
tags = append(tags, "unnecessary")
|
||||
case protocol.Deprecated:
|
||||
tags = append(tags, "deprecated")
|
||||
}
|
||||
}
|
||||
if len(tags) > 0 {
|
||||
tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s: %s [%s]%s%s %s",
|
||||
severity,
|
||||
location,
|
||||
sourceInfo,
|
||||
codeInfo,
|
||||
tagsInfo,
|
||||
diagnostic.Message)
|
||||
}
|
||||
|
||||
for lspName, client := range lsps {
|
||||
diagnostics := client.GetDiagnostics()
|
||||
if len(diagnostics) > 0 {
|
||||
for location, diags := range diagnostics {
|
||||
isCurrentFile := location.Path() == filePath
|
||||
|
||||
for _, diag := range diags {
|
||||
formattedDiag := formatDiagnostic(location.Path(), diag, lspName)
|
||||
|
||||
if isCurrentFile {
|
||||
fileDiagnostics = append(fileDiagnostics, formattedDiag)
|
||||
} else {
|
||||
projectDiagnostics = append(projectDiagnostics, formattedDiag)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(fileDiagnostics, func(i, j int) bool {
|
||||
iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
|
||||
jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
|
||||
if iIsError != jIsError {
|
||||
return iIsError // Errors come first
|
||||
}
|
||||
return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
|
||||
})
|
||||
|
||||
sort.Slice(projectDiagnostics, func(i, j int) bool {
|
||||
iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
|
||||
jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
|
||||
if iIsError != jIsError {
|
||||
return iIsError
|
||||
}
|
||||
return projectDiagnostics[i] < projectDiagnostics[j]
|
||||
})
|
||||
|
||||
output := ""
|
||||
|
||||
if len(fileDiagnostics) > 0 {
|
||||
output += "\n<file_diagnostics>\n"
|
||||
if len(fileDiagnostics) > 10 {
|
||||
output += strings.Join(fileDiagnostics[:10], "\n")
|
||||
output += fmt.Sprintf("\n... and %d more diagnostics", len(fileDiagnostics)-10)
|
||||
} else {
|
||||
output += strings.Join(fileDiagnostics, "\n")
|
||||
}
|
||||
output += "\n</file_diagnostics>\n"
|
||||
}
|
||||
|
||||
if len(projectDiagnostics) > 0 {
|
||||
output += "\n<project_diagnostics>\n"
|
||||
if len(projectDiagnostics) > 10 {
|
||||
output += strings.Join(projectDiagnostics[:10], "\n")
|
||||
output += fmt.Sprintf("\n... and %d more diagnostics", len(projectDiagnostics)-10)
|
||||
} else {
|
||||
output += strings.Join(projectDiagnostics, "\n")
|
||||
}
|
||||
output += "\n</project_diagnostics>\n"
|
||||
}
|
||||
|
||||
if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
|
||||
fileErrors := countSeverity(fileDiagnostics, "Error")
|
||||
fileWarnings := countSeverity(fileDiagnostics, "Warn")
|
||||
projectErrors := countSeverity(projectDiagnostics, "Error")
|
||||
projectWarnings := countSeverity(projectDiagnostics, "Warn")
|
||||
|
||||
output += "\n<diagnostic_summary>\n"
|
||||
output += fmt.Sprintf("Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
|
||||
output += fmt.Sprintf("Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
|
||||
output += "</diagnostic_summary>\n"
|
||||
}
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
func countSeverity(diagnostics []string, severity string) int {
|
||||
count := 0
|
||||
for _, diag := range diagnostics {
|
||||
if strings.HasPrefix(diag, severity) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
@@ -1,204 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
type DocSymbolsParams struct {
|
||||
FilePath string `json:"file_path"`
|
||||
}
|
||||
|
||||
type docSymbolsTool struct {
|
||||
lspClients map[string]*lsp.Client
|
||||
}
|
||||
|
||||
const (
|
||||
DocSymbolsToolName = "docSymbols"
|
||||
docSymbolsDescription = `Get document symbols for a file.
|
||||
WHEN TO USE THIS TOOL:
|
||||
- Use when you need to understand the structure of a file
|
||||
- Helpful for finding classes, functions, methods, and variables in a file
|
||||
- Great for getting an overview of a file's organization
|
||||
|
||||
HOW TO USE:
|
||||
- Provide the path to the file to get symbols for
|
||||
- Results show all symbols defined in the file with their kind and location
|
||||
|
||||
FEATURES:
|
||||
- Lists all symbols in a hierarchical structure
|
||||
- Shows symbol types (function, class, variable, etc.)
|
||||
- Provides location information for each symbol
|
||||
- Organizes symbols by their scope and relationship
|
||||
|
||||
LIMITATIONS:
|
||||
- Requires a functioning LSP server for the file type
|
||||
- Results depend on the accuracy of the LSP server
|
||||
- May not work for all file types
|
||||
|
||||
TIPS:
|
||||
- Use to quickly understand the structure of a large file
|
||||
- Combine with Definition and References tools for deeper code exploration
|
||||
`
|
||||
)
|
||||
|
||||
func NewDocSymbolsTool(lspClients map[string]*lsp.Client) BaseTool {
|
||||
return &docSymbolsTool{
|
||||
lspClients,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *docSymbolsTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: DocSymbolsToolName,
|
||||
Description: docSymbolsDescription,
|
||||
Parameters: map[string]any{
|
||||
"file_path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The path to the file to get symbols for",
|
||||
},
|
||||
},
|
||||
Required: []string{"file_path"},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *docSymbolsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params DocSymbolsParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
lsps := b.lspClients
|
||||
|
||||
if len(lsps) == 0 {
|
||||
return NewTextResponse("\nLSP clients are still initializing. Document symbols lookup will be available once they're ready.\n"), nil
|
||||
}
|
||||
|
||||
// Ensure file is open in LSP
|
||||
notifyLspOpenFile(ctx, params.FilePath, lsps)
|
||||
|
||||
output := getDocumentSymbols(ctx, params.FilePath, lsps)
|
||||
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func getDocumentSymbols(ctx context.Context, filePath string, lsps map[string]*lsp.Client) string {
|
||||
var results []string
|
||||
|
||||
for lspName, client := range lsps {
|
||||
// Create document symbol params
|
||||
uri := fmt.Sprintf("file://%s", filePath)
|
||||
symbolParams := protocol.DocumentSymbolParams{
|
||||
TextDocument: protocol.TextDocumentIdentifier{
|
||||
URI: protocol.DocumentUri(uri),
|
||||
},
|
||||
}
|
||||
|
||||
// Get document symbols
|
||||
symbolResult, err := client.DocumentSymbol(ctx, symbolParams)
|
||||
if err != nil {
|
||||
results = append(results, fmt.Sprintf("Error from %s: %s", lspName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Process the symbol result
|
||||
symbols := processDocumentSymbolResult(symbolResult)
|
||||
if len(symbols) == 0 {
|
||||
results = append(results, fmt.Sprintf("No symbols found by %s", lspName))
|
||||
continue
|
||||
}
|
||||
|
||||
// Format the symbols
|
||||
results = append(results, fmt.Sprintf("Symbols found by %s:", lspName))
|
||||
for _, symbol := range symbols {
|
||||
results = append(results, formatSymbol(symbol, 1))
|
||||
}
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return "No symbols found in the specified file."
|
||||
}
|
||||
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
|
||||
func processDocumentSymbolResult(result protocol.Or_Result_textDocument_documentSymbol) []SymbolInfo {
|
||||
var symbols []SymbolInfo
|
||||
|
||||
switch v := result.Value.(type) {
|
||||
case []protocol.SymbolInformation:
|
||||
for _, si := range v {
|
||||
symbols = append(symbols, SymbolInfo{
|
||||
Name: si.Name,
|
||||
Kind: symbolKindToString(si.Kind),
|
||||
Location: locationToString(si.Location),
|
||||
Children: nil,
|
||||
})
|
||||
}
|
||||
case []protocol.DocumentSymbol:
|
||||
for _, ds := range v {
|
||||
symbols = append(symbols, documentSymbolToSymbolInfo(ds))
|
||||
}
|
||||
}
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
// SymbolInfo represents a symbol in a document
type SymbolInfo struct {
	Name     string       // symbol identifier as reported by the LSP server
	Kind     string       // human-readable kind (see symbolKindToString)
	Location string       // human-readable position, e.g. "Line 3-10"
	Children []SymbolInfo // nested symbols (populated only for DocumentSymbol responses)
}
|
||||
|
||||
func documentSymbolToSymbolInfo(symbol protocol.DocumentSymbol) SymbolInfo {
|
||||
info := SymbolInfo{
|
||||
Name: symbol.Name,
|
||||
Kind: symbolKindToString(symbol.Kind),
|
||||
Location: fmt.Sprintf("Line %d-%d",
|
||||
symbol.Range.Start.Line+1,
|
||||
symbol.Range.End.Line+1),
|
||||
Children: []SymbolInfo{},
|
||||
}
|
||||
|
||||
for _, child := range symbol.Children {
|
||||
info.Children = append(info.Children, documentSymbolToSymbolInfo(child))
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func locationToString(location protocol.Location) string {
|
||||
return fmt.Sprintf("Line %d-%d",
|
||||
location.Range.Start.Line+1,
|
||||
location.Range.End.Line+1)
|
||||
}
|
||||
|
||||
func symbolKindToString(kind protocol.SymbolKind) string {
|
||||
if kindStr, ok := protocol.TableKindMap[kind]; ok {
|
||||
return kindStr
|
||||
}
|
||||
return "Unknown"
|
||||
}
|
||||
|
||||
func formatSymbol(symbol SymbolInfo, level int) string {
|
||||
indent := strings.Repeat(" ", level)
|
||||
result := fmt.Sprintf("%s- %s (%s) %s", indent, symbol.Name, symbol.Kind, symbol.Location)
|
||||
|
||||
var childResults []string
|
||||
for _, child := range symbol.Children {
|
||||
childResults = append(childResults, formatSymbol(child, level+1))
|
||||
}
|
||||
|
||||
if len(childResults) > 0 {
|
||||
return result + "\n" + strings.Join(childResults, "\n")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -1,161 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
// ReferencesParams is the JSON input schema for the references tool.
type ReferencesParams struct {
	FilePath           string `json:"file_path"`           // file containing the symbol
	Line               int    `json:"line"`                // 1-based line of the symbol
	Column             int    `json:"column"`              // 1-based column of the symbol
	IncludeDeclaration bool   `json:"include_declaration"` // also report the declaration site
}

// referencesTool finds all references to a symbol via the configured LSP clients.
type referencesTool struct {
	lspClients map[string]*lsp.Client // keyed by LSP server name
}
|
||||
|
||||
// ReferencesToolName identifies the tool to the model; referencesDescription
// is the model-facing usage guide and is effectively part of the prompt.
const (
	ReferencesToolName    = "references"
	referencesDescription = `Find all references to a symbol at a specific position in a file.

WHEN TO USE THIS TOOL:
- Use when you need to find all places where a symbol is used
- Helpful for understanding code usage and dependencies
- Great for refactoring and impact analysis

HOW TO USE:
- Provide the path to the file containing the symbol
- Specify the line number (1-based) where the symbol appears
- Specify the column number (1-based) where the symbol appears
- Optionally set include_declaration to include the declaration in results
- Results show all locations where the symbol is referenced

FEATURES:
- Finds references across files in the project
- Works with variables, functions, classes, interfaces, etc.
- Returns file paths, lines, and columns of all references

LIMITATIONS:
- Requires a functioning LSP server for the file type
- May not find all references depending on LSP capabilities
- Results depend on the accuracy of the LSP server

TIPS:
- Use in conjunction with Definition tool to understand symbol origins
- Combine with View tool to examine the references
`
)
|
||||
|
||||
func NewReferencesTool(lspClients map[string]*lsp.Client) BaseTool {
|
||||
return &referencesTool{
|
||||
lspClients,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *referencesTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: ReferencesToolName,
|
||||
Description: referencesDescription,
|
||||
Parameters: map[string]any{
|
||||
"file_path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The path to the file containing the symbol",
|
||||
},
|
||||
"line": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The line number (1-based) where the symbol appears",
|
||||
},
|
||||
"column": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The column number (1-based) where the symbol appears",
|
||||
},
|
||||
"include_declaration": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Whether to include the declaration in the results",
|
||||
},
|
||||
},
|
||||
Required: []string{"file_path", "line", "column"},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *referencesTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params ReferencesParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
lsps := b.lspClients
|
||||
|
||||
if len(lsps) == 0 {
|
||||
return NewTextResponse("\nLSP clients are still initializing. References lookup will be available once they're ready.\n"), nil
|
||||
}
|
||||
|
||||
// Ensure file is open in LSP
|
||||
notifyLspOpenFile(ctx, params.FilePath, lsps)
|
||||
|
||||
// Convert 1-based line/column to 0-based for LSP protocol
|
||||
line := max(0, params.Line-1)
|
||||
column := max(0, params.Column-1)
|
||||
|
||||
output := getReferences(ctx, params.FilePath, line, column, params.IncludeDeclaration, lsps)
|
||||
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func getReferences(ctx context.Context, filePath string, line, column int, includeDeclaration bool, lsps map[string]*lsp.Client) string {
|
||||
var results []string
|
||||
|
||||
for lspName, client := range lsps {
|
||||
// Create references params
|
||||
uri := fmt.Sprintf("file://%s", filePath)
|
||||
referencesParams := protocol.ReferenceParams{
|
||||
TextDocumentPositionParams: protocol.TextDocumentPositionParams{
|
||||
TextDocument: protocol.TextDocumentIdentifier{
|
||||
URI: protocol.DocumentUri(uri),
|
||||
},
|
||||
Position: protocol.Position{
|
||||
Line: uint32(line),
|
||||
Character: uint32(column),
|
||||
},
|
||||
},
|
||||
Context: protocol.ReferenceContext{
|
||||
IncludeDeclaration: includeDeclaration,
|
||||
},
|
||||
}
|
||||
|
||||
// Get references
|
||||
references, err := client.References(ctx, referencesParams)
|
||||
if err != nil {
|
||||
results = append(results, fmt.Sprintf("Error from %s: %s", lspName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if len(references) == 0 {
|
||||
results = append(results, fmt.Sprintf("No references found by %s", lspName))
|
||||
continue
|
||||
}
|
||||
|
||||
// Format the locations
|
||||
results = append(results, fmt.Sprintf("References found by %s:", lspName))
|
||||
for _, loc := range references {
|
||||
path := strings.TrimPrefix(string(loc.URI), "file://")
|
||||
// Convert 0-based line/column to 1-based for display
|
||||
refLine := loc.Range.Start.Line + 1
|
||||
refColumn := loc.Range.Start.Character + 1
|
||||
results = append(results, fmt.Sprintf(" %s:%d:%d", path, refLine, refColumn))
|
||||
}
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return "No references found for the symbol at the specified position."
|
||||
}
|
||||
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
// WorkspaceSymbolsParams is the JSON input schema for the workspaceSymbols tool.
type WorkspaceSymbolsParams struct {
	Query string `json:"query"` // search string forwarded to workspace/symbol
}

// workspaceSymbolsTool searches for symbols across the workspace via the
// configured LSP clients.
type workspaceSymbolsTool struct {
	lspClients map[string]*lsp.Client // keyed by LSP server name
}
|
||||
|
||||
// WorkspaceSymbolsToolName identifies the tool to the model;
// workspaceSymbolsDescription is the model-facing usage guide.
const (
	WorkspaceSymbolsToolName    = "workspaceSymbols"
	workspaceSymbolsDescription = `Find symbols across the workspace matching a query.

WHEN TO USE THIS TOOL:
- Use when you need to find symbols across multiple files
- Helpful for locating classes, functions, or variables in a project
- Great for exploring large codebases

HOW TO USE:
- Provide a query string to search for symbols
- Results show matching symbols from across the workspace

FEATURES:
- Searches across all files in the workspace
- Shows symbol types (function, class, variable, etc.)
- Provides location information for each symbol
- Works with partial matches and fuzzy search (depending on LSP server)

LIMITATIONS:
- Requires a functioning LSP server for the file types
- Results depend on the accuracy of the LSP server
- Query capabilities vary by language server
- May not work for all file types

TIPS:
- Use specific queries to narrow down results
- Combine with DocSymbols tool for detailed file exploration
- Use with Definition tool to jump to symbol definitions
`
)
|
||||
|
||||
func NewWorkspaceSymbolsTool(lspClients map[string]*lsp.Client) BaseTool {
|
||||
return &workspaceSymbolsTool{
|
||||
lspClients,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *workspaceSymbolsTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: WorkspaceSymbolsToolName,
|
||||
Description: workspaceSymbolsDescription,
|
||||
Parameters: map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The query string to search for symbols",
|
||||
},
|
||||
},
|
||||
Required: []string{"query"},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *workspaceSymbolsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
|
||||
var params WorkspaceSymbolsParams
|
||||
if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
|
||||
return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
||||
}
|
||||
|
||||
lsps := b.lspClients
|
||||
|
||||
if len(lsps) == 0 {
|
||||
return NewTextResponse("\nLSP clients are still initializing. Workspace symbols lookup will be available once they're ready.\n"), nil
|
||||
}
|
||||
|
||||
output := getWorkspaceSymbols(ctx, params.Query, lsps)
|
||||
|
||||
return NewTextResponse(output), nil
|
||||
}
|
||||
|
||||
func getWorkspaceSymbols(ctx context.Context, query string, lsps map[string]*lsp.Client) string {
|
||||
var results []string
|
||||
|
||||
for lspName, client := range lsps {
|
||||
// Create workspace symbol params
|
||||
symbolParams := protocol.WorkspaceSymbolParams{
|
||||
Query: query,
|
||||
}
|
||||
|
||||
// Get workspace symbols
|
||||
symbolResult, err := client.Symbol(ctx, symbolParams)
|
||||
if err != nil {
|
||||
results = append(results, fmt.Sprintf("Error from %s: %s", lspName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Process the symbol result
|
||||
symbols := processWorkspaceSymbolResult(symbolResult)
|
||||
if len(symbols) == 0 {
|
||||
results = append(results, fmt.Sprintf("No symbols found by %s for query '%s'", lspName, query))
|
||||
continue
|
||||
}
|
||||
|
||||
// Format the symbols
|
||||
results = append(results, fmt.Sprintf("Symbols found by %s for query '%s':", lspName, query))
|
||||
for _, symbol := range symbols {
|
||||
results = append(results, fmt.Sprintf(" %s (%s) - %s", symbol.Name, symbol.Kind, symbol.Location))
|
||||
}
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No symbols found matching query '%s'.", query)
|
||||
}
|
||||
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
|
||||
func processWorkspaceSymbolResult(result protocol.Or_Result_workspace_symbol) []SymbolInfo {
|
||||
var symbols []SymbolInfo
|
||||
|
||||
switch v := result.Value.(type) {
|
||||
case []protocol.SymbolInformation:
|
||||
for _, si := range v {
|
||||
symbols = append(symbols, SymbolInfo{
|
||||
Name: si.Name,
|
||||
Kind: symbolKindToString(si.Kind),
|
||||
Location: formatWorkspaceLocation(si.Location),
|
||||
Children: nil,
|
||||
})
|
||||
}
|
||||
case []protocol.WorkspaceSymbol:
|
||||
for _, ws := range v {
|
||||
location := "Unknown location"
|
||||
if ws.Location.Value != nil {
|
||||
if loc, ok := ws.Location.Value.(protocol.Location); ok {
|
||||
location = formatWorkspaceLocation(loc)
|
||||
}
|
||||
}
|
||||
symbols = append(symbols, SymbolInfo{
|
||||
Name: ws.Name,
|
||||
Kind: symbolKindToString(ws.Kind),
|
||||
Location: location,
|
||||
Children: nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
func formatWorkspaceLocation(location protocol.Location) string {
|
||||
path := strings.TrimPrefix(string(location.URI), "file://")
|
||||
return fmt.Sprintf("%s:%d:%d", path, location.Range.Start.Line+1, location.Range.Start.Character+1)
|
||||
}
|
||||
@@ -1,375 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/diff"
|
||||
"github.com/sst/opencode/internal/history"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// PatchParams is the JSON input schema for the patch tool.
type PatchParams struct {
	PatchText string `json:"patch_text"` // full patch in the Begin/End Patch format
}

// PatchResponseMetadata summarizes what a successful patch changed.
type PatchResponseMetadata struct {
	FilesChanged []string `json:"files_changed"` // absolute paths of every touched file
	Additions    int      `json:"additions"`     // total added lines across all files
	Removals     int      `json:"removals"`      // total removed lines across all files
}

// patchTool applies multi-file patches, gated by the permission service and
// recorded in the file-history service.
type patchTool struct {
	lspClients  map[string]*lsp.Client // used for post-apply diagnostics
	permissions permission.Service     // approves each create/update/delete
	files       history.Service        // stores per-session file versions
}
|
||||
|
||||
// PatchToolName identifies the tool to the model; patchDescription is the
// model-facing usage guide, including the required patch text format.
const (
	PatchToolName    = "patch"
	patchDescription = `Applies a patch to multiple files in one operation. This tool is useful for making coordinated changes across multiple files.

The patch text must follow this format:
*** Begin Patch
*** Update File: /path/to/file
@@ Context line (unique within the file)
 Line to keep
-Line to remove
+Line to add
 Line to keep
*** Add File: /path/to/new/file
+Content of the new file
+More content
*** Delete File: /path/to/file/to/delete
*** End Patch

Before using this tool:
1. Use the FileRead tool to understand the files' contents and context
2. Verify all file paths are correct (use the LS tool)

CRITICAL REQUIREMENTS FOR USING THIS TOOL:

1. UNIQUENESS: Context lines MUST uniquely identify the specific sections you want to change
2. PRECISION: All whitespace, indentation, and surrounding code must match exactly
3. VALIDATION: Ensure edits result in idiomatic, correct code
4. PATHS: Always use absolute file paths (starting with /)

The tool will apply all changes in a single atomic operation.`
)
|
||||
|
||||
func NewPatchTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool {
|
||||
return &patchTool{
|
||||
lspClients: lspClients,
|
||||
permissions: permissions,
|
||||
files: files,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *patchTool) Info() ToolInfo {
|
||||
return ToolInfo{
|
||||
Name: PatchToolName,
|
||||
Description: patchDescription,
|
||||
Parameters: map[string]any{
|
||||
"patch_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The full patch text that describes all changes to be made",
|
||||
},
|
||||
},
|
||||
Required: []string{"patch_text"},
|
||||
}
|
||||
}
|
||||
|
||||
// Run validates, authorizes, and applies a multi-file patch:
//  1. parse params and pre-flight every referenced file (must have been
//     read, must exist, must not have changed since the last read),
//  2. verify files the patch adds do not already exist,
//  3. parse the patch against current file contents and reject high fuzz,
//  4. request permission for every create/update/delete,
//  5. apply all changes, then record file history and report diagnostics.
// User-facing failures return NewTextErrorResponse with a nil error;
// internal failures return a wrapped Go error.
func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
	var params PatchParams
	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
		return NewTextErrorResponse("invalid parameters"), nil
	}

	if params.PatchText == "" {
		return NewTextErrorResponse("patch_text is required"), nil
	}

	// Identify all files needed for the patch and verify they've been read
	filesToRead := diff.IdentifyFilesNeeded(params.PatchText)
	for _, filePath := range filesToRead {
		absPath := filePath
		if !filepath.IsAbs(absPath) {
			wd := config.WorkingDirectory()
			absPath = filepath.Join(wd, absPath)
		}

		// A zero last-read time means this session never read the file.
		if getLastReadTime(absPath).IsZero() {
			return NewTextErrorResponse(fmt.Sprintf("you must read the file %s before patching it. Use the FileRead tool first", filePath)), nil
		}

		fileInfo, err := os.Stat(absPath)
		if err != nil {
			if os.IsNotExist(err) {
				return NewTextErrorResponse(fmt.Sprintf("file not found: %s", absPath)), nil
			}
			return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
		}

		if fileInfo.IsDir() {
			return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", absPath)), nil
		}

		// Refuse to patch files modified (e.g. by the user) after our read.
		modTime := fileInfo.ModTime()
		lastRead := getLastReadTime(absPath)
		if modTime.After(lastRead) {
			return NewTextErrorResponse(
				fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
					absPath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
				)), nil
		}
	}

	// Check for new files to ensure they don't already exist
	filesToAdd := diff.IdentifyFilesAdded(params.PatchText)
	for _, filePath := range filesToAdd {
		absPath := filePath
		if !filepath.IsAbs(absPath) {
			wd := config.WorkingDirectory()
			absPath = filepath.Join(wd, absPath)
		}

		_, err := os.Stat(absPath)
		if err == nil {
			return NewTextErrorResponse(fmt.Sprintf("file already exists and cannot be added: %s", absPath)), nil
		} else if !os.IsNotExist(err) {
			return ToolResponse{}, fmt.Errorf("failed to check file: %w", err)
		}
	}

	// Load all required files; keyed by the patch's (possibly relative) path.
	currentFiles := make(map[string]string)
	for _, filePath := range filesToRead {
		absPath := filePath
		if !filepath.IsAbs(absPath) {
			wd := config.WorkingDirectory()
			absPath = filepath.Join(wd, absPath)
		}

		content, err := os.ReadFile(absPath)
		if err != nil {
			return ToolResponse{}, fmt.Errorf("failed to read file %s: %w", absPath, err)
		}
		currentFiles[filePath] = string(content)
	}

	// Process the patch
	patch, fuzz, err := diff.TextToPatch(params.PatchText, currentFiles)
	if err != nil {
		return NewTextErrorResponse(fmt.Sprintf("failed to parse patch: %s", err)), nil
	}

	// Fuzz above 3 means context lines matched only loosely; too risky.
	if fuzz > 3 {
		return NewTextErrorResponse(fmt.Sprintf("patch contains fuzzy matches (fuzz level: %d). Please make your context lines more precise", fuzz)), nil
	}

	// Convert patch to commit
	commit, err := diff.PatchToCommit(patch, currentFiles)
	if err != nil {
		return NewTextErrorResponse(fmt.Sprintf("failed to create commit from patch: %s", err)), nil
	}

	// Get session ID and message ID (messageID is validated only; the
	// session ID keys the permission requests and history entries below).
	sessionID, messageID := GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a patch")
	}

	// Request permission for all changes before touching the filesystem.
	// NOTE(review): in each case below `p :=` shadows the receiver `p` with
	// the bool result — it works, but renaming the local would be clearer.
	for path, change := range commit.Changes {
		switch change.Type {
		case diff.ActionAdd:
			dir := filepath.Dir(path)
			// Diff from empty content renders the whole new file as additions.
			patchDiff, _, _ := diff.GenerateDiff("", *change.NewContent, path)
			p := p.permissions.Request(
				ctx,
				permission.CreatePermissionRequest{
					SessionID:   sessionID,
					Path:        dir,
					ToolName:    PatchToolName,
					Action:      "create",
					Description: fmt.Sprintf("Create file %s", path),
					Params: EditPermissionsParams{
						FilePath: path,
						Diff:     patchDiff,
					},
				},
			)
			if !p {
				return ToolResponse{}, permission.ErrorPermissionDenied
			}
		case diff.ActionUpdate:
			currentContent := ""
			if change.OldContent != nil {
				currentContent = *change.OldContent
			}
			newContent := ""
			if change.NewContent != nil {
				newContent = *change.NewContent
			}
			patchDiff, _, _ := diff.GenerateDiff(currentContent, newContent, path)
			dir := filepath.Dir(path)
			p := p.permissions.Request(
				ctx,
				permission.CreatePermissionRequest{
					SessionID:   sessionID,
					Path:        dir,
					ToolName:    PatchToolName,
					Action:      "update",
					Description: fmt.Sprintf("Update file %s", path),
					Params: EditPermissionsParams{
						FilePath: path,
						Diff:     patchDiff,
					},
				},
			)
			if !p {
				return ToolResponse{}, permission.ErrorPermissionDenied
			}
		case diff.ActionDelete:
			dir := filepath.Dir(path)
			// Diff to empty content renders the whole file as removals.
			patchDiff, _, _ := diff.GenerateDiff(*change.OldContent, "", path)
			p := p.permissions.Request(
				ctx,
				permission.CreatePermissionRequest{
					SessionID:   sessionID,
					Path:        dir,
					ToolName:    PatchToolName,
					Action:      "delete",
					Description: fmt.Sprintf("Delete file %s", path),
					Params: EditPermissionsParams{
						FilePath: path,
						Diff:     patchDiff,
					},
				},
			)
			if !p {
				return ToolResponse{}, permission.ErrorPermissionDenied
			}
		}
	}

	// Apply the changes to the filesystem: first callback writes/creates,
	// second callback deletes.
	err = diff.ApplyCommit(commit, func(path string, content string) error {
		absPath := path
		if !filepath.IsAbs(absPath) {
			wd := config.WorkingDirectory()
			absPath = filepath.Join(wd, absPath)
		}

		// Create parent directories if needed
		dir := filepath.Dir(absPath)
		if err := os.MkdirAll(dir, 0o755); err != nil {
			return fmt.Errorf("failed to create parent directories for %s: %w", absPath, err)
		}

		return os.WriteFile(absPath, []byte(content), 0o644)
	}, func(path string) error {
		absPath := path
		if !filepath.IsAbs(absPath) {
			wd := config.WorkingDirectory()
			absPath = filepath.Join(wd, absPath)
		}
		return os.Remove(absPath)
	})
	if err != nil {
		return NewTextErrorResponse(fmt.Sprintf("failed to apply patch: %s", err)), nil
	}

	// Update file history for all modified files
	changedFiles := []string{}
	totalAdditions := 0
	totalRemovals := 0

	for path, change := range commit.Changes {
		absPath := path
		if !filepath.IsAbs(absPath) {
			wd := config.WorkingDirectory()
			absPath = filepath.Join(wd, absPath)
		}
		changedFiles = append(changedFiles, absPath)

		oldContent := ""
		if change.OldContent != nil {
			oldContent = *change.OldContent
		}

		newContent := ""
		if change.NewContent != nil {
			newContent = *change.NewContent
		}

		// Calculate diff statistics
		_, additions, removals := diff.GenerateDiff(oldContent, newContent, path)
		totalAdditions += additions
		totalRemovals += removals

		// Update history. History failures are logged, never fatal.
		file, err := p.files.GetLatestByPathAndSession(ctx, absPath, sessionID)
		if err != nil && change.Type != diff.ActionAdd {
			// If not adding a file, create history entry for existing file
			_, err = p.files.Create(ctx, sessionID, absPath, oldContent)
			if err != nil {
				slog.Debug("Error creating file history", "error", err)
			}
		}

		if err == nil && change.Type != diff.ActionAdd && file.Content != oldContent {
			// User manually changed content, store intermediate version
			_, err = p.files.CreateVersion(ctx, sessionID, absPath, oldContent)
			if err != nil {
				slog.Debug("Error creating file history version", "error", err)
			}
		}

		// Store new version
		if change.Type == diff.ActionDelete {
			_, err = p.files.CreateVersion(ctx, sessionID, absPath, "")
		} else {
			_, err = p.files.CreateVersion(ctx, sessionID, absPath, newContent)
		}
		if err != nil {
			slog.Debug("Error creating file history version", "error", err)
		}

		// Record file operations
		recordFileWrite(absPath)
		recordFileRead(absPath)
	}

	// Run LSP diagnostics on all changed files
	for _, filePath := range changedFiles {
		waitForLspDiagnostics(ctx, filePath, p.lspClients)
	}

	result := fmt.Sprintf("Patch applied successfully. %d files changed, %d additions, %d removals",
		len(changedFiles), totalAdditions, totalRemovals)

	diagnosticsText := ""
	for _, filePath := range changedFiles {
		diagnosticsText += getDiagnostics(filePath, p.lspClients)
	}

	if diagnosticsText != "" {
		result += "\n\nDiagnostics:\n" + diagnosticsText
	}

	return WithResponseMetadata(
		NewTextResponse(result),
		PatchResponseMetadata{
			FilesChanged: changedFiles,
			Additions:    totalAdditions,
			Removals:     totalRemovals,
		}), nil
}
|
||||
@@ -1,308 +0,0 @@
|
||||
package shell
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sst/opencode/internal/status"
|
||||
)
|
||||
|
||||
// PersistentShell is a single long-lived login shell that commands are
// serialized through, so shell state (cwd, exported vars) persists between
// invocations.
type PersistentShell struct {
	cmd          *exec.Cmd              // the underlying shell process
	stdin        *os.File               // pipe used to feed commands to the shell
	isAlive      bool                   // set false once the shell process exits
	cwd          string                 // last observed working directory of the shell
	mu           sync.Mutex             // serializes command execution
	commandQueue chan *commandExecution // pending commands, consumed by processCommands
}

// commandExecution is one queued request for the shell worker goroutine.
type commandExecution struct {
	command    string             // shell command text to run
	timeout    time.Duration      // 0 means no timeout
	resultChan chan commandResult // receives exactly one result per request
	ctx        context.Context    // cancellation for this command
}

// commandResult carries the outcome of a single executed command.
type commandResult struct {
	stdout      string
	stderr      string
	exitCode    int
	interrupted bool // true when the command was cancelled or timed out
	err         error
}

// Process-wide singleton shell; see GetPersistentShell.
var (
	shellInstance     *PersistentShell
	shellInstanceOnce sync.Once
)
|
||||
|
||||
// GetPersistentShell returns the process-wide shell singleton, creating it
// on first use and replacing it if the previous instance failed to start
// or has since died.
//
// NOTE(review): the nil/isAlive checks after Once.Do read and reassign
// shellInstance without synchronization, so concurrent callers can race —
// confirm callers are effectively single-threaded or guard with a mutex.
func GetPersistentShell(workingDir string) *PersistentShell {
	shellInstanceOnce.Do(func() {
		shellInstance = newPersistentShell(workingDir)
	})

	if shellInstance == nil {
		// Initial creation failed; retry with the requested directory.
		shellInstance = newPersistentShell(workingDir)
	} else if !shellInstance.isAlive {
		// Shell died; respawn it in its last known working directory.
		shellInstance = newPersistentShell(shellInstance.cwd)
	}

	return shellInstance
}
|
||||
|
||||
// newPersistentShell starts a login shell ($SHELL, defaulting to
// /bin/bash) in cwd, wires up its stdin pipe and command-processing
// goroutine, and returns nil if the process cannot be started.
func newPersistentShell(cwd string) *PersistentShell {
	shellPath := os.Getenv("SHELL")
	if shellPath == "" {
		shellPath = "/bin/bash"
	}

	// -l: login shell, so the user's profile/rc environment is loaded.
	cmd := exec.Command(shellPath, "-l")
	cmd.Dir = cwd

	stdinPipe, err := cmd.StdinPipe()
	if err != nil {
		return nil
	}

	// GIT_EDITOR=true makes git commands that would open an editor no-op
	// instead of blocking the non-interactive shell.
	cmd.Env = append(os.Environ(), "GIT_EDITOR=true")

	err = cmd.Start()
	if err != nil {
		return nil
	}

	shell := &PersistentShell{
		cmd: cmd,
		// NOTE(review): StdinPipe returns io.WriteCloser; asserting
		// *os.File relies on the current os/exec implementation and would
		// panic if that ever changes — confirm or store the interface.
		stdin:        stdinPipe.(*os.File),
		isAlive:      true,
		cwd:          cwd,
		commandQueue: make(chan *commandExecution, 10),
	}

	// Worker goroutine: serially executes queued commands.
	go func() {
		defer func() {
			if r := recover(); r != nil {
				fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
				shell.isAlive = false
				// NOTE(review): the monitor goroutine below also closes
				// commandQueue; if both paths run, the second close panics.
				close(shell.commandQueue)
			}
		}()
		shell.processCommands()
	}()

	// Monitor goroutine: marks the shell dead when the process exits.
	go func() {
		err := cmd.Wait()
		if err != nil {
			status.Error(fmt.Sprintf("Shell process exited with error: %v", err))
		}
		shell.isAlive = false
		close(shell.commandQueue)
	}()

	return shell
}
|
||||
|
||||
func (s *PersistentShell) processCommands() {
|
||||
for cmd := range s.commandQueue {
|
||||
result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
|
||||
cmd.resultChan <- result
|
||||
}
|
||||
}
|
||||
|
||||
// execCommand runs one command inside the persistent shell. Output and
// status are exchanged through four temp files that the injected shell
// snippet writes (stdout, stderr, exit code, resulting cwd), and a polling
// goroutine watches the status file to detect completion, cancellation,
// or timeout. The mutex serializes access to the shell's stdin and cwd.
func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
	s.mu.Lock()
	defer s.mu.Unlock()

	if !s.isAlive {
		return commandResult{
			stderr:   "Shell is not alive",
			exitCode: 1,
			err:      errors.New("shell is not alive"),
		}
	}

	// Unique temp file names per invocation; cleaned up on return.
	tempDir := os.TempDir()
	stdoutFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stdout-%d", time.Now().UnixNano()))
	stderrFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stderr-%d", time.Now().UnixNano()))
	statusFile := filepath.Join(tempDir, fmt.Sprintf("opencode-status-%d", time.Now().UnixNano()))
	cwdFile := filepath.Join(tempDir, fmt.Sprintf("opencode-cwd-%d", time.Now().UnixNano()))

	defer func() {
		os.Remove(stdoutFile)
		os.Remove(stderrFile)
		os.Remove(statusFile)
		os.Remove(cwdFile)
	}()

	// The status file is written last, so its presence (non-empty) signals
	// that the command — and the cwd capture — have finished.
	fullCommand := fmt.Sprintf(`
eval %s < /dev/null > %s 2> %s
EXEC_EXIT_CODE=$?
pwd > %s
echo $EXEC_EXIT_CODE > %s
`,
		shellQuote(command),
		shellQuote(stdoutFile),
		shellQuote(stderrFile),
		shellQuote(cwdFile),
		shellQuote(statusFile),
	)

	_, err := s.stdin.Write([]byte(fullCommand + "\n"))
	if err != nil {
		return commandResult{
			stderr:   fmt.Sprintf("Failed to write command to shell: %v", err),
			exitCode: 1,
			err:      err,
		}
	}

	interrupted := false

	startTime := time.Now()

	// Poll every 10ms for completion; honor ctx cancellation and timeout
	// by SIGTERM-ing the shell's children. The channel send establishes
	// the happens-before for the `interrupted` write.
	done := make(chan bool)
	go func() {
		for {
			select {
			case <-ctx.Done():
				s.killChildren()
				interrupted = true
				done <- true
				return

			case <-time.After(10 * time.Millisecond):
				if fileExists(statusFile) && fileSize(statusFile) > 0 {
					done <- true
					return
				}

				if timeout > 0 {
					elapsed := time.Since(startTime)
					if elapsed > timeout {
						s.killChildren()
						interrupted = true
						done <- true
						return
					}
				}
			}
		}
	}()

	<-done

	stdout := readFileOrEmpty(stdoutFile)
	stderr := readFileOrEmpty(stderrFile)
	exitCodeStr := readFileOrEmpty(statusFile)
	newCwd := readFileOrEmpty(cwdFile)

	exitCode := 0
	if exitCodeStr != "" {
		fmt.Sscanf(exitCodeStr, "%d", &exitCode)
	} else if interrupted {
		// 143 = 128 + SIGTERM, the conventional "terminated" exit code.
		exitCode = 143
		stderr += "\nCommand execution timed out or was interrupted"
	}

	// Track `cd` performed by the command so the cwd survives restarts.
	if newCwd != "" {
		s.cwd = strings.TrimSpace(newCwd)
	}

	return commandResult{
		stdout:      stdout,
		stderr:      stderr,
		exitCode:    exitCode,
		interrupted: interrupted,
	}
}
|
||||
|
||||
func (s *PersistentShell) killChildren() {
|
||||
if s.cmd == nil || s.cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
|
||||
output, err := pgrepCmd.Output()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for pidStr := range strings.SplitSeq(string(output), "\n") {
|
||||
if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
|
||||
var pid int
|
||||
fmt.Sscanf(pidStr, "%d", &pid)
|
||||
if pid > 0 {
|
||||
proc, err := os.FindProcess(pid)
|
||||
if err == nil {
|
||||
proc.Signal(syscall.SIGTERM)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
|
||||
if !s.isAlive {
|
||||
return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
|
||||
}
|
||||
|
||||
timeout := time.Duration(timeoutMs) * time.Millisecond
|
||||
|
||||
resultChan := make(chan commandResult)
|
||||
s.commandQueue <- &commandExecution{
|
||||
command: command,
|
||||
timeout: timeout,
|
||||
resultChan: resultChan,
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
result := <-resultChan
|
||||
return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
|
||||
}
|
||||
|
||||
// Close shuts down the persistent shell. It first asks the shell to exit
// politely by writing "exit" to its stdin, then force-kills the process to
// guarantee termination. Safe to call more than once: after the first call
// isAlive is false and subsequent calls return immediately.
func (s *PersistentShell) Close() {
	s.mu.Lock()
	defer s.mu.Unlock()

	if !s.isAlive {
		return
	}

	// Best effort: give the shell a chance to exit cleanly. The write error
	// is ignored because the process is killed unconditionally below.
	s.stdin.Write([]byte("exit\n"))

	s.cmd.Process.Kill()
	s.isAlive = false
}
|
||||
|
||||
// shellQuote wraps s in single quotes for safe interpolation into a POSIX
// shell command line. Embedded single quotes are emitted as '\'' (close the
// quote, escape one quote, reopen the quote).
func shellQuote(s string) string {
	const escapedQuote = `'\''`
	return "'" + strings.ReplaceAll(s, "'", escapedQuote) + "'"
}
|
||||
|
||||
func readFileOrEmpty(path string) string {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(content)
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func fileSize(path string) int64 {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return info.Size()
|
||||
}
|
||||
@@ -1,84 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// ToolInfo describes a tool to the model: its name, natural-language
// description, JSON-schema-style parameter map, and which parameters are
// required.
type ToolInfo struct {
	Name        string
	Description string
	Parameters  map[string]any
	Required    []string
}

// toolResponseType discriminates the payload kind carried by a ToolResponse.
type toolResponseType string

// Dedicated context-key types so values stored by this package cannot
// collide with string keys set by other packages.
type (
	sessionIDContextKey string
	messageIDContextKey string
)

const (
	ToolResponseTypeText  toolResponseType = "text"
	ToolResponseTypeImage toolResponseType = "image"

	// Context keys under which the current session and message IDs are stored.
	SessionIDContextKey sessionIDContextKey = "session_id"
	MessageIDContextKey messageIDContextKey = "message_id"
)

// ToolResponse is the result a tool returns to its caller. Metadata, when
// present, is a JSON-encoded string (see WithResponseMetadata).
type ToolResponse struct {
	Type     toolResponseType `json:"type"`
	Content  string           `json:"content"`
	Metadata string           `json:"metadata,omitempty"`
	IsError  bool             `json:"is_error"`
}
|
||||
|
||||
func NewTextResponse(content string) ToolResponse {
|
||||
return ToolResponse{
|
||||
Type: ToolResponseTypeText,
|
||||
Content: content,
|
||||
}
|
||||
}
|
||||
|
||||
func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse {
|
||||
if metadata != nil {
|
||||
metadataBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return response
|
||||
}
|
||||
response.Metadata = string(metadataBytes)
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
func NewTextErrorResponse(content string) ToolResponse {
|
||||
return ToolResponse{
|
||||
Type: ToolResponseTypeText,
|
||||
Content: content,
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
|
||||
// ToolCall is a single invocation request from the model: the call ID, the
// tool name, and the raw JSON input string to be unmarshalled by the tool.
type ToolCall struct {
	ID    string `json:"id"`
	Name  string `json:"name"`
	Input string `json:"input"`
}

// BaseTool is the interface every tool implements: static metadata via
// Info, and execution via Run.
type BaseTool interface {
	Info() ToolInfo
	Run(ctx context.Context, params ToolCall) (ToolResponse, error)
}
|
||||
|
||||
func GetContextValues(ctx context.Context) (string, string) {
|
||||
sessionID := ctx.Value(SessionIDContextKey)
|
||||
messageID := ctx.Value(MessageIDContextKey)
|
||||
if sessionID == nil {
|
||||
return "", ""
|
||||
}
|
||||
if messageID == nil {
|
||||
return sessionID.(string), ""
|
||||
}
|
||||
return sessionID.(string), messageID.(string)
|
||||
}
|
||||
@@ -1,312 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
)
|
||||
|
||||
// ViewParams are the JSON arguments accepted by the view tool.
type ViewParams struct {
	FilePath string `json:"file_path"` // path to read; relative paths resolve against the working directory
	Offset   int    `json:"offset"`    // first line to read, 0-based
	Limit    int    `json:"limit"`     // max lines to read; <= 0 means DefaultReadLimit
}

// viewTool implements BaseTool for reading files with line numbers.
type viewTool struct {
	lspClients map[string]*lsp.Client // LSP clients used for open-file notifications and diagnostics
}

// ViewResponseMetadata is attached to successful view responses so callers
// can access the raw (un-numbered) content that was read.
type ViewResponseMetadata struct {
	FilePath string `json:"file_path"`
	Content  string `json:"content"`
}
|
||||
|
||||
const (
	ViewToolName = "view"
	// MaxReadSize caps readable file size at 250KB; larger files are rejected.
	MaxReadSize = 250 * 1024
	// DefaultReadLimit is the number of lines read when no limit is given.
	DefaultReadLimit = 2000
	// MaxLineLength is the byte length beyond which a line is truncated.
	MaxLineLength = 2000
	// viewDescription is the model-facing help text for the view tool.
	viewDescription = `File viewing tool that reads and displays the contents of files with line numbers, allowing you to examine code, logs, or text data.

WHEN TO USE THIS TOOL:
- Use when you need to read the contents of a specific file
- Helpful for examining source code, configuration files, or log files
- Perfect for looking at text-based file formats

HOW TO USE:
- Provide the path to the file you want to view
- Optionally specify an offset to start reading from a specific line
- Optionally specify a limit to control how many lines are read

FEATURES:
- Displays file contents with line numbers for easy reference
- Can read from any position in a file using the offset parameter
- Handles large files by limiting the number of lines read
- Automatically truncates very long lines for better display
- Suggests similar file names when the requested file isn't found

LIMITATIONS:
- Maximum file size is 250KB
- Default reading limit is 2000 lines
- Lines longer than 2000 characters are truncated
- Cannot display binary files or images
- Images can be identified but not displayed

TIPS:
- Use with Glob tool to first find files you want to view
- For code exploration, first use Grep to find relevant files, then View to examine them
- When viewing large files, use the offset parameter to read specific sections`
)
|
||||
|
||||
func NewViewTool(lspClients map[string]*lsp.Client) BaseTool {
|
||||
return &viewTool{
|
||||
lspClients,
|
||||
}
|
||||
}
|
||||
|
||||
// Info returns the view tool's metadata: its name, model-facing description,
// and the JSON schema of its parameters (file_path required; offset and
// limit optional).
func (v *viewTool) Info() ToolInfo {
	return ToolInfo{
		Name:        ViewToolName,
		Description: viewDescription,
		Parameters: map[string]any{
			"file_path": map[string]any{
				"type":        "string",
				"description": "The path to the file to read",
			},
			"offset": map[string]any{
				"type":        "integer",
				"description": "The line number to start reading from (0-based)",
			},
			"limit": map[string]any{
				"type":        "integer",
				"description": "The number of lines to read (defaults to 2000)",
			},
		},
		Required: []string{"file_path"},
	}
}
|
||||
|
||||
// Run implements Tool. It validates the parameters, resolves the path,
// rejects directories / oversized files / images, reads up to Limit lines
// starting at Offset, and returns the numbered content wrapped in <file>
// tags with LSP diagnostics appended. Missing files get up to three
// "did you mean" suggestions from the same directory.
func (v *viewTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
	var params ViewParams
	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	if params.FilePath == "" {
		return NewTextErrorResponse("file_path is required"), nil
	}

	// Handle relative paths
	filePath := params.FilePath
	if !filepath.IsAbs(filePath) {
		filePath = filepath.Join(config.WorkingDirectory(), filePath)
	}

	// Check if file exists
	fileInfo, err := os.Stat(filePath)
	if err != nil {
		if os.IsNotExist(err) {
			// Try to offer suggestions for similarly named files: any entry
			// whose lowercase name contains (or is contained in) the
			// requested base name, capped at three suggestions.
			dir := filepath.Dir(filePath)
			base := filepath.Base(filePath)

			dirEntries, dirErr := os.ReadDir(dir)
			if dirErr == nil {
				var suggestions []string
				for _, entry := range dirEntries {
					if strings.Contains(strings.ToLower(entry.Name()), strings.ToLower(base)) ||
						strings.Contains(strings.ToLower(base), strings.ToLower(entry.Name())) {
						suggestions = append(suggestions, filepath.Join(dir, entry.Name()))
						if len(suggestions) >= 3 {
							break
						}
					}
				}

				if len(suggestions) > 0 {
					return NewTextErrorResponse(fmt.Sprintf("File not found: %s\n\nDid you mean one of these?\n%s",
						filePath, strings.Join(suggestions, "\n"))), nil
				}
			}

			return NewTextErrorResponse(fmt.Sprintf("File not found: %s", filePath)), nil
		}
		return ToolResponse{}, fmt.Errorf("error accessing file: %w", err)
	}

	// Check if it's a directory
	if fileInfo.IsDir() {
		return NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
	}

	// Check file size
	if fileInfo.Size() > MaxReadSize {
		return NewTextErrorResponse(fmt.Sprintf("File is too large (%d bytes). Maximum size is %d bytes",
			fileInfo.Size(), MaxReadSize)), nil
	}

	// Set default limit if not provided
	if params.Limit <= 0 {
		params.Limit = DefaultReadLimit
	}

	// Check if it's an image file
	isImage, imageType := isImageFile(filePath)
	// TODO: handle images
	if isImage {
		return NewTextErrorResponse(fmt.Sprintf("This is an image file of type: %s\nUse a different tool to process images", imageType)), nil
	}

	// Read the file content
	content, lineCount, err := readTextFile(filePath, params.Offset, params.Limit)
	if err != nil {
		return ToolResponse{}, fmt.Errorf("error reading file: %w", err)
	}

	// Let LSP servers know the file was opened (helper defined elsewhere in
	// this package).
	notifyLspOpenFile(ctx, filePath, v.lspClients)
	output := "<file>\n"
	// Format the output with line numbers (1-based for display).
	output += addLineNumbers(content, params.Offset+1)

	// Add a note if the content was truncated
	if lineCount > params.Offset+len(strings.Split(content, "\n")) {
		output += fmt.Sprintf("\n\n(File has more lines. Use 'offset' parameter to read beyond line %d)",
			params.Offset+len(strings.Split(content, "\n")))
	}
	output += "\n</file>\n"
	output += getDiagnostics(filePath, v.lspClients)
	// Record the read so a later write can detect stale content.
	recordFileRead(filePath)
	return WithResponseMetadata(
		NewTextResponse(output),
		ViewResponseMetadata{
			FilePath: filePath,
			Content:  content,
		},
	), nil
}
|
||||
|
||||
// addLineNumbers prefixes every line of content with its line number,
// right-aligned in a 6-character column and separated by '|'. Numbers that
// need 6 or more digits are emitted unpadded. startLine is the number given
// to the first line; a trailing "\r" is stripped from each line. Empty
// content yields an empty string.
func addLineNumbers(content string, startLine int) string {
	if content == "" {
		return ""
	}

	var b strings.Builder
	for i, line := range strings.Split(content, "\n") {
		if i > 0 {
			b.WriteByte('\n')
		}
		num := fmt.Sprintf("%d", i+startLine)
		if len(num) < 6 {
			num = fmt.Sprintf("%6s", num)
		}
		b.WriteString(num)
		b.WriteByte('|')
		b.WriteString(strings.TrimSuffix(line, "\r"))
	}
	return b.String()
}
|
||||
|
||||
func readTextFile(filePath string, offset, limit int) (string, int, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
lineCount := 0
|
||||
|
||||
scanner := NewLineScanner(file)
|
||||
if offset > 0 {
|
||||
for lineCount < offset && scanner.Scan() {
|
||||
lineCount++
|
||||
}
|
||||
if err = scanner.Err(); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if offset == 0 {
|
||||
_, err = file.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lineCount = offset
|
||||
|
||||
for scanner.Scan() && len(lines) < limit {
|
||||
lineCount++
|
||||
lineText := scanner.Text()
|
||||
if len(lineText) > MaxLineLength {
|
||||
lineText = lineText[:MaxLineLength] + "..."
|
||||
}
|
||||
lines = append(lines, lineText)
|
||||
}
|
||||
|
||||
// Continue scanning to get total line count
|
||||
for scanner.Scan() {
|
||||
lineCount++
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), lineCount, nil
|
||||
}
|
||||
|
||||
// isImageFile reports whether filePath has a known image extension and, if
// so, the human-readable format name. Unknown extensions return (false, "").
func isImageFile(filePath string) (bool, string) {
	formats := map[string]string{
		".jpg":  "JPEG",
		".jpeg": "JPEG",
		".png":  "PNG",
		".gif":  "GIF",
		".bmp":  "BMP",
		".svg":  "SVG",
		".webp": "WebP",
	}
	if name, known := formats[strings.ToLower(filepath.Ext(filePath))]; known {
		return true, name
	}
	return false, ""
}
|
||||
|
||||
type LineScanner struct {
|
||||
scanner *bufio.Scanner
|
||||
}
|
||||
|
||||
func NewLineScanner(r io.Reader) *LineScanner {
|
||||
return &LineScanner{
|
||||
scanner: bufio.NewScanner(r),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LineScanner) Scan() bool {
|
||||
return s.scanner.Scan()
|
||||
}
|
||||
|
||||
func (s *LineScanner) Text() string {
|
||||
return s.scanner.Text()
|
||||
}
|
||||
|
||||
func (s *LineScanner) Err() error {
|
||||
return s.scanner.Err()
|
||||
}
|
||||
@@ -1,228 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/diff"
|
||||
"github.com/sst/opencode/internal/history"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// WriteParams are the JSON arguments accepted by the write tool.
type WriteParams struct {
	FilePath string `json:"file_path"` // destination path; relative paths resolve against the working directory
	Content  string `json:"content"`   // full file content (the write replaces the file, it never appends)
}

// WritePermissionsParams is the payload attached to the permission request
// shown to the user, including the diff that would be applied.
type WritePermissionsParams struct {
	FilePath string `json:"file_path"`
	Diff     string `json:"diff"`
}

// writeTool implements BaseTool for creating and overwriting files.
type writeTool struct {
	lspClients  map[string]*lsp.Client // used to surface diagnostics after the write
	permissions permission.Service     // user approval for file modifications
	files       history.Service        // per-session file version history
}

// WriteResponseMetadata is attached to write responses: the unified diff
// plus the number of added and removed lines.
type WriteResponseMetadata struct {
	Diff      string `json:"diff"`
	Additions int    `json:"additions"`
	Removals  int    `json:"removals"`
}
|
||||
|
||||
const (
	WriteToolName = "write"
	// writeDescription is the model-facing help text for the write tool.
	writeDescription = `File writing tool that creates or updates files in the filesystem, allowing you to save or modify text content.

WHEN TO USE THIS TOOL:
- Use when you need to create a new file
- Helpful for updating existing files with modified content
- Perfect for saving generated code, configurations, or text data

HOW TO USE:
- Provide the path to the file you want to write
- Include the content to be written to the file
- The tool will create any necessary parent directories

FEATURES:
- Can create new files or overwrite existing ones
- Creates parent directories automatically if they don't exist
- Checks if the file has been modified since last read for safety
- Avoids unnecessary writes when content hasn't changed

LIMITATIONS:
- You should read a file before writing to it to avoid conflicts
- Cannot append to files (rewrites the entire file)


TIPS:
- Use the View tool first to examine existing files before modifying them
- Use the LS tool to verify the correct location when creating new files
- Combine with Glob and Grep tools to find and modify multiple files
- Always include descriptive comments when making changes to existing code`
)
|
||||
|
||||
func NewWriteTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool {
|
||||
return &writeTool{
|
||||
lspClients: lspClients,
|
||||
permissions: permissions,
|
||||
files: files,
|
||||
}
|
||||
}
|
||||
|
||||
// Info returns the write tool's metadata: its name, model-facing
// description, and the JSON schema of its parameters (file_path and
// content, both required).
func (w *writeTool) Info() ToolInfo {
	return ToolInfo{
		Name:        WriteToolName,
		Description: writeDescription,
		Parameters: map[string]any{
			"file_path": map[string]any{
				"type":        "string",
				"description": "The path to the file to write",
			},
			"content": map[string]any{
				"type":        "string",
				"description": "The content to write to the file",
			},
		},
		Required: []string{"file_path", "content"},
	}
}
|
||||
|
||||
// Run implements BaseTool. It validates the parameters, enforces the
// read-before-write freshness check, asks the user for permission with a
// diff preview, writes the file, records version history, and returns the
// result with LSP diagnostics appended.
func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
	var params WriteParams
	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	if params.FilePath == "" {
		return NewTextErrorResponse("file_path is required"), nil
	}

	if params.Content == "" {
		return NewTextErrorResponse("content is required"), nil
	}

	// Resolve relative paths against the working directory.
	filePath := params.FilePath
	if !filepath.IsAbs(filePath) {
		filePath = filepath.Join(config.WorkingDirectory(), filePath)
	}

	fileInfo, err := os.Stat(filePath)
	if err == nil {
		if fileInfo.IsDir() {
			return NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
		}

		// Reject writes to files modified since our last recorded read to
		// avoid clobbering concurrent edits.
		modTime := fileInfo.ModTime()
		lastRead := getLastReadTime(filePath)
		if modTime.After(lastRead) {
			return NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
				filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
		}

		// Skip the write entirely when the content is unchanged.
		oldContent, readErr := os.ReadFile(filePath)
		if readErr == nil && string(oldContent) == params.Content {
			return NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
		}
	} else if !os.IsNotExist(err) {
		return ToolResponse{}, fmt.Errorf("error checking file: %w", err)
	}

	dir := filepath.Dir(filePath)
	if err = os.MkdirAll(dir, 0o755); err != nil {
		return ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
	}

	// Capture the pre-write content (for diffing and history). Empty when
	// the file does not exist yet or cannot be read.
	oldContent := ""
	if fileInfo != nil && !fileInfo.IsDir() {
		oldBytes, readErr := os.ReadFile(filePath)
		if readErr == nil {
			oldContent = string(oldBytes)
		}
	}

	sessionID, messageID := GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return ToolResponse{}, fmt.Errorf("session_id and message_id are required")
	}

	// NOTE: the local `diff` shadows the imported diff package from here on.
	diff, additions, removals := diff.GenerateDiff(
		oldContent,
		params.Content,
		filePath,
	)

	// Request permission on the project root when the file is inside it,
	// otherwise on the file's own directory.
	rootDir := config.WorkingDirectory()
	permissionPath := filepath.Dir(filePath)
	if strings.HasPrefix(filePath, rootDir) {
		permissionPath = rootDir
	}
	p := w.permissions.Request(
		ctx,
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			Path:        permissionPath,
			ToolName:    WriteToolName,
			Action:      "write",
			Description: fmt.Sprintf("Create file %s", filePath),
			Params: WritePermissionsParams{
				FilePath: filePath,
				Diff:     diff,
			},
		},
	)
	if !p {
		return ToolResponse{}, permission.ErrorPermissionDenied
	}

	err = os.WriteFile(filePath, []byte(params.Content), 0o644)
	if err != nil {
		return ToolResponse{}, fmt.Errorf("error writing file: %w", err)
	}

	// Check if file exists in history
	file, err := w.files.GetLatestByPathAndSession(ctx, filePath, sessionID)
	if err != nil {
		_, err = w.files.Create(ctx, sessionID, filePath, oldContent)
		if err != nil {
			// NOTE(review): contrary to the comment below, this DOES fail
			// the operation even though the file was already written —
			// confirm whether history errors should be fatal here. Also,
			// `file` is the zero value on this path, so the comparison
			// below compares against an empty Content.
			// Log error but don't fail the operation
			return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
		}
	}
	if file.Content != oldContent {
		// User Manually changed the content store an intermediate version
		_, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent)
		if err != nil {
			slog.Debug("Error creating file history version", "error", err)
		}
	}
	// Store the new version
	_, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content)
	if err != nil {
		slog.Debug("Error creating file history version", "error", err)
	}

	recordFileWrite(filePath)
	recordFileRead(filePath)
	waitForLspDiagnostics(ctx, filePath, w.lspClients)

	result := fmt.Sprintf("File successfully written: %s", filePath)
	result = fmt.Sprintf("<result>\n%s\n</result>", result)
	result += getDiagnostics(filePath, w.lspClients)
	return WithResponseMetadata(NewTextResponse(result),
		WriteResponseMetadata{
			Diff:      diff,
			Additions: additions,
			Removals:  removals,
		},
	), nil
}
|
||||
@@ -1,292 +0,0 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-logfmt/logfmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sst/opencode/internal/db"
|
||||
"github.com/sst/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
// Log is one persisted log record: severity level, message, structured
// attributes, and the session (if any) it belongs to.
type Log struct {
	ID         string
	SessionID  string
	Timestamp  time.Time // when the event happened, as reported by the logger
	Level      string
	Message    string
	Attributes map[string]string
	CreatedAt  time.Time // when the row was persisted
}

const (
	// EventLogCreated is published on the broker each time a log is persisted.
	EventLogCreated pubsub.EventType = "log_created"
)

// Service persists log records and lets subscribers observe new ones.
type Service interface {
	pubsub.Subscriber[Log]

	Create(ctx context.Context, timestamp time.Time, level, message string, attributes map[string]string, sessionID string) error
	ListBySession(ctx context.Context, sessionID string) ([]Log, error)
	ListAll(ctx context.Context, limit int) ([]Log, error)
}

// service is the database-backed Service implementation.
type service struct {
	db     *db.Queries
	broker *pubsub.Broker[Log]
}
|
||||
|
||||
// globalLoggingService is the process-wide singleton, set once by
// InitService. NOTE(review): access is not synchronized — this assumes
// InitService runs before any concurrent use; confirm callers honor that.
var globalLoggingService *service

// InitService wires the logging service to the given database connection.
// It must be called exactly once before GetService; a second call returns
// an error.
func InitService(dbConn *sql.DB) error {
	if globalLoggingService != nil {
		return fmt.Errorf("logging service already initialized")
	}
	queries := db.New(dbConn)
	broker := pubsub.NewBroker[Log]()

	globalLoggingService = &service{
		db:     queries,
		broker: broker,
	}
	return nil
}

// GetService returns the singleton logging service, panicking if
// InitService has not been called yet (a programmer error, not a runtime
// condition).
func GetService() Service {
	if globalLoggingService == nil {
		panic("logging service not initialized. Call logging.InitService() first.")
	}
	return globalLoggingService
}
|
||||
|
||||
func (s *service) Create(ctx context.Context, timestamp time.Time, level, message string, attributes map[string]string, sessionID string) error {
|
||||
if level == "" {
|
||||
level = "info"
|
||||
}
|
||||
|
||||
var attributesJSON sql.NullString
|
||||
if len(attributes) > 0 {
|
||||
attributesBytes, err := json.Marshal(attributes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal log attributes: %w", err)
|
||||
}
|
||||
attributesJSON = sql.NullString{String: string(attributesBytes), Valid: true}
|
||||
}
|
||||
|
||||
dbLog, err := s.db.CreateLog(ctx, db.CreateLogParams{
|
||||
ID: uuid.New().String(),
|
||||
SessionID: sql.NullString{String: sessionID, Valid: sessionID != ""},
|
||||
Timestamp: timestamp.UTC().Format(time.RFC3339Nano),
|
||||
Level: level,
|
||||
Message: message,
|
||||
Attributes: attributesJSON,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.CreateLog: %w", err)
|
||||
}
|
||||
|
||||
log := s.fromDBItem(dbLog)
|
||||
s.broker.Publish(EventLogCreated, log)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) ListBySession(ctx context.Context, sessionID string) ([]Log, error) {
|
||||
dbLogs, err := s.db.ListLogsBySession(ctx, sql.NullString{String: sessionID, Valid: true})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListLogsBySession: %w", err)
|
||||
}
|
||||
|
||||
logs := make([]Log, len(dbLogs))
|
||||
for i, dbSess := range dbLogs {
|
||||
logs[i] = s.fromDBItem(dbSess)
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
func (s *service) ListAll(ctx context.Context, limit int) ([]Log, error) {
|
||||
dbLogs, err := s.db.ListAllLogs(ctx, int64(limit))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListAllLogs: %w", err)
|
||||
}
|
||||
logs := make([]Log, len(dbLogs))
|
||||
for i, dbSess := range dbLogs {
|
||||
logs[i] = s.fromDBItem(dbSess)
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
// Subscribe forwards to the underlying pubsub broker's Subscribe, exposing
// the stream of log-created events; ctx is passed through to govern the
// subscription's lifetime.
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Log] {
	return s.broker.Subscribe(ctx)
}
|
||||
|
||||
func (s *service) fromDBItem(item db.Log) Log {
|
||||
log := Log{
|
||||
ID: item.ID,
|
||||
SessionID: item.SessionID.String,
|
||||
Level: item.Level,
|
||||
Message: item.Message,
|
||||
}
|
||||
|
||||
// Parse timestamp from ISO string
|
||||
timestamp, err := time.Parse(time.RFC3339Nano, item.Timestamp)
|
||||
if err == nil {
|
||||
log.Timestamp = timestamp
|
||||
} else {
|
||||
log.Timestamp = time.Now() // Fallback
|
||||
}
|
||||
|
||||
// Parse created_at from ISO string
|
||||
createdAt, err := time.Parse(time.RFC3339Nano, item.CreatedAt)
|
||||
if err == nil {
|
||||
log.CreatedAt = createdAt
|
||||
} else {
|
||||
log.CreatedAt = time.Now() // Fallback
|
||||
}
|
||||
|
||||
if item.Attributes.Valid && item.Attributes.String != "" {
|
||||
if err := json.Unmarshal([]byte(item.Attributes.String), &log.Attributes); err != nil {
|
||||
slog.Error("Failed to unmarshal log attributes", "log_id", item.ID, "error", err)
|
||||
log.Attributes = make(map[string]string)
|
||||
}
|
||||
} else {
|
||||
log.Attributes = make(map[string]string)
|
||||
}
|
||||
|
||||
return log
|
||||
}
|
||||
|
||||
// Create persists a log record via the global logging service.
func Create(ctx context.Context, timestamp time.Time, level, message string, attributes map[string]string, sessionID string) error {
	return GetService().Create(ctx, timestamp, level, message, attributes, sessionID)
}

// ListBySession returns the global service's logs for one session.
func ListBySession(ctx context.Context, sessionID string) ([]Log, error) {
	return GetService().ListBySession(ctx, sessionID)
}

// ListAll returns up to limit logs from the global service.
func ListAll(ctx context.Context, limit int) ([]Log, error) {
	return GetService().ListAll(ctx, limit)
}

// Subscribe subscribes to log-created events on the global service.
func Subscribe(ctx context.Context) <-chan pubsub.Event[Log] {
	return GetService().Subscribe(ctx)
}
|
||||
|
||||
// slogWriter adapts slog's logfmt text output into persisted Log records:
// each record written to it is parsed and stored via the logging service.
type slogWriter struct{}

// Write parses one or more logfmt records from p, extracting time/level/
// msg/session_id into dedicated fields and everything else into attributes,
// then persists each record asynchronously. It always reports len(p) bytes
// consumed, even on parse errors, so slog does not retry the write.
func (sw *slogWriter) Write(p []byte) (n int, err error) {
	// Example: time=2024-05-09T12:34:56.789-05:00 level=INFO msg="User request" session=xyz foo=bar
	d := logfmt.NewDecoder(bytes.NewReader(p))
	for d.ScanRecord() {
		var timestamp time.Time
		var level string
		var message string
		var sessionID string
		var attributes map[string]string

		attributes = make(map[string]string)
		hasTimestamp := false

		for d.ScanKeyval() {
			key := string(d.Key())
			value := string(d.Value())

			switch key {
			case "time":
				// Accept RFC3339 with or without sub-second precision;
				// fall back to "now" when the value is unparseable.
				parsedTime, timeErr := time.Parse(time.RFC3339Nano, value)
				if timeErr != nil {
					parsedTime, timeErr = time.Parse(time.RFC3339, value)
					if timeErr != nil {
						slog.Error("Failed to parse time in slog writer", "value", value, "error", timeErr)
						timestamp = time.Now().UTC()
						hasTimestamp = true
						continue
					}
				}
				timestamp = parsedTime
				hasTimestamp = true
			case "level":
				level = strings.ToLower(value)
			case "msg", "message":
				message = value
			case "session_id":
				sessionID = value
			default:
				attributes[key] = value
			}
		}
		if d.Err() != nil {
			return len(p), fmt.Errorf("logfmt.ScanRecord: %w", d.Err())
		}

		if !hasTimestamp {
			timestamp = time.Now()
		}

		// Create log entry via the service (non-blocking or handle error appropriately)
		// Using context.Background() as this is a low-level logging write.
		go func(timestamp time.Time, level, message string, attributes map[string]string, sessionID string) { // Run in a goroutine to avoid blocking slog
			if globalLoggingService == nil {
				// If the logging service is not initialized, log the message to stderr
				// fmt.Fprintf(os.Stderr, "ERROR [logging.slogWriter]: logging service not initialized\n")
				return
			}
			if err := Create(context.Background(), timestamp, level, message, attributes, sessionID); err != nil {
				// Log internal error using a more primitive logger to avoid loops
				fmt.Fprintf(os.Stderr, "ERROR [logging.slogWriter]: failed to persist log: %v\n", err)
			}
		}(timestamp, level, message, attributes, sessionID)
	}
	// NOTE(review): this duplicates the in-loop d.Err() check — it is only
	// reachable when ScanRecord itself stops with an error; confirm intent.
	if d.Err() != nil {
		return len(p), fmt.Errorf("logfmt.ScanRecord final: %w", d.Err())
	}
	return len(p), nil
}

// NewSlogWriter returns an io.Writer suitable for slog's handler output
// that persists parsed records through the logging service.
func NewSlogWriter() io.Writer {
	return &slogWriter{}
}
|
||||
|
||||
// RecoverPanic is a common function to handle panics gracefully.
// It logs the error, creates a panic log file with stack trace,
// and executes an optional cleanup function. Call it via defer; recover()
// only works when invoked directly from a deferred function.
func RecoverPanic(name string, cleanup func()) {
	if r := recover(); r != nil {
		errorMsg := fmt.Sprintf("Panic in %s: %v", name, r)
		// Use slog directly here, as our service might be the one panicking.
		slog.Error(errorMsg)
		// status.Error(errorMsg)

		// Write a timestamped crash report into the current directory.
		timestamp := time.Now().Format("20060102-150405")
		filename := fmt.Sprintf("opencode-panic-%s-%s.log", name, timestamp)

		file, err := os.Create(filename)
		if err != nil {
			errMsg := fmt.Sprintf("Failed to create panic log file '%s': %v", filename, err)
			slog.Error(errMsg)
			// status.Error(errMsg)
		} else {
			defer file.Close()
			fmt.Fprintf(file, "Panic in %s: %v\n\n", name, r)
			fmt.Fprintf(file, "Time: %s\n\n", time.Now().Format(time.RFC3339))
			fmt.Fprintf(file, "Stack Trace:\n%s\n", string(debug.Stack())) // Capture stack trace
			infoMsg := fmt.Sprintf("Panic details written to %s", filename)
			slog.Info(infoMsg)
			// status.Info(infoMsg)
		}

		// Cleanup runs even if the crash report could not be written.
		if cleanup != nil {
			cleanup()
		}
	}
}
|
||||
@@ -1,797 +0,0 @@
|
||||
package lsp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/logging"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
"github.com/sst/opencode/internal/status"
|
||||
)
|
||||
|
||||
// Client talks to a single LSP server subprocess over its stdio pipes,
// multiplexing requests, server-initiated requests, notifications, and a
// diagnostics cache. All shared maps are guarded by their adjacent mutexes.
type Client struct {
	Cmd    *exec.Cmd
	stdin  io.WriteCloser
	stdout *bufio.Reader
	stderr io.ReadCloser

	// Request ID counter; incremented per outgoing request.
	nextID atomic.Int32

	// Response handlers: request ID -> channel the matching response is
	// delivered on.
	handlers   map[int32]chan *Message
	handlersMu sync.RWMutex

	// Server request handlers, keyed by LSP method name.
	serverRequestHandlers map[string]ServerRequestHandler
	serverHandlersMu      sync.RWMutex

	// Notification handlers, keyed by LSP method name.
	notificationHandlers map[string]NotificationHandler
	notificationMu       sync.RWMutex

	// Diagnostic cache: latest diagnostics per document URI.
	diagnostics   map[protocol.DocumentUri][]protocol.Diagnostic
	diagnosticsMu sync.RWMutex

	// Files are currently opened by the LSP
	openFiles   map[string]*OpenFileInfo
	openFilesMu sync.RWMutex

	// Server state; stores lifecycle values such as StateStarting.
	serverState atomic.Value
}
|
||||
|
||||
func NewClient(ctx context.Context, command string, args ...string) (*Client, error) {
|
||||
cmd := exec.CommandContext(ctx, command, args...)
|
||||
// Copy env
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create stdin pipe: %w", err)
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
Cmd: cmd,
|
||||
stdin: stdin,
|
||||
stdout: bufio.NewReader(stdout),
|
||||
stderr: stderr,
|
||||
handlers: make(map[int32]chan *Message),
|
||||
notificationHandlers: make(map[string]NotificationHandler),
|
||||
serverRequestHandlers: make(map[string]ServerRequestHandler),
|
||||
diagnostics: make(map[protocol.DocumentUri][]protocol.Diagnostic),
|
||||
openFiles: make(map[string]*OpenFileInfo),
|
||||
}
|
||||
|
||||
// Initialize server state
|
||||
client.serverState.Store(StateStarting)
|
||||
|
||||
// Start the LSP server process
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("failed to start LSP server: %w", err)
|
||||
}
|
||||
|
||||
// Handle stderr in a separate goroutine
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
for scanner.Scan() {
|
||||
slog.Info("LSP Server", "message", scanner.Text())
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
slog.Error("Error reading LSP stderr", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start message handling loop
|
||||
go func() {
|
||||
defer logging.RecoverPanic("LSP-message-handler", func() {
|
||||
status.Error("LSP message handler crashed, LSP functionality may be impaired")
|
||||
})
|
||||
client.handleMessages()
|
||||
}()
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) RegisterNotificationHandler(method string, handler NotificationHandler) {
|
||||
c.notificationMu.Lock()
|
||||
defer c.notificationMu.Unlock()
|
||||
c.notificationHandlers[method] = handler
|
||||
}
|
||||
|
||||
func (c *Client) RegisterServerRequestHandler(method string, handler ServerRequestHandler) {
|
||||
c.serverHandlersMu.Lock()
|
||||
defer c.serverHandlersMu.Unlock()
|
||||
c.serverRequestHandlers[method] = handler
|
||||
}
|
||||
|
||||
func (c *Client) InitializeLSPClient(ctx context.Context, workspaceDir string) (*protocol.InitializeResult, error) {
|
||||
initParams := &protocol.InitializeParams{
|
||||
WorkspaceFoldersInitializeParams: protocol.WorkspaceFoldersInitializeParams{
|
||||
WorkspaceFolders: []protocol.WorkspaceFolder{
|
||||
{
|
||||
URI: protocol.URI("file://" + workspaceDir),
|
||||
Name: workspaceDir,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
XInitializeParams: protocol.XInitializeParams{
|
||||
ProcessID: int32(os.Getpid()),
|
||||
ClientInfo: &protocol.ClientInfo{
|
||||
Name: "mcp-language-server",
|
||||
Version: "0.1.0",
|
||||
},
|
||||
RootPath: workspaceDir,
|
||||
RootURI: protocol.DocumentUri("file://" + workspaceDir),
|
||||
Capabilities: protocol.ClientCapabilities{
|
||||
Workspace: protocol.WorkspaceClientCapabilities{
|
||||
Configuration: true,
|
||||
DidChangeConfiguration: protocol.DidChangeConfigurationClientCapabilities{
|
||||
DynamicRegistration: true,
|
||||
},
|
||||
DidChangeWatchedFiles: protocol.DidChangeWatchedFilesClientCapabilities{
|
||||
DynamicRegistration: true,
|
||||
RelativePatternSupport: true,
|
||||
},
|
||||
},
|
||||
TextDocument: protocol.TextDocumentClientCapabilities{
|
||||
Synchronization: &protocol.TextDocumentSyncClientCapabilities{
|
||||
DynamicRegistration: true,
|
||||
DidSave: true,
|
||||
},
|
||||
Completion: protocol.CompletionClientCapabilities{
|
||||
CompletionItem: protocol.ClientCompletionItemOptions{},
|
||||
},
|
||||
CodeLens: &protocol.CodeLensClientCapabilities{
|
||||
DynamicRegistration: true,
|
||||
},
|
||||
DocumentSymbol: protocol.DocumentSymbolClientCapabilities{},
|
||||
CodeAction: protocol.CodeActionClientCapabilities{
|
||||
CodeActionLiteralSupport: protocol.ClientCodeActionLiteralOptions{
|
||||
CodeActionKind: protocol.ClientCodeActionKindOptions{
|
||||
ValueSet: []protocol.CodeActionKind{},
|
||||
},
|
||||
},
|
||||
},
|
||||
PublishDiagnostics: protocol.PublishDiagnosticsClientCapabilities{
|
||||
VersionSupport: true,
|
||||
},
|
||||
SemanticTokens: protocol.SemanticTokensClientCapabilities{
|
||||
Requests: protocol.ClientSemanticTokensRequestOptions{
|
||||
Range: &protocol.Or_ClientSemanticTokensRequestOptions_range{},
|
||||
Full: &protocol.Or_ClientSemanticTokensRequestOptions_full{},
|
||||
},
|
||||
TokenTypes: []string{},
|
||||
TokenModifiers: []string{},
|
||||
Formats: []protocol.TokenFormat{},
|
||||
},
|
||||
},
|
||||
Window: protocol.WindowClientCapabilities{},
|
||||
},
|
||||
InitializationOptions: map[string]any{
|
||||
"codelenses": map[string]bool{
|
||||
"generate": true,
|
||||
"regenerate_cgo": true,
|
||||
"test": true,
|
||||
"tidy": true,
|
||||
"upgrade_dependency": true,
|
||||
"vendor": true,
|
||||
"vulncheck": false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var result protocol.InitializeResult
|
||||
if err := c.Call(ctx, "initialize", initParams, &result); err != nil {
|
||||
return nil, fmt.Errorf("initialize failed: %w", err)
|
||||
}
|
||||
|
||||
if err := c.Notify(ctx, "initialized", struct{}{}); err != nil {
|
||||
return nil, fmt.Errorf("initialized notification failed: %w", err)
|
||||
}
|
||||
|
||||
// Register handlers
|
||||
c.RegisterServerRequestHandler("workspace/applyEdit", HandleApplyEdit)
|
||||
c.RegisterServerRequestHandler("workspace/configuration", HandleWorkspaceConfiguration)
|
||||
c.RegisterServerRequestHandler("client/registerCapability", HandleRegisterCapability)
|
||||
c.RegisterNotificationHandler("window/showMessage", HandleServerMessage)
|
||||
c.RegisterNotificationHandler("textDocument/publishDiagnostics",
|
||||
func(params json.RawMessage) { HandleDiagnostics(c, params) })
|
||||
|
||||
// Notify the LSP server
|
||||
err := c.Initialized(ctx, protocol.InitializedParams{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initialization failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
// Try to close all open files first
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Attempt to close files but continue shutdown regardless
|
||||
c.CloseAllFiles(ctx)
|
||||
|
||||
// Close stdin to signal the server
|
||||
if err := c.stdin.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close stdin: %w", err)
|
||||
}
|
||||
|
||||
// Use a channel to handle the Wait with timeout
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- c.Cmd.Wait()
|
||||
}()
|
||||
|
||||
// Wait for process to exit with timeout
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-time.After(2 * time.Second):
|
||||
// If we timeout, try to kill the process
|
||||
if err := c.Cmd.Process.Kill(); err != nil {
|
||||
return fmt.Errorf("failed to kill process: %w", err)
|
||||
}
|
||||
return fmt.Errorf("process killed after timeout")
|
||||
}
|
||||
}
|
||||
|
||||
type ServerState int
|
||||
|
||||
const (
|
||||
StateStarting ServerState = iota
|
||||
StateReady
|
||||
StateError
|
||||
)
|
||||
|
||||
// GetServerState returns the current state of the LSP server
|
||||
func (c *Client) GetServerState() ServerState {
|
||||
if val := c.serverState.Load(); val != nil {
|
||||
return val.(ServerState)
|
||||
}
|
||||
return StateStarting
|
||||
}
|
||||
|
||||
// SetServerState sets the current state of the LSP server
|
||||
func (c *Client) SetServerState(state ServerState) {
|
||||
c.serverState.Store(state)
|
||||
}
|
||||
|
||||
// WaitForServerReady waits for the server to be ready by polling the server
|
||||
// with a simple request until it responds successfully or times out
|
||||
func (c *Client) WaitForServerReady(ctx context.Context) error {
|
||||
cnf := config.Get()
|
||||
|
||||
// Set initial state
|
||||
c.SetServerState(StateStarting)
|
||||
|
||||
// Create a context with timeout
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Try to ping the server with a simple request
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Waiting for LSP server to be ready...")
|
||||
}
|
||||
|
||||
// Determine server type for specialized initialization
|
||||
serverType := c.detectServerType()
|
||||
|
||||
// For TypeScript-like servers, we need to open some key files first
|
||||
if serverType == ServerTypeTypeScript {
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("TypeScript-like server detected, opening key configuration files")
|
||||
}
|
||||
c.openKeyConfigFiles(ctx)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.SetServerState(StateError)
|
||||
return fmt.Errorf("timeout waiting for LSP server to be ready")
|
||||
case <-ticker.C:
|
||||
// Try a ping method appropriate for this server type
|
||||
err := c.pingServerByType(ctx, serverType)
|
||||
if err == nil {
|
||||
// Server responded successfully
|
||||
c.SetServerState(StateReady)
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("LSP server is ready")
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
slog.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
|
||||
}
|
||||
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("LSP server not ready yet", "error", err, "serverType", serverType)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ServerType represents the type of LSP server
|
||||
type ServerType int
|
||||
|
||||
const (
|
||||
ServerTypeUnknown ServerType = iota
|
||||
ServerTypeGo
|
||||
ServerTypeTypeScript
|
||||
ServerTypeRust
|
||||
ServerTypePython
|
||||
ServerTypeGeneric
|
||||
)
|
||||
|
||||
// detectServerType tries to determine what type of LSP server we're dealing with
|
||||
func (c *Client) detectServerType() ServerType {
|
||||
if c.Cmd == nil {
|
||||
return ServerTypeUnknown
|
||||
}
|
||||
|
||||
cmdPath := strings.ToLower(c.Cmd.Path)
|
||||
|
||||
switch {
|
||||
case strings.Contains(cmdPath, "gopls"):
|
||||
return ServerTypeGo
|
||||
case strings.Contains(cmdPath, "typescript") || strings.Contains(cmdPath, "vtsls") || strings.Contains(cmdPath, "tsserver"):
|
||||
return ServerTypeTypeScript
|
||||
case strings.Contains(cmdPath, "rust-analyzer"):
|
||||
return ServerTypeRust
|
||||
case strings.Contains(cmdPath, "pyright") || strings.Contains(cmdPath, "pylsp") || strings.Contains(cmdPath, "python"):
|
||||
return ServerTypePython
|
||||
default:
|
||||
return ServerTypeGeneric
|
||||
}
|
||||
}
|
||||
|
||||
// openKeyConfigFiles opens important configuration files that help initialize the server
|
||||
func (c *Client) openKeyConfigFiles(ctx context.Context) {
|
||||
workDir := config.WorkingDirectory()
|
||||
serverType := c.detectServerType()
|
||||
|
||||
var filesToOpen []string
|
||||
|
||||
switch serverType {
|
||||
case ServerTypeTypeScript:
|
||||
// TypeScript servers need these config files to properly initialize
|
||||
filesToOpen = []string{
|
||||
filepath.Join(workDir, "tsconfig.json"),
|
||||
filepath.Join(workDir, "package.json"),
|
||||
filepath.Join(workDir, "jsconfig.json"),
|
||||
}
|
||||
|
||||
// Also find and open a few TypeScript files to help the server initialize
|
||||
c.openTypeScriptFiles(ctx, workDir)
|
||||
case ServerTypeGo:
|
||||
filesToOpen = []string{
|
||||
filepath.Join(workDir, "go.mod"),
|
||||
filepath.Join(workDir, "go.sum"),
|
||||
}
|
||||
case ServerTypeRust:
|
||||
filesToOpen = []string{
|
||||
filepath.Join(workDir, "Cargo.toml"),
|
||||
filepath.Join(workDir, "Cargo.lock"),
|
||||
}
|
||||
}
|
||||
|
||||
// Try to open each file, ignoring errors if they don't exist
|
||||
for _, file := range filesToOpen {
|
||||
if _, err := os.Stat(file); err == nil {
|
||||
// File exists, try to open it
|
||||
if err := c.OpenFile(ctx, file); err != nil {
|
||||
slog.Debug("Failed to open key config file", "file", file, "error", err)
|
||||
} else {
|
||||
slog.Debug("Opened key config file for initialization", "file", file)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pingServerByType sends a ping request appropriate for the server type
|
||||
func (c *Client) pingServerByType(ctx context.Context, serverType ServerType) error {
|
||||
switch serverType {
|
||||
case ServerTypeTypeScript:
|
||||
// For TypeScript, try a document symbol request on an open file
|
||||
return c.pingTypeScriptServer(ctx)
|
||||
case ServerTypeGo:
|
||||
// For Go, workspace/symbol works well
|
||||
return c.pingWithWorkspaceSymbol(ctx)
|
||||
case ServerTypeRust:
|
||||
// For Rust, workspace/symbol works well
|
||||
return c.pingWithWorkspaceSymbol(ctx)
|
||||
default:
|
||||
// Default ping method
|
||||
return c.pingWithWorkspaceSymbol(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// pingTypeScriptServer tries to ping a TypeScript server with appropriate methods
|
||||
func (c *Client) pingTypeScriptServer(ctx context.Context) error {
|
||||
// First try workspace/symbol which works for many servers
|
||||
if err := c.pingWithWorkspaceSymbol(ctx); err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If that fails, try to find an open file and request document symbols
|
||||
c.openFilesMu.RLock()
|
||||
defer c.openFilesMu.RUnlock()
|
||||
|
||||
// If we have any open files, try to get document symbols for one
|
||||
for uri := range c.openFiles {
|
||||
filePath := strings.TrimPrefix(uri, "file://")
|
||||
if strings.HasSuffix(filePath, ".ts") || strings.HasSuffix(filePath, ".js") ||
|
||||
strings.HasSuffix(filePath, ".tsx") || strings.HasSuffix(filePath, ".jsx") {
|
||||
var symbols []protocol.DocumentSymbol
|
||||
err := c.Call(ctx, "textDocument/documentSymbol", protocol.DocumentSymbolParams{
|
||||
TextDocument: protocol.TextDocumentIdentifier{
|
||||
URI: protocol.DocumentUri(uri),
|
||||
},
|
||||
}, &symbols)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we have no open TypeScript files, try to find and open one
|
||||
workDir := config.WorkingDirectory()
|
||||
err := filepath.WalkDir(workDir, func(path string, d os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip directories and non-TypeScript files
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
ext := filepath.Ext(path)
|
||||
if ext == ".ts" || ext == ".js" || ext == ".tsx" || ext == ".jsx" {
|
||||
// Found a TypeScript file, try to open it
|
||||
if err := c.OpenFile(ctx, path); err == nil {
|
||||
// Successfully opened, stop walking
|
||||
return filepath.SkipAll
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
slog.Debug("Error walking directory for TypeScript files", "error", err)
|
||||
}
|
||||
|
||||
// Final fallback - just try a generic capability
|
||||
return c.pingWithServerCapabilities(ctx)
|
||||
}
|
||||
|
||||
// openTypeScriptFiles finds and opens TypeScript files to help initialize the server
|
||||
func (c *Client) openTypeScriptFiles(ctx context.Context, workDir string) {
|
||||
cnf := config.Get()
|
||||
filesOpened := 0
|
||||
maxFilesToOpen := 5 // Limit to a reasonable number of files
|
||||
|
||||
// Find and open TypeScript files
|
||||
err := filepath.WalkDir(workDir, func(path string, d os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip directories and non-TypeScript files
|
||||
if d.IsDir() {
|
||||
// Skip common directories to avoid wasting time
|
||||
if shouldSkipDir(path) {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if we've opened enough files
|
||||
if filesOpened >= maxFilesToOpen {
|
||||
return filepath.SkipAll
|
||||
}
|
||||
|
||||
// Check file extension
|
||||
ext := filepath.Ext(path)
|
||||
if ext == ".ts" || ext == ".tsx" || ext == ".js" || ext == ".jsx" {
|
||||
// Try to open the file
|
||||
if err := c.OpenFile(ctx, path); err == nil {
|
||||
filesOpened++
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Opened TypeScript file for initialization", "file", path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil && cnf.DebugLSP {
|
||||
slog.Debug("Error walking directory for TypeScript files", "error", err)
|
||||
}
|
||||
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Opened TypeScript files for initialization", "count", filesOpened)
|
||||
}
|
||||
}
|
||||
|
||||
// shouldSkipDir returns true if the directory should be skipped during file search
|
||||
func shouldSkipDir(path string) bool {
|
||||
dirName := filepath.Base(path)
|
||||
|
||||
// Skip hidden directories
|
||||
if strings.HasPrefix(dirName, ".") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip common directories that won't contain relevant source files
|
||||
skipDirs := map[string]bool{
|
||||
"node_modules": true,
|
||||
"dist": true,
|
||||
"build": true,
|
||||
"coverage": true,
|
||||
"vendor": true,
|
||||
"target": true,
|
||||
}
|
||||
|
||||
return skipDirs[dirName]
|
||||
}
|
||||
|
||||
// pingWithWorkspaceSymbol tries a workspace/symbol request
|
||||
func (c *Client) pingWithWorkspaceSymbol(ctx context.Context) error {
|
||||
var result []protocol.SymbolInformation
|
||||
return c.Call(ctx, "workspace/symbol", protocol.WorkspaceSymbolParams{
|
||||
Query: "",
|
||||
}, &result)
|
||||
}
|
||||
|
||||
// pingWithServerCapabilities tries to get server capabilities
|
||||
func (c *Client) pingWithServerCapabilities(ctx context.Context) error {
|
||||
// This is a very lightweight request that should work for most servers
|
||||
return c.Notify(ctx, "$/cancelRequest", struct{ ID int }{ID: -1})
|
||||
}
|
||||
|
||||
type OpenFileInfo struct {
|
||||
Version int32
|
||||
URI protocol.DocumentUri
|
||||
}
|
||||
|
||||
func (c *Client) OpenFile(ctx context.Context, filepath string) error {
|
||||
uri := fmt.Sprintf("file://%s", filepath)
|
||||
|
||||
c.openFilesMu.Lock()
|
||||
if _, exists := c.openFiles[uri]; exists {
|
||||
c.openFilesMu.Unlock()
|
||||
return nil // Already open
|
||||
}
|
||||
c.openFilesMu.Unlock()
|
||||
|
||||
// Skip files that do not exist or cannot be read
|
||||
content, err := os.ReadFile(filepath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading file: %w", err)
|
||||
}
|
||||
|
||||
params := protocol.DidOpenTextDocumentParams{
|
||||
TextDocument: protocol.TextDocumentItem{
|
||||
URI: protocol.DocumentUri(uri),
|
||||
LanguageID: DetectLanguageID(uri),
|
||||
Version: 1,
|
||||
Text: string(content),
|
||||
},
|
||||
}
|
||||
|
||||
if err := c.Notify(ctx, "textDocument/didOpen", params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.openFilesMu.Lock()
|
||||
c.openFiles[uri] = &OpenFileInfo{
|
||||
Version: 1,
|
||||
URI: protocol.DocumentUri(uri),
|
||||
}
|
||||
c.openFilesMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) NotifyChange(ctx context.Context, filepath string) error {
|
||||
uri := fmt.Sprintf("file://%s", filepath)
|
||||
|
||||
// Verify file exists before attempting to read it
|
||||
if _, err := os.Stat(filepath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// File was deleted - close it in the LSP client instead of notifying change
|
||||
return c.CloseFile(ctx, filepath)
|
||||
}
|
||||
return fmt.Errorf("error checking file: %w", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(filepath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading file: %w", err)
|
||||
}
|
||||
|
||||
c.openFilesMu.Lock()
|
||||
fileInfo, isOpen := c.openFiles[uri]
|
||||
if !isOpen {
|
||||
c.openFilesMu.Unlock()
|
||||
return fmt.Errorf("cannot notify change for unopened file: %s", filepath)
|
||||
}
|
||||
|
||||
// Increment version
|
||||
fileInfo.Version++
|
||||
version := fileInfo.Version
|
||||
c.openFilesMu.Unlock()
|
||||
|
||||
params := protocol.DidChangeTextDocumentParams{
|
||||
TextDocument: protocol.VersionedTextDocumentIdentifier{
|
||||
TextDocumentIdentifier: protocol.TextDocumentIdentifier{
|
||||
URI: protocol.DocumentUri(uri),
|
||||
},
|
||||
Version: version,
|
||||
},
|
||||
ContentChanges: []protocol.TextDocumentContentChangeEvent{
|
||||
{
|
||||
Value: protocol.TextDocumentContentChangeWholeDocument{
|
||||
Text: string(content),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return c.Notify(ctx, "textDocument/didChange", params)
|
||||
}
|
||||
|
||||
func (c *Client) CloseFile(ctx context.Context, filepath string) error {
|
||||
cnf := config.Get()
|
||||
uri := fmt.Sprintf("file://%s", filepath)
|
||||
|
||||
c.openFilesMu.Lock()
|
||||
if _, exists := c.openFiles[uri]; !exists {
|
||||
c.openFilesMu.Unlock()
|
||||
return nil // Already closed
|
||||
}
|
||||
c.openFilesMu.Unlock()
|
||||
|
||||
params := protocol.DidCloseTextDocumentParams{
|
||||
TextDocument: protocol.TextDocumentIdentifier{
|
||||
URI: protocol.DocumentUri(uri),
|
||||
},
|
||||
}
|
||||
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Closing file", "file", filepath)
|
||||
}
|
||||
if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.openFilesMu.Lock()
|
||||
delete(c.openFiles, uri)
|
||||
c.openFilesMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) IsFileOpen(filepath string) bool {
|
||||
uri := fmt.Sprintf("file://%s", filepath)
|
||||
c.openFilesMu.RLock()
|
||||
defer c.openFilesMu.RUnlock()
|
||||
_, exists := c.openFiles[uri]
|
||||
return exists
|
||||
}
|
||||
|
||||
// CloseAllFiles closes all currently open files
|
||||
func (c *Client) CloseAllFiles(ctx context.Context) {
|
||||
cnf := config.Get()
|
||||
c.openFilesMu.Lock()
|
||||
filesToClose := make([]string, 0, len(c.openFiles))
|
||||
|
||||
// First collect all URIs that need to be closed
|
||||
for uri := range c.openFiles {
|
||||
// Convert URI back to file path by trimming "file://" prefix
|
||||
filePath := strings.TrimPrefix(uri, "file://")
|
||||
filesToClose = append(filesToClose, filePath)
|
||||
}
|
||||
c.openFilesMu.Unlock()
|
||||
|
||||
// Then close them all
|
||||
for _, filePath := range filesToClose {
|
||||
err := c.CloseFile(ctx, filePath)
|
||||
if err != nil && cnf.DebugLSP {
|
||||
slog.Warn("Error closing file", "file", filePath, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Closed all files", "files", filesToClose)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) GetFileDiagnostics(uri protocol.DocumentUri) []protocol.Diagnostic {
|
||||
c.diagnosticsMu.RLock()
|
||||
defer c.diagnosticsMu.RUnlock()
|
||||
|
||||
return c.diagnostics[uri]
|
||||
}
|
||||
|
||||
// GetDiagnostics returns all diagnostics for all files
|
||||
func (c *Client) GetDiagnostics() map[protocol.DocumentUri][]protocol.Diagnostic {
|
||||
return c.diagnostics
|
||||
}
|
||||
|
||||
// OpenFileOnDemand opens a file only if it's not already open
|
||||
// This is used for lazy-loading files when they're actually needed
|
||||
func (c *Client) OpenFileOnDemand(ctx context.Context, filepath string) error {
|
||||
// Check if the file is already open
|
||||
if c.IsFileOpen(filepath) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Open the file
|
||||
return c.OpenFile(ctx, filepath)
|
||||
}
|
||||
|
||||
// GetDiagnosticsForFile ensures a file is open and returns its diagnostics
|
||||
// This is useful for on-demand diagnostics when using lazy loading
|
||||
func (c *Client) GetDiagnosticsForFile(ctx context.Context, filepath string) ([]protocol.Diagnostic, error) {
|
||||
uri := fmt.Sprintf("file://%s", filepath)
|
||||
documentUri := protocol.DocumentUri(uri)
|
||||
|
||||
// Make sure the file is open
|
||||
if !c.IsFileOpen(filepath) {
|
||||
if err := c.OpenFile(ctx, filepath); err != nil {
|
||||
return nil, fmt.Errorf("failed to open file for diagnostics: %w", err)
|
||||
}
|
||||
|
||||
// Give the LSP server a moment to process the file
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Get diagnostics
|
||||
c.diagnosticsMu.RLock()
|
||||
diagnostics := c.diagnostics[documentUri]
|
||||
c.diagnosticsMu.RUnlock()
|
||||
|
||||
return diagnostics, nil
|
||||
}
|
||||
|
||||
// ClearDiagnosticsForURI removes diagnostics for a specific URI from the cache
|
||||
func (c *Client) ClearDiagnosticsForURI(uri protocol.DocumentUri) {
|
||||
c.diagnosticsMu.Lock()
|
||||
defer c.diagnosticsMu.Unlock()
|
||||
delete(c.diagnostics, uri)
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package discovery
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// IntegrateLSPServers discovers languages and LSP servers and integrates them into the application configuration
|
||||
func IntegrateLSPServers(workingDir string) error {
|
||||
// Get the current configuration
|
||||
cfg := config.Get()
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("config not loaded")
|
||||
}
|
||||
|
||||
// Check if this is the first run
|
||||
shouldInit, err := config.ShouldShowInitDialog()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check initialization status: %w", err)
|
||||
}
|
||||
|
||||
// Always run language detection, but log differently for first run vs. subsequent runs
|
||||
if shouldInit || len(cfg.LSP) == 0 {
|
||||
slog.Info("Running initial LSP auto-discovery...")
|
||||
} else {
|
||||
slog.Debug("Running LSP auto-discovery to detect new languages...")
|
||||
}
|
||||
|
||||
// Configure LSP servers
|
||||
servers, err := ConfigureLSPServers(workingDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure LSP servers: %w", err)
|
||||
}
|
||||
|
||||
// Update the configuration with discovered servers
|
||||
for langID, serverInfo := range servers {
|
||||
// Skip languages that already have a configured server
|
||||
if _, exists := cfg.LSP[langID]; exists {
|
||||
slog.Debug("LSP server already configured for language", "language", langID)
|
||||
continue
|
||||
}
|
||||
|
||||
if serverInfo.Available {
|
||||
// Only add servers that were found
|
||||
cfg.LSP[langID] = config.LSPConfig{
|
||||
Disabled: false,
|
||||
Command: serverInfo.Path,
|
||||
Args: serverInfo.Args,
|
||||
}
|
||||
slog.Info("Added LSP server to configuration",
|
||||
"language", langID,
|
||||
"command", serverInfo.Command,
|
||||
"path", serverInfo.Path)
|
||||
} else {
|
||||
slog.Warn("LSP server not available",
|
||||
"language", langID,
|
||||
"command", serverInfo.Command,
|
||||
"installCmd", serverInfo.InstallCmd)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,298 +0,0 @@
|
||||
package discovery
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// LanguageInfo stores information about a detected language
|
||||
type LanguageInfo struct {
|
||||
// Language identifier (e.g., "go", "typescript", "python")
|
||||
ID string
|
||||
|
||||
// Number of files detected for this language
|
||||
FileCount int
|
||||
|
||||
// Project files associated with this language (e.g., go.mod, package.json)
|
||||
ProjectFiles []string
|
||||
|
||||
// Whether this is likely a primary language in the project
|
||||
IsPrimary bool
|
||||
}
|
||||
|
||||
// ProjectFile represents a project configuration file
|
||||
type ProjectFile struct {
|
||||
// File name or pattern to match
|
||||
Name string
|
||||
|
||||
// Associated language ID
|
||||
LanguageID string
|
||||
|
||||
// Whether this file strongly indicates the language is primary
|
||||
IsPrimary bool
|
||||
}
|
||||
|
||||
// Common project files that indicate specific languages
|
||||
var projectFilePatterns = []ProjectFile{
|
||||
{Name: "go.mod", LanguageID: "go", IsPrimary: true},
|
||||
{Name: "go.sum", LanguageID: "go", IsPrimary: false},
|
||||
{Name: "package.json", LanguageID: "javascript", IsPrimary: true}, // Could be TypeScript too
|
||||
{Name: "tsconfig.json", LanguageID: "typescript", IsPrimary: true},
|
||||
{Name: "jsconfig.json", LanguageID: "javascript", IsPrimary: true},
|
||||
{Name: "pyproject.toml", LanguageID: "python", IsPrimary: true},
|
||||
{Name: "setup.py", LanguageID: "python", IsPrimary: true},
|
||||
{Name: "requirements.txt", LanguageID: "python", IsPrimary: true},
|
||||
{Name: "Cargo.toml", LanguageID: "rust", IsPrimary: true},
|
||||
{Name: "Cargo.lock", LanguageID: "rust", IsPrimary: false},
|
||||
{Name: "CMakeLists.txt", LanguageID: "cmake", IsPrimary: true},
|
||||
{Name: "pom.xml", LanguageID: "java", IsPrimary: true},
|
||||
{Name: "build.gradle", LanguageID: "java", IsPrimary: true},
|
||||
{Name: "build.gradle.kts", LanguageID: "kotlin", IsPrimary: true},
|
||||
{Name: "composer.json", LanguageID: "php", IsPrimary: true},
|
||||
{Name: "Gemfile", LanguageID: "ruby", IsPrimary: true},
|
||||
{Name: "Rakefile", LanguageID: "ruby", IsPrimary: true},
|
||||
{Name: "mix.exs", LanguageID: "elixir", IsPrimary: true},
|
||||
{Name: "rebar.config", LanguageID: "erlang", IsPrimary: true},
|
||||
{Name: "dune-project", LanguageID: "ocaml", IsPrimary: true},
|
||||
{Name: "stack.yaml", LanguageID: "haskell", IsPrimary: true},
|
||||
{Name: "cabal.project", LanguageID: "haskell", IsPrimary: true},
|
||||
{Name: "Makefile", LanguageID: "make", IsPrimary: false},
|
||||
{Name: "Dockerfile", LanguageID: "dockerfile", IsPrimary: false},
|
||||
}
|
||||
|
||||
// Map of file extensions to language IDs
|
||||
var extensionToLanguage = map[string]string{
|
||||
".go": "go",
|
||||
".js": "javascript",
|
||||
".jsx": "javascript",
|
||||
".ts": "typescript",
|
||||
".tsx": "typescript",
|
||||
".py": "python",
|
||||
".rs": "rust",
|
||||
".java": "java",
|
||||
".c": "c",
|
||||
".cpp": "cpp",
|
||||
".h": "c",
|
||||
".hpp": "cpp",
|
||||
".rb": "ruby",
|
||||
".php": "php",
|
||||
".cs": "csharp",
|
||||
".fs": "fsharp",
|
||||
".swift": "swift",
|
||||
".kt": "kotlin",
|
||||
".scala": "scala",
|
||||
".hs": "haskell",
|
||||
".ml": "ocaml",
|
||||
".ex": "elixir",
|
||||
".exs": "elixir",
|
||||
".erl": "erlang",
|
||||
".lua": "lua",
|
||||
".r": "r",
|
||||
".sh": "shell",
|
||||
".bash": "shell",
|
||||
".zsh": "shell",
|
||||
".html": "html",
|
||||
".css": "css",
|
||||
".scss": "scss",
|
||||
".sass": "sass",
|
||||
".less": "less",
|
||||
".json": "json",
|
||||
".xml": "xml",
|
||||
".yaml": "yaml",
|
||||
".yml": "yaml",
|
||||
".md": "markdown",
|
||||
".dart": "dart",
|
||||
}
|
||||
|
||||
// Directories to exclude from scanning
|
||||
var excludedDirs = map[string]bool{
|
||||
".git": true,
|
||||
"node_modules": true,
|
||||
"vendor": true,
|
||||
"dist": true,
|
||||
"build": true,
|
||||
"target": true,
|
||||
".idea": true,
|
||||
".vscode": true,
|
||||
".github": true,
|
||||
".gitlab": true,
|
||||
"__pycache__": true,
|
||||
".next": true,
|
||||
".nuxt": true,
|
||||
"venv": true,
|
||||
"env": true,
|
||||
".env": true,
|
||||
}
|
||||
|
||||
// DetectLanguages scans a directory to identify programming languages used in the project
|
||||
func DetectLanguages(rootDir string) (map[string]LanguageInfo, error) {
|
||||
languages := make(map[string]LanguageInfo)
|
||||
var mutex sync.Mutex
|
||||
|
||||
// Walk the directory tree
|
||||
err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil // Skip files that can't be accessed
|
||||
}
|
||||
|
||||
// Skip excluded directories
|
||||
if info.IsDir() {
|
||||
if excludedDirs[info.Name()] || strings.HasPrefix(info.Name(), ".") {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip hidden files
|
||||
if strings.HasPrefix(info.Name(), ".") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for project files
|
||||
for _, pattern := range projectFilePatterns {
|
||||
if info.Name() == pattern.Name {
|
||||
mutex.Lock()
|
||||
lang, exists := languages[pattern.LanguageID]
|
||||
if !exists {
|
||||
lang = LanguageInfo{
|
||||
ID: pattern.LanguageID,
|
||||
FileCount: 0,
|
||||
ProjectFiles: []string{},
|
||||
IsPrimary: pattern.IsPrimary,
|
||||
}
|
||||
}
|
||||
lang.ProjectFiles = append(lang.ProjectFiles, path)
|
||||
if pattern.IsPrimary {
|
||||
lang.IsPrimary = true
|
||||
}
|
||||
languages[pattern.LanguageID] = lang
|
||||
mutex.Unlock()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check file extension
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
if langID, ok := extensionToLanguage[ext]; ok {
|
||||
mutex.Lock()
|
||||
lang, exists := languages[langID]
|
||||
if !exists {
|
||||
lang = LanguageInfo{
|
||||
ID: langID,
|
||||
FileCount: 0,
|
||||
ProjectFiles: []string{},
|
||||
}
|
||||
}
|
||||
lang.FileCount++
|
||||
languages[langID] = lang
|
||||
mutex.Unlock()
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Determine primary languages based on file count if not already marked
|
||||
determinePrimaryLanguages(languages)
|
||||
|
||||
// Log detected languages
|
||||
for id, info := range languages {
|
||||
if info.IsPrimary {
|
||||
slog.Debug("Detected primary language", "language", id, "files", info.FileCount, "projectFiles", len(info.ProjectFiles))
|
||||
} else {
|
||||
slog.Debug("Detected secondary language", "language", id, "files", info.FileCount)
|
||||
}
|
||||
}
|
||||
|
||||
return languages, nil
|
||||
}
|
||||
|
||||
// determinePrimaryLanguages marks languages as primary based on file count
|
||||
func determinePrimaryLanguages(languages map[string]LanguageInfo) {
|
||||
// Find the language with the most files
|
||||
var maxFiles int
|
||||
for _, info := range languages {
|
||||
if info.FileCount > maxFiles {
|
||||
maxFiles = info.FileCount
|
||||
}
|
||||
}
|
||||
|
||||
// Mark languages with at least 20% of the max files as primary
|
||||
threshold := max(maxFiles/5, 5) // At least 5 files to be considered primary
|
||||
|
||||
for id, info := range languages {
|
||||
if !info.IsPrimary && info.FileCount >= threshold {
|
||||
info.IsPrimary = true
|
||||
languages[id] = info
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetLanguageIDFromExtension returns the language ID for a given file extension
|
||||
func GetLanguageIDFromExtension(ext string) string {
|
||||
ext = strings.ToLower(ext)
|
||||
if langID, ok := extensionToLanguage[ext]; ok {
|
||||
return langID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetLanguageIDFromProtocol converts a protocol.LanguageKind string to
// our internal language ID. The react variants collapse onto their base
// language; unknown kinds fall back to a lowercased form of the input.
func GetLanguageIDFromProtocol(langKind string) string {
	switch langKind {
	case "typescript", "typescriptreact":
		return "typescript"
	case "javascript", "javascriptreact":
		return "javascript"
	case "go", "python", "rust", "java", "c", "cpp":
		// These kinds are already identical to our IDs.
		return langKind
	default:
		// Try to normalize any other kind by lowercasing it.
		return strings.ToLower(langKind)
	}
}
|
||||
|
||||
// GetLanguageIDFromPath determines the language ID from a file path
|
||||
func GetLanguageIDFromPath(path string) string {
|
||||
// Check file extension first
|
||||
ext := filepath.Ext(path)
|
||||
if langID := GetLanguageIDFromExtension(ext); langID != "" {
|
||||
return langID
|
||||
}
|
||||
|
||||
// Check if it's a known project file
|
||||
filename := filepath.Base(path)
|
||||
for _, pattern := range projectFilePatterns {
|
||||
if filename == pattern.Name {
|
||||
return pattern.LanguageID
|
||||
}
|
||||
}
|
||||
|
||||
// Use LSP's detection as a fallback
|
||||
uri := "file://" + path
|
||||
langKind := lsp.DetectLanguageID(uri)
|
||||
return GetLanguageIDFromProtocol(string(langKind))
|
||||
}
|
||||
@@ -1,306 +0,0 @@
|
||||
package discovery
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// ServerInfo contains information about an LSP server binary: how to
// launch it, how to install it, and whether it was found on this system.
type ServerInfo struct {
	// Command to run the server
	Command string

	// Arguments to pass to the command
	Args []string

	// Command to install the server — shown to the user as guidance
	// when the server is missing
	InstallCmd string

	// Whether this server is available on this machine; set by
	// FindLSPServer after a successful lookup
	Available bool

	// Full path to the executable (if found)
	Path string
}
|
||||
|
||||
// LanguageServerMap maps language IDs to their corresponding LSP servers.
// Each entry carries the launch command/args plus a human-readable
// install hint; Available and Path are left zero here and filled in by
// FindLSPServer at lookup time.
var LanguageServerMap = map[string]ServerInfo{
	"go": {
		Command:    "gopls",
		InstallCmd: "go install golang.org/x/tools/gopls@latest",
	},
	// TypeScript and JavaScript share typescript-language-server.
	"typescript": {
		Command:    "typescript-language-server",
		Args:       []string{"--stdio"},
		InstallCmd: "npm install -g typescript-language-server typescript",
	},
	"javascript": {
		Command:    "typescript-language-server",
		Args:       []string{"--stdio"},
		InstallCmd: "npm install -g typescript-language-server typescript",
	},
	"python": {
		Command:    "pylsp",
		InstallCmd: "pip install python-lsp-server",
	},
	"rust": {
		Command:    "rust-analyzer",
		InstallCmd: "rustup component add rust-analyzer",
	},
	"java": {
		Command:    "jdtls",
		InstallCmd: "Install Eclipse JDT Language Server",
	},
	// C and C++ share clangd.
	"c": {
		Command:    "clangd",
		InstallCmd: "Install clangd from your package manager",
	},
	"cpp": {
		Command:    "clangd",
		InstallCmd: "Install clangd from your package manager",
	},
	"php": {
		Command:    "intelephense",
		Args:       []string{"--stdio"},
		InstallCmd: "npm install -g intelephense",
	},
	// Note: solargraph takes "stdio" as a subcommand, not a "--stdio" flag.
	"ruby": {
		Command:    "solargraph",
		Args:       []string{"stdio"},
		InstallCmd: "gem install solargraph",
	},
	"lua": {
		Command:    "lua-language-server",
		InstallCmd: "Install lua-language-server from your package manager",
	},
	// html/css/json all ship in the vscode-langservers-extracted package.
	"html": {
		Command:    "vscode-html-language-server",
		Args:       []string{"--stdio"},
		InstallCmd: "npm install -g vscode-langservers-extracted",
	},
	"css": {
		Command:    "vscode-css-language-server",
		Args:       []string{"--stdio"},
		InstallCmd: "npm install -g vscode-langservers-extracted",
	},
	"json": {
		Command:    "vscode-json-language-server",
		Args:       []string{"--stdio"},
		InstallCmd: "npm install -g vscode-langservers-extracted",
	},
	"yaml": {
		Command:    "yaml-language-server",
		Args:       []string{"--stdio"},
		InstallCmd: "npm install -g yaml-language-server",
	},
}
|
||||
|
||||
// FindLSPServer searches for an LSP server for the given language
|
||||
func FindLSPServer(languageID string) (ServerInfo, error) {
|
||||
// Get server info for the language
|
||||
serverInfo, exists := LanguageServerMap[languageID]
|
||||
if !exists {
|
||||
return ServerInfo{}, fmt.Errorf("no LSP server defined for language: %s", languageID)
|
||||
}
|
||||
|
||||
// Check if the command is in PATH
|
||||
path, err := exec.LookPath(serverInfo.Command)
|
||||
if err == nil {
|
||||
serverInfo.Available = true
|
||||
serverInfo.Path = path
|
||||
slog.Debug("Found LSP server in PATH", "language", languageID, "command", serverInfo.Command, "path", path)
|
||||
return serverInfo, nil
|
||||
}
|
||||
|
||||
// If not in PATH, search in common installation locations
|
||||
paths := getCommonLSPPaths(languageID, serverInfo.Command)
|
||||
for _, searchPath := range paths {
|
||||
if _, err := os.Stat(searchPath); err == nil {
|
||||
// Found the server
|
||||
serverInfo.Available = true
|
||||
serverInfo.Path = searchPath
|
||||
slog.Debug("Found LSP server in common location", "language", languageID, "command", serverInfo.Command, "path", searchPath)
|
||||
return serverInfo, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Server not found
|
||||
slog.Debug("LSP server not found", "language", languageID, "command", serverInfo.Command)
|
||||
return serverInfo, fmt.Errorf("LSP server for %s not found. Install with: %s", languageID, serverInfo.InstallCmd)
|
||||
}
|
||||
|
||||
// getCommonLSPPaths returns common installation paths for LSP servers
// based on language and OS. The slice order is the search priority used
// by FindLSPServer. Entries may contain glob patterns (e.g. nvm/rustup
// toolchain directories); these are expanded before returning, so every
// returned entry is a concrete path (existence is checked by the caller).
func getCommonLSPPaths(languageID, command string) []string {
	var paths []string
	homeDir, err := os.UserHomeDir()
	if err != nil {
		// Without a home directory no candidate path can be built.
		slog.Error("Failed to get user home directory", "error", err)
		return paths
	}

	// Add platform-specific paths
	switch runtime.GOOS {
	case "darwin":
		// macOS paths (Homebrew on both Intel and Apple Silicon prefixes)
		paths = append(paths,
			fmt.Sprintf("/usr/local/bin/%s", command),
			fmt.Sprintf("/opt/homebrew/bin/%s", command),
			fmt.Sprintf("%s/.local/bin/%s", homeDir, command),
		)
	case "linux":
		// Linux paths
		paths = append(paths,
			fmt.Sprintf("/usr/bin/%s", command),
			fmt.Sprintf("/usr/local/bin/%s", command),
			fmt.Sprintf("%s/.local/bin/%s", homeDir, command),
		)
	case "windows":
		// Windows paths
		paths = append(paths,
			fmt.Sprintf("%s\\AppData\\Local\\Programs\\%s.exe", homeDir, command),
			fmt.Sprintf("C:\\Program Files\\%s\\bin\\%s.exe", command, command),
		)
	}

	// Add language-specific paths
	switch languageID {
	case "go":
		gopath := os.Getenv("GOPATH")
		if gopath == "" {
			// Default GOPATH when the environment variable is unset.
			gopath = filepath.Join(homeDir, "go")
		}
		paths = append(paths, filepath.Join(gopath, "bin", command))
		if runtime.GOOS == "windows" {
			paths = append(paths, filepath.Join(gopath, "bin", command+".exe"))
		}
	case "typescript", "javascript", "html", "css", "json", "yaml", "php":
		// Node.js global packages
		if runtime.GOOS == "windows" {
			paths = append(paths,
				fmt.Sprintf("%s\\AppData\\Roaming\\npm\\%s.cmd", homeDir, command),
				fmt.Sprintf("%s\\AppData\\Roaming\\npm\\node_modules\\.bin\\%s.cmd", homeDir, command),
			)
		} else {
			paths = append(paths,
				fmt.Sprintf("%s/.npm-global/bin/%s", homeDir, command),
				// nvm installs one bin dir per node version; glob over them.
				fmt.Sprintf("%s/.nvm/versions/node/*/bin/%s", homeDir, command),
				fmt.Sprintf("/usr/local/lib/node_modules/.bin/%s", command),
			)
		}
	case "python":
		// Python paths (pip --user installs and pyenv shims)
		if runtime.GOOS == "windows" {
			paths = append(paths,
				fmt.Sprintf("%s\\AppData\\Local\\Programs\\Python\\Python*\\Scripts\\%s.exe", homeDir, command),
				fmt.Sprintf("C:\\Python*\\Scripts\\%s.exe", command),
			)
		} else {
			paths = append(paths,
				fmt.Sprintf("%s/.local/bin/%s", homeDir, command),
				fmt.Sprintf("%s/.pyenv/shims/%s", homeDir, command),
				fmt.Sprintf("/usr/local/bin/%s", command),
			)
		}
	case "rust":
		// Rust paths (rustup toolchains and cargo-installed binaries)
		if runtime.GOOS == "windows" {
			paths = append(paths,
				fmt.Sprintf("%s\\.rustup\\toolchains\\*\\bin\\%s.exe", homeDir, command),
				fmt.Sprintf("%s\\.cargo\\bin\\%s.exe", homeDir, command),
			)
		} else {
			paths = append(paths,
				fmt.Sprintf("%s/.rustup/toolchains/*/bin/%s", homeDir, command),
				fmt.Sprintf("%s/.cargo/bin/%s", homeDir, command),
			)
		}
	}

	// Add VSCode extensions path
	// NOTE(review): this appends a *directory*, while the caller stats
	// candidates like executables — presumably intentional; verify.
	vscodePath := getVSCodeExtensionsPath(homeDir)
	if vscodePath != "" {
		paths = append(paths, vscodePath)
	}

	// Expand any glob patterns in paths
	var expandedPaths []string
	for _, path := range paths {
		if strings.Contains(path, "*") {
			// This is a glob pattern, expand it
			matches, err := filepath.Glob(path)
			if err == nil {
				expandedPaths = append(expandedPaths, matches...)
			}
		} else {
			expandedPaths = append(expandedPaths, path)
		}
	}

	return expandedPaths
}
|
||||
|
||||
// getVSCodeExtensionsPath returns the VSCode globalStorage directory
// under homeDir for the current OS, or "" when the OS is unsupported or
// the directory does not exist.
func getVSCodeExtensionsPath(homeDir string) string {
	// OS-specific path segments below the home directory.
	var segments []string
	switch runtime.GOOS {
	case "darwin":
		segments = []string{"Library", "Application Support", "Code", "User", "globalStorage"}
	case "linux":
		segments = []string{".config", "Code", "User", "globalStorage"}
	case "windows":
		segments = []string{"AppData", "Roaming", "Code", "User", "globalStorage"}
	default:
		return ""
	}

	base := filepath.Join(append([]string{homeDir}, segments...)...)

	// Only report the path when it actually exists.
	if _, err := os.Stat(base); err != nil {
		return ""
	}
	return base
}
|
||||
|
||||
// ConfigureLSPServers detects languages and configures LSP servers
|
||||
func ConfigureLSPServers(rootDir string) (map[string]ServerInfo, error) {
|
||||
// Detect languages in the project
|
||||
languages, err := DetectLanguages(rootDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to detect languages: %w", err)
|
||||
}
|
||||
|
||||
// Find LSP servers for detected languages
|
||||
servers := make(map[string]ServerInfo)
|
||||
for langID, langInfo := range languages {
|
||||
// Prioritize primary languages but include all languages that have server definitions
|
||||
if !langInfo.IsPrimary && langInfo.FileCount < 3 {
|
||||
// Skip non-primary languages with very few files
|
||||
slog.Debug("Skipping non-primary language with few files", "language", langID, "files", langInfo.FileCount)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if we have a server for this language
|
||||
serverInfo, err := FindLSPServer(langID)
|
||||
if err != nil {
|
||||
slog.Warn("LSP server not found", "language", langID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Add to the map of configured servers
|
||||
servers[langID] = serverInfo
|
||||
if langInfo.IsPrimary {
|
||||
slog.Info("Configured LSP server for primary language", "language", langID, "command", serverInfo.Command, "path", serverInfo.Path)
|
||||
} else {
|
||||
slog.Info("Configured LSP server for secondary language", "language", langID, "command", serverInfo.Command, "path", serverInfo.Path)
|
||||
}
|
||||
}
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
package lsp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
"github.com/sst/opencode/internal/lsp/util"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// Requests
|
||||
|
||||
// HandleWorkspaceConfiguration answers a workspace/configuration request.
// The params are ignored: this client exposes no configurable settings,
// so the server receives a single empty configuration section.
func HandleWorkspaceConfiguration(params json.RawMessage) (any, error) {
	emptySection := map[string]any{}
	return []map[string]any{emptySection}, nil
}
|
||||
|
||||
func HandleRegisterCapability(params json.RawMessage) (any, error) {
|
||||
var registerParams protocol.RegistrationParams
|
||||
if err := json.Unmarshal(params, ®isterParams); err != nil {
|
||||
slog.Error("Error unmarshaling registration params", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, reg := range registerParams.Registrations {
|
||||
switch reg.Method {
|
||||
case "workspace/didChangeWatchedFiles":
|
||||
// Parse the registration options
|
||||
optionsJSON, err := json.Marshal(reg.RegisterOptions)
|
||||
if err != nil {
|
||||
slog.Error("Error marshaling registration options", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
var options protocol.DidChangeWatchedFilesRegistrationOptions
|
||||
if err := json.Unmarshal(optionsJSON, &options); err != nil {
|
||||
slog.Error("Error unmarshaling registration options", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Store the file watchers registrations
|
||||
notifyFileWatchRegistration(reg.ID, options.Watchers)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func HandleApplyEdit(params json.RawMessage) (any, error) {
|
||||
var edit protocol.ApplyWorkspaceEditParams
|
||||
if err := json.Unmarshal(params, &edit); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err := util.ApplyWorkspaceEdit(edit.Edit)
|
||||
if err != nil {
|
||||
slog.Error("Error applying workspace edit", "error", err)
|
||||
return protocol.ApplyWorkspaceEditResult{Applied: false, FailureReason: err.Error()}, nil
|
||||
}
|
||||
|
||||
return protocol.ApplyWorkspaceEditResult{Applied: true}, nil
|
||||
}
|
||||
|
||||
// FileWatchRegistrationHandler is a function that will be called when file watch registrations are received
type FileWatchRegistrationHandler func(id string, watchers []protocol.FileSystemWatcher)

// fileWatchHandler holds the current handler for file watch registrations.
// NOTE(review): package-level mutable state with no synchronization —
// this assumes RegisterFileWatchHandler is called once at startup before
// any LSP traffic arrives; confirm with callers.
var fileWatchHandler FileWatchRegistrationHandler

// RegisterFileWatchHandler sets the handler for file watch registrations
func RegisterFileWatchHandler(handler FileWatchRegistrationHandler) {
	fileWatchHandler = handler
}

// notifyFileWatchRegistration notifies the handler about new file watch
// registrations. It is a no-op when no handler has been registered.
func notifyFileWatchRegistration(id string, watchers []protocol.FileSystemWatcher) {
	if fileWatchHandler != nil {
		fileWatchHandler(id, watchers)
	}
}
|
||||
|
||||
// Notifications
|
||||
|
||||
func HandleServerMessage(params json.RawMessage) {
|
||||
cnf := config.Get()
|
||||
var msg struct {
|
||||
Type int `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.Unmarshal(params, &msg); err == nil {
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Server message", "type", msg.Type, "message", msg.Message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func HandleDiagnostics(client *Client, params json.RawMessage) {
|
||||
var diagParams protocol.PublishDiagnosticsParams
|
||||
if err := json.Unmarshal(params, &diagParams); err != nil {
|
||||
slog.Error("Error unmarshaling diagnostics params", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
client.diagnosticsMu.Lock()
|
||||
defer client.diagnosticsMu.Unlock()
|
||||
|
||||
client.diagnostics[diagParams.URI] = diagParams.Diagnostics
|
||||
}
|
||||
@@ -1,132 +0,0 @@
|
||||
package lsp
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
func DetectLanguageID(uri string) protocol.LanguageKind {
|
||||
ext := strings.ToLower(filepath.Ext(uri))
|
||||
switch ext {
|
||||
case ".abap":
|
||||
return protocol.LangABAP
|
||||
case ".bat":
|
||||
return protocol.LangWindowsBat
|
||||
case ".bib", ".bibtex":
|
||||
return protocol.LangBibTeX
|
||||
case ".clj":
|
||||
return protocol.LangClojure
|
||||
case ".coffee":
|
||||
return protocol.LangCoffeescript
|
||||
case ".c":
|
||||
return protocol.LangC
|
||||
case ".cpp", ".cxx", ".cc", ".c++":
|
||||
return protocol.LangCPP
|
||||
case ".cs":
|
||||
return protocol.LangCSharp
|
||||
case ".css":
|
||||
return protocol.LangCSS
|
||||
case ".d":
|
||||
return protocol.LangD
|
||||
case ".pas", ".pascal":
|
||||
return protocol.LangDelphi
|
||||
case ".diff", ".patch":
|
||||
return protocol.LangDiff
|
||||
case ".dart":
|
||||
return protocol.LangDart
|
||||
case ".dockerfile":
|
||||
return protocol.LangDockerfile
|
||||
case ".ex", ".exs":
|
||||
return protocol.LangElixir
|
||||
case ".erl", ".hrl":
|
||||
return protocol.LangErlang
|
||||
case ".fs", ".fsi", ".fsx", ".fsscript":
|
||||
return protocol.LangFSharp
|
||||
case ".gitcommit":
|
||||
return protocol.LangGitCommit
|
||||
case ".gitrebase":
|
||||
return protocol.LangGitRebase
|
||||
case ".go":
|
||||
return protocol.LangGo
|
||||
case ".groovy":
|
||||
return protocol.LangGroovy
|
||||
case ".hbs", ".handlebars":
|
||||
return protocol.LangHandlebars
|
||||
case ".hs":
|
||||
return protocol.LangHaskell
|
||||
case ".html", ".htm":
|
||||
return protocol.LangHTML
|
||||
case ".ini":
|
||||
return protocol.LangIni
|
||||
case ".java":
|
||||
return protocol.LangJava
|
||||
case ".js":
|
||||
return protocol.LangJavaScript
|
||||
case ".jsx":
|
||||
return protocol.LangJavaScriptReact
|
||||
case ".json":
|
||||
return protocol.LangJSON
|
||||
case ".tex", ".latex":
|
||||
return protocol.LangLaTeX
|
||||
case ".less":
|
||||
return protocol.LangLess
|
||||
case ".lua":
|
||||
return protocol.LangLua
|
||||
case ".makefile", "makefile":
|
||||
return protocol.LangMakefile
|
||||
case ".md", ".markdown":
|
||||
return protocol.LangMarkdown
|
||||
case ".m":
|
||||
return protocol.LangObjectiveC
|
||||
case ".mm":
|
||||
return protocol.LangObjectiveCPP
|
||||
case ".pl":
|
||||
return protocol.LangPerl
|
||||
case ".pm":
|
||||
return protocol.LangPerl6
|
||||
case ".php":
|
||||
return protocol.LangPHP
|
||||
case ".ps1", ".psm1":
|
||||
return protocol.LangPowershell
|
||||
case ".pug", ".jade":
|
||||
return protocol.LangPug
|
||||
case ".py":
|
||||
return protocol.LangPython
|
||||
case ".r":
|
||||
return protocol.LangR
|
||||
case ".cshtml", ".razor":
|
||||
return protocol.LangRazor
|
||||
case ".rb":
|
||||
return protocol.LangRuby
|
||||
case ".rs":
|
||||
return protocol.LangRust
|
||||
case ".scss":
|
||||
return protocol.LangSCSS
|
||||
case ".sass":
|
||||
return protocol.LangSASS
|
||||
case ".scala":
|
||||
return protocol.LangScala
|
||||
case ".shader":
|
||||
return protocol.LangShaderLab
|
||||
case ".sh", ".bash", ".zsh", ".ksh":
|
||||
return protocol.LangShellScript
|
||||
case ".sql":
|
||||
return protocol.LangSQL
|
||||
case ".swift":
|
||||
return protocol.LangSwift
|
||||
case ".ts":
|
||||
return protocol.LangTypeScript
|
||||
case ".tsx":
|
||||
return protocol.LangTypeScriptReact
|
||||
case ".xml":
|
||||
return protocol.LangXML
|
||||
case ".xsl":
|
||||
return protocol.LangXSL
|
||||
case ".yaml", ".yml":
|
||||
return protocol.LangYAML
|
||||
default:
|
||||
return protocol.LanguageKind("") // Unknown language
|
||||
}
|
||||
}
|
||||
@@ -1,554 +0,0 @@
|
||||
// Generated code. Do not edit
|
||||
package lsp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
// Implementation sends a textDocument/implementation request to the LSP server.
|
||||
// A request to resolve the implementation locations of a symbol at a given text document position. The request's parameter is of type TextDocumentPositionParams the response is of type Definition or a Thenable that resolves to such.
|
||||
func (c *Client) Implementation(ctx context.Context, params protocol.ImplementationParams) (protocol.Or_Result_textDocument_implementation, error) {
|
||||
var result protocol.Or_Result_textDocument_implementation
|
||||
err := c.Call(ctx, "textDocument/implementation", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// TypeDefinition sends a textDocument/typeDefinition request to the LSP server.
|
||||
// A request to resolve the type definition locations of a symbol at a given text document position. The request's parameter is of type TextDocumentPositionParams the response is of type Definition or a Thenable that resolves to such.
|
||||
func (c *Client) TypeDefinition(ctx context.Context, params protocol.TypeDefinitionParams) (protocol.Or_Result_textDocument_typeDefinition, error) {
|
||||
var result protocol.Or_Result_textDocument_typeDefinition
|
||||
err := c.Call(ctx, "textDocument/typeDefinition", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// DocumentColor sends a textDocument/documentColor request to the LSP server.
|
||||
// A request to list all color symbols found in a given text document. The request's parameter is of type DocumentColorParams the response is of type ColorInformation ColorInformation[] or a Thenable that resolves to such.
|
||||
func (c *Client) DocumentColor(ctx context.Context, params protocol.DocumentColorParams) ([]protocol.ColorInformation, error) {
|
||||
var result []protocol.ColorInformation
|
||||
err := c.Call(ctx, "textDocument/documentColor", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ColorPresentation sends a textDocument/colorPresentation request to the LSP server.
|
||||
// A request to list all presentation for a color. The request's parameter is of type ColorPresentationParams the response is of type ColorInformation ColorInformation[] or a Thenable that resolves to such.
|
||||
func (c *Client) ColorPresentation(ctx context.Context, params protocol.ColorPresentationParams) ([]protocol.ColorPresentation, error) {
|
||||
var result []protocol.ColorPresentation
|
||||
err := c.Call(ctx, "textDocument/colorPresentation", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// FoldingRange sends a textDocument/foldingRange request to the LSP server.
|
||||
// A request to provide folding ranges in a document. The request's parameter is of type FoldingRangeParams, the response is of type FoldingRangeList or a Thenable that resolves to such.
|
||||
func (c *Client) FoldingRange(ctx context.Context, params protocol.FoldingRangeParams) ([]protocol.FoldingRange, error) {
|
||||
var result []protocol.FoldingRange
|
||||
err := c.Call(ctx, "textDocument/foldingRange", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Declaration sends a textDocument/declaration request to the LSP server.
|
||||
// A request to resolve the type definition locations of a symbol at a given text document position. The request's parameter is of type TextDocumentPositionParams the response is of type Declaration or a typed array of DeclarationLink or a Thenable that resolves to such.
|
||||
func (c *Client) Declaration(ctx context.Context, params protocol.DeclarationParams) (protocol.Or_Result_textDocument_declaration, error) {
|
||||
var result protocol.Or_Result_textDocument_declaration
|
||||
err := c.Call(ctx, "textDocument/declaration", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// SelectionRange sends a textDocument/selectionRange request to the LSP server.
|
||||
// A request to provide selection ranges in a document. The request's parameter is of type SelectionRangeParams, the response is of type SelectionRange SelectionRange[] or a Thenable that resolves to such.
|
||||
func (c *Client) SelectionRange(ctx context.Context, params protocol.SelectionRangeParams) ([]protocol.SelectionRange, error) {
|
||||
var result []protocol.SelectionRange
|
||||
err := c.Call(ctx, "textDocument/selectionRange", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// PrepareCallHierarchy sends a textDocument/prepareCallHierarchy request to the LSP server.
|
||||
// A request to result a CallHierarchyItem in a document at a given position. Can be used as an input to an incoming or outgoing call hierarchy. Since 3.16.0
|
||||
func (c *Client) PrepareCallHierarchy(ctx context.Context, params protocol.CallHierarchyPrepareParams) ([]protocol.CallHierarchyItem, error) {
|
||||
var result []protocol.CallHierarchyItem
|
||||
err := c.Call(ctx, "textDocument/prepareCallHierarchy", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// IncomingCalls sends a callHierarchy/incomingCalls request to the LSP server.
|
||||
// A request to resolve the incoming calls for a given CallHierarchyItem. Since 3.16.0
|
||||
func (c *Client) IncomingCalls(ctx context.Context, params protocol.CallHierarchyIncomingCallsParams) ([]protocol.CallHierarchyIncomingCall, error) {
|
||||
var result []protocol.CallHierarchyIncomingCall
|
||||
err := c.Call(ctx, "callHierarchy/incomingCalls", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// OutgoingCalls sends a callHierarchy/outgoingCalls request to the LSP server.
|
||||
// A request to resolve the outgoing calls for a given CallHierarchyItem. Since 3.16.0
|
||||
func (c *Client) OutgoingCalls(ctx context.Context, params protocol.CallHierarchyOutgoingCallsParams) ([]protocol.CallHierarchyOutgoingCall, error) {
|
||||
var result []protocol.CallHierarchyOutgoingCall
|
||||
err := c.Call(ctx, "callHierarchy/outgoingCalls", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// SemanticTokensFull sends a textDocument/semanticTokens/full request to the LSP server.
|
||||
// Since 3.16.0
|
||||
func (c *Client) SemanticTokensFull(ctx context.Context, params protocol.SemanticTokensParams) (protocol.SemanticTokens, error) {
|
||||
var result protocol.SemanticTokens
|
||||
err := c.Call(ctx, "textDocument/semanticTokens/full", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// SemanticTokensFullDelta sends a textDocument/semanticTokens/full/delta request to the LSP server.
|
||||
// Since 3.16.0
|
||||
func (c *Client) SemanticTokensFullDelta(ctx context.Context, params protocol.SemanticTokensDeltaParams) (protocol.Or_Result_textDocument_semanticTokens_full_delta, error) {
|
||||
var result protocol.Or_Result_textDocument_semanticTokens_full_delta
|
||||
err := c.Call(ctx, "textDocument/semanticTokens/full/delta", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// SemanticTokensRange sends a textDocument/semanticTokens/range request to the LSP server.
|
||||
// Since 3.16.0
|
||||
func (c *Client) SemanticTokensRange(ctx context.Context, params protocol.SemanticTokensRangeParams) (protocol.SemanticTokens, error) {
|
||||
var result protocol.SemanticTokens
|
||||
err := c.Call(ctx, "textDocument/semanticTokens/range", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// LinkedEditingRange sends a textDocument/linkedEditingRange request to the LSP server.
|
||||
// A request to provide ranges that can be edited together. Since 3.16.0
|
||||
func (c *Client) LinkedEditingRange(ctx context.Context, params protocol.LinkedEditingRangeParams) (protocol.LinkedEditingRanges, error) {
|
||||
var result protocol.LinkedEditingRanges
|
||||
err := c.Call(ctx, "textDocument/linkedEditingRange", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// WillCreateFiles sends a workspace/willCreateFiles request to the LSP server.
|
||||
// The will create files request is sent from the client to the server before files are actually created as long as the creation is triggered from within the client. The request can return a WorkspaceEdit which will be applied to workspace before the files are created. Hence the WorkspaceEdit can not manipulate the content of the file to be created. Since 3.16.0
|
||||
func (c *Client) WillCreateFiles(ctx context.Context, params protocol.CreateFilesParams) (protocol.WorkspaceEdit, error) {
|
||||
var result protocol.WorkspaceEdit
|
||||
err := c.Call(ctx, "workspace/willCreateFiles", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// WillRenameFiles sends a workspace/willRenameFiles request to the LSP server.
|
||||
// The will rename files request is sent from the client to the server before files are actually renamed as long as the rename is triggered from within the client. Since 3.16.0
|
||||
func (c *Client) WillRenameFiles(ctx context.Context, params protocol.RenameFilesParams) (protocol.WorkspaceEdit, error) {
|
||||
var result protocol.WorkspaceEdit
|
||||
err := c.Call(ctx, "workspace/willRenameFiles", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// WillDeleteFiles sends a workspace/willDeleteFiles request to the LSP server.
//
// The will delete files request is sent from the client to the server before
// files are actually deleted, as long as the deletion is triggered from within
// the client. The returned WorkspaceEdit is applied to the workspace before
// the files are deleted. Since 3.16.0
//
// NOTE(review): the previous generated comment described the
// workspace/didDeleteFiles notification; it had been swapped with the
// comment on DidDeleteFiles.
func (c *Client) WillDeleteFiles(ctx context.Context, params protocol.DeleteFilesParams) (protocol.WorkspaceEdit, error) {
	var result protocol.WorkspaceEdit
	err := c.Call(ctx, "workspace/willDeleteFiles", params, &result)
	return result, err
}
|
||||
|
||||
// Moniker sends a textDocument/moniker request to the LSP server.
|
||||
// A request to get the moniker of a symbol at a given text document position. The request parameter is of type TextDocumentPositionParams. The response is of type Moniker Moniker[] or null.
|
||||
func (c *Client) Moniker(ctx context.Context, params protocol.MonikerParams) ([]protocol.Moniker, error) {
|
||||
var result []protocol.Moniker
|
||||
err := c.Call(ctx, "textDocument/moniker", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// PrepareTypeHierarchy sends a textDocument/prepareTypeHierarchy request to the LSP server.
|
||||
// A request to result a TypeHierarchyItem in a document at a given position. Can be used as an input to a subtypes or supertypes type hierarchy. Since 3.17.0
|
||||
func (c *Client) PrepareTypeHierarchy(ctx context.Context, params protocol.TypeHierarchyPrepareParams) ([]protocol.TypeHierarchyItem, error) {
|
||||
var result []protocol.TypeHierarchyItem
|
||||
err := c.Call(ctx, "textDocument/prepareTypeHierarchy", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Supertypes sends a typeHierarchy/supertypes request to the LSP server.
|
||||
// A request to resolve the supertypes for a given TypeHierarchyItem. Since 3.17.0
|
||||
func (c *Client) Supertypes(ctx context.Context, params protocol.TypeHierarchySupertypesParams) ([]protocol.TypeHierarchyItem, error) {
|
||||
var result []protocol.TypeHierarchyItem
|
||||
err := c.Call(ctx, "typeHierarchy/supertypes", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Subtypes sends a typeHierarchy/subtypes request to the LSP server.
|
||||
// A request to resolve the subtypes for a given TypeHierarchyItem. Since 3.17.0
|
||||
func (c *Client) Subtypes(ctx context.Context, params protocol.TypeHierarchySubtypesParams) ([]protocol.TypeHierarchyItem, error) {
|
||||
var result []protocol.TypeHierarchyItem
|
||||
err := c.Call(ctx, "typeHierarchy/subtypes", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// InlineValue sends a textDocument/inlineValue request to the LSP server.
|
||||
// A request to provide inline values in a document. The request's parameter is of type InlineValueParams, the response is of type InlineValue InlineValue[] or a Thenable that resolves to such. Since 3.17.0
|
||||
func (c *Client) InlineValue(ctx context.Context, params protocol.InlineValueParams) ([]protocol.InlineValue, error) {
|
||||
var result []protocol.InlineValue
|
||||
err := c.Call(ctx, "textDocument/inlineValue", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// InlayHint sends a textDocument/inlayHint request to the LSP server.
|
||||
// A request to provide inlay hints in a document. The request's parameter is of type InlayHintsParams, the response is of type InlayHint InlayHint[] or a Thenable that resolves to such. Since 3.17.0
|
||||
func (c *Client) InlayHint(ctx context.Context, params protocol.InlayHintParams) ([]protocol.InlayHint, error) {
|
||||
var result []protocol.InlayHint
|
||||
err := c.Call(ctx, "textDocument/inlayHint", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Resolve sends an inlayHint/resolve request to the LSP server.
//
// A request to resolve additional properties for an inlay hint. The request's
// parameter is of type InlayHint and the response is of type InlayHint.
// Since 3.17.0
func (c *Client) Resolve(ctx context.Context, params protocol.InlayHint) (protocol.InlayHint, error) {
	var result protocol.InlayHint
	err := c.Call(ctx, "inlayHint/resolve", params, &result)
	return result, err
}
|
||||
|
||||
// Diagnostic sends a textDocument/diagnostic request to the LSP server.
|
||||
// The document diagnostic request definition. Since 3.17.0
|
||||
func (c *Client) Diagnostic(ctx context.Context, params protocol.DocumentDiagnosticParams) (protocol.DocumentDiagnosticReport, error) {
|
||||
var result protocol.DocumentDiagnosticReport
|
||||
err := c.Call(ctx, "textDocument/diagnostic", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// DiagnosticWorkspace sends a workspace/diagnostic request to the LSP server.
|
||||
// The workspace diagnostic request definition. Since 3.17.0
|
||||
func (c *Client) DiagnosticWorkspace(ctx context.Context, params protocol.WorkspaceDiagnosticParams) (protocol.WorkspaceDiagnosticReport, error) {
|
||||
var result protocol.WorkspaceDiagnosticReport
|
||||
err := c.Call(ctx, "workspace/diagnostic", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// InlineCompletion sends a textDocument/inlineCompletion request to the LSP server.
|
||||
// A request to provide inline completions in a document. The request's parameter is of type InlineCompletionParams, the response is of type InlineCompletion InlineCompletion[] or a Thenable that resolves to such. Since 3.18.0 PROPOSED
|
||||
func (c *Client) InlineCompletion(ctx context.Context, params protocol.InlineCompletionParams) (protocol.Or_Result_textDocument_inlineCompletion, error) {
|
||||
var result protocol.Or_Result_textDocument_inlineCompletion
|
||||
err := c.Call(ctx, "textDocument/inlineCompletion", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// TextDocumentContent sends a workspace/textDocumentContent request to the LSP server.
|
||||
// The workspace/textDocumentContent request is sent from the client to the server to request the content of a text document. Since 3.18.0 PROPOSED
|
||||
func (c *Client) TextDocumentContent(ctx context.Context, params protocol.TextDocumentContentParams) (string, error) {
|
||||
var result string
|
||||
err := c.Call(ctx, "workspace/textDocumentContent", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Initialize sends a initialize request to the LSP server.
|
||||
// The initialize request is sent from the client to the server. It is sent once as the request after starting up the server. The requests parameter is of type InitializeParams the response if of type InitializeResult of a Thenable that resolves to such.
|
||||
func (c *Client) Initialize(ctx context.Context, params protocol.ParamInitialize) (protocol.InitializeResult, error) {
|
||||
var result protocol.InitializeResult
|
||||
err := c.Call(ctx, "initialize", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Shutdown sends a shutdown request to the LSP server.
|
||||
// A shutdown request is sent from the client to the server. It is sent once when the client decides to shutdown the server. The only notification that is sent after a shutdown request is the exit event.
|
||||
func (c *Client) Shutdown(ctx context.Context) error {
|
||||
return c.Call(ctx, "shutdown", nil, nil)
|
||||
}
|
||||
|
||||
// WillSaveWaitUntil sends a textDocument/willSaveWaitUntil request to the LSP server.
|
||||
// A document will save request is sent from the client to the server before the document is actually saved. The request can return an array of TextEdits which will be applied to the text document before it is saved. Please note that clients might drop results if computing the text edits took too long or if a server constantly fails on this request. This is done to keep the save fast and reliable.
|
||||
func (c *Client) WillSaveWaitUntil(ctx context.Context, params protocol.WillSaveTextDocumentParams) ([]protocol.TextEdit, error) {
|
||||
var result []protocol.TextEdit
|
||||
err := c.Call(ctx, "textDocument/willSaveWaitUntil", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Completion sends a textDocument/completion request to the LSP server.
|
||||
// Request to request completion at a given text document position. The request's parameter is of type TextDocumentPosition the response is of type CompletionItem CompletionItem[] or CompletionList or a Thenable that resolves to such. The request can delay the computation of the CompletionItem.detail detail and CompletionItem.documentation documentation properties to the completionItem/resolve request. However, properties that are needed for the initial sorting and filtering, like sortText, filterText, insertText, and textEdit, must not be changed during resolve.
|
||||
func (c *Client) Completion(ctx context.Context, params protocol.CompletionParams) (protocol.Or_Result_textDocument_completion, error) {
|
||||
var result protocol.Or_Result_textDocument_completion
|
||||
err := c.Call(ctx, "textDocument/completion", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ResolveCompletionItem sends a completionItem/resolve request to the LSP server.
|
||||
// Request to resolve additional information for a given completion item.The request's parameter is of type CompletionItem the response is of type CompletionItem or a Thenable that resolves to such.
|
||||
func (c *Client) ResolveCompletionItem(ctx context.Context, params protocol.CompletionItem) (protocol.CompletionItem, error) {
|
||||
var result protocol.CompletionItem
|
||||
err := c.Call(ctx, "completionItem/resolve", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Hover sends a textDocument/hover request to the LSP server.
|
||||
// Request to request hover information at a given text document position. The request's parameter is of type TextDocumentPosition the response is of type Hover or a Thenable that resolves to such.
|
||||
func (c *Client) Hover(ctx context.Context, params protocol.HoverParams) (protocol.Hover, error) {
|
||||
var result protocol.Hover
|
||||
err := c.Call(ctx, "textDocument/hover", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// SignatureHelp sends a textDocument/signatureHelp request to the LSP server.
|
||||
func (c *Client) SignatureHelp(ctx context.Context, params protocol.SignatureHelpParams) (protocol.SignatureHelp, error) {
|
||||
var result protocol.SignatureHelp
|
||||
err := c.Call(ctx, "textDocument/signatureHelp", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Definition sends a textDocument/definition request to the LSP server.
|
||||
// A request to resolve the definition location of a symbol at a given text document position. The request's parameter is of type TextDocumentPosition the response is of either type Definition or a typed array of DefinitionLink or a Thenable that resolves to such.
|
||||
func (c *Client) Definition(ctx context.Context, params protocol.DefinitionParams) (protocol.Or_Result_textDocument_definition, error) {
|
||||
var result protocol.Or_Result_textDocument_definition
|
||||
err := c.Call(ctx, "textDocument/definition", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// References sends a textDocument/references request to the LSP server.
|
||||
// A request to resolve project-wide references for the symbol denoted by the given text document position. The request's parameter is of type ReferenceParams the response is of type Location Location[] or a Thenable that resolves to such.
|
||||
func (c *Client) References(ctx context.Context, params protocol.ReferenceParams) ([]protocol.Location, error) {
|
||||
var result []protocol.Location
|
||||
err := c.Call(ctx, "textDocument/references", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// DocumentHighlight sends a textDocument/documentHighlight request to the LSP server.
|
||||
// Request to resolve a DocumentHighlight for a given text document position. The request's parameter is of type TextDocumentPosition the request response is an array of type DocumentHighlight or a Thenable that resolves to such.
|
||||
func (c *Client) DocumentHighlight(ctx context.Context, params protocol.DocumentHighlightParams) ([]protocol.DocumentHighlight, error) {
|
||||
var result []protocol.DocumentHighlight
|
||||
err := c.Call(ctx, "textDocument/documentHighlight", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// DocumentSymbol sends a textDocument/documentSymbol request to the LSP server.
|
||||
// A request to list all symbols found in a given text document. The request's parameter is of type TextDocumentIdentifier the response is of type SymbolInformation SymbolInformation[] or a Thenable that resolves to such.
|
||||
func (c *Client) DocumentSymbol(ctx context.Context, params protocol.DocumentSymbolParams) (protocol.Or_Result_textDocument_documentSymbol, error) {
|
||||
var result protocol.Or_Result_textDocument_documentSymbol
|
||||
err := c.Call(ctx, "textDocument/documentSymbol", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// CodeAction sends a textDocument/codeAction request to the LSP server.
|
||||
// A request to provide commands for the given text document and range.
|
||||
func (c *Client) CodeAction(ctx context.Context, params protocol.CodeActionParams) ([]protocol.Or_Result_textDocument_codeAction_Item0_Elem, error) {
|
||||
var result []protocol.Or_Result_textDocument_codeAction_Item0_Elem
|
||||
err := c.Call(ctx, "textDocument/codeAction", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ResolveCodeAction sends a codeAction/resolve request to the LSP server.
|
||||
// Request to resolve additional information for a given code action.The request's parameter is of type CodeAction the response is of type CodeAction or a Thenable that resolves to such.
|
||||
func (c *Client) ResolveCodeAction(ctx context.Context, params protocol.CodeAction) (protocol.CodeAction, error) {
|
||||
var result protocol.CodeAction
|
||||
err := c.Call(ctx, "codeAction/resolve", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Symbol sends a workspace/symbol request to the LSP server.
|
||||
// A request to list project-wide symbols matching the query string given by the WorkspaceSymbolParams. The response is of type SymbolInformation SymbolInformation[] or a Thenable that resolves to such. Since 3.17.0 - support for WorkspaceSymbol in the returned data. Clients need to advertise support for WorkspaceSymbols via the client capability workspace.symbol.resolveSupport.
|
||||
func (c *Client) Symbol(ctx context.Context, params protocol.WorkspaceSymbolParams) (protocol.Or_Result_workspace_symbol, error) {
|
||||
var result protocol.Or_Result_workspace_symbol
|
||||
err := c.Call(ctx, "workspace/symbol", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ResolveWorkspaceSymbol sends a workspaceSymbol/resolve request to the LSP server.
|
||||
// A request to resolve the range inside the workspace symbol's location. Since 3.17.0
|
||||
func (c *Client) ResolveWorkspaceSymbol(ctx context.Context, params protocol.WorkspaceSymbol) (protocol.WorkspaceSymbol, error) {
|
||||
var result protocol.WorkspaceSymbol
|
||||
err := c.Call(ctx, "workspaceSymbol/resolve", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// CodeLens sends a textDocument/codeLens request to the LSP server.
|
||||
// A request to provide code lens for the given text document.
|
||||
func (c *Client) CodeLens(ctx context.Context, params protocol.CodeLensParams) ([]protocol.CodeLens, error) {
|
||||
var result []protocol.CodeLens
|
||||
err := c.Call(ctx, "textDocument/codeLens", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ResolveCodeLens sends a codeLens/resolve request to the LSP server.
|
||||
// A request to resolve a command for a given code lens.
|
||||
func (c *Client) ResolveCodeLens(ctx context.Context, params protocol.CodeLens) (protocol.CodeLens, error) {
|
||||
var result protocol.CodeLens
|
||||
err := c.Call(ctx, "codeLens/resolve", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// DocumentLink sends a textDocument/documentLink request to the LSP server.
|
||||
// A request to provide document links
|
||||
func (c *Client) DocumentLink(ctx context.Context, params protocol.DocumentLinkParams) ([]protocol.DocumentLink, error) {
|
||||
var result []protocol.DocumentLink
|
||||
err := c.Call(ctx, "textDocument/documentLink", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ResolveDocumentLink sends a documentLink/resolve request to the LSP server.
|
||||
// Request to resolve additional information for a given document link. The request's parameter is of type DocumentLink the response is of type DocumentLink or a Thenable that resolves to such.
|
||||
func (c *Client) ResolveDocumentLink(ctx context.Context, params protocol.DocumentLink) (protocol.DocumentLink, error) {
|
||||
var result protocol.DocumentLink
|
||||
err := c.Call(ctx, "documentLink/resolve", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Formatting sends a textDocument/formatting request to the LSP server.
|
||||
// A request to format a whole document.
|
||||
func (c *Client) Formatting(ctx context.Context, params protocol.DocumentFormattingParams) ([]protocol.TextEdit, error) {
|
||||
var result []protocol.TextEdit
|
||||
err := c.Call(ctx, "textDocument/formatting", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// RangeFormatting sends a textDocument/rangeFormatting request to the LSP server.
|
||||
// A request to format a range in a document.
|
||||
func (c *Client) RangeFormatting(ctx context.Context, params protocol.DocumentRangeFormattingParams) ([]protocol.TextEdit, error) {
|
||||
var result []protocol.TextEdit
|
||||
err := c.Call(ctx, "textDocument/rangeFormatting", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// RangesFormatting sends a textDocument/rangesFormatting request to the LSP server.
|
||||
// A request to format ranges in a document. Since 3.18.0 PROPOSED
|
||||
func (c *Client) RangesFormatting(ctx context.Context, params protocol.DocumentRangesFormattingParams) ([]protocol.TextEdit, error) {
|
||||
var result []protocol.TextEdit
|
||||
err := c.Call(ctx, "textDocument/rangesFormatting", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// OnTypeFormatting sends a textDocument/onTypeFormatting request to the LSP server.
|
||||
// A request to format a document on type.
|
||||
func (c *Client) OnTypeFormatting(ctx context.Context, params protocol.DocumentOnTypeFormattingParams) ([]protocol.TextEdit, error) {
|
||||
var result []protocol.TextEdit
|
||||
err := c.Call(ctx, "textDocument/onTypeFormatting", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Rename sends a textDocument/rename request to the LSP server.
|
||||
// A request to rename a symbol.
|
||||
func (c *Client) Rename(ctx context.Context, params protocol.RenameParams) (protocol.WorkspaceEdit, error) {
|
||||
var result protocol.WorkspaceEdit
|
||||
err := c.Call(ctx, "textDocument/rename", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// PrepareRename sends a textDocument/prepareRename request to the LSP server.
|
||||
// A request to test and perform the setup necessary for a rename. Since 3.16 - support for default behavior
|
||||
func (c *Client) PrepareRename(ctx context.Context, params protocol.PrepareRenameParams) (protocol.PrepareRenameResult, error) {
|
||||
var result protocol.PrepareRenameResult
|
||||
err := c.Call(ctx, "textDocument/prepareRename", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ExecuteCommand sends a workspace/executeCommand request to the LSP server.
|
||||
// A request send from the client to the server to execute a command. The request might return a workspace edit which the client will apply to the workspace.
|
||||
func (c *Client) ExecuteCommand(ctx context.Context, params protocol.ExecuteCommandParams) (any, error) {
|
||||
var result any
|
||||
err := c.Call(ctx, "workspace/executeCommand", params, &result)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// DidChangeWorkspaceFolders sends a workspace/didChangeWorkspaceFolders notification to the LSP server.
|
||||
// The workspace/didChangeWorkspaceFolders notification is sent from the client to the server when the workspace folder configuration changes.
|
||||
func (c *Client) DidChangeWorkspaceFolders(ctx context.Context, params protocol.DidChangeWorkspaceFoldersParams) error {
|
||||
return c.Notify(ctx, "workspace/didChangeWorkspaceFolders", params)
|
||||
}
|
||||
|
||||
// WorkDoneProgressCancel sends a window/workDoneProgress/cancel notification to the LSP server.
|
||||
// The window/workDoneProgress/cancel notification is sent from the client to the server to cancel a progress initiated on the server side.
|
||||
func (c *Client) WorkDoneProgressCancel(ctx context.Context, params protocol.WorkDoneProgressCancelParams) error {
|
||||
return c.Notify(ctx, "window/workDoneProgress/cancel", params)
|
||||
}
|
||||
|
||||
// DidCreateFiles sends a workspace/didCreateFiles notification to the LSP server.
|
||||
// The did create files notification is sent from the client to the server when files were created from within the client. Since 3.16.0
|
||||
func (c *Client) DidCreateFiles(ctx context.Context, params protocol.CreateFilesParams) error {
|
||||
return c.Notify(ctx, "workspace/didCreateFiles", params)
|
||||
}
|
||||
|
||||
// DidRenameFiles sends a workspace/didRenameFiles notification to the LSP server.
|
||||
// The did rename files notification is sent from the client to the server when files were renamed from within the client. Since 3.16.0
|
||||
func (c *Client) DidRenameFiles(ctx context.Context, params protocol.RenameFilesParams) error {
|
||||
return c.Notify(ctx, "workspace/didRenameFiles", params)
|
||||
}
|
||||
|
||||
// DidDeleteFiles sends a workspace/didDeleteFiles notification to the LSP server.
//
// The did delete files notification is sent from the client to the server
// after files were deleted from within the client. Since 3.16.0
//
// NOTE(review): the previous generated comment described the
// workspace/willDeleteFiles request; it had been swapped with the comment
// on WillDeleteFiles.
func (c *Client) DidDeleteFiles(ctx context.Context, params protocol.DeleteFilesParams) error {
	return c.Notify(ctx, "workspace/didDeleteFiles", params)
}
|
||||
|
||||
// DidOpenNotebookDocument sends a notebookDocument/didOpen notification to the LSP server.
|
||||
// A notification sent when a notebook opens. Since 3.17.0
|
||||
func (c *Client) DidOpenNotebookDocument(ctx context.Context, params protocol.DidOpenNotebookDocumentParams) error {
|
||||
return c.Notify(ctx, "notebookDocument/didOpen", params)
|
||||
}
|
||||
|
||||
// DidChangeNotebookDocument sends a notebookDocument/didChange notification to the LSP server.
|
||||
func (c *Client) DidChangeNotebookDocument(ctx context.Context, params protocol.DidChangeNotebookDocumentParams) error {
|
||||
return c.Notify(ctx, "notebookDocument/didChange", params)
|
||||
}
|
||||
|
||||
// DidSaveNotebookDocument sends a notebookDocument/didSave notification to the LSP server.
|
||||
// A notification sent when a notebook document is saved. Since 3.17.0
|
||||
func (c *Client) DidSaveNotebookDocument(ctx context.Context, params protocol.DidSaveNotebookDocumentParams) error {
|
||||
return c.Notify(ctx, "notebookDocument/didSave", params)
|
||||
}
|
||||
|
||||
// DidCloseNotebookDocument sends a notebookDocument/didClose notification to the LSP server.
|
||||
// A notification sent when a notebook closes. Since 3.17.0
|
||||
func (c *Client) DidCloseNotebookDocument(ctx context.Context, params protocol.DidCloseNotebookDocumentParams) error {
|
||||
return c.Notify(ctx, "notebookDocument/didClose", params)
|
||||
}
|
||||
|
||||
// Initialized sends a initialized notification to the LSP server.
|
||||
// The initialized notification is sent from the client to the server after the client is fully initialized and the server is allowed to send requests from the server to the client.
|
||||
func (c *Client) Initialized(ctx context.Context, params protocol.InitializedParams) error {
|
||||
return c.Notify(ctx, "initialized", params)
|
||||
}
|
||||
|
||||
// Exit sends an exit notification to the LSP server.
//
// The exit notification asks the server to exit its process. Per the LSP
// lifecycle it is sent after Shutdown; this method does not enforce that
// ordering.
func (c *Client) Exit(ctx context.Context) error {
	return c.Notify(ctx, "exit", nil)
}
|
||||
|
||||
// DidChangeConfiguration sends a workspace/didChangeConfiguration notification to the LSP server.
|
||||
// The configuration change notification is sent from the client to the server when the client's configuration has changed. The notification contains the changed configuration as defined by the language client.
|
||||
func (c *Client) DidChangeConfiguration(ctx context.Context, params protocol.DidChangeConfigurationParams) error {
|
||||
return c.Notify(ctx, "workspace/didChangeConfiguration", params)
|
||||
}
|
||||
|
||||
// DidOpen sends a textDocument/didOpen notification to the LSP server.
|
||||
// The document open notification is sent from the client to the server to signal newly opened text documents. The document's truth is now managed by the client and the server must not try to read the document's truth using the document's uri. Open in this sense means it is managed by the client. It doesn't necessarily mean that its content is presented in an editor. An open notification must not be sent more than once without a corresponding close notification send before. This means open and close notification must be balanced and the max open count is one.
|
||||
func (c *Client) DidOpen(ctx context.Context, params protocol.DidOpenTextDocumentParams) error {
|
||||
return c.Notify(ctx, "textDocument/didOpen", params)
|
||||
}
|
||||
|
||||
// DidChange sends a textDocument/didChange notification to the LSP server.
|
||||
// The document change notification is sent from the client to the server to signal changes to a text document.
|
||||
func (c *Client) DidChange(ctx context.Context, params protocol.DidChangeTextDocumentParams) error {
|
||||
return c.Notify(ctx, "textDocument/didChange", params)
|
||||
}
|
||||
|
||||
// DidClose sends a textDocument/didClose notification to the LSP server.
|
||||
// The document close notification is sent from the client to the server when the document got closed in the client. The document's truth now exists where the document's uri points to (e.g. if the document's uri is a file uri the truth now exists on disk). As with the open notification the close notification is about managing the document's content. Receiving a close notification doesn't mean that the document was open in an editor before. A close notification requires a previous open notification to be sent.
|
||||
func (c *Client) DidClose(ctx context.Context, params protocol.DidCloseTextDocumentParams) error {
|
||||
return c.Notify(ctx, "textDocument/didClose", params)
|
||||
}
|
||||
|
||||
// DidSave sends a textDocument/didSave notification to the LSP server.
|
||||
// The document save notification is sent from the client to the server when the document got saved in the client.
|
||||
func (c *Client) DidSave(ctx context.Context, params protocol.DidSaveTextDocumentParams) error {
|
||||
return c.Notify(ctx, "textDocument/didSave", params)
|
||||
}
|
||||
|
||||
// WillSave sends a textDocument/willSave notification to the LSP server.
|
||||
// A document will save notification is sent from the client to the server before the document is actually saved.
|
||||
func (c *Client) WillSave(ctx context.Context, params protocol.WillSaveTextDocumentParams) error {
|
||||
return c.Notify(ctx, "textDocument/willSave", params)
|
||||
}
|
||||
|
||||
// DidChangeWatchedFiles sends a workspace/didChangeWatchedFiles notification to the LSP server.
|
||||
// The watched files notification is sent from the client to the server when the client detects changes to file watched by the language client.
|
||||
func (c *Client) DidChangeWatchedFiles(ctx context.Context, params protocol.DidChangeWatchedFilesParams) error {
|
||||
return c.Notify(ctx, "workspace/didChangeWatchedFiles", params)
|
||||
}
|
||||
|
||||
// SetTrace sends a $/setTrace notification to the LSP server.
|
||||
func (c *Client) SetTrace(ctx context.Context, params protocol.SetTraceParams) error {
|
||||
return c.Notify(ctx, "$/setTrace", params)
|
||||
}
|
||||
|
||||
// Progress sends a $/progress notification to the LSP server.
|
||||
func (c *Client) Progress(ctx context.Context, params protocol.ProgressParams) error {
|
||||
return c.Notify(ctx, "$/progress", params)
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
package lsp
|
||||
|
||||
import (
	"encoding/json"
	"fmt"
)
|
||||
|
||||
// Message represents a JSON-RPC 2.0 message: a request, a response, or a
// notification. Which role it plays is determined by which fields are set —
// requests carry ID, Method and Params; notifications carry Method and Params
// without a meaningful ID; responses carry ID plus either Result or Error.
type Message struct {
	// JSONRPC is the protocol version; NewRequest/NewNotification set it to "2.0".
	JSONRPC string `json:"jsonrpc"`
	// ID correlates a response with its request; omitted (zero) for notifications.
	ID int32 `json:"id,omitempty"`
	// Method is the RPC method name (requests and notifications only).
	Method string `json:"method,omitempty"`
	// Params holds the call parameters, kept raw so decoding can be deferred.
	Params json.RawMessage `json:"params,omitempty"`
	// Result holds a successful response payload, kept raw for the caller to decode.
	Result json.RawMessage `json:"result,omitempty"`
	// Error is set instead of Result when the call failed.
	Error *ResponseError `json:"error,omitempty"`
}
|
||||
|
||||
// ResponseError represents a JSON-RPC 2.0 error object carried in the "error"
// member of a response message.
type ResponseError struct {
	// Code is the JSON-RPC error code (e.g. -32601 "method not found").
	Code int `json:"code"`
	// Message is a short, human-readable description of the error.
	Message string `json:"message"`
}

// Error implements the built-in error interface so a *ResponseError can be
// returned, wrapped, and inspected like any other Go error.
func (e *ResponseError) Error() string {
	return fmt.Sprintf("jsonrpc error %d: %s", e.Code, e.Message)
}
|
||||
|
||||
func NewRequest(id int32, method string, params any) (*Message, error) {
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Message{
|
||||
JSONRPC: "2.0",
|
||||
ID: id,
|
||||
Method: method,
|
||||
Params: paramsJSON,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewNotification(method string, params any) (*Message, error) {
|
||||
paramsJSON, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Message{
|
||||
JSONRPC: "2.0",
|
||||
Method: method,
|
||||
Params: paramsJSON,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
Copyright 2009 The Go Authors.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google LLC nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
@@ -1,117 +0,0 @@
|
||||
package protocol
|
||||
|
||||
import "fmt"
|
||||
|
||||
// TextEditResult is an interface for types that represent workspace symbols
|
||||
type WorkspaceSymbolResult interface {
|
||||
GetName() string
|
||||
GetLocation() Location
|
||||
isWorkspaceSymbol() // marker method
|
||||
}
|
||||
|
||||
func (ws *WorkspaceSymbol) GetName() string { return ws.Name }
|
||||
func (ws *WorkspaceSymbol) GetLocation() Location {
|
||||
switch v := ws.Location.Value.(type) {
|
||||
case Location:
|
||||
return v
|
||||
case LocationUriOnly:
|
||||
return Location{URI: v.URI}
|
||||
}
|
||||
return Location{}
|
||||
}
|
||||
func (ws *WorkspaceSymbol) isWorkspaceSymbol() {}
|
||||
|
||||
func (si *SymbolInformation) GetName() string { return si.Name }
|
||||
func (si *SymbolInformation) GetLocation() Location { return si.Location }
|
||||
func (si *SymbolInformation) isWorkspaceSymbol() {}
|
||||
|
||||
// Results converts the Value to a slice of WorkspaceSymbolResult
|
||||
func (r Or_Result_workspace_symbol) Results() ([]WorkspaceSymbolResult, error) {
|
||||
if r.Value == nil {
|
||||
return make([]WorkspaceSymbolResult, 0), nil
|
||||
}
|
||||
switch v := r.Value.(type) {
|
||||
case []WorkspaceSymbol:
|
||||
results := make([]WorkspaceSymbolResult, len(v))
|
||||
for i := range v {
|
||||
results[i] = &v[i]
|
||||
}
|
||||
return results, nil
|
||||
case []SymbolInformation:
|
||||
results := make([]WorkspaceSymbolResult, len(v))
|
||||
for i := range v {
|
||||
results[i] = &v[i]
|
||||
}
|
||||
return results, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown symbol type: %T", r.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// TextEditResult is an interface for types that represent document symbols
|
||||
type DocumentSymbolResult interface {
|
||||
GetRange() Range
|
||||
GetName() string
|
||||
isDocumentSymbol() // marker method
|
||||
}
|
||||
|
||||
func (ds *DocumentSymbol) GetRange() Range { return ds.Range }
|
||||
func (ds *DocumentSymbol) GetName() string { return ds.Name }
|
||||
func (ds *DocumentSymbol) isDocumentSymbol() {}
|
||||
|
||||
func (si *SymbolInformation) GetRange() Range { return si.Location.Range }
|
||||
|
||||
// Note: SymbolInformation already has GetName() implemented above
|
||||
func (si *SymbolInformation) isDocumentSymbol() {}
|
||||
|
||||
// Results converts the Value to a slice of DocumentSymbolResult
|
||||
func (r Or_Result_textDocument_documentSymbol) Results() ([]DocumentSymbolResult, error) {
|
||||
if r.Value == nil {
|
||||
return make([]DocumentSymbolResult, 0), nil
|
||||
}
|
||||
switch v := r.Value.(type) {
|
||||
case []DocumentSymbol:
|
||||
results := make([]DocumentSymbolResult, len(v))
|
||||
for i := range v {
|
||||
results[i] = &v[i]
|
||||
}
|
||||
return results, nil
|
||||
case []SymbolInformation:
|
||||
results := make([]DocumentSymbolResult, len(v))
|
||||
for i := range v {
|
||||
results[i] = &v[i]
|
||||
}
|
||||
return results, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown document symbol type: %T", v)
|
||||
}
|
||||
}
|
||||
|
||||
// TextEditResult is an interface for types that can be used as text edits
|
||||
type TextEditResult interface {
|
||||
GetRange() Range
|
||||
GetNewText() string
|
||||
isTextEdit() // marker method
|
||||
}
|
||||
|
||||
func (te *TextEdit) GetRange() Range { return te.Range }
|
||||
func (te *TextEdit) GetNewText() string { return te.NewText }
|
||||
func (te *TextEdit) isTextEdit() {}
|
||||
|
||||
// Convert Or_TextDocumentEdit_edits_Elem to TextEdit
|
||||
func (e Or_TextDocumentEdit_edits_Elem) AsTextEdit() (TextEdit, error) {
|
||||
if e.Value == nil {
|
||||
return TextEdit{}, fmt.Errorf("nil text edit")
|
||||
}
|
||||
switch v := e.Value.(type) {
|
||||
case TextEdit:
|
||||
return v, nil
|
||||
case AnnotatedTextEdit:
|
||||
return TextEdit{
|
||||
Range: v.Range,
|
||||
NewText: v.NewText,
|
||||
}, nil
|
||||
default:
|
||||
return TextEdit{}, fmt.Errorf("unknown text edit type: %T", e.Value)
|
||||
}
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PatternInfo is an interface for types that represent glob patterns
|
||||
type PatternInfo interface {
|
||||
GetPattern() string
|
||||
GetBasePath() string
|
||||
isPattern() // marker method
|
||||
}
|
||||
|
||||
// StringPattern implements PatternInfo for string patterns
|
||||
type StringPattern struct {
|
||||
Pattern string
|
||||
}
|
||||
|
||||
func (p StringPattern) GetPattern() string { return p.Pattern }
|
||||
func (p StringPattern) GetBasePath() string { return "" }
|
||||
func (p StringPattern) isPattern() {}
|
||||
|
||||
// RelativePatternInfo implements PatternInfo for RelativePattern
|
||||
type RelativePatternInfo struct {
|
||||
RP RelativePattern
|
||||
BasePath string
|
||||
}
|
||||
|
||||
func (p RelativePatternInfo) GetPattern() string { return string(p.RP.Pattern) }
|
||||
func (p RelativePatternInfo) GetBasePath() string { return p.BasePath }
|
||||
func (p RelativePatternInfo) isPattern() {}
|
||||
|
||||
// AsPattern converts GlobPattern to a PatternInfo object
|
||||
func (g *GlobPattern) AsPattern() (PatternInfo, error) {
|
||||
if g.Value == nil {
|
||||
return nil, fmt.Errorf("nil pattern")
|
||||
}
|
||||
|
||||
switch v := g.Value.(type) {
|
||||
case string:
|
||||
return StringPattern{Pattern: v}, nil
|
||||
case RelativePattern:
|
||||
// Handle BaseURI which could be string or DocumentUri
|
||||
basePath := ""
|
||||
switch baseURI := v.BaseURI.Value.(type) {
|
||||
case string:
|
||||
basePath = strings.TrimPrefix(baseURI, "file://")
|
||||
case DocumentUri:
|
||||
basePath = strings.TrimPrefix(string(baseURI), "file://")
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown BaseURI type: %T", v.BaseURI.Value)
|
||||
}
|
||||
return RelativePatternInfo{RP: v, BasePath: basePath}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown pattern type: %T", g.Value)
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package protocol
|
||||
|
||||
var TableKindMap = map[SymbolKind]string{
|
||||
File: "File",
|
||||
Module: "Module",
|
||||
Namespace: "Namespace",
|
||||
Package: "Package",
|
||||
Class: "Class",
|
||||
Method: "Method",
|
||||
Property: "Property",
|
||||
Field: "Field",
|
||||
Constructor: "Constructor",
|
||||
Enum: "Enum",
|
||||
Interface: "Interface",
|
||||
Function: "Function",
|
||||
Variable: "Variable",
|
||||
Constant: "Constant",
|
||||
String: "String",
|
||||
Number: "Number",
|
||||
Boolean: "Boolean",
|
||||
Array: "Array",
|
||||
Object: "Object",
|
||||
Key: "Key",
|
||||
Null: "Null",
|
||||
EnumMember: "EnumMember",
|
||||
Struct: "Struct",
|
||||
Event: "Event",
|
||||
Operator: "Operator",
|
||||
TypeParameter: "TypeParameter",
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
// Copyright 2022 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// DocumentChange is a union of various file edit operations.
|
||||
//
|
||||
// Exactly one field of this struct is non-nil; see [DocumentChange.Valid].
|
||||
//
|
||||
// See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#resourceChanges
|
||||
type DocumentChange struct {
|
||||
TextDocumentEdit *TextDocumentEdit
|
||||
CreateFile *CreateFile
|
||||
RenameFile *RenameFile
|
||||
DeleteFile *DeleteFile
|
||||
}
|
||||
|
||||
// Valid reports whether the DocumentChange sum-type value is valid,
|
||||
// that is, exactly one of create, delete, edit, or rename.
|
||||
func (ch DocumentChange) Valid() bool {
|
||||
n := 0
|
||||
if ch.TextDocumentEdit != nil {
|
||||
n++
|
||||
}
|
||||
if ch.CreateFile != nil {
|
||||
n++
|
||||
}
|
||||
if ch.RenameFile != nil {
|
||||
n++
|
||||
}
|
||||
if ch.DeleteFile != nil {
|
||||
n++
|
||||
}
|
||||
return n == 1
|
||||
}
|
||||
|
||||
func (d *DocumentChange) UnmarshalJSON(data []byte) error {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := m["textDocument"]; ok {
|
||||
d.TextDocumentEdit = new(TextDocumentEdit)
|
||||
return json.Unmarshal(data, d.TextDocumentEdit)
|
||||
}
|
||||
|
||||
// The {Create,Rename,Delete}File types all share a 'kind' field.
|
||||
kind := m["kind"]
|
||||
switch kind {
|
||||
case "create":
|
||||
d.CreateFile = new(CreateFile)
|
||||
return json.Unmarshal(data, d.CreateFile)
|
||||
case "rename":
|
||||
d.RenameFile = new(RenameFile)
|
||||
return json.Unmarshal(data, d.RenameFile)
|
||||
case "delete":
|
||||
d.DeleteFile = new(DeleteFile)
|
||||
return json.Unmarshal(data, d.DeleteFile)
|
||||
}
|
||||
return fmt.Errorf("DocumentChanges: unexpected kind: %q", kind)
|
||||
}
|
||||
|
||||
func (d *DocumentChange) MarshalJSON() ([]byte, error) {
|
||||
if d.TextDocumentEdit != nil {
|
||||
return json.Marshal(d.TextDocumentEdit)
|
||||
} else if d.CreateFile != nil {
|
||||
return json.Marshal(d.CreateFile)
|
||||
} else if d.RenameFile != nil {
|
||||
return json.Marshal(d.RenameFile)
|
||||
} else if d.DeleteFile != nil {
|
||||
return json.Marshal(d.DeleteFile)
|
||||
}
|
||||
return nil, fmt.Errorf("empty DocumentChanges union value")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,218 +0,0 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package protocol
|
||||
|
||||
// This file declares URI, DocumentUri, and its methods.
|
||||
//
|
||||
// For the LSP definition of these types, see
|
||||
// https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#uri
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// A DocumentUri is the URI of a client editor document.
|
||||
//
|
||||
// According to the LSP specification:
|
||||
//
|
||||
// Care should be taken to handle encoding in URIs. For
|
||||
// example, some clients (such as VS Code) may encode colons
|
||||
// in drive letters while others do not. The URIs below are
|
||||
// both valid, but clients and servers should be consistent
|
||||
// with the form they use themselves to ensure the other party
|
||||
// doesn’t interpret them as distinct URIs. Clients and
|
||||
// servers should not assume that each other are encoding the
|
||||
// same way (for example a client encoding colons in drive
|
||||
// letters cannot assume server responses will have encoded
|
||||
// colons). The same applies to casing of drive letters - one
|
||||
// party should not assume the other party will return paths
|
||||
// with drive letters cased the same as it.
|
||||
//
|
||||
// file:///c:/project/readme.md
|
||||
// file:///C%3A/project/readme.md
|
||||
//
|
||||
// This is done during JSON unmarshalling;
|
||||
// see [DocumentUri.UnmarshalText] for details.
|
||||
type DocumentUri string
|
||||
|
||||
// A URI is an arbitrary URL (e.g. https), not necessarily a file.
|
||||
type URI = string
|
||||
|
||||
// UnmarshalText implements decoding of DocumentUri values.
|
||||
//
|
||||
// In particular, it implements a systematic correction of various odd
|
||||
// features of the definition of DocumentUri in the LSP spec that
|
||||
// appear to be workarounds for bugs in VS Code. For example, it may
|
||||
// URI-encode the URI itself, so that colon becomes %3A, and it may
|
||||
// send file://foo.go URIs that have two slashes (not three) and no
|
||||
// hostname.
|
||||
//
|
||||
// We use UnmarshalText, not UnmarshalJSON, because it is called even
|
||||
// for non-addressable values such as keys and values of map[K]V,
|
||||
// where there is no pointer of type *K or *V on which to call
|
||||
// UnmarshalJSON. (See Go issue #28189 for more detail.)
|
||||
//
|
||||
// Non-empty DocumentUris are valid "file"-scheme URIs.
|
||||
// The empty DocumentUri is valid.
|
||||
func (uri *DocumentUri) UnmarshalText(data []byte) (err error) {
|
||||
*uri, err = ParseDocumentUri(string(data))
|
||||
return
|
||||
}
|
||||
|
||||
// Path returns the file path for the given URI.
|
||||
//
|
||||
// DocumentUri("").Path() returns the empty string.
|
||||
//
|
||||
// Path panics if called on a URI that is not a valid filename.
|
||||
func (uri DocumentUri) Path() string {
|
||||
filename, err := filename(uri)
|
||||
if err != nil {
|
||||
// e.g. ParseRequestURI failed.
|
||||
//
|
||||
// This can only affect DocumentUris created by
|
||||
// direct string manipulation; all DocumentUris
|
||||
// received from the client pass through
|
||||
// ParseRequestURI, which ensures validity.
|
||||
panic(err)
|
||||
}
|
||||
return filepath.FromSlash(filename)
|
||||
}
|
||||
|
||||
// Dir returns the URI for the directory containing the receiver.
|
||||
func (uri DocumentUri) Dir() DocumentUri {
|
||||
// This function could be more efficiently implemented by avoiding any call
|
||||
// to Path(), but at least consolidates URI manipulation.
|
||||
return URIFromPath(uri.DirPath())
|
||||
}
|
||||
|
||||
// DirPath returns the file path to the directory containing this URI, which
|
||||
// must be a file URI.
|
||||
func (uri DocumentUri) DirPath() string {
|
||||
return filepath.Dir(uri.Path())
|
||||
}
|
||||
|
||||
func filename(uri DocumentUri) (string, error) {
|
||||
if uri == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// This conservative check for the common case
|
||||
// of a simple non-empty absolute POSIX filename
|
||||
// avoids the allocation of a net.URL.
|
||||
if strings.HasPrefix(string(uri), "file:///") {
|
||||
rest := string(uri)[len("file://"):] // leave one slash
|
||||
for i := range len(rest) {
|
||||
b := rest[i]
|
||||
// Reject these cases:
|
||||
if b < ' ' || b == 0x7f || // control character
|
||||
b == '%' || b == '+' || // URI escape
|
||||
b == ':' || // Windows drive letter
|
||||
b == '@' || b == '&' || b == '?' { // authority or query
|
||||
goto slow
|
||||
}
|
||||
}
|
||||
return rest, nil
|
||||
}
|
||||
slow:
|
||||
|
||||
u, err := url.ParseRequestURI(string(uri))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if u.Scheme != fileScheme {
|
||||
return "", fmt.Errorf("only file URIs are supported, got %q from %q", u.Scheme, uri)
|
||||
}
|
||||
// If the URI is a Windows URI, we trim the leading "/" and uppercase
|
||||
// the drive letter, which will never be case sensitive.
|
||||
if isWindowsDriveURIPath(u.Path) {
|
||||
u.Path = strings.ToUpper(string(u.Path[1])) + u.Path[2:]
|
||||
}
|
||||
|
||||
return u.Path, nil
|
||||
}
|
||||
|
||||
// ParseDocumentUri interprets a string as a DocumentUri, applying VS
|
||||
// Code workarounds; see [DocumentUri.UnmarshalText] for details.
|
||||
func ParseDocumentUri(s string) (DocumentUri, error) {
|
||||
if s == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(s, "file://") {
|
||||
return "", fmt.Errorf("DocumentUri scheme is not 'file': %s", s)
|
||||
}
|
||||
|
||||
// VS Code sends URLs with only two slashes,
|
||||
// which are invalid. golang/go#39789.
|
||||
if !strings.HasPrefix(s, "file:///") {
|
||||
s = "file:///" + s[len("file://"):]
|
||||
}
|
||||
|
||||
// Even though the input is a URI, it may not be in canonical form. VS Code
|
||||
// in particular over-escapes :, @, etc. Unescape and re-encode to canonicalize.
|
||||
path, err := url.PathUnescape(s[len("file://"):])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// File URIs from Windows may have lowercase drive letters.
|
||||
// Since drive letters are guaranteed to be case insensitive,
|
||||
// we change them to uppercase to remain consistent.
|
||||
// For example, file:///c:/x/y/z becomes file:///C:/x/y/z.
|
||||
if isWindowsDriveURIPath(path) {
|
||||
path = path[:1] + strings.ToUpper(string(path[1])) + path[2:]
|
||||
}
|
||||
u := url.URL{Scheme: fileScheme, Path: path}
|
||||
return DocumentUri(u.String()), nil
|
||||
}
|
||||
|
||||
// URIFromPath returns DocumentUri for the supplied file path.
|
||||
// Given "", it returns "".
|
||||
func URIFromPath(path string) DocumentUri {
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
if !isWindowsDrivePath(path) {
|
||||
if abs, err := filepath.Abs(path); err == nil {
|
||||
path = abs
|
||||
}
|
||||
}
|
||||
// Check the file path again, in case it became absolute.
|
||||
if isWindowsDrivePath(path) {
|
||||
path = "/" + strings.ToUpper(string(path[0])) + path[1:]
|
||||
}
|
||||
path = filepath.ToSlash(path)
|
||||
u := url.URL{
|
||||
Scheme: fileScheme,
|
||||
Path: path,
|
||||
}
|
||||
return DocumentUri(u.String())
|
||||
}
|
||||
|
||||
const fileScheme = "file"
|
||||
|
||||
// isWindowsDrivePath returns true if the file path is of the form used by
|
||||
// Windows. We check if the path begins with a drive letter, followed by a ":".
|
||||
// For example: C:/x/y/z.
|
||||
func isWindowsDrivePath(path string) bool {
|
||||
if len(path) < 3 {
|
||||
return false
|
||||
}
|
||||
return unicode.IsLetter(rune(path[0])) && path[1] == ':'
|
||||
}
|
||||
|
||||
// isWindowsDriveURIPath returns true if the file URI is of the format used by
|
||||
// Windows URIs. The url.Parse package does not specially handle Windows paths
|
||||
// (see golang/go#6027), so we check if the URI path has a drive prefix (e.g. "/C:").
|
||||
func isWindowsDriveURIPath(uri string) bool {
|
||||
if len(uri) < 4 {
|
||||
return false
|
||||
}
|
||||
return uri[0] == '/' && unicode.IsLetter(rune(uri[1])) && uri[2] == ':'
|
||||
}
|
||||
@@ -1,272 +0,0 @@
|
||||
package lsp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// Write writes an LSP message to the given writer
|
||||
func WriteMessage(w io.Writer, msg *Message) error {
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
cnf := config.Get()
|
||||
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Sending message to server", "method", msg.Method, "id", msg.ID)
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintf(w, "Content-Length: %d\r\n\r\n", len(data))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write header: %w", err)
|
||||
}
|
||||
|
||||
_, err = w.Write(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadMessage reads a single LSP message from the given reader
|
||||
func ReadMessage(r *bufio.Reader) (*Message, error) {
|
||||
cnf := config.Get()
|
||||
// Read headers
|
||||
var contentLength int
|
||||
for {
|
||||
line, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read header: %w", err)
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Received header", "line", line)
|
||||
}
|
||||
|
||||
if line == "" {
|
||||
break // End of headers
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "Content-Length: ") {
|
||||
_, err := fmt.Sscanf(line, "Content-Length: %d", &contentLength)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid Content-Length: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Content-Length", "length", contentLength)
|
||||
}
|
||||
|
||||
// Read content
|
||||
content := make([]byte, contentLength)
|
||||
_, err := io.ReadFull(r, content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read content: %w", err)
|
||||
}
|
||||
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Received content", "content", string(content))
|
||||
}
|
||||
|
||||
// Parse message
|
||||
var msg Message
|
||||
if err := json.Unmarshal(content, &msg); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal message: %w", err)
|
||||
}
|
||||
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
// handleMessages reads and dispatches messages in a loop
|
||||
func (c *Client) handleMessages() {
|
||||
cnf := config.Get()
|
||||
for {
|
||||
msg, err := ReadMessage(c.stdout)
|
||||
if err != nil {
|
||||
if cnf.DebugLSP {
|
||||
slog.Error("Error reading message", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Handle server->client request (has both Method and ID)
|
||||
if msg.Method != "" && msg.ID != 0 {
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Received request from server", "method", msg.Method, "id", msg.ID)
|
||||
}
|
||||
|
||||
response := &Message{
|
||||
JSONRPC: "2.0",
|
||||
ID: msg.ID,
|
||||
}
|
||||
|
||||
// Look up handler for this method
|
||||
c.serverHandlersMu.RLock()
|
||||
handler, ok := c.serverRequestHandlers[msg.Method]
|
||||
c.serverHandlersMu.RUnlock()
|
||||
|
||||
if ok {
|
||||
result, err := handler(msg.Params)
|
||||
if err != nil {
|
||||
response.Error = &ResponseError{
|
||||
Code: -32603,
|
||||
Message: err.Error(),
|
||||
}
|
||||
} else {
|
||||
rawJSON, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
response.Error = &ResponseError{
|
||||
Code: -32603,
|
||||
Message: fmt.Sprintf("failed to marshal response: %v", err),
|
||||
}
|
||||
} else {
|
||||
response.Result = rawJSON
|
||||
}
|
||||
}
|
||||
} else {
|
||||
response.Error = &ResponseError{
|
||||
Code: -32601,
|
||||
Message: fmt.Sprintf("method not found: %s", msg.Method),
|
||||
}
|
||||
}
|
||||
|
||||
// Send response back to server
|
||||
if err := WriteMessage(c.stdin, response); err != nil {
|
||||
slog.Error("Error sending response to server", "error", err)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle notification (has Method but no ID)
|
||||
if msg.Method != "" && msg.ID == 0 {
|
||||
c.notificationMu.RLock()
|
||||
handler, ok := c.notificationHandlers[msg.Method]
|
||||
c.notificationMu.RUnlock()
|
||||
|
||||
if ok {
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Handling notification", "method", msg.Method)
|
||||
}
|
||||
go handler(msg.Params)
|
||||
} else if cnf.DebugLSP {
|
||||
slog.Debug("No handler for notification", "method", msg.Method)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle response to our request (has ID but no Method)
|
||||
if msg.ID != 0 && msg.Method == "" {
|
||||
c.handlersMu.RLock()
|
||||
ch, ok := c.handlers[msg.ID]
|
||||
c.handlersMu.RUnlock()
|
||||
|
||||
if ok {
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Received response for request", "id", msg.ID)
|
||||
}
|
||||
ch <- msg
|
||||
close(ch)
|
||||
} else if cnf.DebugLSP {
|
||||
slog.Debug("No handler for response", "id", msg.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Call makes a request and waits for the response
|
||||
func (c *Client) Call(ctx context.Context, method string, params any, result any) error {
|
||||
cnf := config.Get()
|
||||
id := c.nextID.Add(1)
|
||||
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Making call", "method", method, "id", id)
|
||||
}
|
||||
|
||||
msg, err := NewRequest(id, method, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Create response channel
|
||||
ch := make(chan *Message, 1)
|
||||
c.handlersMu.Lock()
|
||||
c.handlers[id] = ch
|
||||
c.handlersMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
c.handlersMu.Lock()
|
||||
delete(c.handlers, id)
|
||||
c.handlersMu.Unlock()
|
||||
}()
|
||||
|
||||
// Send request
|
||||
if err := WriteMessage(c.stdin, msg); err != nil {
|
||||
return fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Request sent", "method", method, "id", id)
|
||||
}
|
||||
|
||||
// Wait for response
|
||||
resp := <-ch
|
||||
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Received response", "id", id)
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return fmt.Errorf("request failed: %s (code: %d)", resp.Error.Message, resp.Error.Code)
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
// If result is a json.RawMessage, just copy the raw bytes
|
||||
if rawMsg, ok := result.(*json.RawMessage); ok {
|
||||
*rawMsg = resp.Result
|
||||
return nil
|
||||
}
|
||||
// Otherwise unmarshal into the provided type
|
||||
if err := json.Unmarshal(resp.Result, result); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal result: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Notify sends a notification (a request without an ID that doesn't expect a response)
|
||||
func (c *Client) Notify(ctx context.Context, method string, params any) error {
|
||||
cnf := config.Get()
|
||||
if cnf.DebugLSP {
|
||||
slog.Debug("Sending notification", "method", method)
|
||||
}
|
||||
|
||||
msg, err := NewNotification(method, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create notification: %w", err)
|
||||
}
|
||||
|
||||
if err := WriteMessage(c.stdin, msg); err != nil {
|
||||
return fmt.Errorf("failed to send notification: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type (
|
||||
NotificationHandler func(params json.RawMessage)
|
||||
ServerRequestHandler func(params json.RawMessage) (any, error)
|
||||
)
|
||||
@@ -1,239 +0,0 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
)
|
||||
|
||||
func applyTextEdits(uri protocol.DocumentUri, edits []protocol.TextEdit) error {
|
||||
path := strings.TrimPrefix(string(uri), "file://")
|
||||
|
||||
// Read the file content
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
// Detect line ending style
|
||||
var lineEnding string
|
||||
if bytes.Contains(content, []byte("\r\n")) {
|
||||
lineEnding = "\r\n"
|
||||
} else {
|
||||
lineEnding = "\n"
|
||||
}
|
||||
|
||||
// Track if file ends with a newline
|
||||
endsWithNewline := len(content) > 0 && bytes.HasSuffix(content, []byte(lineEnding))
|
||||
|
||||
// Split into lines without the endings
|
||||
lines := strings.Split(string(content), lineEnding)
|
||||
|
||||
// Check for overlapping edits
|
||||
for i, edit1 := range edits {
|
||||
for j := i + 1; j < len(edits); j++ {
|
||||
if rangesOverlap(edit1.Range, edits[j].Range) {
|
||||
return fmt.Errorf("overlapping edits detected between edit %d and %d", i, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort edits in reverse order
|
||||
sortedEdits := make([]protocol.TextEdit, len(edits))
|
||||
copy(sortedEdits, edits)
|
||||
sort.Slice(sortedEdits, func(i, j int) bool {
|
||||
if sortedEdits[i].Range.Start.Line != sortedEdits[j].Range.Start.Line {
|
||||
return sortedEdits[i].Range.Start.Line > sortedEdits[j].Range.Start.Line
|
||||
}
|
||||
return sortedEdits[i].Range.Start.Character > sortedEdits[j].Range.Start.Character
|
||||
})
|
||||
|
||||
// Apply each edit
|
||||
for _, edit := range sortedEdits {
|
||||
newLines, err := applyTextEdit(lines, edit)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to apply edit: %w", err)
|
||||
}
|
||||
lines = newLines
|
||||
}
|
||||
|
||||
// Join lines with proper line endings
|
||||
var newContent strings.Builder
|
||||
for i, line := range lines {
|
||||
if i > 0 {
|
||||
newContent.WriteString(lineEnding)
|
||||
}
|
||||
newContent.WriteString(line)
|
||||
}
|
||||
|
||||
// Only add a newline if the original file had one and we haven't already added it
|
||||
if endsWithNewline && !strings.HasSuffix(newContent.String(), lineEnding) {
|
||||
newContent.WriteString(lineEnding)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, []byte(newContent.String()), 0o644); err != nil {
|
||||
return fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyTextEdit(lines []string, edit protocol.TextEdit) ([]string, error) {
|
||||
startLine := int(edit.Range.Start.Line)
|
||||
endLine := int(edit.Range.End.Line)
|
||||
startChar := int(edit.Range.Start.Character)
|
||||
endChar := int(edit.Range.End.Character)
|
||||
|
||||
// Validate positions
|
||||
if startLine < 0 || startLine >= len(lines) {
|
||||
return nil, fmt.Errorf("invalid start line: %d", startLine)
|
||||
}
|
||||
if endLine < 0 || endLine >= len(lines) {
|
||||
endLine = len(lines) - 1
|
||||
}
|
||||
|
||||
// Create result slice with initial capacity
|
||||
result := make([]string, 0, len(lines))
|
||||
|
||||
// Copy lines before edit
|
||||
result = append(result, lines[:startLine]...)
|
||||
|
||||
// Get the prefix of the start line
|
||||
startLineContent := lines[startLine]
|
||||
if startChar < 0 || startChar > len(startLineContent) {
|
||||
startChar = len(startLineContent)
|
||||
}
|
||||
prefix := startLineContent[:startChar]
|
||||
|
||||
// Get the suffix of the end line
|
||||
endLineContent := lines[endLine]
|
||||
if endChar < 0 || endChar > len(endLineContent) {
|
||||
endChar = len(endLineContent)
|
||||
}
|
||||
suffix := endLineContent[endChar:]
|
||||
|
||||
// Handle the edit
|
||||
if edit.NewText == "" {
|
||||
if prefix+suffix != "" {
|
||||
result = append(result, prefix+suffix)
|
||||
}
|
||||
} else {
|
||||
// Split new text into lines, being careful not to add extra newlines
|
||||
// newLines := strings.Split(strings.TrimRight(edit.NewText, "\n"), "\n")
|
||||
newLines := strings.Split(edit.NewText, "\n")
|
||||
|
||||
if len(newLines) == 1 {
|
||||
// Single line change
|
||||
result = append(result, prefix+newLines[0]+suffix)
|
||||
} else {
|
||||
// Multi-line change
|
||||
result = append(result, prefix+newLines[0])
|
||||
result = append(result, newLines[1:len(newLines)-1]...)
|
||||
result = append(result, newLines[len(newLines)-1]+suffix)
|
||||
}
|
||||
}
|
||||
|
||||
// Add remaining lines
|
||||
if endLine+1 < len(lines) {
|
||||
result = append(result, lines[endLine+1:]...)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// applyDocumentChange applies a DocumentChange (create/rename/delete operations)
|
||||
func applyDocumentChange(change protocol.DocumentChange) error {
|
||||
if change.CreateFile != nil {
|
||||
path := strings.TrimPrefix(string(change.CreateFile.URI), "file://")
|
||||
if change.CreateFile.Options != nil {
|
||||
if change.CreateFile.Options.Overwrite {
|
||||
// Proceed with overwrite
|
||||
} else if change.CreateFile.Options.IgnoreIfExists {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return nil // File exists and we're ignoring it
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := os.WriteFile(path, []byte(""), 0o644); err != nil {
|
||||
return fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if change.DeleteFile != nil {
|
||||
path := strings.TrimPrefix(string(change.DeleteFile.URI), "file://")
|
||||
if change.DeleteFile.Options != nil && change.DeleteFile.Options.Recursive {
|
||||
if err := os.RemoveAll(path); err != nil {
|
||||
return fmt.Errorf("failed to delete directory recursively: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := os.Remove(path); err != nil {
|
||||
return fmt.Errorf("failed to delete file: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if change.RenameFile != nil {
|
||||
oldPath := strings.TrimPrefix(string(change.RenameFile.OldURI), "file://")
|
||||
newPath := strings.TrimPrefix(string(change.RenameFile.NewURI), "file://")
|
||||
if change.RenameFile.Options != nil {
|
||||
if !change.RenameFile.Options.Overwrite {
|
||||
if _, err := os.Stat(newPath); err == nil {
|
||||
return fmt.Errorf("target file already exists and overwrite is not allowed: %s", newPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := os.Rename(oldPath, newPath); err != nil {
|
||||
return fmt.Errorf("failed to rename file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if change.TextDocumentEdit != nil {
|
||||
textEdits := make([]protocol.TextEdit, len(change.TextDocumentEdit.Edits))
|
||||
for i, edit := range change.TextDocumentEdit.Edits {
|
||||
var err error
|
||||
textEdits[i], err = edit.AsTextEdit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid edit type: %w", err)
|
||||
}
|
||||
}
|
||||
return applyTextEdits(change.TextDocumentEdit.TextDocument.URI, textEdits)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyWorkspaceEdit applies the given WorkspaceEdit to the filesystem
|
||||
func ApplyWorkspaceEdit(edit protocol.WorkspaceEdit) error {
|
||||
// Handle Changes field
|
||||
for uri, textEdits := range edit.Changes {
|
||||
if err := applyTextEdits(uri, textEdits); err != nil {
|
||||
return fmt.Errorf("failed to apply text edits: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle DocumentChanges field
|
||||
for _, change := range edit.DocumentChanges {
|
||||
if err := applyDocumentChange(change); err != nil {
|
||||
return fmt.Errorf("failed to apply document change: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func rangesOverlap(r1, r2 protocol.Range) bool {
|
||||
if r1.Start.Line > r2.End.Line || r2.Start.Line > r1.End.Line {
|
||||
return false
|
||||
}
|
||||
if r1.Start.Line == r2.End.Line && r1.Start.Character > r2.End.Character {
|
||||
return false
|
||||
}
|
||||
if r2.Start.Line == r1.End.Line && r2.Start.Character > r1.End.Character {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,8 +0,0 @@
|
||||
package message
|
||||
|
||||
// Attachment is a file attached to a message: its location on disk, display
// name, MIME type, and raw bytes.
type Attachment struct {
	FilePath string // path the file was read from
	FileName string // display name of the file
	MimeType string // MIME type of Content
	Content  []byte // raw file contents
}
|
||||
@@ -1,325 +0,0 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
)
|
||||
|
||||
// MessageRole identifies the author of a message.
type MessageRole string

const (
	Assistant MessageRole = "assistant"
	User      MessageRole = "user"
	System    MessageRole = "system"
	Tool      MessageRole = "tool"
)

// FinishReason describes why generation of a message stopped.
type FinishReason string

const (
	FinishReasonEndTurn          FinishReason = "end_turn"
	FinishReasonMaxTokens        FinishReason = "max_tokens"
	FinishReasonToolUse          FinishReason = "tool_use"
	FinishReasonCanceled         FinishReason = "canceled"
	FinishReasonError            FinishReason = "error"
	FinishReasonPermissionDenied FinishReason = "permission_denied"

	// Should never happen
	FinishReasonUnknown FinishReason = "unknown"
)
|
||||
|
||||
// ContentPart is the sealed interface implemented by every kind of message
// part (text, reasoning, image, binary, tool call/result, finish marker).
type ContentPart interface {
	isPart()
}

// ReasoningContent holds the model's reasoning ("thinking") text.
type ReasoningContent struct {
	Thinking string `json:"thinking"`
}

// String returns the raw thinking text.
func (tc ReasoningContent) String() string {
	return tc.Thinking
}
func (ReasoningContent) isPart() {}

// TextContent holds the visible text of a message.
type TextContent struct {
	Text string `json:"text"`
}

// String returns the text. It is safe to call on a nil receiver, in which
// case it returns the empty string.
func (tc *TextContent) String() string {
	if tc == nil {
		return ""
	}
	return tc.Text
}

func (TextContent) isPart() {}

// ImageURLContent references an image by URL.
type ImageURLContent struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"` // optional detail hint; omitted when empty
}

// String returns the image URL.
func (iuc ImageURLContent) String() string {
	return iuc.URL
}

func (ImageURLContent) isPart() {}

// BinaryContent holds raw attachment bytes together with their MIME type.
type BinaryContent struct {
	Path     string
	MIMEType string
	Data     []byte
}

// String base64-encodes Data; for OpenAI the result is wrapped in a data: URI,
// for other providers the bare base64 string is returned.
func (bc BinaryContent) String(provider models.ModelProvider) string {
	base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
	if provider == models.ProviderOpenAI {
		return "data:" + bc.MIMEType + ";base64," + base64Encoded
	}
	return base64Encoded
}

func (BinaryContent) isPart() {}

// ToolCall records a tool invocation requested by the model.
type ToolCall struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	Input    string `json:"input"` // accumulated input, appended to in streamed deltas
	Type     string `json:"type"`
	Finished bool   `json:"finished"` // set once the call's input is complete
}

func (ToolCall) isPart() {}

// ToolResult records the outcome of a tool call.
type ToolResult struct {
	ToolCallID string `json:"tool_call_id"`
	Name       string `json:"name"`
	Content    string `json:"content"`
	Metadata   string `json:"metadata"`
	IsError    bool   `json:"is_error"`
}

func (ToolResult) isPart() {}

// Finish marks a message as complete, with the reason and time of completion.
type Finish struct {
	Reason FinishReason `json:"reason"`
	Time   time.Time    `json:"time"`
}

// DBFinish is the persisted form of Finish; Time is stored as unix
// milliseconds (see marshallParts/unmarshallParts).
type DBFinish struct {
	Reason FinishReason `json:"reason"`
	Time   int64        `json:"time"`
}

func (Finish) isPart() {}
|
||||
|
||||
func (m *Message) Content() *TextContent {
|
||||
for _, part := range m.Parts {
|
||||
if c, ok := part.(TextContent); ok {
|
||||
return &c
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) ReasoningContent() ReasoningContent {
|
||||
for _, part := range m.Parts {
|
||||
if c, ok := part.(ReasoningContent); ok {
|
||||
return c
|
||||
}
|
||||
}
|
||||
return ReasoningContent{}
|
||||
}
|
||||
|
||||
func (m *Message) ImageURLContent() []ImageURLContent {
|
||||
imageURLContents := make([]ImageURLContent, 0)
|
||||
for _, part := range m.Parts {
|
||||
if c, ok := part.(ImageURLContent); ok {
|
||||
imageURLContents = append(imageURLContents, c)
|
||||
}
|
||||
}
|
||||
return imageURLContents
|
||||
}
|
||||
|
||||
func (m *Message) BinaryContent() []BinaryContent {
|
||||
binaryContents := make([]BinaryContent, 0)
|
||||
for _, part := range m.Parts {
|
||||
if c, ok := part.(BinaryContent); ok {
|
||||
binaryContents = append(binaryContents, c)
|
||||
}
|
||||
}
|
||||
return binaryContents
|
||||
}
|
||||
|
||||
func (m *Message) ToolCalls() []ToolCall {
|
||||
toolCalls := make([]ToolCall, 0)
|
||||
for _, part := range m.Parts {
|
||||
if c, ok := part.(ToolCall); ok {
|
||||
toolCalls = append(toolCalls, c)
|
||||
}
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func (m *Message) ToolResults() []ToolResult {
|
||||
toolResults := make([]ToolResult, 0)
|
||||
for _, part := range m.Parts {
|
||||
if c, ok := part.(ToolResult); ok {
|
||||
toolResults = append(toolResults, c)
|
||||
}
|
||||
}
|
||||
return toolResults
|
||||
}
|
||||
|
||||
func (m *Message) IsFinished() bool {
|
||||
for _, part := range m.Parts {
|
||||
if _, ok := part.(Finish); ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Message) FinishPart() *Finish {
|
||||
for _, part := range m.Parts {
|
||||
if c, ok := part.(Finish); ok {
|
||||
return &c
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Message) FinishReason() FinishReason {
|
||||
for _, part := range m.Parts {
|
||||
if c, ok := part.(Finish); ok {
|
||||
return c.Reason
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Message) IsThinking() bool {
|
||||
if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Message) AppendContent(delta string) {
|
||||
found := false
|
||||
for i, part := range m.Parts {
|
||||
if c, ok := part.(TextContent); ok {
|
||||
m.Parts[i] = TextContent{Text: c.Text + delta}
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
m.Parts = append(m.Parts, TextContent{Text: delta})
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Message) AppendReasoningContent(delta string) {
|
||||
found := false
|
||||
for i, part := range m.Parts {
|
||||
if c, ok := part.(ReasoningContent); ok {
|
||||
m.Parts[i] = ReasoningContent{Thinking: c.Thinking + delta}
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
m.Parts = append(m.Parts, ReasoningContent{Thinking: delta})
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Message) FinishToolCall(toolCallID string) {
|
||||
for i, part := range m.Parts {
|
||||
if c, ok := part.(ToolCall); ok {
|
||||
if c.ID == toolCallID {
|
||||
m.Parts[i] = ToolCall{
|
||||
ID: c.ID,
|
||||
Name: c.Name,
|
||||
Input: c.Input,
|
||||
Type: c.Type,
|
||||
Finished: true,
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
|
||||
for i, part := range m.Parts {
|
||||
if c, ok := part.(ToolCall); ok {
|
||||
if c.ID == toolCallID {
|
||||
m.Parts[i] = ToolCall{
|
||||
ID: c.ID,
|
||||
Name: c.Name,
|
||||
Input: c.Input + inputDelta,
|
||||
Type: c.Type,
|
||||
Finished: c.Finished,
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Message) AddToolCall(tc ToolCall) {
|
||||
for i, part := range m.Parts {
|
||||
if c, ok := part.(ToolCall); ok {
|
||||
if c.ID == tc.ID {
|
||||
m.Parts[i] = tc
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
m.Parts = append(m.Parts, tc)
|
||||
}
|
||||
|
||||
func (m *Message) SetToolCalls(tc []ToolCall) {
|
||||
// remove any existing tool call part it could have multiple
|
||||
parts := make([]ContentPart, 0)
|
||||
for _, part := range m.Parts {
|
||||
if _, ok := part.(ToolCall); ok {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
m.Parts = parts
|
||||
for _, toolCall := range tc {
|
||||
m.Parts = append(m.Parts, toolCall)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Message) AddToolResult(tr ToolResult) {
|
||||
m.Parts = append(m.Parts, tr)
|
||||
}
|
||||
|
||||
func (m *Message) SetToolResults(tr []ToolResult) {
|
||||
for _, toolResult := range tr {
|
||||
m.Parts = append(m.Parts, toolResult)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Message) AddFinish(reason FinishReason) {
|
||||
// remove any existing finish part
|
||||
for i, part := range m.Parts {
|
||||
if _, ok := part.(Finish); ok {
|
||||
m.Parts = slices.Delete(m.Parts, i, i+1)
|
||||
break
|
||||
}
|
||||
}
|
||||
m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now()})
|
||||
}
|
||||
|
||||
func (m *Message) AddImageURL(url, detail string) {
|
||||
m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
|
||||
}
|
||||
|
||||
func (m *Message) AddBinary(mimeType string, data []byte) {
|
||||
m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
|
||||
}
|
||||
@@ -1,503 +0,0 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sst/opencode/internal/db"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
// Message is a single conversation message assembled from typed content parts.
type Message struct {
	ID        string
	Role      MessageRole
	SessionID string
	Parts     []ContentPart
	Model     models.ModelID // may be empty; stored as NULL in the DB then
	CreatedAt time.Time
	UpdatedAt time.Time
}

// Pub/sub event types emitted by the message service.
const (
	EventMessageCreated pubsub.EventType = "message_created"
	EventMessageUpdated pubsub.EventType = "message_updated"
	EventMessageDeleted pubsub.EventType = "message_deleted"
)

// CreateMessageParams are the caller-supplied fields for a new message; the
// service assigns the ID and timestamps.
type CreateMessageParams struct {
	Role  MessageRole
	Parts []ContentPart
	Model models.ModelID
}

// Service is the message persistence API: CRUD over messages plus a
// subscription stream of create/update/delete events.
type Service interface {
	pubsub.Subscriber[Message]

	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
	Update(ctx context.Context, message Message) (Message, error)
	Get(ctx context.Context, id string) (Message, error)
	List(ctx context.Context, sessionID string) ([]Message, error)
	ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error)
	Delete(ctx context.Context, id string) error
	DeleteSessionMessages(ctx context.Context, sessionID string) error
}
|
||||
|
||||
// service is the database-backed Service implementation. mu serializes
// access; broker fans out change events to subscribers.
type service struct {
	db     *db.Queries
	broker *pubsub.Broker[Message]
	mu     sync.RWMutex
}

// globalMessageService is the process-wide singleton created by InitService.
var globalMessageService *service

// InitService wires the message service to the given database connection.
// It must be called exactly once, before any call to GetService.
func InitService(dbConn *sql.DB) error {
	if globalMessageService != nil {
		return fmt.Errorf("message service already initialized")
	}
	queries := db.New(dbConn)
	broker := pubsub.NewBroker[Message]()

	globalMessageService = &service{
		db:     queries,
		broker: broker,
	}
	return nil
}

// GetService returns the singleton service; it panics when InitService has
// not been called.
func GetService() Service {
	if globalMessageService == nil {
		panic("message service not initialized. Call message.InitService() first.")
	}
	return globalMessageService
}
|
||||
|
||||
// Create persists a new message in the given session and publishes an
// EventMessageCreated event. User messages that carry no Finish part get a
// synthetic end_turn finish appended before storage.
func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Check whether the caller already included a Finish part.
	isFinished := false
	for _, p := range params.Parts {
		if _, ok := p.(Finish); ok {
			isFinished = true
			break
		}
	}
	if params.Role == User && !isFinished {
		params.Parts = append(params.Parts, Finish{Reason: FinishReasonEndTurn, Time: time.Now()})
	}

	// Parts are stored as a single JSON column.
	partsJSON, err := marshallParts(params.Parts)
	if err != nil {
		return Message{}, fmt.Errorf("failed to marshal message parts: %w", err)
	}

	dbMsgParams := db.CreateMessageParams{
		ID:        uuid.New().String(),
		SessionID: sessionID,
		Role:      string(params.Role),
		Parts:     string(partsJSON),
		// Model may be empty; store NULL in that case.
		Model: sql.NullString{String: string(params.Model), Valid: params.Model != ""},
	}

	dbMessage, err := s.db.CreateMessage(ctx, dbMsgParams)
	if err != nil {
		return Message{}, fmt.Errorf("db.CreateMessage: %w", err)
	}

	message, err := s.fromDBItem(dbMessage)
	if err != nil {
		return Message{}, fmt.Errorf("failed to convert DB message: %w", err)
	}

	s.broker.Publish(EventMessageCreated, message)
	return message, nil
}
|
||||
|
||||
// Update rewrites the message's parts (and its finished_at column, derived
// from the Finish part) in the database, re-reads the stored row, and
// publishes an EventMessageUpdated event with the persisted state.
func (s *service) Update(ctx context.Context, message Message) (Message, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if message.ID == "" {
		return Message{}, fmt.Errorf("cannot update message with empty ID")
	}

	partsJSON, err := marshallParts(message.Parts)
	if err != nil {
		return Message{}, fmt.Errorf("failed to marshal message parts for update: %w", err)
	}

	// finished_at is stored as an RFC3339Nano UTC string, NULL while the
	// message is unfinished.
	var dbFinishedAt sql.NullString
	finishPart := message.FinishPart()
	if finishPart != nil && !finishPart.Time.IsZero() {
		dbFinishedAt = sql.NullString{
			String: finishPart.Time.UTC().Format(time.RFC3339Nano),
			Valid:  true,
		}
	}

	// UpdatedAt is handled by the DB trigger (strftime('%s', 'now'))
	err = s.db.UpdateMessage(ctx, db.UpdateMessageParams{
		ID:         message.ID,
		Parts:      string(partsJSON),
		FinishedAt: dbFinishedAt,
	})
	if err != nil {
		return Message{}, fmt.Errorf("db.UpdateMessage: %w", err)
	}

	// Re-fetch so callers and subscribers observe exactly what was persisted,
	// including trigger-maintained columns.
	dbUpdatedMessage, err := s.db.GetMessage(ctx, message.ID)
	if err != nil {
		return Message{}, fmt.Errorf("failed to fetch message after update: %w", err)
	}
	updatedMessage, err := s.fromDBItem(dbUpdatedMessage)
	if err != nil {
		return Message{}, fmt.Errorf("failed to convert updated DB message: %w", err)
	}

	s.broker.Publish(EventMessageUpdated, updatedMessage)
	return updatedMessage, nil
}
|
||||
|
||||
func (s *service) Get(ctx context.Context, id string) (Message, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
dbMessage, err := s.db.GetMessage(ctx, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return Message{}, fmt.Errorf("message with ID '%s' not found", id)
|
||||
}
|
||||
return Message{}, fmt.Errorf("db.GetMessage: %w", err)
|
||||
}
|
||||
return s.fromDBItem(dbMessage)
|
||||
}
|
||||
|
||||
func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
dbMessages, err := s.db.ListMessagesBySession(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListMessagesBySession: %w", err)
|
||||
}
|
||||
messages := make([]Message, len(dbMessages))
|
||||
for i, dbMsg := range dbMessages {
|
||||
msg, convErr := s.fromDBItem(dbMsg)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("failed to convert DB message at index %d: %w", i, convErr)
|
||||
}
|
||||
messages[i] = msg
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (s *service) ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
dbMessages, err := s.db.ListMessagesBySessionAfter(ctx, db.ListMessagesBySessionAfterParams{
|
||||
SessionID: sessionID,
|
||||
CreatedAt: timestamp.Format(time.RFC3339Nano),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListMessagesBySessionAfter: %w", err)
|
||||
}
|
||||
messages := make([]Message, len(dbMessages))
|
||||
for i, dbMsg := range dbMessages {
|
||||
msg, convErr := s.fromDBItem(dbMsg)
|
||||
if convErr != nil {
|
||||
return nil, fmt.Errorf("failed to convert DB message at index %d (ListAfter): %w", i, convErr)
|
||||
}
|
||||
messages[i] = msg
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// Delete removes a message by ID and publishes an EventMessageDeleted event
// carrying the message's last known state. A message that cannot be found is
// treated as already deleted.
func (s *service) Delete(ctx context.Context, id string) error {
	// Fetch the message first so the delete event can carry its contents.
	// NOTE(review): the lock is released between this read and the delete
	// below, so another goroutine can act on the row in that window —
	// confirm this race is acceptable.
	s.mu.Lock()
	messageToPublish, err := s.getServiceForPublish(ctx, id)
	s.mu.Unlock()

	if err != nil {
		// If error was due to not found, it's not a critical failure for deletion intent
		// NOTE(review): db.GetMessage yields sql.ErrNoRows here, whose text is
		// "sql: no rows in result set" — verify this substring match is ever hit.
		if strings.Contains(err.Error(), "not found") {
			return nil // Or return the error if strictness is required
		}
		return err
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	err = s.db.DeleteMessage(ctx, id)
	if err != nil {
		return fmt.Errorf("db.DeleteMessage: %w", err)
	}

	if messageToPublish != nil {
		s.broker.Publish(EventMessageDeleted, *messageToPublish)
	}
	return nil
}

// getServiceForPublish loads a message and converts it to the domain type so
// it can be attached to a pub/sub event. Callers are expected to hold s.mu.
func (s *service) getServiceForPublish(ctx context.Context, id string) (*Message, error) {
	dbMsg, err := s.db.GetMessage(ctx, id)
	if err != nil {
		return nil, err
	}
	msg, convErr := s.fromDBItem(dbMsg)
	if convErr != nil {
		return nil, fmt.Errorf("failed to convert DB message for publishing: %w", convErr)
	}
	return &msg, nil
}
|
||||
|
||||
// DeleteSessionMessages deletes every message in a session and publishes one
// EventMessageDeleted event per deleted message.
func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Snapshot the messages first so the delete events can carry their contents.
	messagesToDelete, err := s.db.ListMessagesBySession(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to list messages for deletion: %w", err)
	}

	err = s.db.DeleteSessionMessages(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("db.DeleteSessionMessages: %w", err)
	}

	// Best effort: a row that fails conversion is logged and skipped so the
	// remaining delete events still go out.
	for _, dbMsg := range messagesToDelete {
		msg, convErr := s.fromDBItem(dbMsg)
		if convErr == nil {
			s.broker.Publish(EventMessageDeleted, msg)
		} else {
			slog.Error("Failed to convert DB message for delete event publishing", "id", dbMsg.ID, "error", convErr)
		}
	}
	return nil
}

// Subscribe returns the broker's event stream, scoped to ctx.
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Message] {
	return s.broker.Subscribe(ctx)
}
|
||||
|
||||
func (s *service) fromDBItem(item db.Message) (Message, error) {
|
||||
parts, err := unmarshallParts([]byte(item.Parts))
|
||||
if err != nil {
|
||||
return Message{}, fmt.Errorf("unmarshallParts for message ID %s: %w. Raw parts: %s", item.ID, err, item.Parts)
|
||||
}
|
||||
|
||||
// Parse timestamps from ISO strings
|
||||
createdAt, err := time.Parse(time.RFC3339Nano, item.CreatedAt)
|
||||
if err != nil {
|
||||
slog.Error("Failed to parse created_at", "value", item.CreatedAt, "error", err)
|
||||
createdAt = time.Now() // Fallback
|
||||
}
|
||||
|
||||
updatedAt, err := time.Parse(time.RFC3339Nano, item.UpdatedAt)
|
||||
if err != nil {
|
||||
slog.Error("Failed to parse created_at", "value", item.CreatedAt, "error", err)
|
||||
updatedAt = time.Now() // Fallback
|
||||
}
|
||||
|
||||
msg := Message{
|
||||
ID: item.ID,
|
||||
SessionID: item.SessionID,
|
||||
Role: MessageRole(item.Role),
|
||||
Parts: parts,
|
||||
Model: models.ModelID(item.Model.String),
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// Package-level convenience wrappers that delegate to the global service.

// Create persists a new message in the given session.
func Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
	return GetService().Create(ctx, sessionID, params)
}

// Update rewrites an existing message.
func Update(ctx context.Context, message Message) (Message, error) {
	return GetService().Update(ctx, message)
}

// Get loads a single message by ID.
func Get(ctx context.Context, id string) (Message, error) {
	return GetService().Get(ctx, id)
}

// List returns every message in a session.
func List(ctx context.Context, sessionID string) ([]Message, error) {
	return GetService().List(ctx, sessionID)
}

// ListAfter returns a session's messages created after the given timestamp.
func ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error) {
	return GetService().ListAfter(ctx, sessionID, timestamp)
}

// Delete removes a message by ID.
func Delete(ctx context.Context, id string) error {
	return GetService().Delete(ctx, id)
}

// DeleteSessionMessages removes every message in a session.
func DeleteSessionMessages(ctx context.Context, sessionID string) error {
	return GetService().DeleteSessionMessages(ctx, sessionID)
}

// Subscribe returns the message event stream, scoped to ctx.
func Subscribe(ctx context.Context) <-chan pubsub.Event[Message] {
	return GetService().Subscribe(ctx)
}
|
||||
|
||||
// partType tags the JSON envelope of each persisted content part so it can be
// decoded back into its concrete Go type.
type partType string

const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)

// partWrapper is the persisted envelope: a type tag plus the part's raw JSON.
type partWrapper struct {
	Type partType        `json:"type"`
	Data json.RawMessage `json:"data"`
}
|
||||
|
||||
func marshallParts(parts []ContentPart) ([]byte, error) {
|
||||
wrappedParts := make([]json.RawMessage, len(parts))
|
||||
for i, part := range parts {
|
||||
var typ partType
|
||||
var dataBytes []byte
|
||||
var err error
|
||||
|
||||
switch p := part.(type) {
|
||||
case ReasoningContent:
|
||||
typ = reasoningType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case TextContent:
|
||||
typ = textType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case *TextContent:
|
||||
typ = textType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case ImageURLContent:
|
||||
typ = imageURLType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case BinaryContent:
|
||||
typ = binaryType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case ToolCall:
|
||||
typ = toolCallType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case ToolResult:
|
||||
typ = toolResultType
|
||||
dataBytes, err = json.Marshal(p)
|
||||
case Finish:
|
||||
typ = finishType
|
||||
var dbFinish DBFinish
|
||||
dbFinish.Reason = p.Reason
|
||||
dbFinish.Time = p.Time.UnixMilli()
|
||||
dataBytes, err = json.Marshal(dbFinish)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown part type for marshalling: %T", part)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal part data for type %s: %w", typ, err)
|
||||
}
|
||||
wrapper := struct {
|
||||
Type partType `json:"type"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}{Type: typ, Data: dataBytes}
|
||||
wrappedBytes, err := json.Marshal(wrapper)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal part wrapper for type %s: %w", typ, err)
|
||||
}
|
||||
wrappedParts[i] = wrappedBytes
|
||||
}
|
||||
return json.Marshal(wrappedParts)
|
||||
}
|
||||
|
||||
// unmarshallParts decodes the JSON produced by marshallParts back into typed
// ContentPart values. It tolerates two legacy shapes: a bare JSON string
// (treated as TextContent) and an unknown type tag (fallback-parsed as
// TextContent, otherwise logged and skipped).
func unmarshallParts(data []byte) ([]ContentPart, error) {
	var rawMessages []json.RawMessage
	if err := json.Unmarshal(data, &rawMessages); err != nil {
		return nil, fmt.Errorf("failed to unmarshal parts data as array: %w. Data: %s", err, string(data))
	}

	parts := make([]ContentPart, 0, len(rawMessages))
	for _, rawPart := range rawMessages {
		var wrapper partWrapper
		if err := json.Unmarshal(rawPart, &wrapper); err != nil {
			// Fallback for old format where parts might be just TextContent string
			var text string
			if errText := json.Unmarshal(rawPart, &text); errText == nil {
				parts = append(parts, TextContent{Text: text})
				continue
			}
			return nil, fmt.Errorf("failed to unmarshal part wrapper: %w. Raw part: %s", err, string(rawPart))
		}

		// Dispatch on the type tag; each arm decodes into its concrete type.
		switch wrapper.Type {
		case reasoningType:
			var p ReasoningContent
			if err := json.Unmarshal(wrapper.Data, &p); err != nil {
				return nil, fmt.Errorf("unmarshal ReasoningContent: %w. Data: %s", err, string(wrapper.Data))
			}
			parts = append(parts, p)
		case textType:
			var p TextContent
			if err := json.Unmarshal(wrapper.Data, &p); err != nil {
				return nil, fmt.Errorf("unmarshal TextContent: %w. Data: %s", err, string(wrapper.Data))
			}
			parts = append(parts, p)
		case imageURLType:
			var p ImageURLContent
			if err := json.Unmarshal(wrapper.Data, &p); err != nil {
				return nil, fmt.Errorf("unmarshal ImageURLContent: %w. Data: %s", err, string(wrapper.Data))
			}
			parts = append(parts, p)
		case binaryType:
			var p BinaryContent
			if err := json.Unmarshal(wrapper.Data, &p); err != nil {
				return nil, fmt.Errorf("unmarshal BinaryContent: %w. Data: %s", err, string(wrapper.Data))
			}
			parts = append(parts, p)
		case toolCallType:
			var p ToolCall
			if err := json.Unmarshal(wrapper.Data, &p); err != nil {
				return nil, fmt.Errorf("unmarshal ToolCall: %w. Data: %s", err, string(wrapper.Data))
			}
			parts = append(parts, p)
		case toolResultType:
			var p ToolResult
			if err := json.Unmarshal(wrapper.Data, &p); err != nil {
				return nil, fmt.Errorf("unmarshal ToolResult: %w. Data: %s", err, string(wrapper.Data))
			}
			parts = append(parts, p)
		case finishType:
			// Stored as DBFinish (unix milliseconds); convert back to time.Time.
			var p DBFinish
			if err := json.Unmarshal(wrapper.Data, &p); err != nil {
				return nil, fmt.Errorf("unmarshal Finish: %w. Data: %s", err, string(wrapper.Data))
			}
			parts = append(parts, Finish{Reason: FinishReason(p.Reason), Time: time.UnixMilli(p.Time)})
		default:
			slog.Warn("Unknown part type during unmarshalling, attempting to parse as TextContent", "type", wrapper.Type, "data", string(wrapper.Data))
			// Fallback: if type is unknown or empty, try to parse data as TextContent directly
			var p TextContent
			if err := json.Unmarshal(wrapper.Data, &p); err == nil {
				parts = append(parts, p)
			} else {
				// If that also fails, log it but continue if possible, or return error
				slog.Error("Failed to unmarshal unknown part type and fallback to TextContent failed", "type", wrapper.Type, "data", string(wrapper.Data), "error", err)
				// Depending on strictness, you might return an error here:
				// return nil, fmt.Errorf("unknown part type '%s' and failed fallback: %w", wrapper.Type, err)
			}
		}
	}
	return parts, nil
}
|
||||
@@ -1,246 +0,0 @@
|
||||
package permission
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
// ErrorPermissionDenied is returned when a permission request is rejected.
var ErrorPermissionDenied = errors.New("permission denied")

// CreatePermissionRequest carries the fields a tool supplies when asking for
// permission to perform an action.
type CreatePermissionRequest struct {
	SessionID   string `json:"session_id"`
	ToolName    string `json:"tool_name"`
	Description string `json:"description"`
	Action      string `json:"action"`
	Params      any    `json:"params"`
	Path        string `json:"path"`
}

// PermissionRequest is a CreatePermissionRequest that has been assigned an ID
// and had its path normalized, ready to be presented for approval.
type PermissionRequest struct {
	ID          string `json:"id"`
	SessionID   string `json:"session_id"`
	ToolName    string `json:"tool_name"`
	Description string `json:"description"`
	Action      string `json:"action"`
	Params      any    `json:"params"`
	Path        string `json:"path"`
}

// PermissionResponse pairs a request with the decision made on it.
type PermissionResponse struct {
	Request PermissionRequest
	Granted bool
}

// Pub/sub event types emitted by the permission service.
const (
	EventPermissionRequested pubsub.EventType = "permission_requested"
	EventPermissionGranted   pubsub.EventType = "permission_granted"
	EventPermissionDenied    pubsub.EventType = "permission_denied"
	EventPermissionPersisted pubsub.EventType = "permission_persisted"
)

// Service is the permission-prompt API: tools block in Request, responders
// answer via Grant/GrantPersistant/Deny, and subscribers observe both sides.
type Service interface {
	pubsub.Subscriber[PermissionRequest]
	SubscribeToResponseEvents(ctx context.Context) <-chan pubsub.Event[PermissionResponse]

	// GrantPersistant (spelling kept for compatibility) grants the request
	// and remembers it for the rest of the session.
	GrantPersistant(ctx context.Context, permission PermissionRequest)
	Grant(ctx context.Context, permission PermissionRequest)
	Deny(ctx context.Context, permission PermissionRequest)
	Request(ctx context.Context, opts CreatePermissionRequest) bool
	AutoApproveSession(ctx context.Context, sessionID string)
	IsAutoApproved(ctx context.Context, sessionID string) bool
}
|
||||
|
||||
// permissionService implements Service. sessionPermissions stores persistent
// grants per session; pendingRequests maps request ID -> chan bool used to
// unblock the goroutine waiting in Request; autoApproveSessions marks
// sessions that skip prompting. mu guards the two plain maps.
type permissionService struct {
	broker         *pubsub.Broker[PermissionRequest]
	responseBroker *pubsub.Broker[PermissionResponse]

	sessionPermissions  map[string][]PermissionRequest
	pendingRequests     sync.Map
	autoApproveSessions map[string]bool
	mu                  sync.RWMutex
}

// globalPermissionService is the process-wide singleton created by InitService.
var globalPermissionService *permissionService

// InitService creates the singleton permission service. It must be called
// exactly once, before any call to GetService.
func InitService() error {
	if globalPermissionService != nil {
		return fmt.Errorf("permission service already initialized")
	}
	globalPermissionService = &permissionService{
		broker:              pubsub.NewBroker[PermissionRequest](),
		responseBroker:      pubsub.NewBroker[PermissionResponse](),
		sessionPermissions:  make(map[string][]PermissionRequest),
		autoApproveSessions: make(map[string]bool),
	}
	return nil
}

// GetService returns the singleton; it panics when InitService has not been
// called.
func GetService() *permissionService {
	if globalPermissionService == nil {
		panic("permission service not initialized. Call permission.InitService() first.")
	}
	return globalPermissionService
}
|
||||
|
||||
func (s *permissionService) GrantPersistant(ctx context.Context, permission PermissionRequest) {
|
||||
s.mu.Lock()
|
||||
s.sessionPermissions[permission.SessionID] = append(s.sessionPermissions[permission.SessionID], permission)
|
||||
s.mu.Unlock()
|
||||
|
||||
respCh, ok := s.pendingRequests.Load(permission.ID)
|
||||
if ok {
|
||||
select {
|
||||
case respCh.(chan bool) <- true:
|
||||
case <-ctx.Done():
|
||||
slog.Warn("Context cancelled while sending grant persistent response", "request_id", permission.ID)
|
||||
}
|
||||
}
|
||||
s.responseBroker.Publish(EventPermissionPersisted, PermissionResponse{Request: permission, Granted: true})
|
||||
}
|
||||
|
||||
func (s *permissionService) Grant(ctx context.Context, permission PermissionRequest) {
|
||||
respCh, ok := s.pendingRequests.Load(permission.ID)
|
||||
if ok {
|
||||
select {
|
||||
case respCh.(chan bool) <- true:
|
||||
case <-ctx.Done():
|
||||
slog.Warn("Context cancelled while sending grant response", "request_id", permission.ID)
|
||||
}
|
||||
}
|
||||
s.responseBroker.Publish(EventPermissionGranted, PermissionResponse{Request: permission, Granted: true})
|
||||
}
|
||||
|
||||
func (s *permissionService) Deny(ctx context.Context, permission PermissionRequest) {
|
||||
respCh, ok := s.pendingRequests.Load(permission.ID)
|
||||
if ok {
|
||||
select {
|
||||
case respCh.(chan bool) <- false:
|
||||
case <-ctx.Done():
|
||||
slog.Warn("Context cancelled while sending deny response", "request_id", permission.ID)
|
||||
}
|
||||
}
|
||||
s.responseBroker.Publish(EventPermissionDenied, PermissionResponse{Request: permission, Granted: false})
|
||||
}
|
||||
|
||||
func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRequest) bool {
|
||||
s.mu.RLock()
|
||||
if s.autoApproveSessions[opts.SessionID] {
|
||||
s.mu.RUnlock()
|
||||
return true
|
||||
}
|
||||
|
||||
requestPath := opts.Path
|
||||
if !filepath.IsAbs(requestPath) {
|
||||
requestPath = filepath.Join(config.WorkingDirectory(), requestPath)
|
||||
}
|
||||
requestPath = filepath.Clean(requestPath)
|
||||
|
||||
if permissions, ok := s.sessionPermissions[opts.SessionID]; ok {
|
||||
for _, p := range permissions {
|
||||
storedPath := p.Path
|
||||
if !filepath.IsAbs(storedPath) {
|
||||
storedPath = filepath.Join(config.WorkingDirectory(), storedPath)
|
||||
}
|
||||
storedPath = filepath.Clean(storedPath)
|
||||
|
||||
if p.ToolName == opts.ToolName && p.Action == opts.Action &&
|
||||
(requestPath == storedPath || strings.HasPrefix(requestPath, storedPath+string(filepath.Separator))) {
|
||||
s.mu.RUnlock()
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
normalizedPath := opts.Path
|
||||
if !filepath.IsAbs(normalizedPath) {
|
||||
normalizedPath = filepath.Join(config.WorkingDirectory(), normalizedPath)
|
||||
}
|
||||
normalizedPath = filepath.Clean(normalizedPath)
|
||||
|
||||
permissionReq := PermissionRequest{
|
||||
ID: uuid.New().String(),
|
||||
Path: normalizedPath,
|
||||
SessionID: opts.SessionID,
|
||||
ToolName: opts.ToolName,
|
||||
Description: opts.Description,
|
||||
Action: opts.Action,
|
||||
Params: opts.Params,
|
||||
}
|
||||
|
||||
respCh := make(chan bool, 1)
|
||||
s.pendingRequests.Store(permissionReq.ID, respCh)
|
||||
defer s.pendingRequests.Delete(permissionReq.ID)
|
||||
|
||||
s.broker.Publish(EventPermissionRequested, permissionReq)
|
||||
|
||||
select {
|
||||
case resp := <-respCh:
|
||||
return resp
|
||||
case <-ctx.Done():
|
||||
slog.Warn("Permission request timed out or context cancelled", "request_id", permissionReq.ID, "tool", opts.ToolName)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *permissionService) AutoApproveSession(ctx context.Context, sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.autoApproveSessions[sessionID] = true
|
||||
}
|
||||
|
||||
func (s *permissionService) IsAutoApproved(ctx context.Context, sessionID string) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.autoApproveSessions[sessionID]
|
||||
}
|
||||
|
||||
// Subscribe returns a stream of permission request events; the subscription
// ends when ctx is cancelled.
func (s *permissionService) Subscribe(ctx context.Context) <-chan pubsub.Event[PermissionRequest] {
	return s.broker.Subscribe(ctx)
}
|
||||
|
||||
// SubscribeToResponseEvents returns a stream of permission response events
// (granted/denied/persisted); the subscription ends when ctx is cancelled.
func (s *permissionService) SubscribeToResponseEvents(ctx context.Context) <-chan pubsub.Event[PermissionResponse] {
	return s.responseBroker.Subscribe(ctx)
}
|
||||
|
||||
// Package-level convenience wrappers that delegate to the singleton
// returned by GetService.

// GrantPersistant approves a request and remembers it for its session.
func GrantPersistant(ctx context.Context, permission PermissionRequest) {
	GetService().GrantPersistant(ctx, permission)
}

// Grant approves a single pending request.
func Grant(ctx context.Context, permission PermissionRequest) {
	GetService().Grant(ctx, permission)
}

// Deny rejects a single pending request.
func Deny(ctx context.Context, permission PermissionRequest) {
	GetService().Deny(ctx, permission)
}

// Request asks for permission and blocks until answered or ctx ends.
func Request(ctx context.Context, opts CreatePermissionRequest) bool {
	return GetService().Request(ctx, opts)
}

// AutoApproveSession marks a session so all its future requests are approved.
func AutoApproveSession(ctx context.Context, sessionID string) {
	GetService().AutoApproveSession(ctx, sessionID)
}

// IsAutoApproved reports whether the session is set to auto-approve.
func IsAutoApproved(ctx context.Context, sessionID string) bool {
	return GetService().IsAutoApproved(ctx, sessionID)
}

// SubscribeToRequests streams permission request events until ctx ends.
func SubscribeToRequests(ctx context.Context) <-chan pubsub.Event[PermissionRequest] {
	return GetService().Subscribe(ctx)
}

// SubscribeToResponses streams permission response events until ctx ends.
func SubscribeToResponses(ctx context.Context) <-chan pubsub.Event[PermissionResponse] {
	return GetService().SubscribeToResponseEvents(ctx)
}
|
||||
@@ -1,255 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/sst/opencode/internal/db"
	"github.com/sst/opencode/internal/pubsub"
)
|
||||
|
||||
// Session is a single conversation, including token usage and cost totals.
type Session struct {
	ID              string // unique identifier (UUID, or the tool call ID for task sessions)
	ParentSessionID string // set for task sessions spawned from another session; empty otherwise
	Title           string
	MessageCount    int64
	// Cumulative token usage and accumulated cost for this session.
	PromptTokens     int64
	CompletionTokens int64
	Cost             float64
	Summary          string
	SummarizedAt     time.Time // zero if the session has never been summarized
	CreatedAt        time.Time
	UpdatedAt        time.Time
}
|
||||
|
||||
// Pub/sub event types emitted by the session service when sessions change.
const (
	EventSessionCreated pubsub.EventType = "session_created"
	EventSessionUpdated pubsub.EventType = "session_updated"
	EventSessionDeleted pubsub.EventType = "session_deleted"
)
|
||||
|
||||
// Service is the session persistence API. Subscribers receive the event
// types declared above whenever sessions are created, updated, or deleted.
type Service interface {
	pubsub.Subscriber[Session]

	// Create stores a new session; an empty title gets a timestamped default.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTaskSession stores a session keyed by toolCallID, linked to a parent.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	Get(ctx context.Context, id string) (Session, error)
	List(ctx context.Context) ([]Session, error)
	Update(ctx context.Context, session Session) (Session, error)
	Delete(ctx context.Context, id string) error
}
|
||||
|
||||
// service implements Service on top of the generated sqlc queries.
type service struct {
	db     *db.Queries
	broker *pubsub.Broker[Session]
	mu     sync.RWMutex // serializes database access across service methods
}

// globalSessionService is the process-wide singleton set by InitService.
var globalSessionService *service
|
||||
|
||||
func InitService(dbConn *sql.DB) error {
|
||||
if globalSessionService != nil {
|
||||
return fmt.Errorf("session service already initialized")
|
||||
}
|
||||
queries := db.New(dbConn)
|
||||
broker := pubsub.NewBroker[Session]()
|
||||
|
||||
globalSessionService = &service{
|
||||
db: queries,
|
||||
broker: broker,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetService() Service {
|
||||
if globalSessionService == nil {
|
||||
panic("session service not initialized. Call session.InitService() first.")
|
||||
}
|
||||
return globalSessionService
|
||||
}
|
||||
|
||||
func (s *service) Create(ctx context.Context, title string) (Session, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if title == "" {
|
||||
title = "New Session - " + time.Now().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
dbSessParams := db.CreateSessionParams{
|
||||
ID: uuid.New().String(),
|
||||
Title: title,
|
||||
}
|
||||
dbSession, err := s.db.CreateSession(ctx, dbSessParams)
|
||||
if err != nil {
|
||||
return Session{}, fmt.Errorf("db.CreateSession: %w", err)
|
||||
}
|
||||
|
||||
session := s.fromDBItem(dbSession)
|
||||
s.broker.Publish(EventSessionCreated, session)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if title == "" {
|
||||
title = "Task Session - " + time.Now().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
if toolCallID == "" {
|
||||
toolCallID = uuid.New().String()
|
||||
}
|
||||
|
||||
dbSessParams := db.CreateSessionParams{
|
||||
ID: toolCallID,
|
||||
ParentSessionID: sql.NullString{String: parentSessionID, Valid: parentSessionID != ""},
|
||||
Title: title,
|
||||
}
|
||||
dbSession, err := s.db.CreateSession(ctx, dbSessParams)
|
||||
if err != nil {
|
||||
return Session{}, fmt.Errorf("db.CreateTaskSession: %w", err)
|
||||
}
|
||||
session := s.fromDBItem(dbSession)
|
||||
s.broker.Publish(EventSessionCreated, session)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *service) Get(ctx context.Context, id string) (Session, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbSession, err := s.db.GetSessionByID(ctx, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return Session{}, fmt.Errorf("session ID '%s' not found", id)
|
||||
}
|
||||
return Session{}, fmt.Errorf("db.GetSessionByID: %w", err)
|
||||
}
|
||||
return s.fromDBItem(dbSession), nil
|
||||
}
|
||||
|
||||
func (s *service) List(ctx context.Context) ([]Session, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
dbSessions, err := s.db.ListSessions(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.ListSessions: %w", err)
|
||||
}
|
||||
sessions := make([]Session, len(dbSessions))
|
||||
for i, dbSess := range dbSessions {
|
||||
sessions[i] = s.fromDBItem(dbSess)
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (s *service) Update(ctx context.Context, session Session) (Session, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if session.ID == "" {
|
||||
return Session{}, fmt.Errorf("cannot update session with empty ID")
|
||||
}
|
||||
|
||||
params := db.UpdateSessionParams{
|
||||
ID: session.ID,
|
||||
Title: session.Title,
|
||||
PromptTokens: session.PromptTokens,
|
||||
CompletionTokens: session.CompletionTokens,
|
||||
Cost: session.Cost,
|
||||
Summary: sql.NullString{String: session.Summary, Valid: session.Summary != ""},
|
||||
SummarizedAt: sql.NullString{String: session.SummarizedAt.UTC().Format(time.RFC3339Nano), Valid: !session.SummarizedAt.IsZero()},
|
||||
}
|
||||
dbSession, err := s.db.UpdateSession(ctx, params)
|
||||
if err != nil {
|
||||
return Session{}, fmt.Errorf("db.UpdateSession: %w", err)
|
||||
}
|
||||
updatedSession := s.fromDBItem(dbSession)
|
||||
s.broker.Publish(EventSessionUpdated, updatedSession)
|
||||
return updatedSession, nil
|
||||
}
|
||||
|
||||
func (s *service) Delete(ctx context.Context, id string) error {
|
||||
s.mu.Lock()
|
||||
dbSess, err := s.db.GetSessionByID(ctx, id)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("session ID '%s' not found for deletion", id)
|
||||
}
|
||||
return fmt.Errorf("db.GetSessionByID before delete: %w", err)
|
||||
}
|
||||
sessionToPublish := s.fromDBItem(dbSess)
|
||||
s.mu.Unlock()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
err = s.db.DeleteSession(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.DeleteSession: %w", err)
|
||||
}
|
||||
s.broker.Publish(EventSessionDeleted, sessionToPublish)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Subscribe returns a stream of session change events; the subscription
// ends when ctx is cancelled.
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Session] {
	return s.broker.Subscribe(ctx)
}
|
||||
|
||||
// fromDBItem converts a database row into the domain Session type.
// Timestamp columns are stored as RFC3339Nano strings; parse failures are
// deliberately ignored and leave the corresponding time at its zero value
// (best-effort decoding of possibly legacy rows).
func (s *service) fromDBItem(item db.Session) Session {
	var summarizedAt time.Time
	if item.SummarizedAt.Valid {
		parsedTime, err := time.Parse(time.RFC3339Nano, item.SummarizedAt.String)
		if err == nil {
			summarizedAt = parsedTime
		}
	}

	// Errors intentionally dropped: zero time is the fallback.
	createdAt, _ := time.Parse(time.RFC3339Nano, item.CreatedAt)
	updatedAt, _ := time.Parse(time.RFC3339Nano, item.UpdatedAt)

	return Session{
		ID:               item.ID,
		ParentSessionID:  item.ParentSessionID.String,
		Title:            item.Title,
		MessageCount:     item.MessageCount,
		PromptTokens:     item.PromptTokens,
		CompletionTokens: item.CompletionTokens,
		Cost:             item.Cost,
		Summary:          item.Summary.String,
		SummarizedAt:     summarizedAt,
		CreatedAt:        createdAt,
		UpdatedAt:        updatedAt,
	}
}
|
||||
|
||||
// Package-level convenience wrappers that delegate to the singleton
// returned by GetService.

// Create stores a new session with the given title.
func Create(ctx context.Context, title string) (Session, error) {
	return GetService().Create(ctx, title)
}

// CreateTaskSession stores a task session linked to a parent session.
func CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
	return GetService().CreateTaskSession(ctx, toolCallID, parentSessionID, title)
}

// Get loads a session by ID.
func Get(ctx context.Context, id string) (Session, error) {
	return GetService().Get(ctx, id)
}

// List returns all stored sessions.
func List(ctx context.Context) ([]Session, error) {
	return GetService().List(ctx)
}

// Update persists changes to an existing session.
func Update(ctx context.Context, session Session) (Session, error) {
	return GetService().Update(ctx, session)
}

// Delete removes a session by ID.
func Delete(ctx context.Context, id string) error {
	return GetService().Delete(ctx, id)
}

// Subscribe streams session change events until ctx ends.
func Subscribe(ctx context.Context) <-chan pubsub.Event[Session] {
	return GetService().Subscribe(ctx)
}
|
||||
@@ -1,117 +0,0 @@
|
||||
package status
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sst/opencode/internal/pubsub"
|
||||
)
|
||||
|
||||
// Level is the severity of a status message.
type Level string

const (
	LevelInfo  Level = "info"
	LevelWarn  Level = "warn"
	LevelError Level = "error"
	LevelDebug Level = "debug"
)
|
||||
|
||||
// StatusMessage is a timestamped, leveled message published to subscribers.
type StatusMessage struct {
	Level     Level     `json:"level"`
	Message   string    `json:"message"`
	Timestamp time.Time `json:"timestamp"`
}
|
||||
|
||||
const (
	// EventStatusPublished is emitted for every published status message.
	EventStatusPublished pubsub.EventType = "status_published"
)

// Service publishes leveled status messages to subscribers (and mirrors
// them to the structured logger).
type Service interface {
	pubsub.Subscriber[StatusMessage]

	Info(message string)
	Warn(message string)
	Error(message string)
	Debug(message string)
}
|
||||
|
||||
// service implements Service by fanning messages out through a pubsub broker.
type service struct {
	broker *pubsub.Broker[StatusMessage]
	// NOTE(review): mu is not used by any method visible in this file —
	// confirm against the rest of the package before removing.
	mu sync.RWMutex
}

// globalStatusService is the process-wide singleton set by InitService.
var globalStatusService *service
|
||||
|
||||
func InitService() error {
|
||||
if globalStatusService != nil {
|
||||
return fmt.Errorf("status service already initialized")
|
||||
}
|
||||
broker := pubsub.NewBroker[StatusMessage]()
|
||||
globalStatusService = &service{
|
||||
broker: broker,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetService() Service {
|
||||
if globalStatusService == nil {
|
||||
panic("status service not initialized. Call status.InitService() at application startup.")
|
||||
}
|
||||
return globalStatusService
|
||||
}
|
||||
|
||||
// Info publishes an info-level status message and mirrors it to slog.
func (s *service) Info(message string) {
	s.publish(LevelInfo, message)
	slog.Info(message)
}

// Warn publishes a warn-level status message and mirrors it to slog.
func (s *service) Warn(message string) {
	s.publish(LevelWarn, message)
	slog.Warn(message)
}

// Error publishes an error-level status message and mirrors it to slog.
func (s *service) Error(message string) {
	s.publish(LevelError, message)
	slog.Error(message)
}

// Debug publishes a debug-level status message and mirrors it to slog.
func (s *service) Debug(message string) {
	s.publish(LevelDebug, message)
	slog.Debug(message)
}
|
||||
|
||||
func (s *service) publish(level Level, messageText string) {
|
||||
statusMsg := StatusMessage{
|
||||
Level: level,
|
||||
Message: messageText,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
s.broker.Publish(EventStatusPublished, statusMsg)
|
||||
}
|
||||
|
||||
// Subscribe returns a stream of status message events; the subscription
// ends when ctx is cancelled.
func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[StatusMessage] {
	return s.broker.Subscribe(ctx)
}
|
||||
|
||||
// Package-level convenience wrappers that delegate to the singleton
// returned by GetService.

// Info publishes an info-level status message.
func Info(message string) {
	GetService().Info(message)
}

// Warn publishes a warn-level status message.
func Warn(message string) {
	GetService().Warn(message)
}

// Error publishes an error-level status message.
func Error(message string) {
	GetService().Error(message)
}

// Debug publishes a debug-level status message.
func Debug(message string) {
	GetService().Debug(message)
}

// Subscribe streams status message events until ctx ends.
func Subscribe(ctx context.Context) <-chan pubsub.Event[StatusMessage] {
	return GetService().Subscribe(ctx)
}
|
||||
@@ -1,664 +0,0 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/charmbracelet/x/ansi"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/diff"
|
||||
"github.com/sst/opencode/internal/llm/agent"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/message"
|
||||
"github.com/sst/opencode/internal/tui/styles"
|
||||
"github.com/sst/opencode/internal/tui/theme"
|
||||
)
|
||||
|
||||
// uiMessageType distinguishes the kinds of rendered chat entries.
type uiMessageType int

const (
	userMessageType uiMessageType = iota
	assistantMessageType
	toolMessageType

	// maxResultHeight caps how many lines of a tool result are rendered.
	maxResultHeight = 10
)

// uiMessage is a pre-rendered chat entry plus its layout placement.
type uiMessage struct {
	ID          string
	messageType uiMessageType
	position    int    // vertical offset (in lines) within the transcript
	height      int    // rendered height in lines
	content     string // fully styled, ready-to-print text
}
|
||||
|
||||
// toMarkdown renders content as markdown at the given width.
// The focused flag is currently unused. Render errors are deliberately
// ignored; on failure the rendered string is empty.
func toMarkdown(content string, focused bool, width int) string {
	r := styles.GetMarkdownRenderer(width)
	rendered, _ := r.Render(content)
	return rendered
}
|
||||
|
||||
func renderMessage(msg string, isUser bool, isFocused bool, width int, info ...string) string {
|
||||
t := theme.CurrentTheme()
|
||||
|
||||
style := styles.BaseStyle().
|
||||
Width(width - 1).
|
||||
BorderLeft(true).
|
||||
Foreground(t.TextMuted()).
|
||||
BorderForeground(t.Primary()).
|
||||
BorderStyle(lipgloss.ThickBorder())
|
||||
|
||||
if isUser {
|
||||
style = style.BorderForeground(t.Secondary())
|
||||
}
|
||||
|
||||
// Apply markdown formatting and handle background color
|
||||
parts := []string{
|
||||
styles.ForceReplaceBackgroundWithLipgloss(toMarkdown(msg, isFocused, width), t.Background()),
|
||||
}
|
||||
|
||||
// Remove newline at the end
|
||||
parts[0] = strings.TrimSuffix(parts[0], "\n")
|
||||
if len(info) > 0 {
|
||||
parts = append(parts, info...)
|
||||
}
|
||||
|
||||
rendered := style.Render(
|
||||
lipgloss.JoinVertical(
|
||||
lipgloss.Left,
|
||||
parts...,
|
||||
),
|
||||
)
|
||||
|
||||
return rendered
|
||||
}
|
||||
|
||||
// renderUserMessage renders a user chat message: the markdown body, a
// "username (timestamp)" info line, and badges for any binary attachments.
// position is the message's vertical offset within the transcript.
func renderUserMessage(msg message.Message, isFocused bool, width int, position int) uiMessage {
	var styledAttachments []string
	t := theme.CurrentTheme()
	baseStyle := styles.BaseStyle()
	attachmentStyles := baseStyle.
		MarginLeft(1).
		Background(t.TextMuted()).
		Foreground(t.Text())
	for _, attachment := range msg.BinaryContent() {
		file := filepath.Base(attachment.Path)
		var filename string
		// Long file names are shortened to their first 7 characters.
		if len(file) > 10 {
			filename = fmt.Sprintf(" %s %s...", styles.DocumentIcon, file[0:7])
		} else {
			filename = fmt.Sprintf(" %s %s", styles.DocumentIcon, file)
		}
		styledAttachments = append(styledAttachments, attachmentStyles.Render(filename))
	}

	// Add timestamp info
	info := []string{}
	timestamp := msg.CreatedAt.Local().Format("02 Jan 2006 03:04 PM")
	// Best effort: an error leaves username empty.
	username, _ := config.GetUsername()
	info = append(info, baseStyle.
		Width(width-1).
		Foreground(t.TextMuted()).
		Render(fmt.Sprintf(" %s (%s)", username, timestamp)),
	)

	content := ""
	if len(styledAttachments) > 0 {
		// Attachment badges are laid out in a row beneath the info line.
		attachmentContent := baseStyle.Width(width).Render(lipgloss.JoinHorizontal(lipgloss.Left, styledAttachments...))
		content = renderMessage(msg.Content().String(), true, isFocused, width, append(info, attachmentContent)...)
	} else {
		content = renderMessage(msg.Content().String(), true, isFocused, width, info...)
	}
	userMsg := uiMessage{
		ID:          msg.ID,
		messageType: userMessageType,
		position:    position,
		height:      lipgloss.Height(content),
		content:     content,
	}
	return userMsg
}
|
||||
|
||||
// Returns multiple uiMessages because of the tool calls
|
||||
// renderAssistantMessage renders an assistant message into one or more UI
// entries: the text (or thinking) body first, then — when showToolMessages
// is set — one entry per tool call. position is the running vertical offset
// and is advanced by each rendered entry's height plus one spacer line.
func renderAssistantMessage(
	msg message.Message,
	msgIndex int,
	allMessages []message.Message, // we need this to get tool results and the user message
	messagesService message.Service, // We need this to get the task tool messages
	focusedUIMessageId string,
	width int,
	position int,
	showToolMessages bool,
) []uiMessage {
	messages := []uiMessage{}
	content := strings.TrimSpace(msg.Content().String())
	thinking := msg.IsThinking()
	thinkingContent := msg.ReasoningContent().Thinking
	finished := msg.IsFinished()
	finishData := msg.FinishPart()
	info := []string{}

	t := theme.CurrentTheme()
	baseStyle := styles.BaseStyle()

	// Always add timestamp info
	timestamp := msg.CreatedAt.Local().Format("02 Jan 2006 03:04 PM")
	modelName := "Assistant"
	if msg.Model != "" {
		modelName = models.SupportedModels[msg.Model].Name
	}

	info = append(info, baseStyle.
		Width(width-1).
		Foreground(t.TextMuted()).
		Render(fmt.Sprintf(" %s (%s)", modelName, timestamp)),
	)

	if finished {
		// Add finish info if available; end-turn needs no extra marker.
		switch finishData.Reason {
		case message.FinishReasonCanceled:
			info = append(info, baseStyle.
				Width(width-1).
				Foreground(t.Warning()).
				Render("(canceled)"),
			)
		case message.FinishReasonError:
			info = append(info, baseStyle.
				Width(width-1).
				Foreground(t.Error()).
				Render("(error)"),
			)
		case message.FinishReasonPermissionDenied:
			info = append(info, baseStyle.
				Width(width-1).
				Foreground(t.Info()).
				Render("(permission denied)"),
			)
		}
	}

	// Body entry: real content, or a placeholder when the turn ended with
	// no text; otherwise show streaming "thinking" content when present.
	if content != "" || (finished && finishData.Reason == message.FinishReasonEndTurn) {
		if content == "" {
			content = "*Finished without output*"
		}

		content = renderMessage(content, false, true, width, info...)
		messages = append(messages, uiMessage{
			ID:          msg.ID,
			messageType: assistantMessageType,
			position:    position,
			height:      lipgloss.Height(content),
			content:     content,
		})
		// The body is the first (and so far only) entry.
		position += messages[0].height
		position++ // for the space
	} else if thinking && thinkingContent != "" {
		// Render the thinking content with timestamp
		content = renderMessage(thinkingContent, false, msg.ID == focusedUIMessageId, width, info...)
		messages = append(messages, uiMessage{
			ID:          msg.ID,
			messageType: assistantMessageType,
			position:    position,
			height:      lipgloss.Height(content),
			content:     content,
		})
		position += lipgloss.Height(content)
		position++ // for the space
	}

	// Only render tool messages if they should be shown
	if showToolMessages {
		for i, toolCall := range msg.ToolCalls() {
			toolCallContent := renderToolMessage(
				toolCall,
				allMessages,
				messagesService,
				focusedUIMessageId,
				false,
				width,
				i+1,
			)
			messages = append(messages, toolCallContent)
			position += toolCallContent.height
			position++ // for the space
		}
	}
	return messages
}
|
||||
|
||||
func findToolResponse(toolCallID string, futureMessages []message.Message) *message.ToolResult {
|
||||
for _, msg := range futureMessages {
|
||||
for _, result := range msg.ToolResults() {
|
||||
if result.ToolCallID == toolCallID {
|
||||
return &result
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// toolName maps an internal tool identifier to its display name.
// Unknown tools fall through to their raw identifier.
func toolName(name string) string {
	switch name {
	case agent.AgentToolName:
		return "Task"
	case tools.BashToolName:
		return "Bash"
	case tools.EditToolName:
		return "Edit"
	case tools.FetchToolName:
		return "Fetch"
	case tools.GlobToolName:
		return "Glob"
	case tools.GrepToolName:
		return "Grep"
	case tools.LSToolName:
		return "List"
	case tools.ViewToolName:
		return "View"
	case tools.WriteToolName:
		return "Write"
	case tools.PatchToolName:
		return "Patch"
	}
	return name
}
|
||||
|
||||
// getToolAction returns the in-progress status text shown while a tool
// call's arguments are still streaming in. Unknown tools get a generic label.
func getToolAction(name string) string {
	switch name {
	case agent.AgentToolName:
		return "Preparing prompt..."
	case tools.BashToolName:
		return "Building command..."
	case tools.EditToolName:
		return "Preparing edit..."
	case tools.FetchToolName:
		return "Writing fetch..."
	case tools.GlobToolName:
		return "Finding files..."
	case tools.GrepToolName:
		return "Searching content..."
	case tools.LSToolName:
		return "Listing directory..."
	case tools.ViewToolName:
		return "Reading file..."
	case tools.WriteToolName:
		return "Preparing write..."
	case tools.PatchToolName:
		return "Preparing patch..."
	}
	return "Working..."
}
|
||||
|
||||
// renders params, params[0] (params[1]=params[2] ....)
|
||||
func renderParams(paramsWidth int, params ...string) string {
|
||||
if len(params) == 0 {
|
||||
return ""
|
||||
}
|
||||
mainParam := params[0]
|
||||
if len(mainParam) > paramsWidth {
|
||||
mainParam = mainParam[:paramsWidth-3] + "..."
|
||||
}
|
||||
|
||||
if len(params) == 1 {
|
||||
return mainParam
|
||||
}
|
||||
otherParams := params[1:]
|
||||
// create pairs of key/value
|
||||
// if odd number of params, the last one is a key without value
|
||||
if len(otherParams)%2 != 0 {
|
||||
otherParams = append(otherParams, "")
|
||||
}
|
||||
parts := make([]string, 0, len(otherParams)/2)
|
||||
for i := 0; i < len(otherParams); i += 2 {
|
||||
key := otherParams[i]
|
||||
value := otherParams[i+1]
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("%s=%s", key, value))
|
||||
}
|
||||
|
||||
partsRendered := strings.Join(parts, ", ")
|
||||
remainingWidth := paramsWidth - lipgloss.Width(partsRendered) - 5 // for the space
|
||||
if remainingWidth < 30 {
|
||||
// No space for the params, just show the main
|
||||
return mainParam
|
||||
}
|
||||
|
||||
if len(parts) > 0 {
|
||||
mainParam = fmt.Sprintf("%s (%s)", mainParam, strings.Join(parts, ", "))
|
||||
}
|
||||
|
||||
return ansi.Truncate(mainParam, paramsWidth, "...")
|
||||
}
|
||||
|
||||
func removeWorkingDirPrefix(path string) string {
|
||||
wd := config.WorkingDirectory()
|
||||
if strings.HasPrefix(path, wd) {
|
||||
path = strings.TrimPrefix(path, wd)
|
||||
}
|
||||
if strings.HasPrefix(path, "/") {
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
}
|
||||
if strings.HasPrefix(path, "./") {
|
||||
path = strings.TrimPrefix(path, "./")
|
||||
}
|
||||
if strings.HasPrefix(path, "../") {
|
||||
path = strings.TrimPrefix(path, "../")
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
// renderToolParams formats a tool call's JSON input into a compact one-line
// summary for the given tool. Unmarshal errors are deliberately ignored:
// a failed decode just yields zero-value params and a sparser summary.
func renderToolParams(paramWidth int, toolCall message.ToolCall) string {
	params := ""
	switch toolCall.Name {
	case agent.AgentToolName:
		// Shadows the outer `params` string with the typed parameters.
		var params agent.AgentParams
		json.Unmarshal([]byte(toolCall.Input), &params)
		prompt := strings.ReplaceAll(params.Prompt, "\n", " ")
		return renderParams(paramWidth, prompt)
	case tools.BashToolName:
		var params tools.BashParams
		json.Unmarshal([]byte(toolCall.Input), &params)
		command := strings.ReplaceAll(params.Command, "\n", " ")
		return renderParams(paramWidth, command)
	case tools.EditToolName:
		var params tools.EditParams
		json.Unmarshal([]byte(toolCall.Input), &params)
		filePath := removeWorkingDirPrefix(params.FilePath)
		return renderParams(paramWidth, filePath)
	case tools.FetchToolName:
		var params tools.FetchParams
		json.Unmarshal([]byte(toolCall.Input), &params)
		url := params.URL
		toolParams := []string{
			url,
		}
		if params.Format != "" {
			toolParams = append(toolParams, "format", params.Format)
		}
		if params.Timeout != 0 {
			// Timeout is stored in seconds; shown as a Go duration string.
			toolParams = append(toolParams, "timeout", (time.Duration(params.Timeout) * time.Second).String())
		}
		return renderParams(paramWidth, toolParams...)
	case tools.GlobToolName:
		var params tools.GlobParams
		json.Unmarshal([]byte(toolCall.Input), &params)
		pattern := params.Pattern
		toolParams := []string{
			pattern,
		}
		if params.Path != "" {
			toolParams = append(toolParams, "path", params.Path)
		}
		return renderParams(paramWidth, toolParams...)
	case tools.GrepToolName:
		var params tools.GrepParams
		json.Unmarshal([]byte(toolCall.Input), &params)
		pattern := params.Pattern
		toolParams := []string{
			pattern,
		}
		if params.Path != "" {
			toolParams = append(toolParams, "path", params.Path)
		}
		if params.Include != "" {
			toolParams = append(toolParams, "include", params.Include)
		}
		if params.LiteralText {
			toolParams = append(toolParams, "literal", "true")
		}
		return renderParams(paramWidth, toolParams...)
	case tools.LSToolName:
		var params tools.LSParams
		json.Unmarshal([]byte(toolCall.Input), &params)
		path := params.Path
		if path == "" {
			path = "."
		}
		return renderParams(paramWidth, path)
	case tools.ViewToolName:
		var params tools.ViewParams
		json.Unmarshal([]byte(toolCall.Input), &params)
		filePath := removeWorkingDirPrefix(params.FilePath)
		toolParams := []string{
			filePath,
		}
		if params.Limit != 0 {
			toolParams = append(toolParams, "limit", fmt.Sprintf("%d", params.Limit))
		}
		if params.Offset != 0 {
			toolParams = append(toolParams, "offset", fmt.Sprintf("%d", params.Offset))
		}
		return renderParams(paramWidth, toolParams...)
	case tools.WriteToolName:
		var params tools.WriteParams
		json.Unmarshal([]byte(toolCall.Input), &params)
		filePath := removeWorkingDirPrefix(params.FilePath)
		return renderParams(paramWidth, filePath)
	default:
		// Unknown tool: show the raw JSON input on one line.
		input := strings.ReplaceAll(toolCall.Input, "\n", " ")
		params = renderParams(paramWidth, input)
	}
	return params
}
|
||||
|
||||
// truncateHeight keeps at most height lines of content, dropping the rest.
func truncateHeight(content string, height int) string {
	lines := strings.Split(content, "\n")
	if len(lines) <= height {
		return content
	}
	return strings.Join(lines[:height], "\n")
}
|
||||
|
||||
func renderToolResponse(toolCall message.ToolCall, response message.ToolResult, width int) string {
|
||||
t := theme.CurrentTheme()
|
||||
baseStyle := styles.BaseStyle()
|
||||
|
||||
if response.IsError {
|
||||
errContent := fmt.Sprintf("Error: %s", strings.ReplaceAll(response.Content, "\n", " "))
|
||||
errContent = ansi.Truncate(errContent, width-1, "...")
|
||||
return baseStyle.
|
||||
Width(width).
|
||||
Foreground(t.Error()).
|
||||
Render(errContent)
|
||||
}
|
||||
|
||||
resultContent := truncateHeight(response.Content, maxResultHeight)
|
||||
switch toolCall.Name {
|
||||
case agent.AgentToolName:
|
||||
return styles.ForceReplaceBackgroundWithLipgloss(
|
||||
toMarkdown(resultContent, false, width),
|
||||
t.Background(),
|
||||
)
|
||||
case tools.BashToolName:
|
||||
resultContent = fmt.Sprintf("```bash\n%s\n```", resultContent)
|
||||
return styles.ForceReplaceBackgroundWithLipgloss(
|
||||
toMarkdown(resultContent, true, width),
|
||||
t.Background(),
|
||||
)
|
||||
case tools.EditToolName:
|
||||
metadata := tools.EditResponseMetadata{}
|
||||
json.Unmarshal([]byte(response.Metadata), &metadata)
|
||||
formattedDiff, _ := diff.FormatDiff(metadata.Diff, diff.WithTotalWidth(width))
|
||||
return formattedDiff
|
||||
case tools.FetchToolName:
|
||||
var params tools.FetchParams
|
||||
json.Unmarshal([]byte(toolCall.Input), ¶ms)
|
||||
mdFormat := "markdown"
|
||||
switch params.Format {
|
||||
case "text":
|
||||
mdFormat = "text"
|
||||
case "html":
|
||||
mdFormat = "html"
|
||||
}
|
||||
resultContent = fmt.Sprintf("```%s\n%s\n```", mdFormat, resultContent)
|
||||
return styles.ForceReplaceBackgroundWithLipgloss(
|
||||
toMarkdown(resultContent, true, width),
|
||||
t.Background(),
|
||||
)
|
||||
case tools.GlobToolName:
|
||||
return baseStyle.Width(width).Foreground(t.TextMuted()).Render(resultContent)
|
||||
case tools.GrepToolName:
|
||||
return baseStyle.Width(width).Foreground(t.TextMuted()).Render(resultContent)
|
||||
case tools.LSToolName:
|
||||
return baseStyle.Width(width).Foreground(t.TextMuted()).Render(resultContent)
|
||||
case tools.ViewToolName:
|
||||
metadata := tools.ViewResponseMetadata{}
|
||||
json.Unmarshal([]byte(response.Metadata), &metadata)
|
||||
ext := filepath.Ext(metadata.FilePath)
|
||||
if ext == "" {
|
||||
ext = ""
|
||||
} else {
|
||||
ext = strings.ToLower(ext[1:])
|
||||
}
|
||||
resultContent = fmt.Sprintf("```%s\n%s\n```", ext, truncateHeight(metadata.Content, maxResultHeight))
|
||||
return styles.ForceReplaceBackgroundWithLipgloss(
|
||||
toMarkdown(resultContent, true, width),
|
||||
t.Background(),
|
||||
)
|
||||
case tools.WriteToolName:
|
||||
params := tools.WriteParams{}
|
||||
json.Unmarshal([]byte(toolCall.Input), ¶ms)
|
||||
metadata := tools.WriteResponseMetadata{}
|
||||
json.Unmarshal([]byte(response.Metadata), &metadata)
|
||||
ext := filepath.Ext(params.FilePath)
|
||||
if ext == "" {
|
||||
ext = ""
|
||||
} else {
|
||||
ext = strings.ToLower(ext[1:])
|
||||
}
|
||||
resultContent = fmt.Sprintf("```%s\n%s\n```", ext, truncateHeight(params.Content, maxResultHeight))
|
||||
return styles.ForceReplaceBackgroundWithLipgloss(
|
||||
toMarkdown(resultContent, true, width),
|
||||
t.Background(),
|
||||
)
|
||||
default:
|
||||
resultContent = fmt.Sprintf("```text\n%s\n```", resultContent)
|
||||
return styles.ForceReplaceBackgroundWithLipgloss(
|
||||
toMarkdown(resultContent, true, width),
|
||||
t.Background(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// renderToolMessage renders a single tool call (and, for agent/task calls,
// its nested sub-tool calls) as a bordered uiMessage. When nested is true the
// message is rendered as an indented child row: narrower, prefixed with
// " └ ", and without its own border or response body.
func renderToolMessage(
	toolCall message.ToolCall,
	allMessages []message.Message,
	messagesService message.Service,
	focusedUIMessageId string,
	nested bool,
	width int,
	position int,
) uiMessage {
	// Nested rows give up 3 columns to the " └ " prefix.
	if nested {
		width = width - 3
	}

	t := theme.CurrentTheme()
	baseStyle := styles.BaseStyle()

	// Thick left border marks the whole tool message block.
	style := baseStyle.
		Width(width - 1).
		BorderLeft(true).
		BorderStyle(lipgloss.ThickBorder()).
		PaddingLeft(1).
		BorderForeground(t.TextMuted())

	response := findToolResponse(toolCall.ID, allMessages)
	toolNameText := baseStyle.Foreground(t.TextMuted()).
		Render(fmt.Sprintf("%s: ", toolName(toolCall.Name)))

	// Still-running call: show only "<name>: <action>" progress text.
	if !toolCall.Finished {
		// Get a brief description of what the tool is doing
		toolAction := getToolAction(toolCall.Name)

		progressText := baseStyle.
			Width(width - 2 - lipgloss.Width(toolNameText)).
			Foreground(t.TextMuted()).
			Render(fmt.Sprintf("%s", toolAction))

		content := style.Render(lipgloss.JoinHorizontal(lipgloss.Left, toolNameText, progressText))
		toolMsg := uiMessage{
			messageType: toolMessageType,
			position:    position,
			height:      lipgloss.Height(content),
			content:     content,
		}
		return toolMsg
	}

	params := renderToolParams(width-1-lipgloss.Width(toolNameText), toolCall)
	responseContent := ""
	if response != nil {
		responseContent = renderToolResponse(toolCall, *response, width-2)
		responseContent = strings.TrimSuffix(responseContent, "\n")
	} else {
		// Call finished but its result message has not arrived yet.
		responseContent = baseStyle.
			Italic(true).
			Width(width - 2).
			Foreground(t.TextMuted()).
			Render("Waiting for response...")
	}

	// First part: the "<name>: <params>" header row, with a tree prefix when
	// rendered as a nested child.
	parts := []string{}
	if !nested {
		formattedParams := baseStyle.
			Width(width - 2 - lipgloss.Width(toolNameText)).
			Foreground(t.TextMuted()).
			Render(params)

		parts = append(parts, lipgloss.JoinHorizontal(lipgloss.Left, toolNameText, formattedParams))
	} else {
		prefix := baseStyle.
			Foreground(t.TextMuted()).
			Render(" └ ")
		formattedParams := baseStyle.
			Width(width - 2 - lipgloss.Width(toolNameText)).
			Foreground(t.TextMuted()).
			Render(params)
		parts = append(parts, lipgloss.JoinHorizontal(lipgloss.Left, prefix, toolNameText, formattedParams))
	}

	// Agent (task) calls: list the sub-agent's own tool calls as nested rows.
	// The List error is intentionally ignored; on failure no children render.
	if toolCall.Name == agent.AgentToolName {
		taskMessages, _ := messagesService.List(context.Background(), toolCall.ID)
		toolCalls := []message.ToolCall{}
		for _, v := range taskMessages {
			toolCalls = append(toolCalls, v.ToolCalls()...)
		}
		for _, call := range toolCalls {
			rendered := renderToolMessage(call, []message.Message{}, messagesService, focusedUIMessageId, true, width, 0)
			parts = append(parts, rendered.content)
		}
	}
	// Response bodies are shown only on top-level rows.
	if responseContent != "" && !nested {
		parts = append(parts, responseContent)
	}

	content := style.Render(
		lipgloss.JoinVertical(
			lipgloss.Left,
			parts...,
		),
	)
	// Nested rows skip the bordered wrapper entirely.
	if nested {
		content = lipgloss.JoinVertical(
			lipgloss.Left,
			parts...,
		)
	}
	toolMsg := uiMessage{
		messageType: toolMessageType,
		position:    position,
		height:      lipgloss.Height(content),
		content:     content,
	}
	return toolMsg
}
|
||||
@@ -1,361 +0,0 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/sst/opencode/internal/app"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/diff"
|
||||
"github.com/sst/opencode/internal/history"
|
||||
"github.com/sst/opencode/internal/pubsub"
|
||||
"github.com/sst/opencode/internal/tui/state"
|
||||
"github.com/sst/opencode/internal/tui/styles"
|
||||
"github.com/sst/opencode/internal/tui/theme"
|
||||
)
|
||||
|
||||
// sidebarCmp renders the chat sidebar: header, current session title,
// configured LSP servers, and a live summary of files modified during the
// session.
type sidebarCmp struct {
	app           *app.App
	width, height int
	// modFiles maps a working-directory-relative file path to the diff stats
	// (lines added/removed) of its latest version against its initial version
	// for the current session.
	modFiles map[string]struct {
		additions int
		removals  int
	}
}
|
||||
|
||||
// Init subscribes to file-history events, seeds the modified-files map for
// the current session, and returns a command that delivers the next file
// event to Update. Returns nil when no history service is configured.
func (m *sidebarCmp) Init() tea.Cmd {
	if m.app.History != nil {
		ctx := context.Background()
		// Subscribe to file events
		filesCh := m.app.History.Subscribe(ctx)

		// Initialize the modified files map
		m.modFiles = make(map[string]struct {
			additions int
			removals  int
		})

		// Load initial files and calculate diffs
		m.loadModifiedFiles(ctx)

		// Return a command that will send file events to the Update method.
		// NOTE(review): this command receives a single event from the channel;
		// presumably subsequent events arrive via another pubsub route —
		// confirm against the program's event wiring.
		return func() tea.Msg {
			return <-filesCh
		}
	}
	return nil
}
|
||||
|
||||
// Update reacts to session switches by reloading all modified files, and to
// individual file-history events for the current session by updating just
// that file's diff stats.
func (m *sidebarCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	switch msg := msg.(type) {
	case state.SessionSelectedMsg:
		ctx := context.Background()
		m.loadModifiedFiles(ctx)
	case pubsub.Event[history.File]:
		// Ignore events for sessions other than the one on screen.
		if msg.Payload.SessionID == m.app.CurrentSession.ID {
			// Process the individual file change instead of reloading all files
			ctx := context.Background()
			m.processFileChanges(ctx, msg.Payload)
		}
	}
	return m, nil
}
|
||||
|
||||
// View lays out the sidebar top-to-bottom: header, session line, LSP status,
// and the modified-files section, separated by blank spacer rows.
func (m *sidebarCmp) View() string {
	baseStyle := styles.BaseStyle()

	return baseStyle.
		Width(m.width).
		PaddingLeft(4).
		PaddingRight(2).
		Height(m.height - 1).
		Render(
			lipgloss.JoinVertical(
				lipgloss.Top,
				header(m.width),
				" ",
				m.sessionSection(),
				" ",
				lspsConfigured(m.width),
				" ",
				m.modifiedFiles(),
			),
		)
}
|
||||
|
||||
func (m *sidebarCmp) sessionSection() string {
|
||||
t := theme.CurrentTheme()
|
||||
baseStyle := styles.BaseStyle()
|
||||
|
||||
sessionKey := baseStyle.
|
||||
Foreground(t.Primary()).
|
||||
Bold(true).
|
||||
Render("Session")
|
||||
|
||||
sessionValue := baseStyle.
|
||||
Foreground(t.Text()).
|
||||
Width(m.width - lipgloss.Width(sessionKey)).
|
||||
Render(fmt.Sprintf(": %s", m.app.CurrentSession.Title))
|
||||
|
||||
return lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
sessionKey,
|
||||
sessionValue,
|
||||
)
|
||||
}
|
||||
|
||||
// modifiedFile renders one row of the modified-files list: the relative file
// path followed by "+N" (green) and/or "-N" (red) diff stats. A file with no
// additions and no removals renders with no stats at all.
func (m *sidebarCmp) modifiedFile(filePath string, additions, removals int) string {
	t := theme.CurrentTheme()
	baseStyle := styles.BaseStyle()

	stats := ""
	if additions > 0 && removals > 0 {
		additionsStr := baseStyle.
			Foreground(t.Success()).
			PaddingLeft(1).
			Render(fmt.Sprintf("+%d", additions))

		removalsStr := baseStyle.
			Foreground(t.Error()).
			PaddingLeft(1).
			Render(fmt.Sprintf("-%d", removals))

		content := lipgloss.JoinHorizontal(lipgloss.Left, additionsStr, removalsStr)
		stats = baseStyle.Width(lipgloss.Width(content)).Render(content)
	} else if additions > 0 {
		additionsStr := fmt.Sprintf(" %s", baseStyle.
			PaddingLeft(1).
			Foreground(t.Success()).
			Render(fmt.Sprintf("+%d", additions)))
		stats = baseStyle.Width(lipgloss.Width(additionsStr)).Render(additionsStr)
	} else if removals > 0 {
		removalsStr := fmt.Sprintf(" %s", baseStyle.
			PaddingLeft(1).
			Foreground(t.Error()).
			Render(fmt.Sprintf("-%d", removals)))
		stats = baseStyle.Width(lipgloss.Width(removalsStr)).Render(removalsStr)
	}

	filePathStr := baseStyle.Render(filePath)

	return baseStyle.
		Width(m.width).
		Render(
			lipgloss.JoinHorizontal(
				lipgloss.Left,
				filePathStr,
				stats,
			),
		)
}
|
||||
|
||||
func (m *sidebarCmp) modifiedFiles() string {
|
||||
t := theme.CurrentTheme()
|
||||
baseStyle := styles.BaseStyle()
|
||||
|
||||
modifiedFiles := baseStyle.
|
||||
Width(m.width).
|
||||
Foreground(t.Primary()).
|
||||
Bold(true).
|
||||
Render("Modified Files:")
|
||||
|
||||
// If no modified files, show a placeholder message
|
||||
if m.modFiles == nil || len(m.modFiles) == 0 {
|
||||
message := "No modified files"
|
||||
remainingWidth := m.width - lipgloss.Width(message)
|
||||
if remainingWidth > 0 {
|
||||
message += strings.Repeat(" ", remainingWidth)
|
||||
}
|
||||
return baseStyle.
|
||||
Width(m.width).
|
||||
Render(
|
||||
lipgloss.JoinVertical(
|
||||
lipgloss.Top,
|
||||
modifiedFiles,
|
||||
baseStyle.Foreground(t.TextMuted()).Render(message),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// Sort file paths alphabetically for consistent ordering
|
||||
var paths []string
|
||||
for path := range m.modFiles {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
sort.Strings(paths)
|
||||
|
||||
// Create views for each file in sorted order
|
||||
var fileViews []string
|
||||
for _, path := range paths {
|
||||
stats := m.modFiles[path]
|
||||
fileViews = append(fileViews, m.modifiedFile(path, stats.additions, stats.removals))
|
||||
}
|
||||
|
||||
return baseStyle.
|
||||
Width(m.width).
|
||||
Render(
|
||||
lipgloss.JoinVertical(
|
||||
lipgloss.Top,
|
||||
modifiedFiles,
|
||||
lipgloss.JoinVertical(
|
||||
lipgloss.Left,
|
||||
fileViews...,
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (m *sidebarCmp) SetSize(width, height int) tea.Cmd {
|
||||
m.width = width
|
||||
m.height = height
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSize reports the sidebar's current width and height.
func (m *sidebarCmp) GetSize() (int, int) {
	return m.width, m.height
}
|
||||
|
||||
// NewSidebarCmp constructs the sidebar component. The modified-files map is
// populated later, in Init.
func NewSidebarCmp(app *app.App) tea.Model {
	return &sidebarCmp{
		app: app,
	}
}
|
||||
|
||||
func (m *sidebarCmp) loadModifiedFiles(ctx context.Context) {
|
||||
if m.app.CurrentSession.ID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Get all latest files for this session
|
||||
latestFiles, err := m.app.History.ListLatestSessionFiles(ctx, m.app.CurrentSession.ID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get all files for this session (to find initial versions)
|
||||
allFiles, err := m.app.History.ListBySession(ctx, m.app.CurrentSession.ID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Clear the existing map to rebuild it
|
||||
m.modFiles = make(map[string]struct {
|
||||
additions int
|
||||
removals int
|
||||
})
|
||||
|
||||
// Process each latest file
|
||||
for _, file := range latestFiles {
|
||||
// Skip if this is the initial version (no changes to show)
|
||||
if file.Version == history.InitialVersion {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find the initial version for this specific file
|
||||
var initialVersion history.File
|
||||
for _, v := range allFiles {
|
||||
if v.Path == file.Path && v.Version == history.InitialVersion {
|
||||
initialVersion = v
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Skip if we can't find the initial version
|
||||
if initialVersion.ID == "" {
|
||||
continue
|
||||
}
|
||||
if initialVersion.Content == file.Content {
|
||||
continue
|
||||
}
|
||||
|
||||
// Calculate diff between initial and latest version
|
||||
_, additions, removals := diff.GenerateDiff(initialVersion.Content, file.Content, file.Path)
|
||||
|
||||
// Only add to modified files if there are changes
|
||||
if additions > 0 || removals > 0 {
|
||||
// Remove working directory prefix from file path
|
||||
displayPath := file.Path
|
||||
workingDir := config.WorkingDirectory()
|
||||
displayPath = strings.TrimPrefix(displayPath, workingDir)
|
||||
displayPath = strings.TrimPrefix(displayPath, "/")
|
||||
|
||||
m.modFiles[displayPath] = struct {
|
||||
additions int
|
||||
removals int
|
||||
}{
|
||||
additions: additions,
|
||||
removals: removals,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processFileChanges updates m.modFiles incrementally for a single file event:
// it diffs the event's content against the file's initial version, records the
// stats when there are changes, and removes the entry when the file has
// reverted to its initial content.
func (m *sidebarCmp) processFileChanges(ctx context.Context, file history.File) {
	// Skip if this is the initial version (no changes to show)
	if file.Version == history.InitialVersion {
		return
	}

	// Find the initial version for this file
	initialVersion, err := m.findInitialVersion(ctx, file.Path)
	if err != nil || initialVersion.ID == "" {
		return
	}

	// Skip if content hasn't changed
	if initialVersion.Content == file.Content {
		// If this file was previously modified but now matches the initial version,
		// remove it from the modified files list
		displayPath := getDisplayPath(file.Path)
		delete(m.modFiles, displayPath)
		return
	}

	// Calculate diff between initial and latest version
	_, additions, removals := diff.GenerateDiff(initialVersion.Content, file.Content, file.Path)

	// Only add to modified files if there are changes
	if additions > 0 || removals > 0 {
		displayPath := getDisplayPath(file.Path)
		m.modFiles[displayPath] = struct {
			additions int
			removals  int
		}{
			additions: additions,
			removals:  removals,
		}
	} else {
		// If no changes, remove from modified files
		displayPath := getDisplayPath(file.Path)
		delete(m.modFiles, displayPath)
	}
}
|
||||
|
||||
// findInitialVersion returns the initial version of the file at path for the
// current session, or an error when none exists.
// NOTE(review): this lists every file version for the session on each call;
// fine for occasional events, but worth caching if events become frequent.
func (m *sidebarCmp) findInitialVersion(ctx context.Context, path string) (history.File, error) {
	// Get all versions of this file for the session
	fileVersions, err := m.app.History.ListBySession(ctx, m.app.CurrentSession.ID)
	if err != nil {
		return history.File{}, err
	}

	// Find the initial version
	for _, v := range fileVersions {
		if v.Path == path && v.Version == history.InitialVersion {
			return v, nil
		}
	}

	return history.File{}, fmt.Errorf("initial version not found")
}
|
||||
|
||||
// getDisplayPath strips the working directory prefix (and any leading slash)
// from path, yielding the relative path shown in the sidebar.
func getDisplayPath(path string) string {
	workingDir := config.WorkingDirectory()
	displayPath := strings.TrimPrefix(path, workingDir)
	return strings.TrimPrefix(displayPath, "/")
}
|
||||
@@ -1,302 +0,0 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/sst/opencode/internal/app"
|
||||
"github.com/sst/opencode/internal/config"
|
||||
"github.com/sst/opencode/internal/llm/models"
|
||||
"github.com/sst/opencode/internal/lsp"
|
||||
"github.com/sst/opencode/internal/lsp/protocol"
|
||||
"github.com/sst/opencode/internal/pubsub"
|
||||
"github.com/sst/opencode/internal/status"
|
||||
"github.com/sst/opencode/internal/tui/styles"
|
||||
"github.com/sst/opencode/internal/tui/theme"
|
||||
)
|
||||
|
||||
// StatusCmp is the status bar at the bottom of the TUI: help hint, token/cost
// summary, transient status messages, LSP diagnostics, and the active model.
type StatusCmp interface {
	tea.Model
	// SetHelpWidgetMsg replaces the text of the help hint widget.
	SetHelpWidgetMsg(string)
}
|
||||
|
||||
// statusCmp implements StatusCmp.
type statusCmp struct {
	app *app.App
	// statusMessages is the queue of transient messages; only the oldest is
	// shown, and expired ones are pruned by the periodic cleanup tick.
	statusMessages []statusMessage
	width          int
	// messageTTL is how long each status message stays visible.
	messageTTL time.Duration
}
|
||||
|
||||
// statusMessage is one transient entry in the status bar's message queue,
// carrying its severity and the time after which it should be dropped.
type statusMessage struct {
	Level     status.Level
	Message   string
	Timestamp time.Time
	ExpiresAt time.Time
}
|
||||
|
||||
// clearMessageCmd schedules a statusCleanupMsg one second from now; Update
// re-issues it after each cleanup, forming a once-per-second pruning loop.
func (m statusCmp) clearMessageCmd() tea.Cmd {
	return tea.Tick(time.Second, func(t time.Time) tea.Msg {
		return statusCleanupMsg{time: t}
	})
}
|
||||
|
||||
// statusCleanupMsg is a message that triggers cleanup of expired status
// messages; time is the tick instant used as the expiry cutoff.
type statusCleanupMsg struct {
	time time.Time
}
|
||||
|
||||
// Init starts the periodic status-message cleanup loop.
func (m statusCmp) Init() tea.Cmd {
	return m.clearMessageCmd()
}
|
||||
|
||||
// Update handles window resizes, enqueues newly published status messages
// (stamping their expiry), and prunes expired messages on each cleanup tick.
// The value receiver is fine here: the mutated copy is returned as the new
// model.
func (m statusCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	switch msg := msg.(type) {
	case tea.WindowSizeMsg:
		m.width = msg.Width
		return m, nil
	case pubsub.Event[status.StatusMessage]:
		if msg.Type == status.EventStatusPublished {
			statusMsg := statusMessage{
				Level:     msg.Payload.Level,
				Message:   msg.Payload.Message,
				Timestamp: msg.Payload.Timestamp,
				ExpiresAt: msg.Payload.Timestamp.Add(m.messageTTL),
			}
			m.statusMessages = append(m.statusMessages, statusMsg)
		}
	case statusCleanupMsg:
		// Remove expired messages
		var activeMessages []statusMessage
		for _, sm := range m.statusMessages {
			if sm.ExpiresAt.After(msg.time) {
				activeMessages = append(activeMessages, sm)
			}
		}
		m.statusMessages = activeMessages
		// Re-arm the cleanup tick so pruning continues every second.
		return m, m.clearMessageCmd()
	}
	return m, nil
}
|
||||
|
||||
// helpWidget caches the pre-rendered help hint; it is package-level so that
// SetHelpWidgetMsg (which has a value receiver) can still update it.
var helpWidget = ""
|
||||
|
||||
// getHelpWidget returns the help widget with current theme colors. An empty
// helpText falls back to the default "ctrl+? help" hint.
func getHelpWidget(helpText string) string {
	t := theme.CurrentTheme()
	if helpText == "" {
		helpText = "ctrl+? help"
	}

	return styles.Padded().
		Background(t.TextMuted()).
		Foreground(t.BackgroundDarker()).
		Bold(true).
		Render(helpText)
}
|
||||
|
||||
// formatTokensAndCost renders a status-bar summary of the session: a
// human-readable token count (e.g. 110K, 1.2M), the percentage of the model's
// context window consumed, and the dollar cost with two decimals.
func formatTokensAndCost(tokens int64, contextWindow int64, cost float64) string {
	// Format tokens in human-readable format (e.g., 110K, 1.2M)
	var formattedTokens string
	switch {
	case tokens >= 1_000_000:
		formattedTokens = fmt.Sprintf("%.1fM", float64(tokens)/1_000_000)
	case tokens >= 1_000:
		formattedTokens = fmt.Sprintf("%.1fK", float64(tokens)/1_000)
	default:
		formattedTokens = fmt.Sprintf("%d", tokens)
	}

	// Remove .0 suffix if present
	if strings.HasSuffix(formattedTokens, ".0K") {
		formattedTokens = strings.Replace(formattedTokens, ".0K", "K", 1)
	}
	if strings.HasSuffix(formattedTokens, ".0M") {
		formattedTokens = strings.Replace(formattedTokens, ".0M", "M", 1)
	}

	// Format cost with $ symbol and 2 decimal places
	formattedCost := fmt.Sprintf("$%.2f", cost)

	// Guard against a zero/unknown context window: dividing by zero yields
	// +Inf, and converting Inf to int is undefined behavior in Go.
	percentage := 0.0
	if contextWindow > 0 {
		percentage = (float64(tokens) / float64(contextWindow)) * 100
	}

	return fmt.Sprintf("Tokens: %s (%d%%), Cost: %s", formattedTokens, int(percentage), formattedCost)
}
|
||||
|
||||
// View assembles the status bar left-to-right: help hint, token/cost summary
// (when a session is active), the oldest active status message (or a blank
// filler), diagnostics counts, and the active model name.
func (m statusCmp) View() string {
	t := theme.CurrentTheme()
	modelID := config.Get().Agents[config.AgentPrimary].Model
	model := models.SupportedModels[modelID]

	// Initialize the help widget
	status := getHelpWidget("")

	if m.app.CurrentSession.ID != "" {
		tokens := formatTokensAndCost(m.app.CurrentSession.PromptTokens+m.app.CurrentSession.CompletionTokens, model.ContextWindow, m.app.CurrentSession.Cost)
		tokensStyle := styles.Padded().
			Background(t.Text()).
			Foreground(t.BackgroundSecondary()).
			Render(tokens)
		status += tokensStyle
	}

	diagnostics := styles.Padded().Background(t.BackgroundDarker()).Render(m.projectDiagnostics())

	modelName := m.model()

	// Whatever width is left over goes to the status-message area.
	statusWidth := max(
		0,
		m.width-
			lipgloss.Width(status)-
			lipgloss.Width(modelName)-
			lipgloss.Width(diagnostics),
	)

	// Display the first status message if available
	if len(m.statusMessages) > 0 {
		sm := m.statusMessages[0]
		infoStyle := styles.Padded().
			Foreground(t.Background()).
			Width(statusWidth)

		// Background color reflects the message severity.
		switch sm.Level {
		case "info":
			infoStyle = infoStyle.Background(t.Info())
		case "warn":
			infoStyle = infoStyle.Background(t.Warning())
		case "error":
			infoStyle = infoStyle.Background(t.Error())
		case "debug":
			infoStyle = infoStyle.Background(t.TextMuted())
		}

		// Truncate message if it's longer than available width.
		// NOTE(review): len/slicing here are byte-based, so a multi-byte
		// UTF-8 message could be cut mid-rune — confirm whether messages are
		// always ASCII.
		msg := sm.Message
		availWidth := statusWidth - 10
		if len(msg) > availWidth && availWidth > 0 {
			msg = msg[:availWidth] + "..."
		}

		status += infoStyle.Render(msg)
	} else {
		// No message: render an empty filler so the bar stays full width.
		status += styles.Padded().
			Foreground(t.Text()).
			Background(t.BackgroundSecondary()).
			Width(statusWidth).
			Render("")
	}

	status += diagnostics
	status += modelName
	return status
}
|
||||
|
||||
// projectDiagnostics summarizes LSP state for the status bar: a spinner while
// any server is still starting, otherwise per-severity counts of diagnostics
// aggregated across all LSP clients.
func (m *statusCmp) projectDiagnostics() string {
	t := theme.CurrentTheme()

	// Check if any LSP server is still initializing
	initializing := false
	for _, client := range m.app.LSPClients {
		if client.GetServerState() == lsp.StateStarting {
			initializing = true
			break
		}
	}

	// If any server is initializing, show that status
	if initializing {
		return lipgloss.NewStyle().
			Foreground(t.Warning()).
			Render(fmt.Sprintf("%s Initializing LSP...", styles.SpinnerIcon))
	}

	// Bucket every diagnostic from every client by severity.
	errorDiagnostics := []protocol.Diagnostic{}
	warnDiagnostics := []protocol.Diagnostic{}
	hintDiagnostics := []protocol.Diagnostic{}
	infoDiagnostics := []protocol.Diagnostic{}
	for _, client := range m.app.LSPClients {
		for _, d := range client.GetDiagnostics() {
			for _, diag := range d {
				switch diag.Severity {
				case protocol.SeverityError:
					errorDiagnostics = append(errorDiagnostics, diag)
				case protocol.SeverityWarning:
					warnDiagnostics = append(warnDiagnostics, diag)
				case protocol.SeverityHint:
					hintDiagnostics = append(hintDiagnostics, diag)
				case protocol.SeverityInformation:
					infoDiagnostics = append(infoDiagnostics, diag)
				}
			}
		}
	}

	// Render one "<icon> <count>" segment per severity.
	diagnostics := []string{}

	errStr := lipgloss.NewStyle().
		Background(t.BackgroundDarker()).
		Foreground(t.Error()).
		Render(fmt.Sprintf("%s %d", styles.ErrorIcon, len(errorDiagnostics)))
	diagnostics = append(diagnostics, errStr)

	warnStr := lipgloss.NewStyle().
		Background(t.BackgroundDarker()).
		Foreground(t.Warning()).
		Render(fmt.Sprintf("%s %d", styles.WarningIcon, len(warnDiagnostics)))
	diagnostics = append(diagnostics, warnStr)

	infoStr := lipgloss.NewStyle().
		Background(t.BackgroundDarker()).
		Foreground(t.Info()).
		Render(fmt.Sprintf("%s %d", styles.InfoIcon, len(infoDiagnostics)))
	diagnostics = append(diagnostics, infoStr)

	hintStr := lipgloss.NewStyle().
		Background(t.BackgroundDarker()).
		Foreground(t.Text()).
		Render(fmt.Sprintf("%s %d", styles.HintIcon, len(hintDiagnostics)))
	diagnostics = append(diagnostics, hintStr)

	return styles.ForceReplaceBackgroundWithLipgloss(
		styles.Padded().Render(strings.Join(diagnostics, " ")),
		t.BackgroundDarker(),
	)
}
|
||||
|
||||
// model renders the active primary-agent model name as a padded status-bar
// segment, or "Unknown" when no primary agent is configured.
func (m statusCmp) model() string {
	t := theme.CurrentTheme()

	cfg := config.Get()

	coder, ok := cfg.Agents[config.AgentPrimary]
	if !ok {
		return "Unknown"
	}
	model := models.SupportedModels[coder.Model]

	return styles.Padded().
		Background(t.Secondary()).
		Foreground(t.Background()).
		Render(model.Name)
}
|
||||
|
||||
// SetHelpWidgetMsg replaces the help hint text. Note this writes the
// package-level helpWidget variable (not receiver state), which is why it
// works despite the value receiver — shared mutable package state.
func (m statusCmp) SetHelpWidgetMsg(s string) {
	// Update the help widget text using the getHelpWidget function
	helpWidget = getHelpWidget(s)
}
|
||||
|
||||
// NewStatusCmp constructs the status bar component with a 4-second TTL for
// transient status messages and resets the help widget to its default text.
func NewStatusCmp(app *app.App) StatusCmp {
	// Initialize the help widget with default text
	helpWidget = getHelpWidget("")

	statusComponent := &statusCmp{
		app:            app,
		statusMessages: []statusMessage{},
		messageTTL:     4 * time.Second,
	}

	return statusComponent
}
|
||||
@@ -1,172 +0,0 @@
|
||||
package dialog
|
||||
|
||||
import (
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
|
||||
"github.com/sst/opencode/internal/tui/styles"
|
||||
"github.com/sst/opencode/internal/tui/theme"
|
||||
"github.com/sst/opencode/internal/tui/util"
|
||||
)
|
||||
|
||||
// ArgumentsDialogCmp is a component that asks the user for command arguments.
// It carries the originating command's ID and content so they can be returned
// alongside the entered arguments when the dialog closes.
type ArgumentsDialogCmp struct {
	width, height int
	textInput     textinput.Model
	keys          argumentsDialogKeyMap
	commandID     string
	content       string
}
|
||||
|
||||
// NewArgumentsDialogCmp creates a new ArgumentsDialogCmp with a focused,
// theme-colored text input for the given command.
func NewArgumentsDialogCmp(commandID, content string) ArgumentsDialogCmp {
	t := theme.CurrentTheme()
	ti := textinput.New()
	ti.Placeholder = "Enter arguments..."
	ti.Focus()
	ti.Width = 40
	ti.Prompt = ""
	// Match the input's styles to the current theme background.
	ti.PlaceholderStyle = ti.PlaceholderStyle.Background(t.Background())
	ti.PromptStyle = ti.PromptStyle.Background(t.Background())
	ti.TextStyle = ti.TextStyle.Background(t.Background())

	return ArgumentsDialogCmp{
		textInput: ti,
		keys:      argumentsDialogKeyMap{},
		commandID: commandID,
		content:   content,
	}
}
|
||||
|
||||
// argumentsDialogKeyMap declares the dialog's key bindings.
// NOTE(review): these fields are never initialized; ShortHelp constructs its
// bindings inline instead of reading them — confirm before relying on them.
type argumentsDialogKeyMap struct {
	Enter  key.Binding
	Escape key.Binding
}
|
||||
|
||||
// ShortHelp implements key.Map. Bindings are constructed inline rather than
// read from k's Enter/Escape fields, which are zero-valued (never set).
func (k argumentsDialogKeyMap) ShortHelp() []key.Binding {
	return []key.Binding{
		key.NewBinding(
			key.WithKeys("enter"),
			key.WithHelp("enter", "confirm"),
		),
		key.NewBinding(
			key.WithKeys("esc"),
			key.WithHelp("esc", "cancel"),
		),
	}
}
|
||||
|
||||
// FullHelp implements key.Map by reusing the short help as the only row.
func (k argumentsDialogKeyMap) FullHelp() [][]key.Binding {
	return [][]key.Binding{k.ShortHelp()}
}
|
||||
|
||||
// Init implements tea.Model: start the cursor blink and focus the input.
func (m ArgumentsDialogCmp) Init() tea.Cmd {
	return tea.Batch(
		textinput.Blink,
		m.textInput.Focus(),
	)
}
|
||||
|
||||
// Update implements tea.Model. Esc cancels the dialog; Enter submits the
// entered arguments together with the command ID and content; all other
// messages (including resizes) are forwarded to the text input.
func (m ArgumentsDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	var cmd tea.Cmd
	var cmds []tea.Cmd

	switch msg := msg.(type) {
	case tea.KeyMsg:
		switch {
		case key.Matches(msg, key.NewBinding(key.WithKeys("esc"))):
			return m, util.CmdHandler(CloseArgumentsDialogMsg{})
		case key.Matches(msg, key.NewBinding(key.WithKeys("enter"))):
			return m, util.CmdHandler(CloseArgumentsDialogMsg{
				Submit:    true,
				CommandID: m.commandID,
				Content:   m.content,
				Arguments: m.textInput.Value(),
			})
		}
	case tea.WindowSizeMsg:
		m.width = msg.Width
		m.height = msg.Height
	}

	m.textInput, cmd = m.textInput.Update(msg)
	cmds = append(cmds, cmd)

	return m, tea.Batch(cmds...)
}
|
||||
|
||||
// View implements tea.Model.
|
||||
func (m ArgumentsDialogCmp) View() string {
|
||||
t := theme.CurrentTheme()
|
||||
baseStyle := styles.BaseStyle()
|
||||
|
||||
// Calculate width needed for content
|
||||
maxWidth := 60 // Width for explanation text
|
||||
|
||||
title := baseStyle.
|
||||
Foreground(t.Primary()).
|
||||
Bold(true).
|
||||
Width(maxWidth).
|
||||
Padding(0, 1).
|
||||
Render("Command Arguments")
|
||||
|
||||
explanation := baseStyle.
|
||||
Foreground(t.Text()).
|
||||
Width(maxWidth).
|
||||
Padding(0, 1).
|
||||
Render("This command requires arguments. Please enter the text to replace $ARGUMENTS with:")
|
||||
|
||||
inputField := baseStyle.
|
||||
Foreground(t.Text()).
|
||||
Width(maxWidth).
|
||||
Padding(1, 1).
|
||||
Render(m.textInput.View())
|
||||
|
||||
maxWidth = min(maxWidth, m.width-10)
|
||||
|
||||
content := lipgloss.JoinVertical(
|
||||
lipgloss.Left,
|
||||
title,
|
||||
explanation,
|
||||
inputField,
|
||||
)
|
||||
|
||||
return baseStyle.Padding(1, 2).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderBackground(t.Background()).
|
||||
BorderForeground(t.TextMuted()).
|
||||
Background(t.Background()).
|
||||
Width(lipgloss.Width(content) + 4).
|
||||
Render(content)
|
||||
}
|
||||
|
||||
// SetSize sets the size of the component.
func (m *ArgumentsDialogCmp) SetSize(width, height int) {
	m.width = width
	m.height = height
}
|
||||
|
||||
// Bindings implements layout.Bindings by exposing the dialog's short help.
func (m ArgumentsDialogCmp) Bindings() []key.Binding {
	return m.keys.ShortHelp()
}
|
||||
|
||||
// CloseArgumentsDialogMsg is a message that is sent when the arguments dialog
// is closed. Submit is false on cancel; on submit, CommandID/Content echo the
// originating command and Arguments carries the user's input.
type CloseArgumentsDialogMsg struct {
	Submit    bool
	CommandID string
	Content   string
	Arguments string
}
|
||||
|
||||
// ShowArgumentsDialogMsg is a message that is sent to show the arguments
// dialog for the given command.
type ShowArgumentsDialogMsg struct {
	CommandID string
	Content   string
}
|
||||
@@ -1,249 +0,0 @@
|
||||
package dialog
|
||||
|
||||
import (
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/sst/opencode/internal/tui/layout"
|
||||
"github.com/sst/opencode/internal/tui/styles"
|
||||
"github.com/sst/opencode/internal/tui/theme"
|
||||
"github.com/sst/opencode/internal/tui/util"
|
||||
)
|
||||
|
||||
// Command represents a command that can be executed from the command dialog.
// Handler is invoked with the selected command and returns the tea.Cmd to run.
type Command struct {
	ID          string
	Title       string
	Description string
	Handler     func(cmd Command) tea.Cmd
}
|
||||
|
||||
// CommandSelectedMsg is sent when a command is selected.
type CommandSelectedMsg struct {
	Command Command // the chosen command
}
|
||||
|
||||
// CloseCommandDialogMsg is sent when the command dialog is closed.
type CloseCommandDialogMsg struct{}
|
||||
|
||||
// CommandDialog interface for the command selection dialog.
type CommandDialog interface {
	tea.Model
	layout.Bindings
	// SetCommands replaces the list of selectable commands.
	SetCommands(commands []Command)
	// SetSelectedCommand moves the cursor to the command with this ID.
	SetSelectedCommand(commandID string)
}
|
||||
|
||||
// commandDialogCmp is the concrete CommandDialog implementation.
type commandDialogCmp struct {
	commands          []Command // entries shown in the list
	selectedIdx       int       // index of the highlighted entry
	width             int
	height            int
	selectedCommandID string // ID to re-select when the list is replaced
}
|
||||
|
||||
// commandKeyMap groups the key bindings used by the command dialog;
// J/K mirror Down/Up for vim-style navigation.
type commandKeyMap struct {
	Up     key.Binding
	Down   key.Binding
	Enter  key.Binding
	Escape key.Binding
	J      key.Binding
	K      key.Binding
}
|
||||
|
||||
// commandKeys is the shared binding set for all command dialogs:
// arrows and j/k navigate, enter selects, esc closes.
var commandKeys = commandKeyMap{
	Up: key.NewBinding(
		key.WithKeys("up"),
		key.WithHelp("↑", "previous command"),
	),
	Down: key.NewBinding(
		key.WithKeys("down"),
		key.WithHelp("↓", "next command"),
	),
	Enter: key.NewBinding(
		key.WithKeys("enter"),
		key.WithHelp("enter", "select command"),
	),
	Escape: key.NewBinding(
		key.WithKeys("esc"),
		key.WithHelp("esc", "close"),
	),
	J: key.NewBinding(
		key.WithKeys("j"),
		key.WithHelp("j", "next command"),
	),
	K: key.NewBinding(
		key.WithKeys("k"),
		key.WithHelp("k", "previous command"),
	),
}
|
||||
|
||||
func (c *commandDialogCmp) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *commandDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch {
|
||||
case key.Matches(msg, commandKeys.Up) || key.Matches(msg, commandKeys.K):
|
||||
if c.selectedIdx > 0 {
|
||||
c.selectedIdx--
|
||||
}
|
||||
return c, nil
|
||||
case key.Matches(msg, commandKeys.Down) || key.Matches(msg, commandKeys.J):
|
||||
if c.selectedIdx < len(c.commands)-1 {
|
||||
c.selectedIdx++
|
||||
}
|
||||
return c, nil
|
||||
case key.Matches(msg, commandKeys.Enter):
|
||||
if len(c.commands) > 0 {
|
||||
return c, util.CmdHandler(CommandSelectedMsg{
|
||||
Command: c.commands[c.selectedIdx],
|
||||
})
|
||||
}
|
||||
case key.Matches(msg, commandKeys.Escape):
|
||||
return c, util.CmdHandler(CloseCommandDialogMsg{})
|
||||
}
|
||||
case tea.WindowSizeMsg:
|
||||
c.width = msg.Width
|
||||
c.height = msg.Height
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// View renders the command palette: a title, then a scrolling window of
// up to 10 commands with the selected entry highlighted and kept near
// the middle of the window where possible.
func (c *commandDialogCmp) View() string {
	t := theme.CurrentTheme()
	baseStyle := styles.BaseStyle()

	// With no commands, show a fixed-width placeholder dialog.
	if len(c.commands) == 0 {
		return baseStyle.Padding(1, 2).
			Border(lipgloss.RoundedBorder()).
			BorderBackground(t.Background()).
			BorderForeground(t.TextMuted()).
			Width(40).
			Render("No commands available")
	}

	// Calculate max width needed for command titles.
	// NOTE(review): len() counts bytes, so non-ASCII titles can overshoot
	// the visual width — confirm titles are expected to be ASCII.
	maxWidth := 40 // Minimum width
	for _, cmd := range c.commands {
		if len(cmd.Title) > maxWidth-4 { // Account for padding
			maxWidth = len(cmd.Title) + 4
		}
		if len(cmd.Description) > maxWidth-4 {
			maxWidth = len(cmd.Description) + 4
		}
	}

	// Limit height to avoid taking up too much screen space.
	maxVisibleCommands := min(10, len(c.commands))

	// Build the command list.
	commandItems := make([]string, 0, maxVisibleCommands)
	startIdx := 0

	// If we have more commands than can be displayed, adjust the start index.
	if len(c.commands) > maxVisibleCommands {
		// Center the selected item when possible.
		halfVisible := maxVisibleCommands / 2
		if c.selectedIdx >= halfVisible && c.selectedIdx < len(c.commands)-halfVisible {
			startIdx = c.selectedIdx - halfVisible
		} else if c.selectedIdx >= len(c.commands)-halfVisible {
			// Selection near the end: pin the window to the last page.
			startIdx = len(c.commands) - maxVisibleCommands
		}
	}

	endIdx := min(startIdx+maxVisibleCommands, len(c.commands))

	for i := startIdx; i < endIdx; i++ {
		cmd := c.commands[i]
		itemStyle := baseStyle.Width(maxWidth)
		descStyle := baseStyle.Width(maxWidth).Foreground(t.TextMuted())

		// Invert colors on the highlighted row.
		if i == c.selectedIdx {
			itemStyle = itemStyle.
				Background(t.Primary()).
				Foreground(t.Background()).
				Bold(true)
			descStyle = descStyle.
				Background(t.Primary()).
				Foreground(t.Background())
		}

		title := itemStyle.Padding(0, 1).Render(cmd.Title)
		description := ""
		if cmd.Description != "" {
			description = descStyle.Padding(0, 1).Render(cmd.Description)
			commandItems = append(commandItems, lipgloss.JoinVertical(lipgloss.Left, title, description))
		} else {
			commandItems = append(commandItems, title)
		}
	}

	title := baseStyle.
		Foreground(t.Primary()).
		Bold(true).
		Width(maxWidth).
		Padding(0, 1).
		Render("Commands")

	// Stack header, a blank spacer row, the list, and a trailing spacer.
	content := lipgloss.JoinVertical(
		lipgloss.Left,
		title,
		baseStyle.Width(maxWidth).Render(""),
		baseStyle.Width(maxWidth).Render(lipgloss.JoinVertical(lipgloss.Left, commandItems...)),
		baseStyle.Width(maxWidth).Render(""),
	)

	return baseStyle.Padding(1, 2).
		Border(lipgloss.RoundedBorder()).
		BorderBackground(t.Background()).
		BorderForeground(t.TextMuted()).
		Width(lipgloss.Width(content) + 4).
		Render(content)
}
|
||||
|
||||
func (c *commandDialogCmp) BindingKeys() []key.Binding {
|
||||
return layout.KeyMapToSlice(commandKeys)
|
||||
}
|
||||
|
||||
func (c *commandDialogCmp) SetCommands(commands []Command) {
|
||||
c.commands = commands
|
||||
|
||||
// If we have a selected command ID, find its index
|
||||
if c.selectedCommandID != "" {
|
||||
for i, cmd := range commands {
|
||||
if cmd.ID == c.selectedCommandID {
|
||||
c.selectedIdx = i
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default to first command if selected not found
|
||||
c.selectedIdx = 0
|
||||
}
|
||||
|
||||
func (c *commandDialogCmp) SetSelectedCommand(commandID string) {
|
||||
c.selectedCommandID = commandID
|
||||
|
||||
// Update the selected index if commands are already loaded
|
||||
if len(c.commands) > 0 {
|
||||
for i, cmd := range c.commands {
|
||||
if cmd.ID == commandID {
|
||||
c.selectedIdx = i
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewCommandDialogCmp creates a new command selection dialog
|
||||
func NewCommandDialogCmp() CommandDialog {
|
||||
return &commandDialogCmp{
|
||||
commands: []Command{},
|
||||
selectedIdx: 0,
|
||||
selectedCommandID: "",
|
||||
}
|
||||
}
|
||||
@@ -1,494 +0,0 @@
|
||||
package dialog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/sst/opencode/internal/diff"
|
||||
"github.com/sst/opencode/internal/llm/tools"
|
||||
"github.com/sst/opencode/internal/permission"
|
||||
"github.com/sst/opencode/internal/tui/layout"
|
||||
"github.com/sst/opencode/internal/tui/styles"
|
||||
"github.com/sst/opencode/internal/tui/theme"
|
||||
"github.com/sst/opencode/internal/tui/util"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PermissionAction identifies the user's decision on a permission request.
type PermissionAction string

// Permission responses
const (
	PermissionAllow           PermissionAction = "allow"         // allow this single request
	PermissionAllowForSession PermissionAction = "allow_session" // allow for the rest of the session
	PermissionDeny            PermissionAction = "deny"          // reject the request
)
|
||||
|
||||
// PermissionResponseMsg represents the user's response to a permission request.
type PermissionResponseMsg struct {
	Permission permission.PermissionRequest // the request being answered
	Action     PermissionAction             // the chosen decision
}
|
||||
|
||||
// PermissionDialogCmp interface for permission dialog component.
type PermissionDialogCmp interface {
	tea.Model
	layout.Bindings
	// SetPermissions installs the request to display and resizes the dialog.
	SetPermissions(permission permission.PermissionRequest) tea.Cmd
}
|
||||
|
||||
// permissionsMapping groups the key bindings of the permission dialog:
// arrow/tab navigation between buttons, enter/space confirmation, and
// a/s/d hotkeys for the three responses.
type permissionsMapping struct {
	Left         key.Binding
	Right        key.Binding
	EnterSpace   key.Binding
	Allow        key.Binding
	AllowSession key.Binding
	Deny         key.Binding
	Tab          key.Binding
}
|
||||
|
||||
// permissionsKeys is the shared binding set for all permission dialogs.
var permissionsKeys = permissionsMapping{
	Left: key.NewBinding(
		key.WithKeys("left"),
		key.WithHelp("←", "switch options"),
	),
	Right: key.NewBinding(
		key.WithKeys("right"),
		key.WithHelp("→", "switch options"),
	),
	EnterSpace: key.NewBinding(
		key.WithKeys("enter", " "),
		key.WithHelp("enter/space", "confirm"),
	),
	Allow: key.NewBinding(
		key.WithKeys("a"),
		key.WithHelp("a", "allow"),
	),
	AllowSession: key.NewBinding(
		key.WithKeys("s"),
		key.WithHelp("s", "allow for session"),
	),
	Deny: key.NewBinding(
		key.WithKeys("d"),
		key.WithHelp("d", "deny"),
	),
	Tab: key.NewBinding(
		key.WithKeys("tab"),
		key.WithHelp("tab", "switch options"),
	),
}
|
||||
|
||||
// permissionDialogCmp is the implementation of PermissionDialog.
type permissionDialogCmp struct {
	width           int
	height          int
	permission      permission.PermissionRequest // request currently being decided
	windowSize      tea.WindowSizeMsg            // last known terminal size, consumed by SetSize
	contentViewPort viewport.Model               // scrollable tool-specific content area
	selectedOption  int                          // 0: Allow, 1: Allow for session, 2: Deny

	// Render caches keyed by permission ID; cleared on resize because the
	// cached strings are wrapped for a specific width.
	diffCache     map[string]string
	markdownCache map[string]string
}
|
||||
|
||||
func (p *permissionDialogCmp) Init() tea.Cmd {
|
||||
return p.contentViewPort.Init()
|
||||
}
|
||||
|
||||
// Update handles button cycling, the a/s/d response hotkeys, terminal
// resizes (which also invalidate the render caches), and forwards any
// other key to the content viewport for scrolling.
func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	var cmds []tea.Cmd

	switch msg := msg.(type) {
	case tea.WindowSizeMsg:
		p.windowSize = msg
		cmd := p.SetSize()
		cmds = append(cmds, cmd)
		// Cached renders were produced for the old width; drop them.
		p.markdownCache = make(map[string]string)
		p.diffCache = make(map[string]string)
	case tea.KeyMsg:
		switch {
		case key.Matches(msg, permissionsKeys.Right) || key.Matches(msg, permissionsKeys.Tab):
			// Cycle selection forward through the three buttons.
			p.selectedOption = (p.selectedOption + 1) % 3
			return p, nil
		case key.Matches(msg, permissionsKeys.Left):
			// +2 mod 3 steps backward without going negative.
			p.selectedOption = (p.selectedOption + 2) % 3
		case key.Matches(msg, permissionsKeys.EnterSpace):
			return p, p.selectCurrentOption()
		case key.Matches(msg, permissionsKeys.Allow):
			return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllow, Permission: p.permission})
		case key.Matches(msg, permissionsKeys.AllowSession):
			return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionAllowForSession, Permission: p.permission})
		case key.Matches(msg, permissionsKeys.Deny):
			return p, util.CmdHandler(PermissionResponseMsg{Action: PermissionDeny, Permission: p.permission})
		default:
			// Pass other keys to viewport
			viewPort, cmd := p.contentViewPort.Update(msg)
			p.contentViewPort = viewPort
			cmds = append(cmds, cmd)
		}
	}

	return p, tea.Batch(cmds...)
}
|
||||
|
||||
func (p *permissionDialogCmp) selectCurrentOption() tea.Cmd {
|
||||
var action PermissionAction
|
||||
|
||||
switch p.selectedOption {
|
||||
case 0:
|
||||
action = PermissionAllow
|
||||
case 1:
|
||||
action = PermissionAllowForSession
|
||||
case 2:
|
||||
action = PermissionDeny
|
||||
}
|
||||
|
||||
return util.CmdHandler(PermissionResponseMsg{Action: action, Permission: p.permission})
|
||||
}
|
||||
|
||||
// renderButtons draws the Allow / Allow-for-session / Deny buttons with
// the selected one inverted, and left-pads the row with background-colored
// spaces so it sits flush with the dialog's right edge.
func (p *permissionDialogCmp) renderButtons() string {
	t := theme.CurrentTheme()
	baseStyle := styles.BaseStyle()

	allowStyle := baseStyle
	allowSessionStyle := baseStyle
	denyStyle := baseStyle
	spacerStyle := baseStyle.Background(t.Background())

	// Style the selected button: the active one gets primary background,
	// the other two get primary foreground on the theme background.
	switch p.selectedOption {
	case 0:
		allowStyle = allowStyle.Background(t.Primary()).Foreground(t.Background())
		allowSessionStyle = allowSessionStyle.Background(t.Background()).Foreground(t.Primary())
		denyStyle = denyStyle.Background(t.Background()).Foreground(t.Primary())
	case 1:
		allowStyle = allowStyle.Background(t.Background()).Foreground(t.Primary())
		allowSessionStyle = allowSessionStyle.Background(t.Primary()).Foreground(t.Background())
		denyStyle = denyStyle.Background(t.Background()).Foreground(t.Primary())
	case 2:
		allowStyle = allowStyle.Background(t.Background()).Foreground(t.Primary())
		allowSessionStyle = allowSessionStyle.Background(t.Background()).Foreground(t.Primary())
		denyStyle = denyStyle.Background(t.Primary()).Foreground(t.Background())
	}

	allowButton := allowStyle.Padding(0, 1).Render("Allow (a)")
	allowSessionButton := allowSessionStyle.Padding(0, 1).Render("Allow for session (s)")
	denyButton := denyStyle.Padding(0, 1).Render("Deny (d)")

	content := lipgloss.JoinHorizontal(
		lipgloss.Left,
		allowButton,
		spacerStyle.Render("  "),
		allowSessionButton,
		spacerStyle.Render("  "),
		denyButton,
		spacerStyle.Render("  "),
	)

	// Right-align: pad on the left with whatever width remains.
	remainingWidth := p.width - lipgloss.Width(content)
	if remainingWidth > 0 {
		content = spacerStyle.Render(strings.Repeat(" ", remainingWidth)) + content
	}
	return content
}
|
||||
|
||||
// renderHeader renders the "Tool"/"Path" key-value rows, blank spacer
// rows, and a tool-specific section label (Command, Diff or URL).
func (p *permissionDialogCmp) renderHeader() string {
	t := theme.CurrentTheme()
	baseStyle := styles.BaseStyle()

	toolKey := baseStyle.Foreground(t.TextMuted()).Bold(true).Render("Tool")
	toolValue := baseStyle.
		Foreground(t.Text()).
		Width(p.width - lipgloss.Width(toolKey)).
		Render(fmt.Sprintf(": %s", p.permission.ToolName))

	pathKey := baseStyle.Foreground(t.TextMuted()).Bold(true).Render("Path")
	pathValue := baseStyle.
		Foreground(t.Text()).
		Width(p.width - lipgloss.Width(pathKey)).
		Render(fmt.Sprintf(": %s", p.permission.Path))

	headerParts := []string{
		lipgloss.JoinHorizontal(
			lipgloss.Left,
			toolKey,
			toolValue,
		),
		// Full-width blank rows keep the theme background between lines.
		baseStyle.Render(strings.Repeat(" ", p.width)),
		lipgloss.JoinHorizontal(
			lipgloss.Left,
			pathKey,
			pathValue,
		),
		baseStyle.Render(strings.Repeat(" ", p.width)),
	}

	// Add tool-specific header information
	switch p.permission.ToolName {
	case tools.BashToolName:
		headerParts = append(headerParts, baseStyle.Foreground(t.TextMuted()).Width(p.width).Bold(true).Render("Command"))
	case tools.EditToolName:
		headerParts = append(headerParts, baseStyle.Foreground(t.TextMuted()).Width(p.width).Bold(true).Render("Diff"))
	case tools.WriteToolName:
		headerParts = append(headerParts, baseStyle.Foreground(t.TextMuted()).Width(p.width).Bold(true).Render("Diff"))
	case tools.FetchToolName:
		headerParts = append(headerParts, baseStyle.Foreground(t.TextMuted()).Width(p.width).Bold(true).Render("URL"))
	}

	return lipgloss.NewStyle().Background(t.Background()).Render(lipgloss.JoinVertical(lipgloss.Left, headerParts...))
}
|
||||
|
||||
// renderBashContent renders the requested shell command as a highlighted
// markdown code block into the content viewport. Returns "" when the
// request params are not BashPermissionsParams.
func (p *permissionDialogCmp) renderBashContent() string {
	t := theme.CurrentTheme()
	baseStyle := styles.BaseStyle()

	if pr, ok := p.permission.Params.(tools.BashPermissionsParams); ok {
		content := fmt.Sprintf("```bash\n%s\n```", pr.Command)

		// Use the cache for markdown rendering
		renderedContent := p.GetOrSetMarkdown(p.permission.ID, func() (string, error) {
			r := styles.GetMarkdownRenderer(p.width - 10)
			s, err := r.Render(content)
			return styles.ForceReplaceBackgroundWithLipgloss(s, t.Background()), err
		})

		finalContent := baseStyle.
			Width(p.contentViewPort.Width).
			Render(renderedContent)
		p.contentViewPort.SetContent(finalContent)
		return p.styleViewport()
	}
	return ""
}
|
||||
|
||||
func (p *permissionDialogCmp) renderEditContent() string {
|
||||
if pr, ok := p.permission.Params.(tools.EditPermissionsParams); ok {
|
||||
diff := p.GetOrSetDiff(p.permission.ID, func() (string, error) {
|
||||
return diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width))
|
||||
})
|
||||
|
||||
p.contentViewPort.SetContent(diff)
|
||||
return p.styleViewport()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) renderPatchContent() string {
|
||||
if pr, ok := p.permission.Params.(tools.EditPermissionsParams); ok {
|
||||
diff := p.GetOrSetDiff(p.permission.ID, func() (string, error) {
|
||||
return diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width))
|
||||
})
|
||||
|
||||
p.contentViewPort.SetContent(diff)
|
||||
return p.styleViewport()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) renderWriteContent() string {
|
||||
if pr, ok := p.permission.Params.(tools.WritePermissionsParams); ok {
|
||||
// Use the cache for diff rendering
|
||||
diff := p.GetOrSetDiff(p.permission.ID, func() (string, error) {
|
||||
return diff.FormatDiff(pr.Diff, diff.WithTotalWidth(p.contentViewPort.Width))
|
||||
})
|
||||
|
||||
p.contentViewPort.SetContent(diff)
|
||||
return p.styleViewport()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// renderFetchContent renders the requested URL as a highlighted markdown
// code block into the content viewport. Returns "" when the request
// params are not FetchPermissionsParams.
func (p *permissionDialogCmp) renderFetchContent() string {
	t := theme.CurrentTheme()
	baseStyle := styles.BaseStyle()

	if pr, ok := p.permission.Params.(tools.FetchPermissionsParams); ok {
		content := fmt.Sprintf("```bash\n%s\n```", pr.URL)

		// Use the cache for markdown rendering
		renderedContent := p.GetOrSetMarkdown(p.permission.ID, func() (string, error) {
			r := styles.GetMarkdownRenderer(p.width - 10)
			s, err := r.Render(content)
			return styles.ForceReplaceBackgroundWithLipgloss(s, t.Background()), err
		})

		finalContent := baseStyle.
			Width(p.contentViewPort.Width).
			Render(renderedContent)
		p.contentViewPort.SetContent(finalContent)
		return p.styleViewport()
	}
	return ""
}
|
||||
|
||||
// renderDefaultContent renders the request description as markdown for
// tools without a dedicated renderer; returns "" when the rendered
// markdown comes back empty.
func (p *permissionDialogCmp) renderDefaultContent() string {
	t := theme.CurrentTheme()
	baseStyle := styles.BaseStyle()

	content := p.permission.Description

	// Use the cache for markdown rendering
	renderedContent := p.GetOrSetMarkdown(p.permission.ID, func() (string, error) {
		r := styles.GetMarkdownRenderer(p.width - 10)
		s, err := r.Render(content)
		return styles.ForceReplaceBackgroundWithLipgloss(s, t.Background()), err
	})

	finalContent := baseStyle.
		Width(p.contentViewPort.Width).
		Render(renderedContent)
	p.contentViewPort.SetContent(finalContent)

	// Empty render: signal the caller to show nothing (the viewport
	// content was still set above).
	if renderedContent == "" {
		return ""
	}

	return p.styleViewport()
}
|
||||
|
||||
func (p *permissionDialogCmp) styleViewport() string {
|
||||
t := theme.CurrentTheme()
|
||||
contentStyle := lipgloss.NewStyle().
|
||||
Background(t.Background())
|
||||
|
||||
return contentStyle.Render(p.contentViewPort.View())
|
||||
}
|
||||
|
||||
// render assembles the full dialog: title, header, tool-specific content
// viewport (sized to the space left over) and button row, wrapped in a
// rounded border at the dialog's width/height.
func (p *permissionDialogCmp) render() string {
	t := theme.CurrentTheme()
	baseStyle := styles.BaseStyle()

	title := baseStyle.
		Bold(true).
		Width(p.width - 4).
		Foreground(t.Primary()).
		Render("Permission Required")
	// Render header
	headerContent := p.renderHeader()
	// Render buttons
	buttons := p.renderButtons()

	// Calculate content height dynamically based on window size
	p.contentViewPort.Height = p.height - lipgloss.Height(headerContent) - lipgloss.Height(buttons) - 2 - lipgloss.Height(title)
	p.contentViewPort.Width = p.width - 4

	// Render content based on tool type
	var contentFinal string
	switch p.permission.ToolName {
	case tools.BashToolName:
		contentFinal = p.renderBashContent()
	case tools.EditToolName:
		contentFinal = p.renderEditContent()
	case tools.PatchToolName:
		contentFinal = p.renderPatchContent()
	case tools.WriteToolName:
		contentFinal = p.renderWriteContent()
	case tools.FetchToolName:
		contentFinal = p.renderFetchContent()
	default:
		contentFinal = p.renderDefaultContent()
	}

	content := lipgloss.JoinVertical(
		lipgloss.Top,
		title,
		baseStyle.Render(strings.Repeat(" ", lipgloss.Width(title))),
		headerContent,
		contentFinal,
		buttons,
		baseStyle.Render(strings.Repeat(" ", p.width-4)),
	)

	return baseStyle.
		Padding(1, 0, 0, 1).
		Border(lipgloss.RoundedBorder()).
		BorderBackground(t.Background()).
		BorderForeground(t.TextMuted()).
		Width(p.width).
		Height(p.height).
		Render(
			content,
		)
}
|
||||
|
||||
func (p *permissionDialogCmp) View() string {
|
||||
return p.render()
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) BindingKeys() []key.Binding {
|
||||
return layout.KeyMapToSlice(permissionsKeys)
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) SetSize() tea.Cmd {
|
||||
if p.permission.ID == "" {
|
||||
return nil
|
||||
}
|
||||
switch p.permission.ToolName {
|
||||
case tools.BashToolName:
|
||||
p.width = int(float64(p.windowSize.Width) * 0.4)
|
||||
p.height = int(float64(p.windowSize.Height) * 0.3)
|
||||
case tools.EditToolName:
|
||||
p.width = int(float64(p.windowSize.Width) * 0.8)
|
||||
p.height = int(float64(p.windowSize.Height) * 0.8)
|
||||
case tools.WriteToolName:
|
||||
p.width = int(float64(p.windowSize.Width) * 0.8)
|
||||
p.height = int(float64(p.windowSize.Height) * 0.8)
|
||||
case tools.FetchToolName:
|
||||
p.width = int(float64(p.windowSize.Width) * 0.4)
|
||||
p.height = int(float64(p.windowSize.Height) * 0.3)
|
||||
default:
|
||||
p.width = int(float64(p.windowSize.Width) * 0.7)
|
||||
p.height = int(float64(p.windowSize.Height) * 0.5)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *permissionDialogCmp) SetPermissions(permission permission.PermissionRequest) tea.Cmd {
|
||||
p.permission = permission
|
||||
return p.SetSize()
|
||||
}
|
||||
|
||||
// Helper to get or set cached diff content
|
||||
func (c *permissionDialogCmp) GetOrSetDiff(key string, generator func() (string, error)) string {
|
||||
if cached, ok := c.diffCache[key]; ok {
|
||||
return cached
|
||||
}
|
||||
|
||||
content, err := generator()
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error formatting diff: %v", err)
|
||||
}
|
||||
|
||||
c.diffCache[key] = content
|
||||
|
||||
return content
|
||||
}
|
||||
|
||||
// Helper to get or set cached markdown content
|
||||
func (c *permissionDialogCmp) GetOrSetMarkdown(key string, generator func() (string, error)) string {
|
||||
if cached, ok := c.markdownCache[key]; ok {
|
||||
return cached
|
||||
}
|
||||
|
||||
content, err := generator()
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error rendering markdown: %v", err)
|
||||
}
|
||||
|
||||
c.markdownCache[key] = content
|
||||
|
||||
return content
|
||||
}
|
||||
|
||||
func NewPermissionDialogCmp() PermissionDialogCmp {
|
||||
// Create viewport for content
|
||||
contentViewport := viewport.New(0, 0)
|
||||
|
||||
return &permissionDialogCmp{
|
||||
contentViewPort: contentViewport,
|
||||
selectedOption: 0, // Default to "Allow"
|
||||
diffCache: make(map[string]string),
|
||||
markdownCache: make(map[string]string),
|
||||
}
|
||||
}
|
||||
@@ -1,187 +0,0 @@
|
||||
package logs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/sst/opencode/internal/logging"
|
||||
"github.com/sst/opencode/internal/tui/layout"
|
||||
"github.com/sst/opencode/internal/tui/styles"
|
||||
"github.com/sst/opencode/internal/tui/theme"
|
||||
)
|
||||
|
||||
// DetailComponent is a sizeable, key-bindable view of a single log entry.
type DetailComponent interface {
	tea.Model
	layout.Sizeable
	layout.Bindings
}
|
||||
|
||||
// detailCmp renders one log entry in a scrollable viewport.
type detailCmp struct {
	width, height int
	currentLog    logging.Log    // entry currently displayed
	viewport      viewport.Model // scrollable body
	focused       bool           // only a focused pane consumes key events
}
|
||||
|
||||
func (i *detailCmp) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update reacts to log selection changes (rebuilding the content shortly
// after, off the hot path) and, when focused, forwards key events to the
// viewport for scrolling.
func (i *detailCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	var cmd tea.Cmd
	switch msg := msg.(type) {
	case selectedLogMsg:
		if msg.ID != i.currentLog.ID {
			i.currentLog = logging.Log(msg)
			// Defer content update to avoid blocking the UI
			// NOTE(review): the closure mutates i from inside a command,
			// which bubbletea runs outside Update — confirm this cannot
			// race with a concurrent Update of the same model.
			cmd = tea.Tick(time.Millisecond*1, func(time.Time) tea.Msg {
				i.updateContent()
				return nil
			})
		}
	case tea.KeyMsg:
		// Only process keyboard input when focused
		if !i.focused {
			return i, nil
		}
		// Handle keyboard input for scrolling
		i.viewport, cmd = i.viewport.Update(msg)
		return i, cmd
	}

	return i, cmd
}
|
||||
|
||||
// updateContent rebuilds the detail view for the current log entry:
// a timestamp+level header, the message, any attributes (JSON-looking
// values re-indented), and the session ID when present.
func (i *detailCmp) updateContent() {
	var content strings.Builder
	t := theme.CurrentTheme()

	// Format the header with timestamp and level
	timeStyle := lipgloss.NewStyle().Foreground(t.TextMuted())
	levelStyle := getLevelStyle(i.currentLog.Level)

	// Format timestamp
	timeStr := i.currentLog.Timestamp.Format(time.RFC3339)

	header := lipgloss.JoinHorizontal(
		lipgloss.Center,
		timeStyle.Render(timeStr),
		" ",
		levelStyle.Render(i.currentLog.Level),
	)

	content.WriteString(lipgloss.NewStyle().Bold(true).Render(header))
	content.WriteString("\n\n")

	// Message with styling
	messageStyle := lipgloss.NewStyle().Bold(true).Foreground(t.Text())
	content.WriteString(messageStyle.Render("Message:"))
	content.WriteString("\n")
	content.WriteString(lipgloss.NewStyle().Padding(0, 2).Render(i.currentLog.Message))
	content.WriteString("\n\n")

	// Attributes section
	if len(i.currentLog.Attributes) > 0 {
		attrHeaderStyle := lipgloss.NewStyle().Bold(true).Foreground(t.Text())
		content.WriteString(attrHeaderStyle.Render("Attributes:"))
		content.WriteString("\n")

		// Create a table-like display for attributes
		keyStyle := lipgloss.NewStyle().Foreground(t.Primary()).Bold(true)
		valueStyle := lipgloss.NewStyle().Foreground(t.Text())

		// NOTE(review): map iteration order is random, so attribute rows
		// can reorder between renders — confirm whether that matters here.
		for key, value := range i.currentLog.Attributes {
			// if value is JSON, render it with indentation
			if strings.HasPrefix(value, "{") {
				var indented bytes.Buffer
				// NOTE(review): on a mid-stream Indent failure the buffer
				// may already hold partial output before the raw value is
				// appended — confirm a Reset isn't wanted here.
				if err := json.Indent(&indented, []byte(value), "", "  "); err != nil {
					indented.WriteString(value)
				}
				value = indented.String()
			}

			attrLine := fmt.Sprintf("%s: %s",
				keyStyle.Render(key),
				valueStyle.Render(value),
			)

			content.WriteString(lipgloss.NewStyle().Padding(0, 2).Render(attrLine))
			content.WriteString("\n")
		}
	}

	// Session ID if available
	if i.currentLog.SessionID != "" {
		sessionStyle := lipgloss.NewStyle().Bold(true).Foreground(t.Text())
		content.WriteString("\n")
		content.WriteString(sessionStyle.Render("Session:"))
		content.WriteString("\n")
		content.WriteString(lipgloss.NewStyle().Padding(0, 2).Render(i.currentLog.SessionID))
	}

	i.viewport.SetContent(content.String())
}
|
||||
|
||||
func getLevelStyle(level string) lipgloss.Style {
|
||||
style := lipgloss.NewStyle().Bold(true)
|
||||
t := theme.CurrentTheme()
|
||||
|
||||
switch strings.ToLower(level) {
|
||||
case "info":
|
||||
return style.Foreground(t.Info())
|
||||
case "warn", "warning":
|
||||
return style.Foreground(t.Warning())
|
||||
case "error", "err":
|
||||
return style.Foreground(t.Error())
|
||||
case "debug":
|
||||
return style.Foreground(t.Success())
|
||||
default:
|
||||
return style.Foreground(t.Text())
|
||||
}
|
||||
}
|
||||
|
||||
func (i *detailCmp) View() string {
|
||||
t := theme.CurrentTheme()
|
||||
return styles.ForceReplaceBackgroundWithLipgloss(i.viewport.View(), t.Background())
|
||||
}
|
||||
|
||||
func (i *detailCmp) GetSize() (int, int) {
|
||||
return i.width, i.height
|
||||
}
|
||||
|
||||
func (i *detailCmp) SetSize(width int, height int) tea.Cmd {
|
||||
i.width = width
|
||||
i.height = height
|
||||
i.viewport.Width = i.width
|
||||
i.viewport.Height = i.height
|
||||
i.updateContent()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *detailCmp) BindingKeys() []key.Binding {
|
||||
return layout.KeyMapToSlice(i.viewport.KeyMap)
|
||||
}
|
||||
|
||||
func NewLogsDetails() DetailComponent {
|
||||
return &detailCmp{
|
||||
viewport: viewport.New(0, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Focus implements the focusable interface
func (i *detailCmp) Focus() {
	i.focused = true
	// NOTE(review): re-setting the current offset looks like a no-op —
	// presumably it re-clamps the offset to the viewport bounds; confirm.
	i.viewport.SetYOffset(i.viewport.YOffset)
}
|
||||
|
||||
// Blur implements the blurable interface
|
||||
func (i *detailCmp) Blur() {
|
||||
i.focused = false
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user