diff --git a/.github/scripts/modules/ollama/install-dependencies.sh b/.github/scripts/modules/ollama/install-dependencies.sh
new file mode 100755
index 0000000000..d699158806
--- /dev/null
+++ b/.github/scripts/modules/ollama/install-dependencies.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+curl -fsSL https://ollama.com/install.sh | sh
+
+# kill any running ollama process so that the tests can start from a clean state
+sudo systemctl stop ollama.service
diff --git a/.github/workflows/ci-test-go.yml b/.github/workflows/ci-test-go.yml
index 82be78435f..0d6af15880 100644
--- a/.github/workflows/ci-test-go.yml
+++ b/.github/workflows/ci-test-go.yml
@@ -107,6 +107,16 @@ jobs:
         working-directory: ./${{ inputs.project-directory }}
         run: go build
 
+      - name: Install dependencies
+        shell: bash
+        run: |
+          SCRIPT_PATH="./.github/scripts/${{ inputs.project-directory }}/install-dependencies.sh"
+          if [ -f "$SCRIPT_PATH" ]; then
+            $SCRIPT_PATH
+          else
+            echo "No dependencies script found at $SCRIPT_PATH - skipping installation"
+          fi
+
       - name: go test
         # only run tests on linux, there are a number of things that won't allow the tests to run on anything else
         # many (maybe, all?) images used can only be built on Linux, they don't have Windows in their manifest, and
diff --git a/docs/modules/ollama.md b/docs/modules/ollama.md
index c16e612142..18cb08b47a 100644
--- a/docs/modules/ollama.md
+++ b/docs/modules/ollama.md
@@ -16,10 +16,15 @@ go get github.com/testcontainers/testcontainers-go/modules/ollama
 
 ## Usage example
 
+The module allows you to run the Ollama container or the local Ollama binary.
+
 <!--codeinclude-->
 [Creating an Ollama container](../../modules/ollama/examples_test.go) inside_block:runOllamaContainer
+[Running the local Ollama binary](../../modules/ollama/examples_test.go) inside_block:localOllama
 <!--/codeinclude-->
 
+If the local Ollama binary fails to execute, the module will fall back to the container version of Ollama.
+
 ## Module Reference
 
 ### Run function
@@ -48,6 +53,51 @@ When starting the Ollama container, you can pass options in a variadic way to co
 If you need to set a different Ollama Docker image, you can set a valid Docker image as the second argument in the `Run` function.
 E.g. `Run(context.Background(), "ollama/ollama:0.1.25")`.
 
+#### Use Local
+
+- Not available until the next release of testcontainers-go :material-tag: main
+
+!!!warning
+    Please make sure the local Ollama binary is not running when using the local version of the module:
+    Ollama can be started as a system service, or as part of the Ollama application,
+    and interacting with the logs of a running Ollama process not managed by the module is not supported.
+
+If you need to run the local Ollama binary, you can use the `WithUseLocal` option in the `Run` function.
+This option accepts a list of environment variables as strings, which will be applied to the Ollama binary when executing commands.
+
+E.g. `Run(context.Background(), "ollama/ollama:0.1.25", WithUseLocal("OLLAMA_DEBUG=true"))`.
+
+All the container methods are available when using the local Ollama binary, but they will be executed locally instead of inside the container.
+Please consider the following differences when using the local Ollama binary:
+
+- The local Ollama binary will create a log file in the current working directory, identified by the session ID. E.g. `local-ollama-<session-id>.log`. It's possible to set the log file name using the `OLLAMA_LOGFILE` environment variable.
+So if you're running Ollama yourself, either from the Ollama app or the standalone binary, you can use this environment variable to make the module use the same log file.
+    - For the Ollama app, the default log file resides at `$HOME/.ollama/logs/server.log`.
+    - For the standalone binary, you should start it redirecting the logs to a file. E.g. `ollama serve > /tmp/ollama.log 2>&1`.
+- `ConnectionString` returns the connection string to connect to the local Ollama binary started by the module instead of the container.
+- `ContainerIP` returns the bound host IP `127.0.0.1` by default.
+- `ContainerIPs` returns the bound host IP `["127.0.0.1"]` by default.
+- `CopyToContainer`, `CopyDirToContainer`, `CopyFileToContainer` and `CopyFileFromContainer` return an error if called.
+- `GetLogProductionErrorChannel` returns a nil channel.
+- `Endpoint` returns the endpoint to connect to the local Ollama binary started by the module instead of the container.
+- `Exec` passes the command to the local Ollama binary started by the module instead of running it inside the container. The first element of the command must be the Ollama binary itself, followed by its arguments; otherwise, an error is returned.
+- `GetContainerID` returns the container ID of the local Ollama binary started by the module instead of the container, which maps to `local-ollama-<session-id>`.
+- `Host` returns the bound host IP `127.0.0.1` by default.
+- `Inspect` returns a ContainerJSON with the state of the local Ollama binary started by the module.
+- `IsRunning` returns true if the local Ollama binary process started by the module is running.
+- `Logs` returns the logs from the local Ollama binary started by the module instead of the container.
+- `MappedPort` returns the port mapping for the local Ollama binary started by the module instead of the container.
+- `Start` starts the local Ollama binary process.
+- `State` returns the current state of the local Ollama binary process, `exited` or `running`.
+- `Stop` stops the local Ollama binary process.
+- `Terminate` calls the `Stop` method and then removes the log file.
+
+The local Ollama binary will create a log file in the current working directory, and it will be available through the container's `Logs` method.
+
+!!!info
+    The local Ollama binary will use the `OLLAMA_HOST` environment variable to set the host and port to listen on.
+    If the environment variable is not set, it will default to `localhost:0`,
+    which binds to a loopback address on an ephemeral port to avoid port conflicts.
+
 {% include "../features/common_functional_options.md" %}
 
 ### Container Methods
diff --git a/modules/ollama/examples_test.go b/modules/ollama/examples_test.go
index 741db846be..188be45bbb 100644
--- a/modules/ollama/examples_test.go
+++ b/modules/ollama/examples_test.go
@@ -173,3 +173,73 @@ func ExampleRun_withModel_llama2_langchain() {
 	// Intentionally not asserting the output, as we don't want to run this example in the tests.
 }
+
+func ExampleRun_withLocal() {
+	ctx := context.Background()
+
+	// localOllama {
+	ollamaContainer, err := tcollama.Run(ctx, "ollama/ollama:0.3.13", tcollama.WithUseLocal("OLLAMA_DEBUG=true"))
+	defer func() {
+		if err := testcontainers.TerminateContainer(ollamaContainer); err != nil {
+			log.Printf("failed to terminate container: %s", err)
+		}
+	}()
+	if err != nil {
+		log.Printf("failed to start container: %s", err)
+		return
+	}
+	// }
+
+	model := "llama3.2:1b"
+
+	_, _, err = ollamaContainer.Exec(ctx, []string{"ollama", "pull", model})
+	if err != nil {
+		log.Printf("failed to pull model %s: %s", model, err)
+		return
+	}
+
+	_, _, err = ollamaContainer.Exec(ctx, []string{"ollama", "run", model})
+	if err != nil {
+		log.Printf("failed to run model %s: %s", model, err)
+		return
+	}
+
+	connectionStr, err := ollamaContainer.ConnectionString(ctx)
+	if err != nil {
+		log.Printf("failed to get connection string: %s", err)
+		return
+	}
+
+	var llm *langchainollama.LLM
+	if llm, err = langchainollama.New(
+		langchainollama.WithModel(model),
+		langchainollama.WithServerURL(connectionStr),
+	); err != nil {
+		log.Printf("failed to create langchain ollama: %s", err)
+		return
+	}
+
+	completion, err := llm.Call(
+		context.Background(),
+		"how can Testcontainers help with testing?",
+		llms.WithSeed(42),         // a fixed seed makes the completion reproducible
+		llms.WithTemperature(0.0), // the lower the temperature, the more deterministic the completion
+	)
+	if err != nil {
+		log.Printf("failed to generate completion: %s", err)
+		return
+	}
+
+	words := []string{
+		"easy", "isolation", "consistency",
+	}
+	lwCompletion := strings.ToLower(completion)
+
+	for _, word := range words {
+		if strings.Contains(lwCompletion, word) {
+			fmt.Println(true)
+		}
+	}
+
+	// Intentionally not asserting the output, as we don't want to run this example in the tests.
+}
diff --git a/modules/ollama/go.mod b/modules/ollama/go.mod
index e22b801031..2aab83b978 100644
--- a/modules/ollama/go.mod
+++ b/modules/ollama/go.mod
@@ -4,6 +4,7 @@ go 1.22
 
 require (
 	github.com/docker/docker v27.1.1+incompatible
+	github.com/docker/go-connections v0.5.0
 	github.com/google/uuid v1.6.0
 	github.com/stretchr/testify v1.9.0
 	github.com/testcontainers/testcontainers-go v0.34.0
@@ -22,7 +23,6 @@ require (
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/distribution/reference v0.6.0 // indirect
 	github.com/dlclark/regexp2 v1.8.1 // indirect
-	github.com/docker/go-connections v0.5.0 // indirect
 	github.com/docker/go-units v0.5.0 // indirect
 	github.com/felixge/httpsnoop v1.0.4 // indirect
 	github.com/go-logr/logr v1.4.1 // indirect
diff --git a/modules/ollama/local.go b/modules/ollama/local.go
new file mode 100644
index 0000000000..371cbb60c5
--- /dev/null
+++ b/modules/ollama/local.go
@@ -0,0 +1,750 @@
+package ollama
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"io/fs"
+	"net"
+	"os"
+	"os/exec"
+	"reflect"
+	"strings"
+	"sync"
+	"syscall"
+	"time"
+
+	"github.com/docker/docker/api/types"
+	"github.com/docker/docker/api/types/container"
+	"github.com/docker/docker/api/types/network"
+	"github.com/docker/docker/errdefs"
+	"github.com/docker/docker/pkg/stdcopy"
+	"github.com/docker/go-connections/nat"
+
+	"github.com/testcontainers/testcontainers-go"
+	tcexec "github.com/testcontainers/testcontainers-go/exec"
+	"github.com/testcontainers/testcontainers-go/wait"
+)
+
+const (
+	localPort       = "11434"
+	localBinary     = "ollama"
+	localServeArg   = "serve"
+	localLogRegex   = `Listening on (.*:\d+) \(version\s(.*)\)`
+	localNamePrefix = "local-ollama"
+	localHostVar    = "OLLAMA_HOST"
+	localLogVar     = "OLLAMA_LOGFILE"
+)
+
+var (
+	// Ensure localProcess implements the required interfaces.
+	_ testcontainers.Container           = (*localProcess)(nil)
+	_ testcontainers.ContainerCustomizer = (*localProcess)(nil)
+
+	// defaultStopTimeout is the default timeout for stopping the local Ollama process.
+	defaultStopTimeout = time.Second * 5
+
+	// zeroTime is the zero time value.
+	zeroTime time.Time
+)
+
+// localProcess emulates the Ollama container using a local process to improve performance.
+type localProcess struct {
+	sessionID string
+
+	// env is the combined environment variables passed to the Ollama binary.
+	env []string
+
+	// cmd is the command that runs the Ollama binary, not valid externally if nil.
+	cmd *exec.Cmd
+
+	// logName and logFile are the file where the Ollama logs are written.
+	logName string
+	logFile *os.File
+
+	// host, port and version are extracted from log on startup.
+	host    string
+	port    string
+	version string
+
+	// waitFor is the strategy to wait for the process to be ready.
+	waitFor wait.Strategy
+
+	// done is closed when the process is finished.
+	done chan struct{}
+
+	// wg is used to wait for the process to finish.
+	wg sync.WaitGroup
+
+	// startedAt is the time when the process started.
+	startedAt time.Time
+
+	// mtx is used to synchronize access to the process state fields below.
+	mtx sync.Mutex
+
+	// finishedAt is the time when the process finished.
+	finishedAt time.Time
+
+	// exitErr is the error returned by the process.
+	exitErr error
+
+	// binary is the name of the Ollama binary.
+	binary string
+}
+
+// run starts the local Ollama binary and returns an OllamaContainer that uses it instead of a Docker container.
+func (c *localProcess) run(ctx context.Context, req testcontainers.GenericContainerRequest) (*OllamaContainer, error) {
+	if err := c.validateRequest(req); err != nil {
+		return nil, fmt.Errorf("validate request: %w", err)
+	}
+
+	// Apply the updated details from the request.
+	c.waitFor = req.WaitingFor
+	c.env = c.env[:0]
+	for k, v := range req.Env {
+		c.env = append(c.env, k+"="+v)
+		if k == localLogVar {
+			c.logName = v
+		}
+	}
+
+	err := c.Start(ctx)
+	var container *OllamaContainer
+	if c.cmd != nil {
+		container = &OllamaContainer{Container: c}
+	}
+
+	if err != nil {
+		return container, fmt.Errorf("start ollama: %w", err)
+	}
+
+	return container, nil
+}
+
+// validateRequest checks that req is valid for the local Ollama binary.
+func (c *localProcess) validateRequest(req testcontainers.GenericContainerRequest) error {
+	var errs []error
+	if req.WaitingFor == nil {
+		errs = append(errs, errors.New("ContainerRequest.WaitingFor must be set"))
+	}
+
+	if !req.Started {
+		errs = append(errs, errors.New("Started must be true"))
+	}
+
+	if !reflect.DeepEqual(req.ExposedPorts, []string{localPort + "/tcp"}) {
+		errs = append(errs, fmt.Errorf("ContainerRequest.ExposedPorts must be %s/tcp got: %s", localPort, req.ExposedPorts))
+	}
+
+	// Validate the image and extract the binary name.
+	// The image must be in the format "[<path>/]<binary>[:latest]".
+	if binary := req.Image; binary != "" {
+		// Check if the version is "latest" or not specified.
+		if idx := strings.IndexByte(binary, ':'); idx != -1 {
+			if binary[idx+1:] != "latest" {
+				errs = append(errs, fmt.Errorf(`ContainerRequest.Image version must be blank or "latest", got: %q`, binary[idx+1:]))
+			}
+			binary = binary[:idx]
+		}
+
+		// Trim the path if present.
+		if idx := strings.LastIndexByte(binary, '/'); idx != -1 {
+			binary = binary[idx+1:]
+		}
+
+		if _, err := exec.LookPath(binary); err != nil {
+			errs = append(errs, fmt.Errorf("invalid image %q: %w", req.Image, err))
+		} else {
+			c.binary = binary
+		}
+	}
+
+	// Reset fields we support to their zero values.
+	req.Env = nil
+	req.ExposedPorts = nil
+	req.WaitingFor = nil
+	req.Image = ""
+	req.Started = false
+	req.Logger = nil // We don't need the logger.
+
+	parts := make([]string, 0, 3)
+	value := reflect.ValueOf(req)
+	typ := value.Type()
+	fields := reflect.VisibleFields(typ)
+	for _, f := range fields {
+		field := value.FieldByIndex(f.Index)
+		if field.Kind() == reflect.Struct {
+			// Only check the leaf fields.
+			continue
+		}
+
+		if !field.IsZero() {
+			parts = parts[:0]
+			for i := range f.Index {
+				parts = append(parts, typ.FieldByIndex(f.Index[:i+1]).Name)
+			}
+			errs = append(errs, fmt.Errorf("unsupported field: %s = %q", strings.Join(parts, "."), field))
+		}
+	}
+
+	return errors.Join(errs...)
+}
+
+// Start implements testcontainers.Container interface for the local Ollama binary.
+func (c *localProcess) Start(ctx context.Context) error {
+	if c.IsRunning() {
+		return errors.New("already running")
+	}
+
+	cmd := exec.CommandContext(ctx, c.binary, localServeArg)
+	cmd.Env = c.env
+
+	var err error
+	c.logFile, err = os.Create(c.logName)
+	if err != nil {
+		return fmt.Errorf("create ollama log file: %w", err)
+	}
+
+	// Multiplex stdout and stderr to the log file matching the Docker API.
+	cmd.Stdout = stdcopy.NewStdWriter(c.logFile, stdcopy.Stdout)
+	cmd.Stderr = stdcopy.NewStdWriter(c.logFile, stdcopy.Stderr)
+
+	// Run the ollama serve command in the background.
+ if err = cmd.Start(); err != nil { + return fmt.Errorf("start ollama serve: %w", errors.Join(err, c.cleanupLog())) + } + + // Past this point, the process was started successfully. + c.cmd = cmd + c.startedAt = time.Now() + + // Reset the details to allow multiple start / stop cycles. + c.done = make(chan struct{}) + c.mtx.Lock() + c.finishedAt = zeroTime + c.exitErr = nil + c.mtx.Unlock() + + // Wait for the process to finish in a goroutine. + c.wg.Add(1) + go func() { + defer func() { + c.wg.Done() + close(c.done) + }() + + err := c.cmd.Wait() + c.mtx.Lock() + defer c.mtx.Unlock() + if err != nil { + c.exitErr = fmt.Errorf("process wait: %w", err) + } + c.finishedAt = time.Now() + }() + + if err = c.waitStrategy(ctx); err != nil { + return fmt.Errorf("wait strategy: %w", err) + } + + return nil +} + +// waitStrategy waits until the Ollama process is ready. +func (c *localProcess) waitStrategy(ctx context.Context) error { + if err := c.waitFor.WaitUntilReady(ctx, c); err != nil { + logs, lerr := c.Logs(ctx) + if lerr != nil { + return errors.Join(err, lerr) + } + defer logs.Close() + + var stderr, stdout bytes.Buffer + _, cerr := stdcopy.StdCopy(&stdout, &stderr, logs) + + return fmt.Errorf( + "%w (stdout: %s, stderr: %s)", + errors.Join(err, cerr), + strings.TrimSpace(stdout.String()), + strings.TrimSpace(stderr.String()), + ) + } + + return nil +} + +// extractLogDetails extracts the listening address and version from the log. +func (c *localProcess) extractLogDetails(pattern string, submatches [][][]byte) error { + var err error + for _, matches := range submatches { + if len(matches) != 3 { + err = fmt.Errorf("`%s` matched %d times, expected %d", pattern, len(matches), 3) + continue + } + + c.host, c.port, err = net.SplitHostPort(string(matches[1])) + if err != nil { + return wait.NewPermanentError(fmt.Errorf("split host port: %w", err)) + } + + // Set OLLAMA_HOST variable to the extracted host so Exec can use it. + c.env = append(c.env, localHostVar+"="+string(matches[1])) + c.version = string(matches[2]) + + return nil + } + + if err != nil { + // Return the last error encountered. + return err + } + + return fmt.Errorf("address and version not found: `%s` no matches", pattern) +} + +// ContainerIP implements testcontainers.Container interface for the local Ollama binary. +func (c *localProcess) ContainerIP(ctx context.Context) (string, error) { + return c.host, nil +} + +// ContainerIPs returns a slice with the IP address of the local Ollama binary. +func (c *localProcess) ContainerIPs(ctx context.Context) ([]string, error) { + return []string{c.host}, nil +} + +// CopyToContainer implements testcontainers.Container interface for the local Ollama binary. +// Returns [errors.ErrUnsupported]. +func (c *localProcess) CopyToContainer(ctx context.Context, fileContent []byte, containerFilePath string, fileMode int64) error { + return errors.ErrUnsupported +} + +// CopyDirToContainer implements testcontainers.Container interface for the local Ollama binary. +// Returns [errors.ErrUnsupported]. +func (c *localProcess) CopyDirToContainer(ctx context.Context, hostDirPath string, containerParentPath string, fileMode int64) error { + return errors.ErrUnsupported +} + +// CopyFileToContainer implements testcontainers.Container interface for the local Ollama binary. +// Returns [errors.ErrUnsupported]. 
+func (c *localProcess) CopyFileToContainer(ctx context.Context, hostFilePath string, containerFilePath string, fileMode int64) error { + return errors.ErrUnsupported +} + +// CopyFileFromContainer implements testcontainers.Container interface for the local Ollama binary. +// Returns [errors.ErrUnsupported]. +func (c *localProcess) CopyFileFromContainer(ctx context.Context, filePath string) (io.ReadCloser, error) { + return nil, errors.ErrUnsupported +} + +// GetLogProductionErrorChannel implements testcontainers.Container interface for the local Ollama binary. +// It returns a nil channel because the local Ollama binary doesn't have a production error channel. +func (c *localProcess) GetLogProductionErrorChannel() <-chan error { + return nil +} + +// Exec implements testcontainers.Container interface for the local Ollama binary. +// It executes a command using the local Ollama binary and returns the exit status +// of the executed command, an [io.Reader] containing the combined stdout and stderr, +// and any encountered error. +// +// Reading directly from the [io.Reader] may result in unexpected bytes due to custom +// stream multiplexing headers. Use [tcexec.Multiplexed] option to read the combined output +// without the multiplexing headers. +// Alternatively, to separate the stdout and stderr from [io.Reader] and interpret these +// headers properly, [stdcopy.StdCopy] from the Docker API should be used. +func (c *localProcess) Exec(ctx context.Context, cmd []string, options ...tcexec.ProcessOption) (int, io.Reader, error) { + if len(cmd) == 0 { + return 1, nil, errors.New("no command provided") + } else if cmd[0] != c.binary { + return 1, nil, fmt.Errorf("command %q: %w", cmd[0], errors.ErrUnsupported) + } + + command := exec.CommandContext(ctx, cmd[0], cmd[1:]...) + command.Env = c.env + + // Multiplex stdout and stderr to the buffer so they can be read separately later. + var buf bytes.Buffer + command.Stdout = stdcopy.NewStdWriter(&buf, stdcopy.Stdout) + command.Stderr = stdcopy.NewStdWriter(&buf, stdcopy.Stderr) + + // Use process options to customize the command execution + // emulating the Docker API behaviour. + processOptions := tcexec.NewProcessOptions(cmd) + processOptions.Reader = &buf + for _, o := range options { + o.Apply(processOptions) + } + + if err := c.validateExecOptions(processOptions.ExecConfig); err != nil { + return 1, nil, fmt.Errorf("validate exec option: %w", err) + } + + if !processOptions.ExecConfig.AttachStderr { + command.Stderr = io.Discard + } + if !processOptions.ExecConfig.AttachStdout { + command.Stdout = io.Discard + } + if processOptions.ExecConfig.AttachStdin { + command.Stdin = os.Stdin + } + + command.Dir = processOptions.ExecConfig.WorkingDir + command.Env = append(command.Env, processOptions.ExecConfig.Env...) + + if err := command.Run(); err != nil { + return command.ProcessState.ExitCode(), processOptions.Reader, fmt.Errorf("exec %v: %w", cmd, err) + } + + return command.ProcessState.ExitCode(), processOptions.Reader, nil +} + +// validateExecOptions checks if the given exec options are supported by the local Ollama binary. 
+func (c *localProcess) validateExecOptions(options container.ExecOptions) error { + var errs []error + if options.User != "" { + errs = append(errs, fmt.Errorf("user: %w", errors.ErrUnsupported)) + } + if options.Privileged { + errs = append(errs, fmt.Errorf("privileged: %w", errors.ErrUnsupported)) + } + if options.Tty { + errs = append(errs, fmt.Errorf("tty: %w", errors.ErrUnsupported)) + } + if options.Detach { + errs = append(errs, fmt.Errorf("detach: %w", errors.ErrUnsupported)) + } + if options.DetachKeys != "" { + errs = append(errs, fmt.Errorf("detach keys: %w", errors.ErrUnsupported)) + } + + return errors.Join(errs...) +} + +// Inspect implements testcontainers.Container interface for the local Ollama binary. +// It returns a ContainerJSON with the state of the local Ollama binary. +func (c *localProcess) Inspect(ctx context.Context) (*types.ContainerJSON, error) { + state, err := c.State(ctx) + if err != nil { + return nil, fmt.Errorf("state: %w", err) + } + + return &types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + ID: c.GetContainerID(), + Name: localNamePrefix + "-" + c.sessionID, + State: state, + }, + Config: &container.Config{ + Image: localNamePrefix + ":" + c.version, + ExposedPorts: nat.PortSet{ + nat.Port(localPort + "/tcp"): struct{}{}, + }, + Hostname: c.host, + Entrypoint: []string{c.binary, localServeArg}, + }, + NetworkSettings: &types.NetworkSettings{ + Networks: map[string]*network.EndpointSettings{}, + NetworkSettingsBase: types.NetworkSettingsBase{ + Bridge: "bridge", + Ports: nat.PortMap{ + nat.Port(localPort + "/tcp"): { + {HostIP: c.host, HostPort: c.port}, + }, + }, + }, + DefaultNetworkSettings: types.DefaultNetworkSettings{ + IPAddress: c.host, + }, + }, + }, nil +} + +// IsRunning implements testcontainers.Container interface for the local Ollama binary. +// It returns true if the local Ollama process is running, false otherwise. +func (c *localProcess) IsRunning() bool { + if c.startedAt.IsZero() { + // The process hasn't started yet. + return false + } + + select { + case <-c.done: + // The process exited. + return false + default: + // The process is still running. + return true + } +} + +// Logs implements testcontainers.Container interface for the local Ollama binary. +// It returns the logs from the local Ollama binary. +func (c *localProcess) Logs(ctx context.Context) (io.ReadCloser, error) { + file, err := os.Open(c.logFile.Name()) + if err != nil { + return nil, fmt.Errorf("open log file: %w", err) + } + + return file, nil +} + +// State implements testcontainers.Container interface for the local Ollama binary. +// It returns the current state of the Ollama process, simulating a container state. +func (c *localProcess) State(ctx context.Context) (*types.ContainerState, error) { + c.mtx.Lock() + defer c.mtx.Unlock() + + if !c.IsRunning() { + state := &types.ContainerState{ + Status: "exited", + ExitCode: c.cmd.ProcessState.ExitCode(), + StartedAt: c.startedAt.Format(time.RFC3339Nano), + FinishedAt: c.finishedAt.Format(time.RFC3339Nano), + } + if c.exitErr != nil { + state.Error = c.exitErr.Error() + } + + return state, nil + } + + // Setting the Running field because it's required by the wait strategy + // to check if the given log message is present. 
+	return &types.ContainerState{
+		Status:     "running",
+		Running:    true,
+		Pid:        c.cmd.Process.Pid,
+		StartedAt:  c.startedAt.Format(time.RFC3339Nano),
+		FinishedAt: c.finishedAt.Format(time.RFC3339Nano),
+	}, nil
+}
+
+// Stop implements testcontainers.Container interface for the local Ollama binary.
+// It gracefully stops the local Ollama process.
+func (c *localProcess) Stop(ctx context.Context, d *time.Duration) error {
+	if err := c.cmd.Process.Signal(syscall.SIGTERM); err != nil {
+		return fmt.Errorf("signal ollama: %w", err)
+	}
+
+	if d != nil {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, *d)
+		defer cancel()
+	}
+
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-c.done:
+		// The process exited.
+		c.mtx.Lock()
+		defer c.mtx.Unlock()
+
+		return c.exitErr
+	}
+}
+
+// Terminate implements testcontainers.Container interface for the local Ollama binary.
+// It stops the local Ollama process, removing the log file.
+func (c *localProcess) Terminate(ctx context.Context) error {
+	// First try to stop gracefully.
+	if err := c.Stop(ctx, &defaultStopTimeout); !c.isCleanupSafe(err) {
+		return fmt.Errorf("stop: %w", err)
+	}
+
+	if c.IsRunning() {
+		// Still running, force kill.
+		if err := c.cmd.Process.Kill(); !c.isCleanupSafe(err) {
+			return fmt.Errorf("kill: %w", err)
+		}
+
+		// Wait for the process to exit to capture any error.
+		c.wg.Wait()
+	}
+
+	c.mtx.Lock()
+	exitErr := c.exitErr
+	c.mtx.Unlock()
+
+	return errors.Join(exitErr, c.cleanupLog())
+}
+
+// cleanupLog closes the log file and removes it.
+func (c *localProcess) cleanupLog() error {
+	if c.logFile == nil {
+		return nil
+	}
+
+	var errs []error
+	if err := c.logFile.Close(); err != nil {
+		errs = append(errs, fmt.Errorf("close log: %w", err))
+	}
+
+	if err := os.Remove(c.logFile.Name()); err != nil && !errors.Is(err, fs.ErrNotExist) {
+		errs = append(errs, fmt.Errorf("remove log: %w", err))
+	}
+
+	c.logFile = nil // Prevent double cleanup.
+
+	return errors.Join(errs...)
+}
+
+// Endpoint implements testcontainers.Container interface for the local Ollama binary.
+// It returns a proto://host:port string for the Ollama port.
+// It returns just host:port if proto is blank.
+func (c *localProcess) Endpoint(ctx context.Context, proto string) (string, error) {
+	return c.PortEndpoint(ctx, localPort, proto)
+}
+
+// GetContainerID implements testcontainers.Container interface for the local Ollama binary.
+func (c *localProcess) GetContainerID() string {
+	return localNamePrefix + "-" + c.sessionID
+}
+
+// Host implements testcontainers.Container interface for the local Ollama binary.
+func (c *localProcess) Host(ctx context.Context) (string, error) {
+	return c.host, nil
+}
+
+// MappedPort implements testcontainers.Container interface for the local Ollama binary.
+func (c *localProcess) MappedPort(ctx context.Context, port nat.Port) (nat.Port, error) {
+	if port.Port() != localPort || port.Proto() != "tcp" {
+		return "", errdefs.NotFound(fmt.Errorf("port %q not found", port))
+	}
+
+	return nat.Port(c.port + "/tcp"), nil
+}
+
+// Networks implements testcontainers.Container interface for the local Ollama binary.
+// It returns a nil slice.
+func (c *localProcess) Networks(ctx context.Context) ([]string, error) {
+	return nil, nil
+}
+
+// NetworkAliases implements testcontainers.Container interface for the local Ollama binary.
+// It returns a nil map.
+func (c *localProcess) NetworkAliases(ctx context.Context) (map[string][]string, error) {
+	return nil, nil
+}
+
+// PortEndpoint implements testcontainers.Container interface for the local Ollama binary.
+// It returns a proto://host:port string for the given exposed port.
+// It returns just host:port if proto is blank.
+func (c *localProcess) PortEndpoint(ctx context.Context, port nat.Port, proto string) (string, error) {
+	host, err := c.Host(ctx)
+	if err != nil {
+		return "", fmt.Errorf("host: %w", err)
+	}
+
+	outerPort, err := c.MappedPort(ctx, port)
+	if err != nil {
+		return "", fmt.Errorf("mapped port: %w", err)
+	}
+
+	if proto != "" {
+		proto += "://"
+	}
+
+	return fmt.Sprintf("%s%s:%s", proto, host, outerPort.Port()), nil
+}
+
+// SessionID implements testcontainers.Container interface for the local Ollama binary.
+func (c *localProcess) SessionID() string {
+	return c.sessionID
+}
+
+// Deprecated: it will be removed in the next major release.
+// FollowOutput is not implemented for the local Ollama binary.
+// It panics if called.
+func (c *localProcess) FollowOutput(consumer testcontainers.LogConsumer) {
+	panic("not implemented")
+}
+
+// Deprecated: use c.Inspect(ctx).NetworkSettings.Ports instead.
+// Ports gets the exposed ports for the container.
+func (c *localProcess) Ports(ctx context.Context) (nat.PortMap, error) {
+	inspect, err := c.Inspect(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	return inspect.NetworkSettings.Ports, nil
+}
+
+// Deprecated: it will be removed in the next major release.
+// StartLogProducer implements testcontainers.Container interface for the local Ollama binary.
+// It returns an error because the local Ollama binary doesn't have a log producer.
+func (c *localProcess) StartLogProducer(context.Context, ...testcontainers.LogProductionOption) error {
+	return errors.ErrUnsupported
+}
+
+// Deprecated: it will be removed in the next major release.
+// StopLogProducer implements testcontainers.Container interface for the local Ollama binary.
+// It returns an error because the local Ollama binary doesn't have a log producer.
+func (c *localProcess) StopLogProducer() error {
+	return errors.ErrUnsupported
+}
+
+// Deprecated: Use c.Inspect(ctx).Name instead.
+// Name returns the name for the local Ollama binary.
+func (c *localProcess) Name(context.Context) (string, error) {
+	return localNamePrefix + "-" + c.sessionID, nil
+}
+
+// Customize implements the [testcontainers.ContainerCustomizer] interface.
+// It configures the environment variables set by [WithUseLocal] and sets up
+// the wait strategy to extract the host, port and version from the log.
+func (c *localProcess) Customize(req *testcontainers.GenericContainerRequest) error {
+	// Replace the default host port strategy with one that waits for a log entry
+	// and extracts the host, port and version from it.
+	if err := wait.Walk(&req.WaitingFor, func(w wait.Strategy) error {
+		if _, ok := w.(*wait.HostPortStrategy); ok {
+			return wait.VisitRemove
+		}
+
+		return nil
+	}); err != nil {
+		return fmt.Errorf("walk strategies: %w", err)
+	}
+
+	logStrategy := wait.ForLog(localLogRegex).Submatch(c.extractLogDetails)
+	if req.WaitingFor == nil {
+		req.WaitingFor = logStrategy
+	} else {
+		req.WaitingFor = wait.ForAll(req.WaitingFor, logStrategy)
+	}
+
+	// Set up the environment variables using a random port by default
+	// to avoid conflicts.
+	osEnv := os.Environ()
+	env := make(map[string]string, len(osEnv)+len(c.env)+1)
+	env[localHostVar] = "localhost:0"
+	for _, kv := range append(osEnv, c.env...) {
+		parts := strings.SplitN(kv, "=", 2)
+		if len(parts) != 2 {
+			return fmt.Errorf("invalid environment variable: %q", kv)
+		}
+
+		env[parts[0]] = parts[1]
+	}
+
+	return testcontainers.WithEnv(env)(req)
+}
+
+// isCleanupSafe reports whether all errors in err's tree are one of the
+// following, so they can safely be ignored:
+//   - nil
+//   - os: process already finished
+//   - context deadline exceeded
+func (c *localProcess) isCleanupSafe(err error) bool {
+	switch {
+	case err == nil,
+		errors.Is(err, os.ErrProcessDone),
+		errors.Is(err, context.DeadlineExceeded):
+		return true
+	default:
+		return false
+	}
+}
diff --git a/modules/ollama/local_test.go b/modules/ollama/local_test.go
new file mode 100644
index 0000000000..3e0376d4de
--- /dev/null
+++ b/modules/ollama/local_test.go
@@ -0,0 +1,636 @@
+package ollama_test
+
+import (
+	"context"
+	"errors"
+	"io"
+	"io/fs"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"regexp"
+	"testing"
+	"time"
+
+	"github.com/docker/docker/api/types"
+	"github.com/docker/docker/api/types/strslice"
+	"github.com/stretchr/testify/require"
+
+	"github.com/testcontainers/testcontainers-go"
+	tcexec "github.com/testcontainers/testcontainers-go/exec"
+	"github.com/testcontainers/testcontainers-go/modules/ollama"
+)
+
+const (
+	testImage   = "ollama/ollama:latest"
+	testNatPort = "11434/tcp"
+	testHost    = "127.0.0.1"
+	testBinary  = "ollama"
+)
+
+var (
+	// reLogDetails matches the log details of the local ollama binary and should match localLogRegex.
+	reLogDetails = regexp.MustCompile(`Listening on (.*:\d+) \(version\s(.*)\)`)
+	zeroTime     = time.Time{}.Format(time.RFC3339Nano)
+)
+
+func TestRun_local(t *testing.T) {
+	// check if the local ollama binary is available
+	if _, err := exec.LookPath(testBinary); err != nil {
+		t.Skip("local ollama binary not found, skipping")
+	}
+
+	ctx := context.Background()
+	ollamaContainer, err := ollama.Run(
+		ctx,
+		testImage,
+		ollama.WithUseLocal("FOO=BAR"),
+	)
+	testcontainers.CleanupContainer(t, ollamaContainer)
+	require.NoError(t, err)
+
+	t.Run("state", func(t *testing.T) {
+		state, err := ollamaContainer.State(ctx)
+		require.NoError(t, err)
+		require.NotEmpty(t, state.StartedAt)
+		require.NotEqual(t, zeroTime, state.StartedAt)
+		require.NotZero(t, state.Pid)
+		require.Equal(t, &types.ContainerState{
+			Status:     "running",
+			Running:    true,
+			Pid:        state.Pid,
+			StartedAt:  state.StartedAt,
+			FinishedAt: time.Time{}.Format(time.RFC3339Nano),
+		}, state)
+	})
+
+	t.Run("connection-string", func(t *testing.T) {
+		connectionStr, err := ollamaContainer.ConnectionString(ctx)
+		require.NoError(t, err)
+		require.NotEmpty(t, connectionStr)
+	})
+
+	t.Run("container-id", func(t *testing.T) {
+		id := ollamaContainer.GetContainerID()
+		require.Equal(t, "local-ollama-"+testcontainers.SessionID(), id)
+	})
+
+	t.Run("container-ips", func(t *testing.T) {
+		ip, err := ollamaContainer.ContainerIP(ctx)
+		require.NoError(t, err)
+		require.Equal(t, testHost, ip)
+
+		ips, err := ollamaContainer.ContainerIPs(ctx)
+		require.NoError(t, err)
+		require.Equal(t, []string{testHost}, ips)
+	})
+
+	t.Run("copy", func(t *testing.T) {
+		err := ollamaContainer.CopyToContainer(ctx, []byte("test"), "/tmp", 0o755)
+		require.Error(t, err)
+
+		err = ollamaContainer.CopyDirToContainer(ctx, ".", "/tmp", 0o755)
+		require.Error(t, err)
+
+		err = ollamaContainer.CopyFileToContainer(ctx, ".", "/tmp", 0o755)
require.Error(t, err) + + reader, err := ollamaContainer.CopyFileFromContainer(ctx, "/tmp") + require.Error(t, err) + require.Nil(t, reader) + }) + + t.Run("log-production-error-channel", func(t *testing.T) { + ch := ollamaContainer.GetLogProductionErrorChannel() + require.Nil(t, ch) + }) + + t.Run("endpoint", func(t *testing.T) { + endpoint, err := ollamaContainer.Endpoint(ctx, "") + require.NoError(t, err) + require.Contains(t, endpoint, testHost+":") + + endpoint, err = ollamaContainer.Endpoint(ctx, "http") + require.NoError(t, err) + require.Contains(t, endpoint, "http://"+testHost+":") + }) + + t.Run("is-running", func(t *testing.T) { + require.True(t, ollamaContainer.IsRunning()) + + err = ollamaContainer.Stop(ctx, nil) + require.NoError(t, err) + require.False(t, ollamaContainer.IsRunning()) + + // return it to the running state + err = ollamaContainer.Start(ctx) + require.NoError(t, err) + require.True(t, ollamaContainer.IsRunning()) + }) + + t.Run("host", func(t *testing.T) { + host, err := ollamaContainer.Host(ctx) + require.NoError(t, err) + require.Equal(t, testHost, host) + }) + + t.Run("inspect", func(t *testing.T) { + inspect, err := ollamaContainer.Inspect(ctx) + require.NoError(t, err) + + require.Equal(t, "local-ollama-"+testcontainers.SessionID(), inspect.ContainerJSONBase.ID) + require.Equal(t, "local-ollama-"+testcontainers.SessionID(), inspect.ContainerJSONBase.Name) + require.True(t, inspect.ContainerJSONBase.State.Running) + + require.NotEmpty(t, inspect.Config.Image) + _, exists := inspect.Config.ExposedPorts[testNatPort] + require.True(t, exists) + require.Equal(t, testHost, inspect.Config.Hostname) + require.Equal(t, strslice.StrSlice(strslice.StrSlice{testBinary, "serve"}), inspect.Config.Entrypoint) + + require.Empty(t, inspect.NetworkSettings.Networks) + require.Equal(t, "bridge", inspect.NetworkSettings.NetworkSettingsBase.Bridge) + + ports := inspect.NetworkSettings.NetworkSettingsBase.Ports + port, exists := ports[testNatPort] + require.True(t, exists) + require.Len(t, port, 1) + require.Equal(t, testHost, port[0].HostIP) + require.NotEmpty(t, port[0].HostPort) + }) + + t.Run("logfile", func(t *testing.T) { + file, err := os.Open("local-ollama-" + testcontainers.SessionID() + ".log") + require.NoError(t, err) + require.NoError(t, file.Close()) + }) + + t.Run("logs", func(t *testing.T) { + logs, err := ollamaContainer.Logs(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) + + bs, err := io.ReadAll(logs) + require.NoError(t, err) + require.Regexp(t, reLogDetails, string(bs)) + }) + + t.Run("mapped-port", func(t *testing.T) { + port, err := ollamaContainer.MappedPort(ctx, testNatPort) + require.NoError(t, err) + require.NotEmpty(t, port.Port()) + require.Equal(t, "tcp", port.Proto()) + }) + + t.Run("networks", func(t *testing.T) { + networks, err := ollamaContainer.Networks(ctx) + require.NoError(t, err) + require.Nil(t, networks) + }) + + t.Run("network-aliases", func(t *testing.T) { + aliases, err := ollamaContainer.NetworkAliases(ctx) + require.NoError(t, err) + require.Nil(t, aliases) + }) + + t.Run("port-endpoint", func(t *testing.T) { + endpoint, err := ollamaContainer.PortEndpoint(ctx, testNatPort, "") + require.NoError(t, err) + require.Regexp(t, regexp.MustCompile(`^127.0.0.1:\d+$`), endpoint) + + endpoint, err = ollamaContainer.PortEndpoint(ctx, testNatPort, "http") + require.NoError(t, err) + require.Regexp(t, regexp.MustCompile(`^http://127.0.0.1:\d+$`), endpoint) + }) + + t.Run("session-id", func(t 
*testing.T) {
+		require.Equal(t, testcontainers.SessionID(), ollamaContainer.SessionID())
+	})
+
+	t.Run("stop-start", func(t *testing.T) {
+		d := time.Second * 5
+		err := ollamaContainer.Stop(ctx, &d)
+		require.NoError(t, err)
+
+		state, err := ollamaContainer.State(ctx)
+		require.NoError(t, err)
+		require.Equal(t, "exited", state.Status)
+		require.NotEmpty(t, state.StartedAt)
+		require.NotEqual(t, zeroTime, state.StartedAt)
+		require.NotEmpty(t, state.FinishedAt)
+		require.NotEqual(t, zeroTime, state.FinishedAt)
+		require.Zero(t, state.ExitCode)
+
+		err = ollamaContainer.Start(ctx)
+		require.NoError(t, err)
+
+		state, err = ollamaContainer.State(ctx)
+		require.NoError(t, err)
+		require.Equal(t, "running", state.Status)
+
+		logs, err := ollamaContainer.Logs(ctx)
+		require.NoError(t, err)
+		t.Cleanup(func() {
+			require.NoError(t, logs.Close())
+		})
+
+		bs, err := io.ReadAll(logs)
+		require.NoError(t, err)
+		require.Regexp(t, reLogDetails, string(bs))
+	})
+
+	t.Run("start-start", func(t *testing.T) {
+		state, err := ollamaContainer.State(ctx)
+		require.NoError(t, err)
+		require.Equal(t, "running", state.Status)
+
+		err = ollamaContainer.Start(ctx)
+		require.Error(t, err)
+	})
+
+	t.Run("terminate", func(t *testing.T) {
+		err := ollamaContainer.Terminate(ctx)
+		require.NoError(t, err)
+
+		_, err = os.Stat("local-ollama-" + testcontainers.SessionID() + ".log")
+		require.ErrorIs(t, err, fs.ErrNotExist)
+
+		state, err := ollamaContainer.State(ctx)
+		require.NoError(t, err)
+		require.NotEmpty(t, state.StartedAt)
+		require.NotEqual(t, zeroTime, state.StartedAt)
+		require.NotEmpty(t, state.FinishedAt)
+		require.NotEqual(t, zeroTime, state.FinishedAt)
+		require.Equal(t, &types.ContainerState{
+			Status:     "exited",
+			StartedAt:  state.StartedAt,
+			FinishedAt: state.FinishedAt,
+		}, state)
+	})
+
+	t.Run("deprecated", func(t *testing.T) {
+		t.Run("ports", func(t *testing.T) {
+			inspect, err := ollamaContainer.Inspect(ctx)
+			require.NoError(t, err)
+
+			ports, err := ollamaContainer.Ports(ctx)
+			require.NoError(t, err)
+			require.Equal(t, inspect.NetworkSettings.Ports, ports)
+		})
+
+		t.Run("follow-output", func(t *testing.T) {
+			require.Panics(t, func() {
+				ollamaContainer.FollowOutput(&testcontainers.StdoutLogConsumer{})
+			})
+		})
+
+		t.Run("start-log-producer", func(t *testing.T) {
+			err := ollamaContainer.StartLogProducer(ctx)
+			require.ErrorIs(t, err, errors.ErrUnsupported)
+		})
+
+		t.Run("stop-log-producer", func(t *testing.T) {
+			err := ollamaContainer.StopLogProducer()
+			require.ErrorIs(t, err, errors.ErrUnsupported)
+		})
+
+		t.Run("name", func(t *testing.T) {
+			name, err := ollamaContainer.Name(ctx)
+			require.NoError(t, err)
+			require.Equal(t, "local-ollama-"+testcontainers.SessionID(), name)
+		})
+	})
+}
+
+func TestRun_localWithCustomLogFile(t *testing.T) {
+	ctx := context.Background()
+	logFile := filepath.Join(t.TempDir(), "server.log")
+
+	t.Run("parent-env", func(t *testing.T) {
+		t.Setenv("OLLAMA_LOGFILE", logFile)
+
+		ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal())
+		testcontainers.CleanupContainer(t, ollamaContainer)
+		require.NoError(t, err)
+
+		logs, err := ollamaContainer.Logs(ctx)
+		require.NoError(t, err)
+		t.Cleanup(func() {
+			require.NoError(t, logs.Close())
+		})
+
+		bs, err := io.ReadAll(logs)
+		require.NoError(t, err)
+		require.Regexp(t, reLogDetails, string(bs))
+
+		file, ok := logs.(*os.File)
+		require.True(t, ok)
+		require.Equal(t, logFile, file.Name())
+	})
+
+	t.Run("local-env", func(t *testing.T) {
+		ollamaContainer, err := ollama.Run(ctx, testImage,
ollama.WithUseLocal("OLLAMA_LOGFILE="+logFile)) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + + logs, err := ollamaContainer.Logs(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) + + bs, err := io.ReadAll(logs) + require.NoError(t, err) + require.Regexp(t, reLogDetails, string(bs)) + + file, ok := logs.(*os.File) + require.True(t, ok) + require.Equal(t, logFile, file.Name()) + }) +} + +func TestRun_localWithCustomHost(t *testing.T) { + ctx := context.Background() + + t.Run("parent-env", func(t *testing.T) { + t.Setenv("OLLAMA_HOST", "127.0.0.1:1234") + + ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal()) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + + testRun_localWithCustomHost(ctx, t, ollamaContainer) + }) + + t.Run("local-env", func(t *testing.T) { + ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal("OLLAMA_HOST=127.0.0.1:1234")) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + + testRun_localWithCustomHost(ctx, t, ollamaContainer) + }) +} + +func testRun_localWithCustomHost(ctx context.Context, t *testing.T, ollamaContainer *ollama.OllamaContainer) { + t.Helper() + + t.Run("connection-string", func(t *testing.T) { + connectionStr, err := ollamaContainer.ConnectionString(ctx) + require.NoError(t, err) + require.Equal(t, "http://127.0.0.1:1234", connectionStr) + }) + + t.Run("endpoint", func(t *testing.T) { + endpoint, err := ollamaContainer.Endpoint(ctx, "http") + require.NoError(t, err) + require.Equal(t, "http://127.0.0.1:1234", endpoint) + }) + + t.Run("inspect", func(t *testing.T) { + inspect, err := ollamaContainer.Inspect(ctx) + require.NoError(t, err) + require.Regexp(t, regexp.MustCompile(`^local-ollama:\d+\.\d+\.\d+$`), inspect.Config.Image) + + _, exists := inspect.Config.ExposedPorts[testNatPort] + require.True(t, exists) + require.Equal(t, testHost, inspect.Config.Hostname) + require.Equal(t, strslice.StrSlice(strslice.StrSlice{testBinary, "serve"}), inspect.Config.Entrypoint) + + require.Empty(t, inspect.NetworkSettings.Networks) + require.Equal(t, "bridge", inspect.NetworkSettings.NetworkSettingsBase.Bridge) + + ports := inspect.NetworkSettings.NetworkSettingsBase.Ports + port, exists := ports[testNatPort] + require.True(t, exists) + require.Len(t, port, 1) + require.Equal(t, testHost, port[0].HostIP) + require.Equal(t, "1234", port[0].HostPort) + }) + + t.Run("logs", func(t *testing.T) { + logs, err := ollamaContainer.Logs(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) + + bs, err := io.ReadAll(logs) + require.NoError(t, err) + + require.Contains(t, string(bs), "Listening on 127.0.0.1:1234") + }) + + t.Run("mapped-port", func(t *testing.T) { + port, err := ollamaContainer.MappedPort(ctx, testNatPort) + require.NoError(t, err) + require.Equal(t, "1234", port.Port()) + require.Equal(t, "tcp", port.Proto()) + }) +} + +func TestRun_localExec(t *testing.T) { + // check if the local ollama binary is available + if _, err := exec.LookPath(testBinary); err != nil { + t.Skip("local ollama binary not found, skipping") + } + + ctx := context.Background() + + ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal()) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + + t.Run("no-command", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, nil) + require.Error(t, err) + require.Equal(t, 1, 
code) + require.Nil(t, r) + }) + + t.Run("unsupported-command", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{"cat", "/etc/hosts"}) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-option-user", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.WithUser("root")) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-option-privileged", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) { + opts.ExecConfig.Privileged = true + })) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-option-tty", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) { + opts.ExecConfig.Tty = true + })) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-option-detach", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) { + opts.ExecConfig.Detach = true + })) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-option-detach-keys", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) { + opts.ExecConfig.DetachKeys = "ctrl-p,ctrl-q" + })) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("pull-and-run-model", func(t *testing.T) { + const model = "llama3.2:1b" + + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "pull", model}) + require.NoError(t, err) + require.Zero(t, code) + + bs, err := io.ReadAll(r) + require.NoError(t, err) + require.Contains(t, string(bs), "success") + + code, r, err = ollamaContainer.Exec(ctx, []string{testBinary, "run", model}, tcexec.Multiplexed()) + require.NoError(t, err) + require.Zero(t, code) + + bs, err = io.ReadAll(r) + require.NoError(t, err) + require.Empty(t, bs) + + logs, err := ollamaContainer.Logs(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) + + bs, err = io.ReadAll(logs) + require.NoError(t, err) + require.Contains(t, string(bs), "llama runner started") + }) +} + +func TestRun_localValidateRequest(t *testing.T) { + // check if the local ollama binary is available + if _, err := exec.LookPath(testBinary); err != nil { + t.Skip("local ollama binary not found, skipping") + } + + ctx := context.Background() + t.Run("waiting-for-nil", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + testImage, + ollama.WithUseLocal("FOO=BAR"), + testcontainers.CustomizeRequestOption(func(req *testcontainers.GenericContainerRequest) error { + req.WaitingFor = nil + return nil + }), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.EqualError(t, err, "validate request: ContainerRequest.WaitingFor must be set") + }) + + t.Run("started-false", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + testImage, + ollama.WithUseLocal("FOO=BAR"), + testcontainers.CustomizeRequestOption(func(req 
*testcontainers.GenericContainerRequest) error { + req.Started = false + return nil + }), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.EqualError(t, err, "validate request: Started must be true") + }) + + t.Run("exposed-ports-empty", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + testImage, + ollama.WithUseLocal("FOO=BAR"), + testcontainers.CustomizeRequestOption(func(req *testcontainers.GenericContainerRequest) error { + req.ExposedPorts = req.ExposedPorts[:0] + return nil + }), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.EqualError(t, err, "validate request: ContainerRequest.ExposedPorts must be 11434/tcp got: []") + }) + + t.Run("dockerfile-set", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + testImage, + ollama.WithUseLocal("FOO=BAR"), + testcontainers.CustomizeRequestOption(func(req *testcontainers.GenericContainerRequest) error { + req.Dockerfile = "FROM scratch" + return nil + }), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.EqualError(t, err, "validate request: unsupported field: ContainerRequest.FromDockerfile.Dockerfile = \"FROM scratch\"") + }) + + t.Run("image-only", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + testBinary, + ollama.WithUseLocal(), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + }) + + t.Run("image-path", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + "prefix-path/"+testBinary, + ollama.WithUseLocal(), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + }) + + t.Run("image-bad-version", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + testBinary+":bad-version", + ollama.WithUseLocal(), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.EqualError(t, err, `validate request: ContainerRequest.Image version must be blank or "latest", got: "bad-version"`) + }) + + t.Run("image-not-found", func(t *testing.T) { + ollamaContainer, err := ollama.Run( + ctx, + "ollama/ollama-not-found", + ollama.WithUseLocal(), + ) + testcontainers.CleanupContainer(t, ollamaContainer) + require.EqualError(t, err, `validate request: invalid image "ollama/ollama-not-found": exec: "ollama-not-found": executable file not found in $PATH`) + }) +} diff --git a/modules/ollama/ollama.go b/modules/ollama/ollama.go index 203d80103f..4d78fa171e 100644 --- a/modules/ollama/ollama.go +++ b/modules/ollama/ollama.go @@ -27,12 +27,12 @@ type OllamaContainer struct { func (c *OllamaContainer) ConnectionString(ctx context.Context) (string, error) { host, err := c.Host(ctx) if err != nil { - return "", err + return "", fmt.Errorf("host: %w", err) } port, err := c.MappedPort(ctx, "11434/tcp") if err != nil { - return "", err + return "", fmt.Errorf("mapped port: %w", err) } return fmt.Sprintf("http://%s:%d", host, port.Int()), nil @@ -43,6 +43,10 @@ func (c *OllamaContainer) ConnectionString(ctx context.Context) (string, error) // of the container into a new image with the given name, so it doesn't override existing images. // It should be used for creating an image that contains a loaded model. 
 func (c *OllamaContainer) Commit(ctx context.Context, targetImage string) error {
+	if _, ok := c.Container.(*localProcess); ok {
+		return nil
+	}
+
 	cli, err := testcontainers.NewDockerClientWithOpts(context.Background())
 	if err != nil {
 		return err
@@ -80,27 +84,34 @@ func RunContainer(ctx context.Context, opts ...testcontainers.ContainerCustomize
 
 // Run creates an instance of the Ollama container type
 func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustomizer) (*OllamaContainer, error) {
-	req := testcontainers.ContainerRequest{
-		Image:        img,
-		ExposedPorts: []string{"11434/tcp"},
-		WaitingFor:   wait.ForListeningPort("11434/tcp").WithStartupTimeout(60 * time.Second),
-	}
-
-	genericContainerReq := testcontainers.GenericContainerRequest{
-		ContainerRequest: req,
-		Started:          true,
+	req := testcontainers.GenericContainerRequest{
+		ContainerRequest: testcontainers.ContainerRequest{
+			Image:        img,
+			ExposedPorts: []string{"11434/tcp"},
+			WaitingFor:   wait.ForListeningPort("11434/tcp").WithStartupTimeout(60 * time.Second),
+		},
+		Started: true,
 	}
 
-	// always request a GPU if the host supports it
+	// Always request a GPU if the host supports it.
 	opts = append(opts, withGpu())
 
+	var local *localProcess
 	for _, opt := range opts {
-		if err := opt.Customize(&genericContainerReq); err != nil {
+		if err := opt.Customize(&req); err != nil {
 			return nil, fmt.Errorf("customize: %w", err)
 		}
+		if l, ok := opt.(*localProcess); ok {
+			local = l
+		}
+	}
+
+	// Now we have processed all the options, we can check if we need to use the local process.
+	if local != nil {
+		return local.run(ctx, req)
 	}
 
-	container, err := testcontainers.GenericContainer(ctx, genericContainerReq)
+	container, err := testcontainers.GenericContainer(ctx, req)
 	var c *OllamaContainer
 	if container != nil {
 		c = &OllamaContainer{Container: container}
diff --git a/modules/ollama/options.go b/modules/ollama/options.go
index 605768a379..1cf29453fe 100644
--- a/modules/ollama/options.go
+++ b/modules/ollama/options.go
@@ -11,7 +11,7 @@ import (
 var noopCustomizeRequestOption = func(req *testcontainers.GenericContainerRequest) error { return nil }
 
 // withGpu requests a GPU for the container, which could improve performance for some models.
-// This option will be automaticall added to the Ollama container to check if the host supports nvidia.
+// This option will be automatically added to the Ollama container to check if the host supports nvidia.
 func withGpu() testcontainers.CustomizeRequestOption {
 	cli, err := testcontainers.NewDockerClientWithOpts(context.Background())
 	if err != nil {
@@ -37,3 +37,30 @@ func withGpu() testcontainers.CustomizeRequestOption {
 	}
 }
+
+// WithUseLocal starts a local Ollama process with the given environment in
+// format KEY=VALUE instead of a Docker container, which can be more performant
+// as it has direct access to the GPU.
+// By default `OLLAMA_HOST=localhost:0` is set to avoid port conflicts.
+//
+// When using this option, the container request will be validated to ensure
+// that only the options that are compatible with the local process are used.
+//
+// Supported fields are:
+// - [testcontainers.GenericContainerRequest.Started] must be set to true
+// - [testcontainers.GenericContainerRequest.ExposedPorts] must be set to ["11434/tcp"]
+// - [testcontainers.ContainerRequest.WaitingFor] should not be changed from the default
+// - [testcontainers.ContainerRequest.Image] used to determine the local process binary [<path>/]<binary>[:latest] if not blank.
+// - [testcontainers.ContainerRequest.Env] applied to all local process executions
+// - [testcontainers.GenericContainerRequest.Logger] is unused
+//
+// Any other leaf field not set to the type's zero value will result in an error.
+func WithUseLocal(envKeyValues ...string) *localProcess {
+	sessionID := testcontainers.SessionID()
+	return &localProcess{
+		sessionID: sessionID,
+		logName:   localNamePrefix + "-" + sessionID + ".log",
+		env:       envKeyValues,
+		binary:    localBinary,
+	}
+}
diff --git a/modules/ollama/options_test.go b/modules/ollama/options_test.go
new file mode 100644
index 0000000000..f842d15a17
--- /dev/null
+++ b/modules/ollama/options_test.go
@@ -0,0 +1,49 @@
+package ollama_test
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/testcontainers/testcontainers-go"
+	"github.com/testcontainers/testcontainers-go/modules/ollama"
+)
+
+func TestWithUseLocal(t *testing.T) {
+	req := testcontainers.GenericContainerRequest{}
+
+	t.Run("keyVal/valid", func(t *testing.T) {
+		opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models")
+		err := opt.Customize(&req)
+		require.NoError(t, err)
+		require.Equal(t, "/path/to/models", req.Env["OLLAMA_MODELS"])
+	})
+
+	t.Run("keyVal/invalid", func(t *testing.T) {
+		opt := ollama.WithUseLocal("OLLAMA_MODELS")
+		err := opt.Customize(&req)
+		require.Error(t, err)
+	})
+
+	t.Run("keyVal/valid/multiple", func(t *testing.T) {
+		opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models", "OLLAMA_HOST=localhost")
+		err := opt.Customize(&req)
+		require.NoError(t, err)
+		require.Equal(t, "/path/to/models", req.Env["OLLAMA_MODELS"])
+		require.Equal(t, "localhost", req.Env["OLLAMA_HOST"])
+	})
+
+	t.Run("keyVal/valid/multiple-equals", func(t *testing.T) {
+		opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models", "OLLAMA_HOST=localhost=127.0.0.1")
+		err := opt.Customize(&req)
+		require.NoError(t, err)
+		require.Equal(t, "/path/to/models", req.Env["OLLAMA_MODELS"])
+		require.Equal(t, "localhost=127.0.0.1", req.Env["OLLAMA_HOST"])
+	})
+
+	t.Run("keyVal/invalid/multiple", func(t *testing.T) {
+		opt := ollama.WithUseLocal("OLLAMA_MODELS=/path/to/models", "OLLAMA_HOST")
+		err := opt.Customize(&req)
+		require.Error(t, err)
+	})
+}
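
As a usage illustration (not part of the change set above), here is a minimal sketch of how a consumer might read `Exec` output from the local process, per the doc comment on `Exec` in `local.go`: once with the `tcexec.Multiplexed` option, which strips the stdcopy stream headers and yields the combined output, and once demultiplexing the raw reader with `stdcopy.StdCopy`. The image tag, the `ollama -v` command and the error handling are illustrative assumptions.

```go
package main

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"log"

	"github.com/docker/docker/pkg/stdcopy"
	"github.com/testcontainers/testcontainers-go"
	tcexec "github.com/testcontainers/testcontainers-go/exec"
	"github.com/testcontainers/testcontainers-go/modules/ollama"
)

func main() {
	ctx := context.Background()

	// Start the local Ollama binary; the image version must be blank or "latest".
	c, err := ollama.Run(ctx, "ollama/ollama:latest", ollama.WithUseLocal())
	defer func() {
		if err := testcontainers.TerminateContainer(c); err != nil {
			log.Printf("failed to terminate: %s", err)
		}
	}()
	if err != nil {
		log.Fatalf("failed to start: %s", err)
	}

	// Option 1: tcexec.Multiplexed strips the stdcopy stream headers,
	// returning the combined stdout and stderr as plain bytes.
	code, r, err := c.Exec(ctx, []string{"ollama", "-v"}, tcexec.Multiplexed())
	if err != nil {
		log.Fatalf("exec: %s", err)
	}
	out, err := io.ReadAll(r)
	if err != nil {
		log.Fatalf("read: %s", err)
	}
	fmt.Printf("exit=%d output=%s\n", code, out)

	// Option 2: demultiplex the raw reader into separate stdout and
	// stderr buffers using the Docker API's stdcopy helper.
	_, raw, err := c.Exec(ctx, []string{"ollama", "-v"})
	if err != nil {
		log.Fatalf("exec: %s", err)
	}
	var stdout, stderr bytes.Buffer
	if _, err := stdcopy.StdCopy(&stdout, &stderr, raw); err != nil {
		log.Fatalf("demux: %s", err)
	}
	fmt.Printf("stdout=%q stderr=%q\n", stdout.String(), stderr.String())
}
```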