From 0d0ff48e27fc1883e90479b2b1bbb058f49b650b Mon Sep 17 00:00:00 2001
From: Steven Hartland
Date: Tue, 17 Dec 2024 23:21:03 +0000
Subject: [PATCH] refactor(ollama): local process

Refactor local process handling for Ollama to use a container
implementation, avoiding the wrapping methods. This defaults to running
the binary with an ephemeral port to avoid port conflicts. This
behaviour can be overridden by setting OLLAMA_HOST either in the parent
environment or in the values passed via WithUseLocal.

Improve API compatibility with:
- Multiplexed output streams
- State reporting
- Exec option processing
- WaitingFor customisation

Fix Container implementation:
- Port management
- Running checks
- Terminate processing
- Endpoint argument definition
- Add missing methods
- Consistent environment handling
---
 modules/ollama/local.go           | 751 +++++++++++++++++-------------
 modules/ollama/local_test.go      | 402 +++++++++++-----
 modules/ollama/local_unit_test.go |  55 ---
 modules/ollama/ollama.go          |  49 +-
 modules/ollama/options.go         |  41 +-
 5 files changed, 782 insertions(+), 516 deletions(-)
 delete mode 100644 modules/ollama/local_unit_test.go

diff --git a/modules/ollama/local.go b/modules/ollama/local.go
index ce6be0cd0a..68934bc019 100644
--- a/modules/ollama/local.go
+++ b/modules/ollama/local.go
@@ -10,6 +10,7 @@ import (
 	"net"
 	"os"
 	"os/exec"
+	"regexp"
 	"strings"
 	"sync"
 	"syscall"
@@ -18,6 +19,8 @@ import (
 	"github.com/docker/docker/api/types"
 	"github.com/docker/docker/api/types/container"
 	"github.com/docker/docker/api/types/network"
+	"github.com/docker/docker/errdefs"
+	"github.com/docker/docker/pkg/stdcopy"
 	"github.com/docker/go-connections/nat"
 
 	"github.com/testcontainers/testcontainers-go"
@@ -26,485 +29,603 @@ import (
 )
 
 const (
-	localIP   = "127.0.0.1"
-	localPort = "11434"
+	localPort       = "11434"
+	localBinary     = "ollama"
+	localServeArg   = "serve"
+	localLogRegex   = `Listening on (.*:\d+) \(version\s(.*)\)`
+	localNamePrefix = "local-ollama"
+	localHostVar    = "OLLAMA_HOST"
+	localLogVar     = "OLLAMA_LOGFILE"
 )
 
 var (
-	defaultStopTimeout      = time.Second * 5
-	errCopyAPIsNotSupported = errors.New("copy APIs are not supported for local Ollama binary")
+	// Ensure localProcess implements the testcontainers.Container interface.
+	_ testcontainers.Container = &localProcess{}
+
+	// defaultStopTimeout is the default timeout for stopping the local Ollama process.
+	defaultStopTimeout = time.Second * 5
+
+	// zeroTime is the zero time value.
+	zeroTime time.Time
+
+	// reLogDetails is the regular expression to extract the listening address and version from the log.
+	reLogDetails = regexp.MustCompile(localLogRegex)
 )
 
-// localContext is a type holding the context for local Ollama executions.
-type localContext struct {
-	env      []string
-	serveCmd *exec.Cmd
-	logFile  *os.File
-	mx       sync.Mutex
-	host     string
-	port     string
-}
+// localProcess emulates the Ollama container using a local process to improve performance.
+type localProcess struct {
+	sessionID string
 
-// runLocal calls the local Ollama binary instead of using a Docker container.
-func runLocal(ctx context.Context, env map[string]string) (*OllamaContainer, error) {
-	// Apply the environment variables to the command.
-	cmdEnv := make([]string, 0, len(env)*2)
-	for k, v := range env {
-		cmdEnv = append(cmdEnv, k+"="+v)
-	}
+	// env is the combined environment variables passed to the Ollama binary.
+ env []string - localCtx := &localContext{ - env: cmdEnv, - host: localIP, - port: localPort, - } + // cmd is the command that runs the Ollama binary, not valid externally if nil. + cmd *exec.Cmd - if envHost := os.Getenv("OLLAMA_HOST"); envHost != "" { - host, port, err := net.SplitHostPort(envHost) - if err != nil { - return nil, fmt.Errorf("invalid OLLAMA_HOST: %w", err) - } + // startedAt and finishedAt are the times when the process started and finished. + startedAt time.Time + finishedAt time.Time - localCtx.host = host - localCtx.port = port - } + // logName and logFile are the file where the Ollama logs are written. + logName string + logFile *os.File - c := &OllamaContainer{ - localCtx: localCtx, - } + // host, port and version are extracted from log on startup. + host string + port string + version string - err := c.startLocalOllama(ctx) - if err != nil { - return nil, fmt.Errorf("start ollama: %w", err) - } + // waitFor is the strategy to wait for the process to be ready. + waitFor wait.Strategy - return c, nil + // done is closed when the process is finished. + done chan struct{} + + // wg is used to wait for the process to finish. + wg sync.WaitGroup + + // exitErr is the error returned by the process. + exitErr error } -// logFile returns an existing log file or creates a new one if it doesn't exist. -func logFile() (*os.File, error) { - logName := "local-ollama-" + testcontainers.SessionID() + ".log" +// runLocal returns an OllamaContainer that uses the local Ollama binary instead of using a Docker container. +func runLocal(ctx context.Context, req testcontainers.GenericContainerRequest) (*OllamaContainer, error) { + // TODO: validate the request and return an error if it + // contains any unsupported elements. + + sessionID := testcontainers.SessionID() + local := &localProcess{ + sessionID: sessionID, + env: make([]string, 0, len(req.Env)), + waitFor: req.WaitingFor, + logName: localNamePrefix + "-" + sessionID + ".log", + } + + // Apply the environment variables to the command and + // override the log file if specified. + for k, v := range req.Env { + local.env = append(local.env, k+"="+v) + if k == localLogVar { + local.logName = v + } + } - if envLogName := os.Getenv("OLLAMA_LOGFILE"); envLogName != "" { - logName = envLogName + err := local.Start(ctx) + var c *OllamaContainer + if local.cmd != nil { + c = &OllamaContainer{Container: local} } - file, err := os.Create(logName) if err != nil { - return nil, fmt.Errorf("create ollama log file: %w", err) + return nil, fmt.Errorf("start ollama: %w", err) } - return file, nil + return c, nil } -// startLocalOllama starts the Ollama serve command in the background, writing to the -// provided log file. -func (c *OllamaContainer) startLocalOllama(ctx context.Context) error { - if c.localCtx.serveCmd != nil { - return nil +// Start implements testcontainers.Container interface for the local Ollama binary. +func (c *localProcess) Start(ctx context.Context) error { + if c.IsRunning() { + return errors.New("already running") } - c.localCtx.mx.Lock() + cmd := exec.CommandContext(ctx, localBinary, localServeArg) + cmd.Env = c.env - serveCmd := exec.CommandContext(ctx, "ollama", "serve") - serveCmd.Env = append(serveCmd.Env, c.localCtx.env...) - serveCmd.Env = append(serveCmd.Env, os.Environ()...) 
- - logFile, err := logFile() + var err error + c.logFile, err = os.Create(c.logName) if err != nil { - c.localCtx.mx.Unlock() - return fmt.Errorf("ollama log file: %w", err) + return fmt.Errorf("create ollama log file: %w", err) } - serveCmd.Stdout = logFile - serveCmd.Stderr = logFile + // Multiplex stdout and stderr to the log file matching the Docker API. + cmd.Stdout = stdcopy.NewStdWriter(c.logFile, stdcopy.Stdout) + cmd.Stderr = stdcopy.NewStdWriter(c.logFile, stdcopy.Stderr) - // Run the ollama serve command in background - err = serveCmd.Start() - if err != nil { - c.localCtx.mx.Unlock() - return fmt.Errorf("start ollama serve: %w", err) + // Run the ollama serve command in background. + if err = cmd.Start(); err != nil { + return fmt.Errorf("start ollama serve: %w", errors.Join(err, c.cleanupLog())) } - c.localCtx.serveCmd = serveCmd - c.localCtx.logFile = logFile + // Past this point, the process was started successfully. + c.cmd = cmd + c.startedAt = time.Now() - // unlock before waiting for the process to be ready - c.localCtx.mx.Unlock() + // Reset the details to allow multiple start / stop cycles. + c.done = make(chan struct{}) + c.finishedAt = zeroTime + c.exitErr = nil - waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() + // Wait for the process to finish in a goroutine. + c.wg.Add(1) + go func() { + defer func() { + c.wg.Done() + c.finishedAt = time.Now() + close(c.done) + }() - err = c.waitForOllama(waitCtx) - if err != nil { - return fmt.Errorf("wait for ollama to start: %w", err) + if err := c.cmd.Wait(); err != nil { + c.exitErr = fmt.Errorf("process wait: %w", err) + } + }() + + if err = c.waitStrategy(ctx); err != nil { + return fmt.Errorf("wait strategy: %w", err) + } + + if err := c.extractLogDetails(ctx); err != nil { + return fmt.Errorf("extract log details: %w", err) } return nil } -// waitForOllama Wait until the Ollama process is ready, checking that the log file contains -// the "Listening on 127.0.0.1:11434" message -func (c *OllamaContainer) waitForOllama(ctx context.Context) error { - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - - err := wait.ForLog("Listening on "+c.localCtx.host+":"+c.localCtx.port).WaitUntilReady(ctx, c) - if err != nil { - logs, err := c.Logs(ctx) - if err != nil { - return fmt.Errorf("wait for ollama to start: %w", err) +// waitStrategy waits until the Ollama process is ready. +func (c *localProcess) waitStrategy(ctx context.Context) error { + if err := c.waitFor.WaitUntilReady(ctx, c); err != nil { + logs, lerr := c.Logs(ctx) + if lerr != nil { + return errors.Join(err, lerr) } + defer logs.Close() - // ignore error as we already have an error and the output is already logged - bs, _ := io.ReadAll(logs) - return fmt.Errorf("wait for ollama to start: %w. Container logs:\n%s", err, string(bs)) + var stderr, stdout bytes.Buffer + _, cerr := stdcopy.StdCopy(&stdout, &stderr, logs) + + return fmt.Errorf( + "%w (stdout: %s, stderr: %s)", + errors.Join(err, cerr), + strings.TrimSpace(stdout.String()), + strings.TrimSpace(stderr.String()), + ) } return nil } -// ContainerIP returns the IP address of the local Ollama binary. -func (c *OllamaContainer) ContainerIP(ctx context.Context) (string, error) { - if c.localCtx == nil { - return c.Container.ContainerIP(ctx) +// extractLogDetails extracts the listening address and version from the log. 
+func (c *localProcess) extractLogDetails(ctx context.Context) error { + rc, err := c.Logs(ctx) + if err != nil { + return fmt.Errorf("logs: %w", err) } + defer rc.Close() - return localIP, nil -} - -// ContainerIPs returns a slice with the IP address of the local Ollama binary. -func (c *OllamaContainer) ContainerIPs(ctx context.Context) ([]string, error) { - if c.localCtx == nil { - return c.Container.ContainerIPs(ctx) + bs, err := io.ReadAll(rc) + if err != nil { + return fmt.Errorf("read logs: %w", err) } - return []string{localIP}, nil -} + matches := reLogDetails.FindSubmatch(bs) + if len(matches) != 3 { + return errors.New("address and version not found") + } -// CopyToContainer is a no-op for the local Ollama binary. -func (c *OllamaContainer) CopyToContainer(ctx context.Context, fileContent []byte, containerFilePath string, fileMode int64) error { - if c.localCtx == nil { - return c.Container.CopyToContainer(ctx, fileContent, containerFilePath, fileMode) + c.host, c.port, err = net.SplitHostPort(string(matches[1])) + if err != nil { + return fmt.Errorf("split host port: %w", err) } - return errCopyAPIsNotSupported -} + // Set OLLAMA_HOST variable to the extracted host so Exec can use it. + c.env = append(c.env, localHostVar+"="+string(matches[1])) + c.version = string(matches[2]) -// CopyDirToContainer is a no-op for the local Ollama binary. -func (c *OllamaContainer) CopyDirToContainer(ctx context.Context, hostDirPath string, containerParentPath string, fileMode int64) error { - if c.localCtx == nil { - return c.Container.CopyDirToContainer(ctx, hostDirPath, containerParentPath, fileMode) - } + return nil +} - return errCopyAPIsNotSupported +// ContainerIP implements testcontainers.Container interface for the local Ollama binary. +func (c *localProcess) ContainerIP(ctx context.Context) (string, error) { + return c.host, nil } -// CopyFileToContainer is a no-op for the local Ollama binary. -func (c *OllamaContainer) CopyFileToContainer(ctx context.Context, hostFilePath string, containerFilePath string, fileMode int64) error { - if c.localCtx == nil { - return c.Container.CopyFileToContainer(ctx, hostFilePath, containerFilePath, fileMode) - } +// ContainerIPs returns a slice with the IP address of the local Ollama binary. +func (c *localProcess) ContainerIPs(ctx context.Context) ([]string, error) { + return []string{c.host}, nil +} - return errCopyAPIsNotSupported +// CopyToContainer implements testcontainers.Container interface for the local Ollama binary. +// Returns [errors.ErrUnsupported]. +func (c *localProcess) CopyToContainer(ctx context.Context, fileContent []byte, containerFilePath string, fileMode int64) error { + return errors.ErrUnsupported } -// CopyFileFromContainer is a no-op for the local Ollama binary. -func (c *OllamaContainer) CopyFileFromContainer(ctx context.Context, filePath string) (io.ReadCloser, error) { - if c.localCtx == nil { - return c.Container.CopyFileFromContainer(ctx, filePath) - } +// CopyDirToContainer implements testcontainers.Container interface for the local Ollama binary. +// Returns [errors.ErrUnsupported]. +func (c *localProcess) CopyDirToContainer(ctx context.Context, hostDirPath string, containerParentPath string, fileMode int64) error { + return errors.ErrUnsupported +} - return nil, errCopyAPIsNotSupported +// CopyFileToContainer implements testcontainers.Container interface for the local Ollama binary. +// Returns [errors.ErrUnsupported]. 
+func (c *localProcess) CopyFileToContainer(ctx context.Context, hostFilePath string, containerFilePath string, fileMode int64) error { + return errors.ErrUnsupported } -// GetLogProductionErrorChannel returns a nil channel. -func (c *OllamaContainer) GetLogProductionErrorChannel() <-chan error { - if c.localCtx == nil { - return c.Container.GetLogProductionErrorChannel() - } +// CopyFileFromContainer implements testcontainers.Container interface for the local Ollama binary. +// Returns [errors.ErrUnsupported]. +func (c *localProcess) CopyFileFromContainer(ctx context.Context, filePath string) (io.ReadCloser, error) { + return nil, errors.ErrUnsupported +} +// GetLogProductionErrorChannel implements testcontainers.Container interface for the local Ollama binary. +// It returns a nil channel because the local Ollama binary doesn't have a production error channel. +func (c *localProcess) GetLogProductionErrorChannel() <-chan error { return nil } -// Endpoint returns the 127.0.0.1:11434 endpoint for the local Ollama binary. -func (c *OllamaContainer) Endpoint(ctx context.Context, port string) (string, error) { - if c.localCtx == nil { - return c.Container.Endpoint(ctx, port) +// Exec implements testcontainers.Container interface for the local Ollama binary. +// It executes a command using the local Ollama binary and returns the exit status +// of the executed command, an [io.Reader] containing the combined stdout and stderr, +// and any encountered error. +// +// Reading directly from the [io.Reader] may result in unexpected bytes due to custom +// stream multiplexing headers. Use [tcexec.Multiplexed] option to read the combined output +// without the multiplexing headers. +// Alternatively, to separate the stdout and stderr from [io.Reader] and interpret these +// headers properly, [stdcopy.StdCopy] from the Docker API should be used. +func (c *localProcess) Exec(ctx context.Context, cmd []string, options ...tcexec.ProcessOption) (int, io.Reader, error) { + if len(cmd) == 0 { + return 1, nil, errors.New("no command provided") + } else if cmd[0] != localBinary { + return 1, nil, fmt.Errorf("command %q: %w", cmd[0], errors.ErrUnsupported) } - return c.localCtx.host + ":" + c.localCtx.port, nil -} - -// Exec executes a command using the local Ollama binary. -func (c *OllamaContainer) Exec(ctx context.Context, cmd []string, options ...tcexec.ProcessOption) (int, io.Reader, error) { - if c.localCtx == nil { - return c.Container.Exec(ctx, cmd, options...) - } + command := exec.CommandContext(ctx, cmd[0], cmd[1:]...) + command.Env = c.env - c.localCtx.mx.Lock() - defer c.localCtx.mx.Unlock() + // Multiplex stdout and stderr to the buffer so they can be read separately later. + var buf bytes.Buffer + command.Stdout = stdcopy.NewStdWriter(&buf, stdcopy.Stdout) + command.Stderr = stdcopy.NewStdWriter(&buf, stdcopy.Stderr) - if len(cmd) == 0 { - err := errors.New("exec: no command provided") - return 1, strings.NewReader(err.Error()), err - } else if cmd[0] != "ollama" { - err := fmt.Errorf("%s: %w", cmd[0], errors.ErrUnsupported) - return 1, strings.NewReader(err.Error()), err + // Use process options to customize the command execution + // emulating the Docker API behaviour. 
+ processOptions := tcexec.NewProcessOptions(cmd) + processOptions.Reader = &buf + for _, o := range options { + o.Apply(processOptions) } - args := []string{} - if len(cmd) > 1 { - args = cmd[1:] // prevent when there is only one command + if err := c.validateExecOptions(processOptions.ExecConfig); err != nil { + return 1, nil, fmt.Errorf("validate exec option: %w", err) } - command := prepareExec(ctx, cmd[0], args, c.localCtx.env, c.localCtx.logFile) - err := command.Run() - if err != nil { - return command.ProcessState.ExitCode(), c.localCtx.logFile, fmt.Errorf("exec %v: %w", cmd, err) + if !processOptions.ExecConfig.AttachStderr { + command.Stderr = io.Discard } - - return command.ProcessState.ExitCode(), c.localCtx.logFile, nil -} - -func prepareExec(ctx context.Context, bin string, args []string, env []string, output io.Writer) *exec.Cmd { - command := exec.CommandContext(ctx, bin, args...) - command.Env = append(command.Env, env...) - command.Env = append(command.Env, os.Environ()...) - - command.Stdout = output - command.Stderr = output - - return command -} - -// GetContainerID returns a placeholder ID for local execution -func (c *OllamaContainer) GetContainerID() string { - if c.localCtx == nil { - return c.Container.GetContainerID() + if !processOptions.ExecConfig.AttachStdout { + command.Stdout = io.Discard + } + if processOptions.ExecConfig.AttachStdin { + command.Stdin = os.Stdin } - return "local-ollama-" + testcontainers.SessionID() -} + command.Dir = processOptions.ExecConfig.WorkingDir + command.Env = append(command.Env, processOptions.ExecConfig.Env...) -// Host returns the 127.0.0.1 address for the local Ollama binary. -func (c *OllamaContainer) Host(ctx context.Context) (string, error) { - if c.localCtx == nil { - return c.Container.Host(ctx) + if err := command.Run(); err != nil { + return command.ProcessState.ExitCode(), processOptions.Reader, fmt.Errorf("exec %v: %w", cmd, err) } - return localIP, nil + return command.ProcessState.ExitCode(), processOptions.Reader, nil } -// Inspect returns a ContainerJSON with the state of the local Ollama binary. -// The version is read from the local Ollama binary (ollama -v), and the port -// mapping is set to 11434. -func (c *OllamaContainer) Inspect(ctx context.Context) (*types.ContainerJSON, error) { - if c.localCtx == nil { - return c.Container.Inspect(ctx) +// validateExecOptions checks if the given exec options are supported by the local Ollama binary. +func (c *localProcess) validateExecOptions(options container.ExecOptions) error { + var errs []error + if options.User != "" { + errs = append(errs, fmt.Errorf("user: %w", errors.ErrUnsupported)) } - - state, err := c.State(ctx) - if err != nil { - return nil, fmt.Errorf("get ollama state: %w", err) + if options.Privileged { + errs = append(errs, fmt.Errorf("privileged: %w", errors.ErrUnsupported)) } - - // read the version from the ollama binary - var buf bytes.Buffer - command := prepareExec(ctx, "ollama", []string{"-v"}, c.localCtx.env, &buf) - if err := command.Run(); err != nil { - return nil, fmt.Errorf("read ollama -v output: %w", err) + if options.Tty { + errs = append(errs, fmt.Errorf("tty: %w", errors.ErrUnsupported)) + } + if options.Detach { + errs = append(errs, fmt.Errorf("detach: %w", errors.ErrUnsupported)) + } + if options.DetachKeys != "" { + errs = append(errs, fmt.Errorf("detach keys: %w", errors.ErrUnsupported)) } - bs, err := io.ReadAll(&buf) + return errors.Join(errs...) 
+} + +// Inspect implements testcontainers.Container interface for the local Ollama binary. +// It returns a ContainerJSON with the state of the local Ollama binary. +func (c *localProcess) Inspect(ctx context.Context) (*types.ContainerJSON, error) { + state, err := c.State(ctx) if err != nil { - return nil, fmt.Errorf("read ollama -v output: %w", err) + return nil, fmt.Errorf("state: %w", err) } return &types.ContainerJSON{ ContainerJSONBase: &types.ContainerJSONBase{ ID: c.GetContainerID(), - Name: "local-ollama-" + testcontainers.SessionID(), + Name: localNamePrefix + "-" + c.sessionID, State: state, }, Config: &container.Config{ - Image: string(bs), + Image: localNamePrefix + ":" + c.version, ExposedPorts: nat.PortSet{ - nat.Port(c.localCtx.port + "/tcp"): struct{}{}, + nat.Port(localPort + "/tcp"): struct{}{}, }, - Hostname: "localhost", - Entrypoint: []string{"ollama", "serve"}, + Hostname: c.host, + Entrypoint: []string{localBinary, localServeArg}, }, NetworkSettings: &types.NetworkSettings{ Networks: map[string]*network.EndpointSettings{}, NetworkSettingsBase: types.NetworkSettingsBase{ Bridge: "bridge", Ports: nat.PortMap{ - nat.Port(c.localCtx.port + "/tcp"): { - {HostIP: c.localCtx.host, HostPort: c.localCtx.port}, + nat.Port(localPort + "/tcp"): { + {HostIP: c.host, HostPort: c.port}, }, }, }, DefaultNetworkSettings: types.DefaultNetworkSettings{ - IPAddress: c.localCtx.host, + IPAddress: c.host, }, }, }, nil } -// IsRunning returns true if the local Ollama process is running. -func (c *OllamaContainer) IsRunning() bool { - if c.localCtx == nil { - return c.Container.IsRunning() +// IsRunning implements testcontainers.Container interface for the local Ollama binary. +// It returns true if the local Ollama process is running, false otherwise. +func (c *localProcess) IsRunning() bool { + if c.startedAt.IsZero() { + // The process hasn't started yet. + return false } - c.localCtx.mx.Lock() - defer c.localCtx.mx.Unlock() - - return c.localCtx.serveCmd != nil + select { + case <-c.done: + // The process exited. + return false + default: + // The process is still running. + return true + } } -// Logs returns the logs from the local Ollama binary. -func (c *OllamaContainer) Logs(ctx context.Context) (io.ReadCloser, error) { - if c.localCtx == nil { - return c.Container.Logs(ctx) +// Logs implements testcontainers.Container interface for the local Ollama binary. +// It returns the logs from the local Ollama binary. +func (c *localProcess) Logs(ctx context.Context) (io.ReadCloser, error) { + file, err := os.Open(c.logFile.Name()) + if err != nil { + return nil, fmt.Errorf("open log file: %w", err) } - c.localCtx.mx.Lock() - defer c.localCtx.mx.Unlock() - - // stream the log file - return os.Open(c.localCtx.logFile.Name()) + return file, nil } -// MappedPort returns the configured port for local Ollama binary. -func (c *OllamaContainer) MappedPort(ctx context.Context, port nat.Port) (nat.Port, error) { - if c.localCtx == nil { - return c.Container.MappedPort(ctx, port) +// State implements testcontainers.Container interface for the local Ollama binary. +// It returns the current state of the Ollama process, simulating a container state. 
+func (c *localProcess) State(ctx context.Context) (*types.ContainerState, error) { + if !c.IsRunning() { + state := &types.ContainerState{ + Status: "exited", + ExitCode: c.cmd.ProcessState.ExitCode(), + StartedAt: c.startedAt.Format(time.RFC3339Nano), + FinishedAt: c.finishedAt.Format(time.RFC3339Nano), + } + if c.exitErr != nil { + state.Error = c.exitErr.Error() + } + + return state, nil } - // Ollama typically uses port 11434 by default - return nat.Port(c.localCtx.port + "/tcp"), nil + // Setting the Running field because it's required by the wait strategy + // to check if the given log message is present. + return &types.ContainerState{ + Status: "running", + Running: true, + Pid: c.cmd.Process.Pid, + StartedAt: c.startedAt.Format(time.RFC3339Nano), + FinishedAt: c.finishedAt.Format(time.RFC3339Nano), + }, nil } -// Networks returns the networks for local Ollama binary, which is a nil slice. -func (c *OllamaContainer) Networks(ctx context.Context) ([]string, error) { - if c.localCtx == nil { - return c.Container.Networks(ctx) +// Stop implements testcontainers.Container interface for the local Ollama binary. +// It gracefully stops the local Ollama process. +func (c *localProcess) Stop(ctx context.Context, d *time.Duration) error { + if err := c.cmd.Process.Signal(syscall.SIGTERM); err != nil { + return fmt.Errorf("signal ollama: %w", err) } - return nil, nil + if d != nil { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, *d) + defer cancel() + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.done: + // The process exited. + return c.exitErr + } } -// NetworkAliases returns the network aliases for local Ollama binary, which is a nil map. -func (c *OllamaContainer) NetworkAliases(ctx context.Context) (map[string][]string, error) { - if c.localCtx == nil { - return c.Container.NetworkAliases(ctx) +// Terminate implements testcontainers.Container interface for the local Ollama binary. +// It stops the local Ollama process, removing the log file. +func (c *localProcess) Terminate(ctx context.Context) error { + // First try to stop gracefully. + if err := c.Stop(ctx, &defaultStopTimeout); !c.isCleanupSafe(err) { + return fmt.Errorf("stop: %w", err) } - return nil, nil -} + if c.IsRunning() { + // Still running, force kill. + if err := c.cmd.Process.Kill(); !c.isCleanupSafe(err) { + return fmt.Errorf("kill: %w", err) + } -// SessionID returns the session ID for local Ollama binary, which is the session ID -// of the test execution. -func (c *OllamaContainer) SessionID() string { - if c.localCtx == nil { - return c.Container.SessionID() + // Wait for the process to exit so capture any error. + c.wg.Wait() } - return testcontainers.SessionID() + return errors.Join(c.exitErr, c.cleanupLog()) } -// Start starts the local Ollama process, not failing if it's already running. -func (c *OllamaContainer) Start(ctx context.Context) error { - if c.localCtx == nil { - return c.Container.Start(ctx) +// cleanupLog closes the log file and removes it. +func (c *localProcess) cleanupLog() error { + if c.logFile == nil { + return nil } - err := c.startLocalOllama(ctx) - if err != nil { - return fmt.Errorf("start ollama: %w", err) + var errs []error + if err := c.logFile.Close(); err != nil { + errs = append(errs, fmt.Errorf("close log: %w", err)) } - return nil + if err := os.Remove(c.logFile.Name()); err != nil && !errors.Is(err, fs.ErrNotExist) { + errs = append(errs, fmt.Errorf("remove log: %w", err)) + } + + c.logFile = nil // Prevent double cleanup. 
+ + return errors.Join(errs...) } -// State returns the current state of the Ollama process, simulating a container state -// for local execution. -func (c *OllamaContainer) State(ctx context.Context) (*types.ContainerState, error) { - if c.localCtx == nil { - return c.Container.State(ctx) - } +// Endpoint implements testcontainers.Container interface for the local Ollama binary. +// It returns proto://host:port string for the Ollama port. +// It returns just host:port if proto is blank. +func (c *localProcess) Endpoint(ctx context.Context, proto string) (string, error) { + return c.PortEndpoint(ctx, localPort, proto) +} - c.localCtx.mx.Lock() - defer c.localCtx.mx.Unlock() +// GetContainerID implements testcontainers.Container interface for the local Ollama binary. +func (c *localProcess) GetContainerID() string { + return localNamePrefix + "-" + c.sessionID +} - if c.localCtx.serveCmd == nil { - return &types.ContainerState{Status: "exited"}, nil - } +// Host implements testcontainers.Container interface for the local Ollama binary. +func (c *localProcess) Host(ctx context.Context) (string, error) { + return c.host, nil +} - // Check if process is still running. Signal(0) is a special case in Unix-like systems. - // When you send signal 0 to a process: - // - It performs all the normal error checking (permissions, process existence, etc.) - // - But it doesn't actually send any signal to the process - if err := c.localCtx.serveCmd.Process.Signal(syscall.Signal(0)); err != nil { - return &types.ContainerState{Status: "created"}, nil +// MappedPort implements testcontainers.Container interface for the local Ollama binary. +func (c *localProcess) MappedPort(ctx context.Context, port nat.Port) (nat.Port, error) { + if port.Port() != localPort || port.Proto() != "tcp" { + return "", errdefs.NotFound(fmt.Errorf("port %q not found", port)) } - // Setting the Running field because it's required by the wait strategy - // to check if the given log message is present. - return &types.ContainerState{Status: "running", Running: true}, nil + return nat.Port(c.port + "/tcp"), nil } -// Stop gracefully stops the local Ollama process -func (c *OllamaContainer) Stop(ctx context.Context, d *time.Duration) error { - if c.localCtx == nil { - return c.Container.Stop(ctx, d) - } +// Networks implements testcontainers.Container interface for the local Ollama binary. +// It returns a nil slice. +func (c *localProcess) Networks(ctx context.Context) ([]string, error) { + return nil, nil +} - c.localCtx.mx.Lock() - defer c.localCtx.mx.Unlock() +// NetworkAliases implements testcontainers.Container interface for the local Ollama binary. +// It returns a nil map. +func (c *localProcess) NetworkAliases(ctx context.Context) (map[string][]string, error) { + return nil, nil +} - if c.localCtx.serveCmd == nil { - return nil +// PortEndpoint implements testcontainers.Container interface for the local Ollama binary. +// It returns proto://host:port string for the given exposed port. +// It returns just host:port if proto is blank. 
+func (c *localProcess) PortEndpoint(ctx context.Context, port nat.Port, proto string) (string, error) { + host, err := c.Host(ctx) + if err != nil { + return "", fmt.Errorf("host: %w", err) } - if err := c.localCtx.serveCmd.Process.Signal(syscall.SIGTERM); err != nil { - return fmt.Errorf("signal ollama: %w", err) + outerPort, err := c.MappedPort(ctx, port) + if err != nil { + return "", fmt.Errorf("mapped port: %w", err) + } + + if proto != "" { + proto = proto + "://" } - c.localCtx.serveCmd = nil + return fmt.Sprintf("%s%s:%s", proto, host, outerPort.Port()), nil +} - return nil +// SessionID implements testcontainers.Container interface for the local Ollama binary. +func (c *localProcess) SessionID() string { + return c.sessionID } -// Terminate stops the local Ollama process, removing the log file. -func (c *OllamaContainer) Terminate(ctx context.Context) error { - if c.localCtx == nil { - return c.Container.Terminate(ctx) - } +// Deprecated: it will be removed in the next major release. +// FollowOutput is not implemented for the local Ollama binary. +// It panics if called. +func (c *localProcess) FollowOutput(consumer testcontainers.LogConsumer) { + panic("not implemented") +} - // First try to stop gracefully - err := c.Stop(ctx, &defaultStopTimeout) +// Deprecated: use c.Inspect(ctx).NetworkSettings.Ports instead. +// Ports gets the exposed ports for the container. +func (c *localProcess) Ports(ctx context.Context) (nat.PortMap, error) { + inspect, err := c.Inspect(ctx) if err != nil { - return fmt.Errorf("stop ollama: %w", err) + return nil, err } - c.localCtx.mx.Lock() - defer c.localCtx.mx.Unlock() + return inspect.NetworkSettings.Ports, nil +} - if c.localCtx.logFile == nil { - return nil - } +// Deprecated: it will be removed in the next major release. +// StartLogProducer implements testcontainers.Container interface for the local Ollama binary. +// It returns an error because the local Ollama binary doesn't have a log producer. +func (c *localProcess) StartLogProducer(context.Context, ...testcontainers.LogProductionOption) error { + return errors.ErrUnsupported +} - var errs []error - if err = c.localCtx.logFile.Close(); err != nil { - errs = append(errs, fmt.Errorf("close log: %w", err)) - } +// Deprecated: it will be removed in the next major release. +// StopLogProducer implements testcontainers.Container interface for the local Ollama binary. +// It returns an error because the local Ollama binary doesn't have a log producer. +func (c *localProcess) StopLogProducer() error { + return errors.ErrUnsupported +} - if err = os.Remove(c.localCtx.logFile.Name()); err != nil && !errors.Is(err, fs.ErrNotExist) { - errs = append(errs, fmt.Errorf("remove log: %w", err)) - } +// Deprecated: Use c.Inspect(ctx).Name instead. +// Name returns the name for the local Ollama binary. +func (c *localProcess) Name(context.Context) (string, error) { + return localNamePrefix + "-" + c.sessionID, nil +} - return errors.Join(errs...) 
+// isCleanupSafe reports whether all errors in err's tree are one of the +// following, so can safely be ignored: +// - nil +// - os: process already finished +// - context deadline exceeded +func (c *localProcess) isCleanupSafe(err error) bool { + switch { + case err == nil, + errors.Is(err, os.ErrProcessDone), + errors.Is(err, context.DeadlineExceeded): + return true + default: + return false + } } diff --git a/modules/ollama/local_test.go b/modules/ollama/local_test.go index 7bd073ca5e..cbf4c41ced 100644 --- a/modules/ollama/local_test.go +++ b/modules/ollama/local_test.go @@ -4,12 +4,15 @@ import ( "context" "errors" "io" + "io/fs" "os" "os/exec" "path/filepath" + "regexp" "testing" "time" + "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/strslice" "github.com/stretchr/testify/require" @@ -18,26 +21,52 @@ import ( "github.com/testcontainers/testcontainers-go/modules/ollama" ) +const ( + testImage = "ollama/ollama:latest" + testNatPort = "11434/tcp" + testHost = "127.0.0.1" + testBinary = "ollama" +) + +var ( + reLogDetails = regexp.MustCompile(`Listening on (.*:\d+) \(version\s(.*)\)`) + zeroTime = time.Time{}.Format(time.RFC3339Nano) +) + func TestRun_local(t *testing.T) { // check if the local ollama binary is available - if _, err := exec.LookPath("ollama"); err != nil { + if _, err := exec.LookPath(testBinary); err != nil { t.Skip("local ollama binary not found, skipping") } ctx := context.Background() - ollamaContainer, err := ollama.Run( ctx, - "ollama/ollama:0.1.25", + testImage, ollama.WithUseLocal("FOO=BAR"), ) testcontainers.CleanupContainer(t, ollamaContainer) require.NoError(t, err) + t.Run("state", func(t *testing.T) { + state, err := ollamaContainer.State(ctx) + require.NoError(t, err) + require.NotEmpty(t, state.StartedAt) + require.NotEqual(t, zeroTime, state.StartedAt) + require.NotZero(t, state.Pid) + require.Equal(t, &types.ContainerState{ + Status: "running", + Running: true, + Pid: state.Pid, + StartedAt: state.StartedAt, + FinishedAt: time.Time{}.Format(time.RFC3339Nano), + }, state) + }) + t.Run("connection-string", func(t *testing.T) { connectionStr, err := ollamaContainer.ConnectionString(ctx) require.NoError(t, err) - require.Equal(t, "http://127.0.0.1:11434", connectionStr) + require.NotEmpty(t, connectionStr) }) t.Run("container-id", func(t *testing.T) { @@ -48,11 +77,11 @@ func TestRun_local(t *testing.T) { t.Run("container-ips", func(t *testing.T) { ip, err := ollamaContainer.ContainerIP(ctx) require.NoError(t, err) - require.Equal(t, "127.0.0.1", ip) + require.Equal(t, testHost, ip) ips, err := ollamaContainer.ContainerIPs(ctx) require.NoError(t, err) - require.Equal(t, []string{"127.0.0.1"}, ips) + require.Equal(t, []string{testHost}, ips) }) t.Run("copy", func(t *testing.T) { @@ -76,52 +105,13 @@ func TestRun_local(t *testing.T) { }) t.Run("endpoint", func(t *testing.T) { - endpoint, err := ollamaContainer.Endpoint(ctx, "88888/tcp") - require.NoError(t, err) - require.Equal(t, "127.0.0.1:11434", endpoint) - }) - - t.Run("exec/pull-and-run-model", func(t *testing.T) { - const model = "llama3.2:1b" - - code, r, err := ollamaContainer.Exec(ctx, []string{"ollama", "pull", model}) - require.NoError(t, err) - require.Equal(t, 0, code) - - bs, err := io.ReadAll(r) - require.NoError(t, err) - require.Empty(t, bs) - - code, _, err = ollamaContainer.Exec(ctx, []string{"ollama", "run", model}, tcexec.Multiplexed()) + endpoint, err := ollamaContainer.Endpoint(ctx, "") require.NoError(t, err) - require.Equal(t, 0, code) - - logs, err := 
ollamaContainer.Logs(ctx) - require.NoError(t, err) - defer logs.Close() - - bs, err = io.ReadAll(logs) - require.NoError(t, err) - require.Contains(t, string(bs), "llama runner started") - }) - - t.Run("exec/unsupported-command", func(t *testing.T) { - code, r, err := ollamaContainer.Exec(ctx, []string{"cat", "/etc/passwd"}) - require.Equal(t, 1, code) - require.Error(t, err) - require.ErrorIs(t, err, errors.ErrUnsupported) - - bs, err := io.ReadAll(r) - require.NoError(t, err) - require.Equal(t, "cat: unsupported operation", string(bs)) - - code, r, err = ollamaContainer.Exec(ctx, []string{}) - require.Equal(t, 1, code) - require.Error(t, err) + require.Contains(t, endpoint, testHost+":") - bs, err = io.ReadAll(r) + endpoint, err = ollamaContainer.Endpoint(ctx, "http") require.NoError(t, err) - require.Equal(t, "exec: no command provided", string(bs)) + require.Contains(t, endpoint, "http://"+testHost+":") }) t.Run("is-running", func(t *testing.T) { @@ -129,20 +119,18 @@ func TestRun_local(t *testing.T) { err = ollamaContainer.Stop(ctx, nil) require.NoError(t, err) - require.False(t, ollamaContainer.IsRunning()) // return it to the running state err = ollamaContainer.Start(ctx) require.NoError(t, err) - require.True(t, ollamaContainer.IsRunning()) }) t.Run("host", func(t *testing.T) { host, err := ollamaContainer.Host(ctx) require.NoError(t, err) - require.Equal(t, "127.0.0.1", host) + require.Equal(t, testHost, host) }) t.Run("inspect", func(t *testing.T) { @@ -153,74 +141,87 @@ func TestRun_local(t *testing.T) { require.Equal(t, "local-ollama-"+testcontainers.SessionID(), inspect.ContainerJSONBase.Name) require.True(t, inspect.ContainerJSONBase.State.Running) - require.Contains(t, string(inspect.Config.Image), "ollama version is") - _, exists := inspect.Config.ExposedPorts["11434/tcp"] + require.NotEmpty(t, inspect.Config.Image) + _, exists := inspect.Config.ExposedPorts[testNatPort] require.True(t, exists) - require.Equal(t, "localhost", inspect.Config.Hostname) - require.Equal(t, strslice.StrSlice(strslice.StrSlice{"ollama", "serve"}), inspect.Config.Entrypoint) + require.Equal(t, testHost, inspect.Config.Hostname) + require.Equal(t, strslice.StrSlice(strslice.StrSlice{testBinary, "serve"}), inspect.Config.Entrypoint) require.Empty(t, inspect.NetworkSettings.Networks) require.Equal(t, "bridge", inspect.NetworkSettings.NetworkSettingsBase.Bridge) ports := inspect.NetworkSettings.NetworkSettingsBase.Ports - _, exists = ports["11434/tcp"] + port, exists := ports[testNatPort] require.True(t, exists) - - require.Equal(t, "127.0.0.1", inspect.NetworkSettings.Ports["11434/tcp"][0].HostIP) - require.Equal(t, "11434", inspect.NetworkSettings.Ports["11434/tcp"][0].HostPort) + require.Len(t, port, 1) + require.Equal(t, testHost, port[0].HostIP) + require.NotEmpty(t, port[0].HostPort) }) t.Run("logfile", func(t *testing.T) { - openFile, err := os.Open("local-ollama-" + testcontainers.SessionID() + ".log") + file, err := os.Open("local-ollama-" + testcontainers.SessionID() + ".log") require.NoError(t, err) - require.NotNil(t, openFile) - require.NoError(t, openFile.Close()) + require.NoError(t, file.Close()) }) t.Run("logs", func(t *testing.T) { logs, err := ollamaContainer.Logs(ctx) require.NoError(t, err) - defer logs.Close() + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) bs, err := io.ReadAll(logs) require.NoError(t, err) - - require.Contains(t, string(bs), "Listening on 127.0.0.1:11434") + require.Regexp(t, reLogDetails, string(bs)) }) t.Run("mapped-port", func(t *testing.T) { - 
port, err := ollamaContainer.MappedPort(ctx, "11434/tcp") + port, err := ollamaContainer.MappedPort(ctx, testNatPort) require.NoError(t, err) - require.Equal(t, "11434", port.Port()) + require.NotEmpty(t, port.Port()) require.Equal(t, "tcp", port.Proto()) }) t.Run("networks", func(t *testing.T) { networks, err := ollamaContainer.Networks(ctx) require.NoError(t, err) - require.Empty(t, networks) + require.Nil(t, networks) }) t.Run("network-aliases", func(t *testing.T) { aliases, err := ollamaContainer.NetworkAliases(ctx) require.NoError(t, err) - require.Empty(t, aliases) + require.Nil(t, aliases) + }) + + t.Run("port-endpoint", func(t *testing.T) { + endpoint, err := ollamaContainer.PortEndpoint(ctx, testNatPort, "") + require.NoError(t, err) + require.Regexp(t, regexp.MustCompile(`^127.0.0.1:\d+$`), endpoint) + + endpoint, err = ollamaContainer.PortEndpoint(ctx, testNatPort, "http") + require.NoError(t, err) + require.Regexp(t, regexp.MustCompile(`^http://127.0.0.1:\d+$`), endpoint) }) t.Run("session-id", func(t *testing.T) { - id := ollamaContainer.SessionID() - require.Equal(t, testcontainers.SessionID(), id) + require.Equal(t, testcontainers.SessionID(), ollamaContainer.SessionID()) }) t.Run("stop-start", func(t *testing.T) { d := time.Second * 5 - err := ollamaContainer.Stop(ctx, &d) require.NoError(t, err) state, err := ollamaContainer.State(ctx) require.NoError(t, err) require.Equal(t, "exited", state.Status) + require.NotEmpty(t, state.StartedAt) + require.NotEqual(t, zeroTime, state.StartedAt) + require.NotEmpty(t, state.FinishedAt) + require.NotEqual(t, zeroTime, state.FinishedAt) + require.Zero(t, state.ExitCode) err = ollamaContainer.Start(ctx) require.NoError(t, err) @@ -231,12 +232,13 @@ func TestRun_local(t *testing.T) { logs, err := ollamaContainer.Logs(ctx) require.NoError(t, err) - defer logs.Close() + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) bs, err := io.ReadAll(logs) require.NoError(t, err) - - require.Contains(t, string(bs), "Listening on 127.0.0.1:11434") + require.Regexp(t, reLogDetails, string(bs)) }) t.Run("start-start", func(t *testing.T) { @@ -245,7 +247,7 @@ func TestRun_local(t *testing.T) { require.Equal(t, "running", state.Status) err = ollamaContainer.Start(ctx) - require.NoError(t, err) + require.Error(t, err) }) t.Run("terminate", func(t *testing.T) { @@ -253,42 +255,126 @@ func TestRun_local(t *testing.T) { require.NoError(t, err) _, err = os.Stat("ollama-" + testcontainers.SessionID() + ".log") - require.True(t, os.IsNotExist(err)) + require.ErrorIs(t, err, fs.ErrNotExist) state, err := ollamaContainer.State(ctx) require.NoError(t, err) - require.Equal(t, "exited", state.Status) + require.NotEmpty(t, state.StartedAt) + require.NotEqual(t, zeroTime, state.StartedAt) + require.NotEmpty(t, state.FinishedAt) + require.NotEqual(t, zeroTime, state.FinishedAt) + require.Equal(t, &types.ContainerState{ + Status: "exited", + StartedAt: state.StartedAt, + FinishedAt: state.FinishedAt, + }, state) + }) + + t.Run("deprecated", func(t *testing.T) { + t.Run("ports", func(t *testing.T) { + inspect, err := ollamaContainer.Inspect(ctx) + require.NoError(t, err) + + ports, err := ollamaContainer.Ports(ctx) + require.NoError(t, err) + require.Equal(t, inspect.NetworkSettings.Ports, ports) + }) + + t.Run("follow-output", func(t *testing.T) { + require.Panics(t, func() { + ollamaContainer.FollowOutput(&testcontainers.StdoutLogConsumer{}) + }) + }) + + t.Run("start-log-producer", func(t *testing.T) { + err := ollamaContainer.StartLogProducer(ctx) + 
require.ErrorIs(t, err, errors.ErrUnsupported) + + }) + + t.Run("stop-log-producer", func(t *testing.T) { + err := ollamaContainer.StopLogProducer() + require.ErrorIs(t, err, errors.ErrUnsupported) + }) + + t.Run("name", func(t *testing.T) { + name, err := ollamaContainer.Name(ctx) + require.NoError(t, err) + require.Equal(t, "local-ollama-"+testcontainers.SessionID(), name) + }) }) } func TestRun_localWithCustomLogFile(t *testing.T) { - t.Setenv("OLLAMA_LOGFILE", filepath.Join(t.TempDir(), "server.log")) - ctx := context.Background() + logFile := filepath.Join(t.TempDir(), "server.log") - ollamaContainer, err := ollama.Run(ctx, "ollama/ollama:0.1.25", ollama.WithUseLocal("FOO=BAR")) - require.NoError(t, err) - testcontainers.CleanupContainer(t, ollamaContainer) + t.Run("parent-env", func(t *testing.T) { + t.Setenv("OLLAMA_LOGFILE", logFile) - logs, err := ollamaContainer.Logs(ctx) - require.NoError(t, err) - defer logs.Close() + ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal()) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) - bs, err := io.ReadAll(logs) - require.NoError(t, err) + logs, err := ollamaContainer.Logs(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) + + bs, err := io.ReadAll(logs) + require.NoError(t, err) + require.Regexp(t, reLogDetails, string(bs)) - require.Contains(t, string(bs), "Listening on 127.0.0.1:11434") + file, ok := logs.(*os.File) + require.True(t, ok) + require.Equal(t, logFile, file.Name()) + }) + + t.Run("local-env", func(t *testing.T) { + ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal("OLLAMA_LOGFILE="+logFile)) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + + logs, err := ollamaContainer.Logs(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) + + bs, err := io.ReadAll(logs) + require.NoError(t, err) + require.Regexp(t, reLogDetails, string(bs)) + + file, ok := logs.(*os.File) + require.True(t, ok) + require.Equal(t, logFile, file.Name()) + }) } func TestRun_localWithCustomHost(t *testing.T) { - t.Setenv("OLLAMA_HOST", "127.0.0.1:1234") - ctx := context.Background() - ollamaContainer, err := ollama.Run(ctx, "ollama/ollama:0.1.25", ollama.WithUseLocal()) - require.NoError(t, err) - testcontainers.CleanupContainer(t, ollamaContainer) + t.Run("parent-env", func(t *testing.T) { + t.Setenv("OLLAMA_HOST", "127.0.0.1:1234") + ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal()) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + + testRun_localWithCustomHost(ctx, t, ollamaContainer) + }) + + t.Run("local-env", func(t *testing.T) { + ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal("OLLAMA_HOST=127.0.0.1:1234")) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + + testRun_localWithCustomHost(ctx, t, ollamaContainer) + }) +} + +func testRun_localWithCustomHost(ctx context.Context, t *testing.T, ollamaContainer *ollama.OllamaContainer) { t.Run("connection-string", func(t *testing.T) { connectionStr, err := ollamaContainer.ConnectionString(ctx) require.NoError(t, err) @@ -296,36 +382,38 @@ func TestRun_localWithCustomHost(t *testing.T) { }) t.Run("endpoint", func(t *testing.T) { - endpoint, err := ollamaContainer.Endpoint(ctx, "1234/tcp") + endpoint, err := ollamaContainer.Endpoint(ctx, "http") require.NoError(t, err) - require.Equal(t, "127.0.0.1:1234", endpoint) + 
require.Equal(t, "http://127.0.0.1:1234", endpoint) }) t.Run("inspect", func(t *testing.T) { inspect, err := ollamaContainer.Inspect(ctx) require.NoError(t, err) + require.Regexp(t, regexp.MustCompile(`^local-ollama:\d+\.\d+\.\d+$`), inspect.Config.Image) - require.Contains(t, string(inspect.Config.Image), "ollama version is") - _, exists := inspect.Config.ExposedPorts["1234/tcp"] + _, exists := inspect.Config.ExposedPorts[testNatPort] require.True(t, exists) - require.Equal(t, "localhost", inspect.Config.Hostname) - require.Equal(t, strslice.StrSlice(strslice.StrSlice{"ollama", "serve"}), inspect.Config.Entrypoint) + require.Equal(t, testHost, inspect.Config.Hostname) + require.Equal(t, strslice.StrSlice(strslice.StrSlice{testBinary, "serve"}), inspect.Config.Entrypoint) require.Empty(t, inspect.NetworkSettings.Networks) require.Equal(t, "bridge", inspect.NetworkSettings.NetworkSettingsBase.Bridge) ports := inspect.NetworkSettings.NetworkSettingsBase.Ports - _, exists = ports["1234/tcp"] + port, exists := ports[testNatPort] require.True(t, exists) - - require.Equal(t, "127.0.0.1", inspect.NetworkSettings.Ports["1234/tcp"][0].HostIP) - require.Equal(t, "1234", inspect.NetworkSettings.Ports["1234/tcp"][0].HostPort) + require.Len(t, port, 1) + require.Equal(t, testHost, port[0].HostIP) + require.Equal(t, "1234", port[0].HostPort) }) t.Run("logs", func(t *testing.T) { logs, err := ollamaContainer.Logs(ctx) require.NoError(t, err) - defer logs.Close() + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) bs, err := io.ReadAll(logs) require.NoError(t, err) @@ -334,9 +422,109 @@ func TestRun_localWithCustomHost(t *testing.T) { }) t.Run("mapped-port", func(t *testing.T) { - port, err := ollamaContainer.MappedPort(ctx, "1234/tcp") + port, err := ollamaContainer.MappedPort(ctx, testNatPort) require.NoError(t, err) require.Equal(t, "1234", port.Port()) require.Equal(t, "tcp", port.Proto()) }) } + +func TestRun_localExec(t *testing.T) { + // check if the local ollama binary is available + if _, err := exec.LookPath(testBinary); err != nil { + t.Skip("local ollama binary not found, skipping") + } + + ctx := context.Background() + + ollamaContainer, err := ollama.Run(ctx, testImage, ollama.WithUseLocal()) + testcontainers.CleanupContainer(t, ollamaContainer) + require.NoError(t, err) + + t.Run("no-command", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, nil) + require.Error(t, err) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-command", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{"cat", "/etc/hosts"}) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-option-user", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.WithUser("root")) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-option-privileged", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) { + opts.ExecConfig.Privileged = true + })) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-option-tty", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) { + opts.ExecConfig.Tty = true + })) + 
require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-option-detach", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) { + opts.ExecConfig.Detach = true + })) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("unsupported-option-detach-keys", func(t *testing.T) { + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "-v"}, tcexec.ProcessOptionFunc(func(opts *tcexec.ProcessOptions) { + opts.ExecConfig.DetachKeys = "ctrl-p,ctrl-q" + })) + require.ErrorIs(t, err, errors.ErrUnsupported) + require.Equal(t, 1, code) + require.Nil(t, r) + }) + + t.Run("pull-and-run-model", func(t *testing.T) { + const model = "llama3.2:1b" + + code, r, err := ollamaContainer.Exec(ctx, []string{testBinary, "pull", model}) + require.NoError(t, err) + require.Zero(t, code) + + bs, err := io.ReadAll(r) + require.NoError(t, err) + require.Contains(t, string(bs), "success") + + code, r, err = ollamaContainer.Exec(ctx, []string{testBinary, "run", model}, tcexec.Multiplexed()) + require.NoError(t, err) + require.Zero(t, code) + + bs, err = io.ReadAll(r) + require.NoError(t, err) + require.NotEmpty(t, r) + + logs, err := ollamaContainer.Logs(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, logs.Close()) + }) + + bs, err = io.ReadAll(logs) + require.NoError(t, err) + require.Contains(t, string(bs), "llama runner started") + }) +} diff --git a/modules/ollama/local_unit_test.go b/modules/ollama/local_unit_test.go deleted file mode 100644 index 95d9b93638..0000000000 --- a/modules/ollama/local_unit_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package ollama - -import ( - "context" - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestRun_localWithCustomLogFileError(t *testing.T) { - t.Run("terminate/close-log-error", func(t *testing.T) { - // Create a temporary file for testing - f, err := os.CreateTemp(t.TempDir(), "test-log-*") - require.NoError(t, err) - - // Close the file before termination to force a "file already closed" error - err = f.Close() - require.NoError(t, err) - - c := &OllamaContainer{ - localCtx: &localContext{ - logFile: f, - }, - } - err = c.Terminate(context.Background()) - require.Error(t, err) - require.ErrorContains(t, err, "close log:") - }) - - t.Run("terminate/log-file-not-removable", func(t *testing.T) { - // Create a temporary file for testing - f, err := os.CreateTemp(t.TempDir(), "test-log-*") - require.NoError(t, err) - defer func() { - // Cleanup: restore permissions - os.Chmod(filepath.Dir(f.Name()), 0700) - }() - - // Make the file read-only and its parent directory read-only - // This should cause removal to fail on most systems - dir := filepath.Dir(f.Name()) - require.NoError(t, os.Chmod(dir, 0500)) - - c := &OllamaContainer{ - localCtx: &localContext{ - logFile: f, - }, - } - err = c.Terminate(context.Background()) - require.Error(t, err) - require.ErrorContains(t, err, "remove log:") - }) -} diff --git a/modules/ollama/ollama.go b/modules/ollama/ollama.go index 3d0cc6fa4e..0284ed215a 100644 --- a/modules/ollama/ollama.go +++ b/modules/ollama/ollama.go @@ -20,24 +20,19 @@ const DefaultOllamaImage = "ollama/ollama:0.1.25" // OllamaContainer represents the Ollama container type used in the module type OllamaContainer struct { testcontainers.Container - localCtx *localContext } // 
ConnectionString returns the connection string for the Ollama container, // using the default port 11434. func (c *OllamaContainer) ConnectionString(ctx context.Context) (string, error) { - if c.localCtx != nil { - return "http://" + c.localCtx.host + ":" + c.localCtx.port, nil - } - host, err := c.Host(ctx) if err != nil { - return "", err + return "", fmt.Errorf("host: %w", err) } port, err := c.MappedPort(ctx, "11434/tcp") if err != nil { - return "", err + return "", fmt.Errorf("mapped port: %w", err) } return fmt.Sprintf("http://%s:%d", host, port.Int()), nil @@ -48,7 +43,7 @@ func (c *OllamaContainer) ConnectionString(ctx context.Context) (string, error) // of the container into a new image with the given name, so it doesn't override existing images. // It should be used for creating an image that contains a loaded model. func (c *OllamaContainer) Commit(ctx context.Context, targetImage string) error { - if c.localCtx != nil { + if _, ok := c.Container.(*localProcess); ok { return nil } @@ -89,40 +84,36 @@ func RunContainer(ctx context.Context, opts ...testcontainers.ContainerCustomize // Run creates an instance of the Ollama container type func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustomizer) (*OllamaContainer, error) { - req := testcontainers.ContainerRequest{ - Image: img, - ExposedPorts: []string{"11434/tcp"}, - WaitingFor: wait.ForListeningPort("11434/tcp").WithStartupTimeout(60 * time.Second), - } - - genericContainerReq := testcontainers.GenericContainerRequest{ - ContainerRequest: req, - Started: true, + req := testcontainers.GenericContainerRequest{ + ContainerRequest: testcontainers.ContainerRequest{ + Image: img, + ExposedPorts: []string{"11434/tcp"}, + Env: map[string]string{ + localHostVar: "localhost:0", // Use a random port to avoid conflicts. + }, + WaitingFor: wait.ForListeningPort("11434/tcp").WithStartupTimeout(60 * time.Second), + }, + Started: true, } // always request a GPU if the host supports it opts = append(opts, withGpu()) - useLocal := false + var local bool for _, opt := range opts { - if err := opt.Customize(&genericContainerReq); err != nil { + if err := opt.Customize(&req); err != nil { return nil, fmt.Errorf("customize: %w", err) } - if _, ok := opt.(UseLocal); ok { - useLocal = true + if _, ok := opt.(useLocal); ok { + local = true } } - if useLocal { - container, err := runLocal(ctx, req.Env) - if err == nil { - return container, nil - } - - testcontainers.Logger.Printf("failed to run local ollama: %v, switching to docker", err) + if local { + return runLocal(ctx, req) } - container, err := testcontainers.GenericContainer(ctx, genericContainerReq) + container, err := testcontainers.GenericContainer(ctx, req) var c *OllamaContainer if container != nil { c = &OllamaContainer{Container: container} diff --git a/modules/ollama/options.go b/modules/ollama/options.go index 4761a28530..ff7a8bcff3 100644 --- a/modules/ollama/options.go +++ b/modules/ollama/options.go @@ -3,17 +3,19 @@ package ollama import ( "context" "fmt" + "os" "strings" "github.com/docker/docker/api/types/container" "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" ) var noopCustomizeRequestOption = func(req *testcontainers.GenericContainerRequest) error { return nil } // withGpu requests a GPU for the container, which could improve performance for some models. -// This option will be automaticall added to the Ollama container to check if the host supports nvidia. 
+// This option will be automatically added to the Ollama container to check if the host supports nvidia. func withGpu() testcontainers.CustomizeRequestOption { cli, err := testcontainers.NewDockerClientWithOpts(context.Background()) if err != nil { @@ -40,29 +42,48 @@ func withGpu() testcontainers.CustomizeRequestOption { }) } -var _ testcontainers.ContainerCustomizer = (*UseLocal)(nil) +var _ testcontainers.ContainerCustomizer = (*useLocal)(nil) -// UseLocal will use the local Ollama instance instead of pulling the Docker image. -type UseLocal struct { +// useLocal will use the local Ollama instance instead of pulling the Docker image. +type useLocal struct { env []string } // WithUseLocal the module will use the local Ollama instance instead of pulling the Docker image. // Pass the environment variables you need to set for the Ollama binary to be used, // in the format of "KEY=VALUE". KeyValue pairs with the wrong format will cause an error. -func WithUseLocal(values ...string) UseLocal { - return UseLocal{env: values} +func WithUseLocal(values ...string) useLocal { + return useLocal{env: values} } // Customize implements the ContainerCustomizer interface, taking the key value pairs // and setting them as environment variables for the Ollama binary. // In the case of an invalid key value pair, an error is returned. -func (u UseLocal) Customize(req *testcontainers.GenericContainerRequest) error { - env := make(map[string]string) - for _, kv := range u.env { +func (u useLocal) Customize(req *testcontainers.GenericContainerRequest) error { + // Replace the default host port strategy with one that waits for a log entry. + if err := wait.Walk(&req.WaitingFor, func(w wait.Strategy) error { + if _, ok := w.(*wait.HostPortStrategy); ok { + return wait.VisitRemove + } + + return nil + }); err != nil { + return fmt.Errorf("walk strategies: %w", err) + } + + logStrategy := wait.ForLog(localLogRegex).AsRegexp() + if req.WaitingFor == nil { + req.WaitingFor = logStrategy + } else { + req.WaitingFor = wait.ForAll(req.WaitingFor, logStrategy) + } + + osEnv := os.Environ() + env := make(map[string]string, len(osEnv)+len(u.env)) + for _, kv := range append(osEnv, u.env...) { parts := strings.SplitN(kv, "=", 2) if len(parts) != 2 { - return fmt.Errorf("invalid environment variable: %s", kv) + return fmt.Errorf("invalid environment variable: %q", kv) } env[parts[0]] = parts[1]