diff --git a/go/plugins/ollama/ollama.go b/go/plugins/ollama/ollama.go
index 6b5806e17..e74adab40 100644
--- a/go/plugins/ollama/ollama.go
+++ b/go/plugins/ollama/ollama.go
@@ -40,10 +40,20 @@ const provider = "ollama"

 var (
 	mediaSupportedModels = []string{"llava"}
-	roleMapping          = map[ai.Role]string{
+	toolSupportedModels  = []string{
+		"qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
+		"qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
+		"mistral-small", "command-r", "hermes3", "mistral-large", "command-r-plus",
+		"phi4-mini", "granite3.1-dense", "granite3-dense", "granite3.2", "athene-v2",
+		"nemotron-mini", "nemotron", "llama3-groq-tool-use", "aya-expanse", "granite3-moe",
+		"granite3.2-vision", "granite3.1-moe", "cogito", "command-r7b", "firefunction-v2",
+		"granite3.3", "command-a", "command-r7b-arabic",
+	}
+	roleMapping = map[ai.Role]string{
 		ai.RoleUser:   "user",
 		ai.RoleModel:  "assistant",
 		ai.RoleSystem: "system",
+		ai.RoleTool:   "tool",
 	}
 )
@@ -57,12 +67,15 @@ func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
 	if info != nil {
 		mi = *info
 	} else {
+		// Check if the model supports tools (must be a chat model and in the supported list)
+		supportsTools := model.Type == "chat" && slices.Contains(toolSupportedModels, model.Name)
 		mi = ai.ModelInfo{
 			Label: model.Name,
 			Supports: &ai.ModelSupports{
 				Multiturn:  true,
 				SystemRole: true,
 				Media:      slices.Contains(mediaSupportedModels, model.Name),
+				Tools:      supportsTools,
 			},
 			Versions: []string{},
 		}
@@ -99,9 +112,10 @@ type generator struct {
 }

 type ollamaMessage struct {
-	Role    string   `json:"role"`
-	Content string   `json:"content"`
-	Images  []string `json:"images,omitempty"`
+	Role      string           `json:"role"`
+	Content   string           `json:"content,omitempty"`
+	Images    []string         `json:"images,omitempty"`
+	ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
 }

 // Ollama has two API endpoints, one with a chat interface and another with a generate response interface.
@@ -122,6 +136,8 @@ type ollamaChatRequest struct {
 	Messages []*ollamaMessage `json:"messages"`
 	Model    string           `json:"model"`
 	Stream   bool             `json:"stream"`
+	Format   string           `json:"format,omitempty"`
+	Tools    []ollamaTool     `json:"tools,omitempty"`
 }

 type ollamaModelRequest struct {
@@ -132,13 +148,38 @@ type ollamaModelRequest struct {
 	Stream bool `json:"stream"`
 }

+// ollamaTool is a tool definition in the Ollama API format.
+type ollamaTool struct {
+	Type     string         `json:"type"`
+	Function ollamaFunction `json:"function"`
+}
+
+// ollamaFunction describes a callable function in the Ollama API format.
+type ollamaFunction struct {
+	Name        string         `json:"name"`
+	Description string         `json:"description"`
+	Parameters  map[string]any `json:"parameters"`
+}
+
+// ollamaToolCall is a tool call emitted by an Ollama model.
+type ollamaToolCall struct {
+	Function ollamaFunctionCall `json:"function"`
+}
+
+// ollamaFunctionCall holds the name and arguments of a requested function call.
+type ollamaFunctionCall struct {
+	Name      string `json:"name"`
+	Arguments any    `json:"arguments"`
+}
+
 // TODO: Add optional parameters (images, format, options, etc.) based on your use case
 type ollamaChatResponse struct {
 	Model     string `json:"model"`
 	CreatedAt string `json:"created_at"`
 	Message   struct {
-		Role    string `json:"role"`
-		Content string `json:"content"`
+		Role      string           `json:"role"`
+		Content   string           `json:"content"`
+		ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
 	} `json:"message"`
 }
@@ -181,6 +222,7 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
 	stream := cb != nil
 	var payload any
 	isChatModel := g.model.Type == "chat"
+
 	if !isChatModel {
 		images, err := concatImages(input, []ai.Role{ai.RoleUser, ai.RoleModel})
 		if err != nil {
@@ -203,33 +245,46 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
 			}
 			messages = append(messages, message)
 		}
-		payload = ollamaChatRequest{
+		chatReq := ollamaChatRequest{
 			Messages: messages,
 			Model:    g.model.Name,
 			Stream:   stream,
 		}
+		if len(input.Tools) > 0 {
+			tools, err := convertTools(input.Tools)
+			if err != nil {
+				return nil, fmt.Errorf("failed to convert tools: %v", err)
+			}
+			chatReq.Tools = tools
+		}
+		payload = chatReq
 	}
+
 	client := &http.Client{Timeout: 30 * time.Second}
 	payloadBytes, err := json.Marshal(payload)
 	if err != nil {
 		return nil, err
 	}
+
 	// Determine the correct endpoint
 	endpoint := g.serverAddress + "/api/chat"
 	if !isChatModel {
 		endpoint = g.serverAddress + "/api/generate"
 	}
+
 	req, err := http.NewRequest("POST", endpoint, bytes.NewReader(payloadBytes))
 	if err != nil {
 		return nil, fmt.Errorf("failed to create request: %v", err)
 	}
 	req.Header.Set("Content-Type", "application/json")
 	req = req.WithContext(ctx)
+
 	resp, err := client.Do(req)
 	if err != nil {
 		return nil, fmt.Errorf("failed to send request: %v", err)
 	}
 	defer resp.Body.Close()
+
 	if cb == nil {
 		// Existing behavior for non-streaming responses
 		var err error
@@ -240,6 +295,7 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
 	if resp.StatusCode != http.StatusOK {
 		return nil, fmt.Errorf("server returned non-200 status: %d, body: %s", resp.StatusCode, body)
 	}
+
 	var response *ai.ModelResponse
 	if isChatModel {
 		response, err = translateChatResponse(body)
@@ -254,8 +310,12 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
 	} else {
 		var chunks []*ai.ModelResponseChunk
 		scanner := bufio.NewScanner(resp.Body)
+		chunkCount := 0
+
 		for scanner.Scan() {
 			line := scanner.Text()
+			chunkCount++
+
 			var chunk *ai.ModelResponseChunk
 			if isChatModel {
 				chunk, err = translateChatChunk(line)
@@ -268,9 +328,11 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
 			chunks = append(chunks, chunk)
 			cb(ctx, chunk)
 		}
+
 		if err := scanner.Err(); err != nil {
 			return nil, fmt.Errorf("reading response stream: %v", err)
 		}
+
 		// Create a final response with the merged chunks
 		finalResponse := &ai.ModelResponse{
 			Request: input,
@@ -288,11 +350,28 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
 	}
 }

+// convertTools converts Genkit tool definitions to Ollama tool format
+func convertTools(tools []*ai.ToolDefinition) ([]ollamaTool, error) {
+	ollamaTools := make([]ollamaTool, 0, len(tools))
+	for _, tool := range tools {
+		ollamaTools = append(ollamaTools, ollamaTool{
+			Type: "function",
+			Function: ollamaFunction{
+				Name:        tool.Name,
+				Description: tool.Description,
+				Parameters:  tool.InputSchema,
+			},
+		})
+	}
+	return ollamaTools, nil
+}
+
 func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
 	message := &ollamaMessage{
 		Role: roleMapping[role],
 	}
 	var contentBuilder strings.Builder
+	var toolCalls []ollamaToolCall
 	for _, part := range parts {
 		if part.IsText() {
 			contentBuilder.WriteString(part.Text)
@@ -303,11 +382,29 @@ func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
 			}
 			base64Encoded := base64.StdEncoding.EncodeToString(data)
 			message.Images = append(message.Images, base64Encoded)
+		} else if part.IsToolRequest() {
+			toolReq := part.ToolRequest
+			toolCalls = append(toolCalls, ollamaToolCall{
+				Function: ollamaFunctionCall{
+					Name:      toolReq.Name,
+					Arguments: toolReq.Input,
+				},
+			})
+		} else if part.IsToolResponse() {
+			toolResp := part.ToolResponse
+			outputJSON, err := json.Marshal(toolResp.Output)
+			if err != nil {
+				return nil, fmt.Errorf("failed to marshal tool response: %v", err)
+			}
+			contentBuilder.WriteString(string(outputJSON))
 		} else {
 			return nil, errors.New("unknown content type")
 		}
 	}
 	message.Content = contentBuilder.String()
+	if len(toolCalls) > 0 {
+		message.ToolCalls = toolCalls
+	}
 	return message, nil
 }
@@ -321,12 +418,22 @@ func translateChatResponse(responseData []byte) (*ai.ModelResponse, error) {
 	modelResponse := &ai.ModelResponse{
 		FinishReason: ai.FinishReason("stop"),
 		Message: &ai.Message{
-			Role: ai.Role(response.Message.Role),
+			Role: ai.RoleModel,
 		},
 	}
-
-	aiPart := ai.NewTextPart(response.Message.Content)
-	modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
+	if len(response.Message.ToolCalls) > 0 {
+		for _, toolCall := range response.Message.ToolCalls {
+			toolRequest := &ai.ToolRequest{
+				Name:  toolCall.Function.Name,
+				Input: toolCall.Function.Arguments,
+			}
+			toolPart := ai.NewToolRequestPart(toolRequest)
+			modelResponse.Message.Content = append(modelResponse.Message.Content, toolPart)
+		}
+	} else if response.Message.Content != "" {
+		aiPart := ai.NewTextPart(response.Message.Content)
+		modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
+	}

 	return modelResponse, nil
 }
@@ -359,8 +466,20 @@ func translateChatChunk(input string) (*ai.ModelResponseChunk, error) {
 		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
 	}
 	chunk := &ai.ModelResponseChunk{}
-	aiPart := ai.NewTextPart(response.Message.Content)
-	chunk.Content = append(chunk.Content, aiPart)
+	if len(response.Message.ToolCalls) > 0 {
+		for _, toolCall := range response.Message.ToolCalls {
+			toolRequest := &ai.ToolRequest{
+				Name:  toolCall.Function.Name,
+				Input: toolCall.Function.Arguments,
+			}
+			toolPart := ai.NewToolRequestPart(toolRequest)
+			chunk.Content = append(chunk.Content, toolPart)
+		}
+	} else if response.Message.Content != "" {
+		aiPart := ai.NewTextPart(response.Message.Content)
+		chunk.Content = append(chunk.Content, aiPart)
+	}
+
 	return chunk, nil
 }
diff --git a/go/samples/ollama-tools/main.go b/go/samples/ollama-tools/main.go
new file mode 100644
index 000000000..ca59dac01
--- /dev/null
+++ b/go/samples/ollama-tools/main.go
@@ -0,0 +1,124 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/firebase/genkit/go/ai"
+	"github.com/firebase/genkit/go/genkit"
+	"github.com/firebase/genkit/go/plugins/ollama"
+)
+
+// WeatherInput defines the input structure for the weather tool.
+type WeatherInput struct {
+	Location string `json:"location"`
+}
+
+// WeatherData represents weather information.
+type WeatherData struct {
+	Location  string  `json:"location"`
+	TempC     float64 `json:"temp_c"`
+	TempF     float64 `json:"temp_f"`
+	Condition string  `json:"condition"`
+}
+
+func main() {
+	ctx := context.Background()
+
+	// Initialize Genkit with the Ollama plugin.
+	ollamaPlugin := &ollama.Ollama{
+		ServerAddress: "http://localhost:11434", // Default Ollama server address
+	}
+
+	g, err := genkit.Init(ctx, genkit.WithPlugins(ollamaPlugin))
+	if err != nil {
+		fmt.Printf("Failed to initialize Genkit: %v\n", err)
+		return
+	}
+
+	// Define the Ollama model.
+	model := ollamaPlugin.DefineModel(g,
+		ollama.ModelDefinition{
+			Name: "llama3.1", // Choose an appropriate model
+			Type: "chat",     // Must be "chat" for tool support
+		},
+		nil)
+
+	// Define the weather tool.
+	weatherTool := genkit.DefineTool(g, "weather", "Get current weather for a location",
+		func(ctx *ai.ToolContext, input WeatherInput) (WeatherData, error) {
+			// Get weather data (simulated).
+			return simulateWeather(input.Location), nil
+		},
+	)
+
+	// Create the system message.
+	systemMsg := ai.NewTextMessage(ai.RoleSystem,
+		"You are a helpful assistant that can look up weather. "+
+			"When providing weather information, use the appropriate tool.")
+
+	// Create the user message.
+	userMsg := ai.NewTextMessage(ai.RoleUser,
+		"I'd like to know the weather in Tokyo.")
+
+	// Generate a response, letting the model call the weather tool.
+	fmt.Println("Generating response with weather tool...")
+
+	resp, err := genkit.Generate(ctx, g,
+		ai.WithModel(model),
+		ai.WithMessages(systemMsg, userMsg),
+		ai.WithTools(
+			ai.ToolName(weatherTool.Name()),
+		),
+		ai.WithToolChoice(ai.ToolChoiceAuto),
+	)
+	if err != nil {
+		fmt.Printf("Error: %v\n", err)
+		return
+	}
+
+	// Print the final response.
+	fmt.Println("\n----- Final Response -----")
+	fmt.Printf("%s\n", resp.Text())
+	fmt.Println("--------------------------")
+}
+
+// simulateWeather returns simulated weather data for a location.
+func simulateWeather(location string) WeatherData {
+	// In a real app, this would call a weather API.
+	// For demonstration, return mock data.
+	tempC := 22.5
+	switch location {
+	case "Tokyo", "Tokyo, Japan":
+		tempC = 24.0
+	case "Paris", "Paris, France":
+		tempC = 18.5
+	case "New York", "New York, USA":
+		tempC = 15.0
+	}
+
+	conditions := []string{"Sunny", "Partly Cloudy", "Cloudy", "Rainy", "Stormy"}
+	condition := conditions[time.Now().Unix()%int64(len(conditions))]
+
+	return WeatherData{
+		Location:  location,
+		TempC:     tempC,
+		TempF:     tempC*9/5 + 32,
+		Condition: condition,
+	}
+}
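
Note: the sample above passes nil for DefineModel's info parameter, so tool support is inferred from the toolSupportedModels allowlist, and a tool-capable chat model missing from that list would be registered without Tools support. Passing an explicit ai.ModelInfo takes the mi = *info branch in DefineModel and skips the inference entirely. Below is a minimal sketch of that override, using only the plugin API shown in this diff; the model name "my-tool-model" is a placeholder, not a model known to support tools.

package main

import (
	"context"
	"log"

	"github.com/firebase/genkit/go/ai"
	"github.com/firebase/genkit/go/genkit"
	"github.com/firebase/genkit/go/plugins/ollama"
)

func main() {
	ctx := context.Background()

	oll := &ollama.Ollama{ServerAddress: "http://localhost:11434"}
	g, err := genkit.Init(ctx, genkit.WithPlugins(oll))
	if err != nil {
		log.Fatalf("failed to initialize Genkit: %v", err)
	}

	// A non-nil ModelInfo bypasses the toolSupportedModels allowlist:
	// DefineModel copies this struct verbatim (mi = *info), so the model is
	// registered with exactly the capabilities declared here.
	// "my-tool-model" is a placeholder name, not a published Ollama model.
	model := oll.DefineModel(g,
		ollama.ModelDefinition{
			Name: "my-tool-model",
			Type: "chat", // tool calls only flow through the /api/chat endpoint
		},
		&ai.ModelInfo{
			Label: "my-tool-model",
			Supports: &ai.ModelSupports{
				Multiturn:  true,
				SystemRole: true,
				Tools:      true, // declare tool support explicitly
			},
		})

	_ = model // pass to genkit.Generate via ai.WithModel(model)
}

The same mechanism works in reverse: declaring Tools: false for a model that is on the allowlist pins it to no tool support, since the allowlist is only consulted when info is nil.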