Merge pull request #110 from tommyalt/feat/ollama-batch-embed

feat(ollama): support repeatable embed inputs
This commit is contained in:
Yuhao
2026-03-20 19:56:08 +08:00
committed by GitHub
5 changed files with 24 additions and 4 deletions

View File

@@ -65,7 +65,7 @@ Ollama already provides a clean REST API. Our CLI wraps it with:
| `ollama ps` | `model ps` |
| `ollama run <model> <prompt>` | `generate text --model <name> --prompt "..."` |
| (no equivalent) | `generate chat --model <name> --message "..."` |
| (no equivalent) | `embed text --model <name> --input "..."` |
| (no equivalent) | `embed text --model <name> --input "..." [--input "..."]` |
| `ollama serve` | (external — must be running) |
## Model Parameters (options)

View File

@@ -81,6 +81,7 @@ generate chat --model <name> --message "user:Hello" [--message "assistant:Hi"]
```bash
embed text --model <name> --input "Text to embed"
embed text --model <name> --input "First text" --input "Second text"
```
### Server

View File

@@ -340,11 +340,16 @@ def embed():
@embed.command("text")
@click.option("--model", "-m", "model_name", required=True, help="Model name")
@click.option("--input", "-i", "input_text", required=True, help="Text to embed")
@click.option(
"--input", "-i", "input_texts",
multiple=True, required=True,
help="Text to embed. Repeat for batch embeddings.",
)
@handle_error
def embed_text(model_name, input_text):
def embed_text(model_name, input_texts):
"""Generate embeddings for text."""
result = embed_mod.embed(_host, model_name, input_text)
payload = list(input_texts)
result = embed_mod.embed(_host, model_name, payload[0] if len(payload) == 1 else payload)
if _json_output:
output(result)
else:

View File

@@ -155,6 +155,7 @@ cli-anything-ollama generate chat --model llama3.2 --file messages.json
```bash
cli-anything-ollama embed text --model nomic-embed-text --input "Hello world"
cli-anything-ollama embed text --model nomic-embed-text --input "Hello" --input "World"
```

View File

@@ -315,6 +315,19 @@ class TestEmbedCommands:
data = json.loads(result.output)
assert "embeddings" in data
@patch("cli_anything.ollama.core.embeddings.api_post")
def test_embed_text_multiple_inputs_json(self, mock_api, runner):
mock_api.return_value = {"embeddings": [[0.1, 0.2], [0.3, 0.4]]}
result = runner.invoke(cli, ["--json", "embed", "text",
"--model", "nomic-embed-text",
"--input", "Hello",
"--input", "World"])
assert result.exit_code == 0
data = json.loads(result.output)
assert len(data["embeddings"]) == 2
call_data = mock_api.call_args[0][2]
assert call_data["input"] == ["Hello", "World"]
@patch("cli_anything.ollama.core.embeddings.api_post")
def test_embed_text_human(self, mock_api, runner):
mock_api.return_value = {"embeddings": [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]]}