
feat(go): Add tool support for ollama models #2796


Status: Draft. Wants to merge 3 commits into main.
145 changes: 132 additions & 13 deletions go/plugins/ollama/ollama.go
@@ -40,10 +40,20 @@ const provider = "ollama"

var (
mediaSupportedModels = []string{"llava"}
toolSupportedModels = []string{
"qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
"qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
"mistral-small", "command-r", "hermes3", "mistral-large", "command-r-plus",
"phi4-mini", "granite3.1-dense", "granite3-dense", "granite3.2", "athene-v2",
"nemotron-mini", "nemotron", "llama3-groq-tool-use", "aya-expanse", "granite3-moe",
"granite3.2-vision", "granite3.1-moe", "cogito", "command-r7b", "firefunction-v2",
"granite3.3", "command-a", "command-r7b-arabic",
}
roleMapping = map[ai.Role]string{
ai.RoleUser: "user",
ai.RoleModel: "assistant",
ai.RoleSystem: "system",
ai.RoleTool: "tool",
}
)

@@ -57,12 +67,15 @@ func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.M
if info != nil {
mi = *info
} else {
// Check if the model supports tools (must be a chat model and in the supported list)
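// Note: the check is by exact model name, so tagged variants such as
// "llama3.1:8b" will not be flagged as tool-capable unless listed above.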
supportsTools := model.Type == "chat" && slices.Contains(toolSupportedModels, model.Name)
mi = ai.ModelInfo{
Label: model.Name,
Supports: &ai.ModelSupports{
Multiturn: true,
SystemRole: true,
Media: slices.Contains(mediaSupportedModels, model.Name),
Tools: supportsTools,
},
Versions: []string{},
}
@@ -99,9 +112,10 @@ type generator struct {
}

type ollamaMessage struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
Images []string `json:"images,omitempty"`
ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
}

// Ollama has two API endpoints, one with a chat interface and another with a generate response interface.
@@ -122,6 +136,8 @@ type ollamaChatRequest struct {
Messages []*ollamaMessage `json:"messages"`
Model string `json:"model"`
Stream bool `json:"stream"`
Format string `json:"format,omitempty"`
Tools []ollamaTool `json:"tools,omitempty"`
}

type ollamaModelRequest struct {
@@ -132,13 +148,38 @@ type ollamaModelRequest struct {
Stream bool `json:"stream"`
}

// Tool definition from Ollama API
type ollamaTool struct {
Type string `json:"type"`
Function ollamaFunction `json:"function"`
}

// Function definition for Ollama API
type ollamaFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]any `json:"parameters"`
}

// Tool Call from Ollama API
type ollamaToolCall struct {
Function ollamaFunctionCall `json:"function"`
}

// Function Call for Ollama API
type ollamaFunctionCall struct {
Name string `json:"name"`
Arguments any `json:"arguments"`
}
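
// For reference, a chat request carrying tools marshals to JSON of roughly
// this shape (field names follow the json tags above; the tool shown is a
// made-up example):
//
//	{"messages":[...],"model":"llama3.1","stream":false,
//	 "tools":[{"type":"function","function":{"name":"get_weather",
//	  "description":"Look up the weather for a city",
//	  "parameters":{"type":"object","properties":{"city":{"type":"string"}}}}}]}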

// TODO: Add optional parameters (images, format, options, etc.) based on your use case
type ollamaChatResponse struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Message struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
} `json:"message"`
}
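
// A tool-call turn from /api/chat unmarshals from JSON like the following
// (values are a hypothetical example):
//
//	{"model":"llama3.1","created_at":"...","message":{"role":"assistant",
//	 "content":"","tool_calls":[{"function":{"name":"get_weather",
//	  "arguments":{"city":"Paris"}}}]}}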

@@ -181,6 +222,7 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
stream := cb != nil
var payload any
isChatModel := g.model.Type == "chat"

if !isChatModel {
images, err := concatImages(input, []ai.Role{ai.RoleUser, ai.RoleModel})
if err != nil {
@@ -203,33 +245,46 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
}
messages = append(messages, message)
}
chatReq := ollamaChatRequest{
Messages: messages,
Model: g.model.Name,
Stream: stream,
}
if len(input.Tools) > 0 {
tools, err := convertTools(input.Tools)
if err != nil {
return nil, fmt.Errorf("failed to convert tools: %v", err)
}
chatReq.Tools = tools
}
payload = chatReq
}

client := &http.Client{Timeout: 30 * time.Second}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, err
}

// Determine the correct endpoint
endpoint := g.serverAddress + "/api/chat"
if !isChatModel {
endpoint = g.serverAddress + "/api/generate"
}

req, err := http.NewRequest("POST", endpoint, bytes.NewReader(payloadBytes))
if err != nil {
return nil, fmt.Errorf("failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
req = req.WithContext(ctx)

resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %v", err)
}
defer resp.Body.Close()

if cb == nil {
// Existing behavior for non-streaming responses
var err error
@@ -240,6 +295,7 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("server returned non-200 status: %d, body: %s", resp.StatusCode, body)
}

var response *ai.ModelResponse
if isChatModel {
response, err = translateChatResponse(body)
@@ -254,8 +310,12 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
} else {
var chunks []*ai.ModelResponseChunk
scanner := bufio.NewScanner(resp.Body)
chunkCount := 0

for scanner.Scan() {
line := scanner.Text()
chunkCount++

var chunk *ai.ModelResponseChunk
if isChatModel {
chunk, err = translateChatChunk(line)
Expand All @@ -268,9 +328,11 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
chunks = append(chunks, chunk)
cb(ctx, chunk)
}

if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("reading response stream: %v", err)
}

// Create a final response with the merged chunks
finalResponse := &ai.ModelResponse{
Request: input,
@@ -288,11 +350,28 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
}
}

// convertTools converts Genkit tool definitions to Ollama tool format
func convertTools(tools []*ai.ToolDefinition) ([]ollamaTool, error) {
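// tool.InputSchema is forwarded untouched; Ollama expects a JSON Schema
// object in the "parameters" field.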
ollamaTools := make([]ollamaTool, 0, len(tools))
for _, tool := range tools {
ollamaTools = append(ollamaTools, ollamaTool{
Type: "function",
Function: ollamaFunction{
Name: tool.Name,
Description: tool.Description,
Parameters: tool.InputSchema,
},
})
}
return ollamaTools, nil
}

func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
message := &ollamaMessage{
Role: roleMapping[role],
}
var contentBuilder strings.Builder
var toolCalls []ollamaToolCall
for _, part := range parts {
if part.IsText() {
contentBuilder.WriteString(part.Text)
@@ -303,11 +382,29 @@ func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
}
base64Encoded := base64.StdEncoding.EncodeToString(data)
message.Images = append(message.Images, base64Encoded)
} else if part.IsToolRequest() {
toolReq := part.ToolRequest
toolCalls = append(toolCalls, ollamaToolCall{
Function: ollamaFunctionCall{
Name: toolReq.Name,
Arguments: toolReq.Input,
},
})
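// Ollama's chat API has no structured slot for tool results on outgoing
// messages; they travel as plain content on a "tool"-role message, so the
// output is flattened to JSON text in the next branch.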
} else if part.IsToolResponse() {
toolResp := part.ToolResponse
outputJSON, err := json.Marshal(toolResp.Output)
if err != nil {
return nil, fmt.Errorf("failed to marshal tool response: %v", err)
}
contentBuilder.WriteString(string(outputJSON))
} else {
return nil, errors.New("unknown content type")
}
}
message.Content = contentBuilder.String()
if len(toolCalls) > 0 {
message.ToolCalls = toolCalls
}
return message, nil
}

@@ -321,12 +418,22 @@ func translateChatResponse(responseData []byte) (*ai.ModelResponse, error) {
modelResponse := &ai.ModelResponse{
FinishReason: ai.FinishReason("stop"),
Message: &ai.Message{
Role: ai.RoleModel,
},
}

if len(response.Message.ToolCalls) > 0 {
for _, toolCall := range response.Message.ToolCalls {
toolRequest := &ai.ToolRequest{
Name: toolCall.Function.Name,
Input: toolCall.Function.Arguments,
}
toolPart := ai.NewToolRequestPart(toolRequest)
modelResponse.Message.Content = append(modelResponse.Message.Content, toolPart)
}
} else if response.Message.Content != "" {
aiPart := ai.NewTextPart(response.Message.Content)
modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
}

return modelResponse, nil
}
@@ -359,8 +466,20 @@ func translateChatChunk(input string) (*ai.ModelResponseChunk, error) {
return nil, fmt.Errorf("failed to parse response JSON: %v", err)
}
chunk := &ai.ModelResponseChunk{}
if len(response.Message.ToolCalls) > 0 {
for _, toolCall := range response.Message.ToolCalls {
toolRequest := &ai.ToolRequest{
Name: toolCall.Function.Name,
Input: toolCall.Function.Arguments,
}
toolPart := ai.NewToolRequestPart(toolRequest)
chunk.Content = append(chunk.Content, toolPart)
}
} else if response.Message.Content != "" {
aiPart := ai.NewTextPart(response.Message.Content)
chunk.Content = append(chunk.Content, aiPart)
}

return chunk, nil
}
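
A minimal sketch of driving the new tool support end to end. This is an assumption-laden example, not part of the PR: the genkit.Init, genkit.DefineTool, genkit.Generate, and ai.With* calls reflect the Genkit Go API at the time of writing and may differ across versions, and the tool, model choice, and server address are hypothetical.

	package main

	import (
		"context"
		"log"

		"github.com/firebase/genkit/go/ai"
		"github.com/firebase/genkit/go/genkit"
		"github.com/firebase/genkit/go/plugins/ollama"
	)

	type weatherInput struct {
		City string `json:"city"`
	}

	func main() {
		ctx := context.Background()

		// Assumes a local Ollama server; adjust the address as needed.
		o := &ollama.Ollama{ServerAddress: "http://127.0.0.1:11434"}
		g, err := genkit.Init(ctx, genkit.WithPlugins(o))
		if err != nil {
			log.Fatal(err)
		}

		// "llama3.1" is a chat model on the tool-supported list, so the
		// inferred ModelInfo carries Tools: true.
		model := o.DefineModel(g, ollama.ModelDefinition{Name: "llama3.1", Type: "chat"}, nil)

		// A stub tool; Genkit derives the JSON Schema sent to Ollama from weatherInput.
		weather := genkit.DefineTool(g, "get_weather", "Look up the weather for a city",
			func(ctx *ai.ToolContext, in weatherInput) (string, error) {
				return "sunny in " + in.City, nil
			})

		resp, err := genkit.Generate(ctx, g,
			ai.WithModel(model),
			ai.WithTools(weather),
			ai.WithPrompt("What is the weather in Paris?"))
		if err != nil {
			log.Fatal(err)
		}
		log.Println(resp.Text())
	}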
