From a6f29bf9f08c34dc92d03d615b62465aecc4e705 Mon Sep 17 00:00:00 2001 From: YuKang Date: Wed, 4 Sep 2024 11:42:25 +0800 Subject: [PATCH 1/2] feat: add inference module --- .../example/dev/example_workspace.yaml | 7 + modules/inference/example/dev/kcl.mod | 10 + modules/inference/example/dev/main.k | 26 ++ modules/inference/example/dev/stack.yaml | 1 + modules/inference/example/project.yaml | 1 + modules/inference/kcl.mod | 3 + modules/inference/src/Makefile | 36 ++ modules/inference/src/go.mod | 49 +++ modules/inference/src/go.sum | 161 ++++++++ modules/inference/src/inference_generator.go | 364 ++++++++++++++++++ .../inference/src/inference_generator_test.go | 338 ++++++++++++++++ modules/inference/v1/inference.k | 60 +++ 12 files changed, 1056 insertions(+) create mode 100644 modules/inference/example/dev/example_workspace.yaml create mode 100644 modules/inference/example/dev/kcl.mod create mode 100644 modules/inference/example/dev/main.k create mode 100644 modules/inference/example/dev/stack.yaml create mode 100644 modules/inference/example/project.yaml create mode 100644 modules/inference/kcl.mod create mode 100644 modules/inference/src/Makefile create mode 100644 modules/inference/src/go.mod create mode 100644 modules/inference/src/go.sum create mode 100644 modules/inference/src/inference_generator.go create mode 100644 modules/inference/src/inference_generator_test.go create mode 100644 modules/inference/v1/inference.k diff --git a/modules/inference/example/dev/example_workspace.yaml b/modules/inference/example/dev/example_workspace.yaml new file mode 100644 index 0000000..ae02970 --- /dev/null +++ b/modules/inference/example/dev/example_workspace.yaml @@ -0,0 +1,7 @@ +# The configuration items in perspective of platform engineers. 
+modules:
+  inference:
+    path: oci://ghcr.io/kusionstack/inference
+    version: 0.1.0-beta.1
+    configs:
+      default: {}
\ No newline at end of file
diff --git a/modules/inference/example/dev/kcl.mod b/modules/inference/example/dev/kcl.mod
new file mode 100644
index 0000000..d009ae0
--- /dev/null
+++ b/modules/inference/example/dev/kcl.mod
@@ -0,0 +1,11 @@
+[package]
+name = "example"
+
+[dependencies]
+inference = { oci = "oci://ghcr.io/kusionstack/inference", tag = "0.1.0-beta.1" }
+service = {oci = "oci://ghcr.io/kusionstack/service", tag = "0.1.0" }
+network = { oci = "oci://ghcr.io/kusionstack/network", tag = "0.2.0" }
+kam = { git = "https://github.com/KusionStack/kam.git", tag = "0.2.0" }
+
+[profile]
+entries = ["main.k"]
diff --git a/modules/inference/example/dev/main.k b/modules/inference/example/dev/main.k
new file mode 100644
index 0000000..cb0832d
--- /dev/null
+++ b/modules/inference/example/dev/main.k
@@ -0,0 +1,27 @@
+# The configuration codes in perspective of developers.
+import kam.v1.app_configuration as ac
+import service
+import service.container as c
+import network as n
+import inference.v1.inference
+
+inference: ac.AppConfiguration {
+    # Declare the workload configurations.
+    workload: service.Service {
+        containers: {
+            myct: c.Container {image: "xxx/ai-app"}
+        }
+        replicas: 1
+    }
+    # Declare the inference module configurations.
+    accessories: {
+        "inference": inference.Inference {
+            model: "llama3"
+            framework: "Ollama"
+        }
+        "network": n.Network {ports: [n.Port {
+            port: 80
+            public: True
+        }]}
+    }
+}
\ No newline at end of file
diff --git a/modules/inference/example/dev/stack.yaml b/modules/inference/example/dev/stack.yaml
new file mode 100644
index 0000000..19e96e3
--- /dev/null
+++ b/modules/inference/example/dev/stack.yaml
@@ -0,0 +1 @@
+name: dev
diff --git a/modules/inference/example/project.yaml b/modules/inference/example/project.yaml
new file mode 100644
index 0000000..2551cb1
--- /dev/null
+++ b/modules/inference/example/project.yaml
@@ -0,0 +1 @@
+name: example
diff --git a/modules/inference/kcl.mod b/modules/inference/kcl.mod
new file mode 100644
index 0000000..98b46f0
--- /dev/null
+++ b/modules/inference/kcl.mod
@@ -0,0 +1,3 @@
+[package]
+name = "inference"
+version = "0.1.0-beta.1"
diff --git a/modules/inference/src/Makefile b/modules/inference/src/Makefile
new file mode 100644
index 0000000..a6c65de
--- /dev/null
+++ b/modules/inference/src/Makefile
@@ -0,0 +1,36 @@
+TEST?=$$(go list ./... | grep -v 'vendor')
+###### change variables below according to your own modules ###
+NAMESPACE=kusionstack
+NAME=inference
+VERSION=0.1.0-beta.1
+BINARY=../bin/kusion-module-${NAME}_${VERSION}
+
+LOCAL_ARCH := $(shell uname -m)
+ifeq ($(LOCAL_ARCH),x86_64)
+GOARCH_LOCAL := amd64
+else
+GOARCH_LOCAL := $(LOCAL_ARCH)
+endif
+export GOOS_LOCAL := $(shell uname|tr 'A-Z' 'a-z')
+export OS_ARCH ?= $(GOARCH_LOCAL)
+
+default: install
+
+build-darwin:
+	GOOS=darwin GOARCH=arm64 go build -o ${BINARY} .
+
+install: build-darwin
+# copy module binary to $KUSION_HOME. e.g. 
~/.kusion/modules/kusionstack/inference/v0.1.0/darwin/arm64/kusion-module-inference_0.1.0 + mkdir -p ${KUSION_HOME}/modules/${NAMESPACE}/${NAME}/${VERSION}/${GOOS_LOCAL}/${OS_ARCH} + cp ${BINARY} ${KUSION_HOME}/modules/${NAMESPACE}/${NAME}/${VERSION}/${GOOS_LOCAL}/${OS_ARCH} + +release: + GOOS=darwin GOARCH=arm64 go build -o ${BINARY}_darwin_arm64 ./${NAME} + GOOS=darwin GOARCH=amd64 go build -o ${BINARY}_darwin_amd64 ./${NAME} + GOOS=linux GOARCH=arm64 go build -o ${BINARY}_linux_arm64 ./${NAME} + GOOS=linux GOARCH=amd64 go build -o ${BINARY}_linux_amd64 ./${NAME} + GOOS=windows GOARCH=amd64 go build -o ${BINARY}_windows_amd64 ./${NAME} + GOOS=windows GOARCH=386 go build -o ${BINARY}_windows_386 ./${NAME} + +test: + TF_ACC=1 go test $(TEST) -v $(TESTARGS) -timeout 5m diff --git a/modules/inference/src/go.mod b/modules/inference/src/go.mod new file mode 100644 index 0000000..1c5c030 --- /dev/null +++ b/modules/inference/src/go.mod @@ -0,0 +1,49 @@ +module inference + +go 1.22.1 + +require ( + github.com/stretchr/testify v1.9.0 + gopkg.in/yaml.v2 v2.4.0 + k8s.io/api v0.30.0 + k8s.io/apimachinery v0.30.0 + kusionstack.io/kusion v0.12.0-rc.3.0.20240612063438-7e50571609dc + kusionstack.io/kusion-module-framework v0.2.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fatih/color v1.16.0 // indirect + github.com/go-logr/logr v1.4.1 // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/protobuf v1.5.4 // indirect + github.com/google/gofuzz v1.2.0 // indirect + github.com/hashicorp/go-hclog v1.6.2 // indirect + github.com/hashicorp/go-plugin v1.6.0 // indirect + github.com/hashicorp/yamux v0.1.1 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mitchellh/go-testing-interface v1.14.1 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + 
github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/oklog/run v1.1.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + go.uber.org/zap v1.27.0 // indirect + golang.org/x/net v0.23.0 // indirect + golang.org/x/sys v0.18.0 // indirect + golang.org/x/text v0.14.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240325203815-454cdb8f5daa // indirect + google.golang.org/grpc v1.64.0 // indirect + google.golang.org/protobuf v1.34.1 // indirect + gopkg.in/inf.v0 v0.9.1 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + k8s.io/klog/v2 v2.120.1 // indirect + k8s.io/utils v0.0.0-20240310230437-4693a0247e57 // indirect + sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect + sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect + sigs.k8s.io/yaml v1.4.0 // indirect +) diff --git a/modules/inference/src/go.sum b/modules/inference/src/go.sum new file mode 100644 index 0000000..09532a8 --- /dev/null +++ b/modules/inference/src/go.sum @@ -0,0 +1,161 @@ +github.com/bufbuild/protocompile v0.4.0 h1:LbFKd2XowZvQ/kajzguUp2DC9UEIQhIq77fZZlaQsNA= +github.com/bufbuild/protocompile v0.4.0/go.mod h1:3v93+mbWn/v3xzN+31nwkJfrEpAUwp+BagBSZWx+TP8= +github.com/bytedance/mockey v1.2.10 h1:4JlMpkm7HMXmTUtItid+iCu2tm61wvq+ca1X2u7ymzE= +github.com/bytedance/mockey v1.2.10/go.mod h1:bNrUnI1u7+pAc0TYDgPATM+wF2yzHxmNH+iDXg4AOCU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= +github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= +github.com/fatih/color v1.16.0/go.mod 
h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= +github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= +github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/hashicorp/go-hclog v1.6.2 h1:NOtoftovWkDheyUM/8JW3QMiXyxJK3uHRK7wV04nD2I= +github.com/hashicorp/go-hclog v1.6.2/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-plugin v1.6.0 h1:wgd4KxHJTVGGqWBq4QPB1i5BZNEx9BR8+OFmHDmTk8A= +github.com/hashicorp/go-plugin v1.6.0/go.mod h1:lBS5MtSSBZk0SHc66KACcjjlU6WzEVP/8pwz68aMkCI= +github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE= +github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= +github.com/jhump/protoreflect v1.15.1 h1:HUMERORf3I3ZdX05WaQ6MIpd/NJ434hTp5YiKgfCL6c= +github.com/jhump/protoreflect v1.15.1/go.mod 
h1:jD/2GMKKE6OqX8qTjhADU1e6DShO+gavG9e0Q693nKo= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mitchellh/go-testing-interface v1.14.1 h1:jrgshOhYAUVNMAJiKbEu7EqAwgJJ2JqpQmpLJOu07cU= +github.com/mitchellh/go-testing-interface v1.14.1/go.mod h1:gfgS7OtZj6MA4U1UrDRp04twqAjfvlZyCfX3sDjEym8= +github.com/modern-go/concurrent 
v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/oklog/run v1.1.0 h1:GEenZ1cK0+q0+wsJew9qUg/DyD8k3JzYsZAi5gYi2mA= +github.com/oklog/run v1.1.0/go.mod h1:sVPdnTZT1zYwAJeCMu2Th4T21pA3FPOQRfWjQlk7DVU= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 
+github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +golang.org/x/arch v0.1.0 h1:oMxhUYsO9VsR1dcoVUjJjIGhx1LXol3989T/yZ59Xsw= +golang.org/x/arch v0.1.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= 
+golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= +golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod 
h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240325203815-454cdb8f5daa h1:RBgMaUMP+6soRkik4VoN8ojR2nex2TqZwjSSogic+eo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240325203815-454cdb8f5daa/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY= +google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= +google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= +gopkg.in/inf.v0 v0.9.1/go.mod 
h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +k8s.io/api v0.30.0 h1:siWhRq7cNjy2iHssOB9SCGNCl2spiF1dO3dABqZ8niA= +k8s.io/api v0.30.0/go.mod h1:OPlaYhoHs8EQ1ql0R/TsUgaRPhpKNxIMrKQfWUp8QSE= +k8s.io/apimachinery v0.30.0 h1:qxVPsyDM5XS96NIh9Oj6LavoVFYff/Pon9cZeDIkHHA= +k8s.io/apimachinery v0.30.0/go.mod h1:iexa2somDaxdnj7bha06bhb43Zpa6eWH8N8dbqVjTUc= +k8s.io/klog/v2 v2.120.1 h1:QXU6cPEOIslTGvZaXvFWiP9VKyeet3sawzTOvdXb4Vw= +k8s.io/klog/v2 v2.120.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= +k8s.io/utils v0.0.0-20240310230437-4693a0247e57 h1:gbqbevonBh57eILzModw6mrkbwM0gQBEuevE/AaBsHY= +k8s.io/utils v0.0.0-20240310230437-4693a0247e57/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= +kusionstack.io/kusion v0.12.0-rc.3.0.20240612063438-7e50571609dc h1:ioSIDfFDXKYz99y+ZBsk4B//Wej6I/pVu2dl8t83lN0= +kusionstack.io/kusion v0.12.0-rc.3.0.20240612063438-7e50571609dc/go.mod h1:ibFycMYFQFuZ/JC7XTTysV4ZKZtVZt8rSkvcoA+3d28= +kusionstack.io/kusion-module-framework v0.2.0 h1:aV6q0lisWF4h8K/i08b1A+CoM89JxUHg//i333AwXTM= +kusionstack.io/kusion-module-framework v0.2.0/go.mod h1:rD5yidwI0WVsgcBYtG2Wxb0ibw4pSTt+53ZCjF880OI= +sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo= +sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0= +sigs.k8s.io/structured-merge-diff/v4 v4.4.1 
h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4= +sigs.k8s.io/structured-merge-diff/v4 v4.4.1/go.mod h1:N8hJocpFajUSSeSJ9bOZ77VzejKZaXsTtZo4/u7Io08= +sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= +sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY= diff --git a/modules/inference/src/inference_generator.go b/modules/inference/src/inference_generator.go new file mode 100644 index 0000000..4b536fa --- /dev/null +++ b/modules/inference/src/inference_generator.go @@ -0,0 +1,364 @@ +package main + +import ( + "context" + "errors" + "fmt" + "strings" + + "gopkg.in/yaml.v2" + appsv1 "k8s.io/api/apps/v1" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "kusionstack.io/kusion-module-framework/pkg/module" + "kusionstack.io/kusion-module-framework/pkg/server" + apiv1 "kusionstack.io/kusion/pkg/apis/api.kusion.io/v1" + "kusionstack.io/kusion/pkg/log" +) + +var ( + ErrUnsupportFramework = errors.New("framework must be Ollama or KubeRay") + ErrRangeTopK = errors.New("topK must be greater than 0 if exist") + ErrRangeTopP = errors.New("topP must be greater than 0 and less than or equal to 1 if exist") + ErrRangeTemperature = errors.New("temperature must be greater than 0 if exist") + ErrRangeNumPredict = errors.New("numPredict must be greater than or equal to -2") + ErrRangeNumCtx = errors.New("numCtx must be greater than 0 if exist") +) + +var ( + inferDeploymentSuffix = "-infer-deployment" + inferStorageSuffix = "-infer-storage" + inferServiceSuffix = "-infer-service" +) + +var ( + defaultTopK int = 40 + defaultTopP float64 = 0.9 + defaultTemperature float64 = 0.8 + defaultNumPredict int = 128 + defaultNumCtx int = 2048 +) + +var ( + OllamaType = "ollama" +) + +var ( + OllamaImage = "ollama" +) + +func main() { + server.Start(&Inference{}) +} + +// Inference implements the Kusion Module generator interface. 
+type Inference struct {
+	// Model is the model name written to the Modelfile FROM line and passed to `ollama create`.
+	Model string `yaml:"model,omitempty" json:"model,omitempty"`
+	// Framework is the serving framework; ValidateConfig accepts "Ollama" or "KubeRay".
+	Framework string `yaml:"framework,omitempty" json:"framework,omitempty"`
+	// System is the optional SYSTEM prompt of the Modelfile.
+	System string `yaml:"system,omitempty" json:"system,omitempty"`
+	// Template is the optional TEMPLATE of the Modelfile.
+	Template string `yaml:"template,omitempty" json:"template,omitempty"`
+	// TopK, TopP, Temperature, NumPredict and NumCtx become Modelfile PARAMETER lines.
+	// CompleteConfig seeds them with the package defaults before applying user config.
+	TopK        int     `yaml:"top_k,omitempty" json:"top_k,omitempty"`
+	TopP        float64 `yaml:"top_p,omitempty" json:"top_p,omitempty"`
+	Temperature float64 `yaml:"temperature,omitempty" json:"temperature,omitempty"`
+	NumPredict  int     `yaml:"num_predict,omitempty" json:"num_predict,omitempty"`
+	NumCtx      int     `yaml:"num_ctx,omitempty" json:"num_ctx,omitempty"`
+}
+
+// ErrEmptyModel is returned by ValidateConfig when no model name is configured.
+var ErrEmptyModel = errors.New("model must not be empty")
+
+// Generate builds the Kusion resources and the workload patcher for the inference accessory.
+//
+// It returns (nil, nil) when the request carries no developer config for this module.
+// Failures while completing, validating or generating are returned wrapped with the step
+// that failed. A panic raised along the way is converted into an error: the previous
+// recover left both results nil, which callers could not tell apart from "nothing to do".
+func (infer *Inference) Generate(_ context.Context, request *module.GeneratorRequest) (resp *module.GeneratorResponse, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			log.Debugf("failed to generate inference module: %v", r)
+			resp, err = nil, fmt.Errorf("generating inference module: recovered panic: %v", r)
+		}
+	}()
+
+	if request == nil {
+		return nil, errors.New("inference generator request must not be nil")
+	}
+
+	// Inference module does not exist in AppConfiguration configs.
+	if request.DevConfig == nil {
+		log.Info("Inference does not exist in AppConfig config")
+		return nil, nil
+	}
+
+	// Merge defaults, developer config and platform config into the receiver.
+	if err := infer.CompleteConfig(request.DevConfig, request.PlatformConfig); err != nil {
+		return nil, fmt.Errorf("completing inference module config: %w", err)
+	}
+
+	// Reject out-of-range values before any resource is built.
+	if err := infer.ValidateConfig(); err != nil {
+		return nil, fmt.Errorf("validating inference module config: %w", err)
+	}
+
+	resources, patcher, err := infer.GenerateInferenceResource(request)
+	if err != nil {
+		return nil, fmt.Errorf("generating inference resources: %w", err)
+	}
+
+	return &module.GeneratorResponse{
+		Resources: resources,
+		Patcher:   patcher,
+	}, nil
+}
+
+// CompleteConfig resets the tunable parameters to the package defaults and then overlays
+// devConfig followed by platformConfig, so a key present in both takes the platform value.
+// Keys that are absent from both configs keep their default.
+func (infer *Inference) CompleteConfig(devConfig apiv1.Accessory, platformConfig apiv1.GenericConfig) error {
+	infer.TopK = defaultTopK
+	infer.TopP = defaultTopP
+	infer.Temperature = defaultTemperature
+	infer.NumPredict = defaultNumPredict
+	infer.NumCtx = defaultNumCtx
+
+	// Retrieve the config items the developers are concerned about.
+	if devConfig != nil {
+		if err := overlayConfig(devConfig, infer); err != nil {
+			return fmt.Errorf("applying developer config: %w", err)
+		}
+	}
+	// Retrieve the config items the platform engineers care about.
+	if platformConfig != nil {
+		if err := overlayConfig(platformConfig, infer); err != nil {
+			return fmt.Errorf("applying platform config: %w", err)
+		}
+	}
+	return nil
+}
+
+// overlayConfig copies the keys of src onto dst by round-tripping src through YAML,
+// which maps the snake_case config keys onto the yaml tags of Inference.
+func overlayConfig(src any, dst *Inference) error {
+	raw, err := yaml.Marshal(src)
+	if err != nil {
+		return err
+	}
+	return yaml.Unmarshal(raw, dst)
+}
+
+// ValidateConfig reports the first invalid field of the completed config, or nil.
+// The range checks run in their original order; the model check is appended last so
+// that existing callers keep receiving the same sentinel for the same input.
+func (infer *Inference) ValidateConfig() error {
+	if infer.Framework != "Ollama" && infer.Framework != "KubeRay" {
+		return ErrUnsupportFramework
+	}
+	if infer.TopK <= 0 {
+		return ErrRangeTopK
+	}
+	if infer.TopP <= 0 || infer.TopP > 1 {
+		return ErrRangeTopP
+	}
+	if infer.Temperature <= 0 {
+		return ErrRangeTemperature
+	}
+	if infer.NumPredict < -2 {
+		return ErrRangeNumPredict
+	}
+	if infer.NumCtx <= 0 {
+		return ErrRangeNumCtx
+	}
+	// An empty model would otherwise be written into the Modelfile as a bare "FROM ".
+	if strings.TrimSpace(infer.Model) == "" {
+		return ErrEmptyModel
+	}
+	return nil
+}
+
+// GenerateInferenceResource generates the Kubernetes Deployment and Service of the inference
+// module, together with the patcher that passes the Service name to the workload.
+//
+// Note that we will use the SDK provided by the kusion module framework to wrap the Kubernetes resource
+// into Kusion resource.
+func (infer *Inference) GenerateInferenceResource(request *module.GeneratorRequest) ([]apiv1.Resource, *apiv1.Patcher, error) {
+	// Build Kubernetes Deployment for the Inference instance.
+	deployment, err := infer.generateDeployment(request)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// Build Kubernetes Service for the Inference instance.
+	svc, svcName, err := infer.generateService(request)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// The workload receives the Service name through the INFERENCE_PATH variable.
+	patcher := &apiv1.Patcher{
+		Environments: []v1.EnvVar{
+			{
+				Name:  "INFERENCE_PATH",
+				Value: svcName,
+			},
+		},
+	}
+
+	return []apiv1.Resource{*deployment, *svc}, patcher, nil
+}
+
+// ollamaContainerPort is the container port declared for the Ollama container.
+const ollamaContainerPort int32 = 11434
+
+// resourcePrefix returns the lower-cased framework name used as the prefix of every
+// generated object, container, port and volume name. ValidateConfig only accepts the
+// capitalised spellings ("Ollama"), which Kubernetes name validation does not allow.
+func (infer *Inference) resourcePrefix() string {
+	return strings.ToLower(infer.Framework)
+}
+
+// shellSingleQuote wraps s in single quotes for /bin/sh, closing and reopening the
+// quotes around every embedded single quote so user text cannot end the literal.
+func shellSingleQuote(s string) string {
+	return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
+}
+
+// ollamaModelfile renders the Modelfile for the configured model: the FROM line, the
+// optional SYSTEM and TEMPLATE directives, and one PARAMETER line per tunable.
+func (infer *Inference) ollamaModelfile() string {
+	var b strings.Builder
+	fmt.Fprintf(&b, "FROM %s\n", infer.Model)
+	if infer.System != "" {
+		fmt.Fprintf(&b, `SYSTEM """%s"""`+"\n", infer.System)
+	}
+	if infer.Template != "" {
+		// The closing delimiter is three quotes; the previous four left a stray quote behind.
+		fmt.Fprintf(&b, `TEMPLATE """%s"""`+"\n", infer.Template)
+	}
+	fmt.Fprintf(&b, "PARAMETER top_k %d\n", infer.TopK)
+	fmt.Fprintf(&b, "PARAMETER top_p %f\n", infer.TopP)
+	fmt.Fprintf(&b, "PARAMETER temperature %f\n", infer.Temperature)
+	fmt.Fprintf(&b, "PARAMETER num_predict %d\n", infer.NumPredict)
+	fmt.Fprintf(&b, "PARAMETER num_ctx %d\n", infer.NumCtx)
+	return b.String()
+}
+
+// generatePodSpec generates the Kubernetes PodSpec for the Inference instance.
+//
+// The framework is matched case-insensitively: the earlier `case "ollama"` never matched
+// the validated value "Ollama", leaving the command, mount path and port empty. A framework
+// without an implementation now yields an error instead of a half-filled PodSpec.
+func (infer *Inference) generatePodSpec(_ *module.GeneratorRequest) (v1.PodSpec, error) {
+	var (
+		mountPath     string
+		modelPullCmd  []string
+		containerPort int32
+	)
+	switch strings.ToLower(infer.Framework) {
+	case OllamaType:
+		mountPath = "/root/.ollama"
+		containerPort = ollamaContainerPort
+
+		// NOTE(review): `ollama create` presumably needs a running ollama server inside
+		// the container - confirm that the image or entrypoint starts one.
+		script := strings.Join([]string{
+			fmt.Sprintf("echo %s > Modelfile", shellSingleQuote(infer.ollamaModelfile())),
+			fmt.Sprintf("ollama create %s -f Modelfile", shellSingleQuote(infer.Model)),
+		}, " && ")
+		modelPullCmd = []string{"/bin/sh", "-c", script}
+	default:
+		return v1.PodSpec{}, fmt.Errorf("generating pod spec: framework %q is not supported yet", infer.Framework)
+	}
+
+	prefix := infer.resourcePrefix()
+	storageName := prefix + inferStorageSuffix
+
+	podSpec := v1.PodSpec{
+		Containers: []v1.Container{
+			{
+				Name:    prefix,
+				Image:   OllamaImage,
+				Command: modelPullCmd,
+				Ports: []v1.ContainerPort{
+					{
+						Name:          prefix,
+						ContainerPort: containerPort,
+					},
+				},
+				VolumeMounts: []v1.VolumeMount{
+					{
+						Name:      storageName,
+						MountPath: mountPath,
+					},
+				},
+			},
+		},
+		Volumes: []v1.Volume{
+			{
+				Name: storageName,
+				VolumeSource: v1.VolumeSource{
+					EmptyDir: &v1.EmptyDirVolumeSource{},
+				},
+			},
+		},
+	}
+	return podSpec, nil
+}
+
+// generateDeployment generates the Kubernetes Deployment resource for the Inference instance.
+// The Deployment is placed in the namespace named after request.Project.
+func (infer *Inference) generateDeployment(request *module.GeneratorRequest) (*apiv1.Resource, error) {
+	// Prepare the Pod Spec for the Inference instance.
+	podSpec, err := infer.generatePodSpec(request)
+	if err != nil {
+		// Propagate the failure: returning (nil, nil) here made the caller dereference nil.
+		return nil, err
+	}
+
+	// Create the Kubernetes Deployment for the Inference instance.
+	deployment := &appsv1.Deployment{
+		TypeMeta: metav1.TypeMeta{
+			Kind:       "Deployment",
+			APIVersion: appsv1.SchemeGroupVersion.String(),
+		},
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      infer.resourcePrefix() + inferDeploymentSuffix,
+			Namespace: request.Project,
+		},
+		Spec: appsv1.DeploymentSpec{
+			Selector: &metav1.LabelSelector{
+				MatchLabels: infer.generateMatchLabels(),
+			},
+			Template: v1.PodTemplateSpec{
+				ObjectMeta: metav1.ObjectMeta{
+					Labels: infer.generateMatchLabels(),
+				},
+				Spec: podSpec,
+			},
+		},
+	}
+
+	resourceID := module.KubernetesResourceID(deployment.TypeMeta, deployment.ObjectMeta)
+	return module.WrapK8sResourceToKusionResource(resourceID, deployment)
+}
+
+// generateService generates the Kubernetes Service resource for the Inference instance.
+// It returns the wrapped resource together with the Service name; the name is returned
+// even when wrapping fails.
+func (infer *Inference) generateService(request *module.GeneratorRequest) (*apiv1.Resource, string, error) {
+	svcName := infer.resourcePrefix() + inferServiceSuffix
+
+	// Expose the instance on port 80 of the Service.
+	port := v1.ServicePort{Port: int32(80)}
+	if strings.ToLower(infer.Framework) == OllamaType {
+		// Without a target port the Service forwards to pod port 80, while the container
+		// declares ollamaContainerPort. The zero value of TargetPort is the integer kind,
+		// so assigning IntVal is enough to select a numeric target port.
+		port.TargetPort.IntVal = ollamaContainerPort
+	}
+
+	// Create the Kubernetes service for Inference instance.
+	service := &v1.Service{
+		TypeMeta: metav1.TypeMeta{
+			Kind:       "Service",
+			APIVersion: v1.SchemeGroupVersion.String(),
+		},
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      svcName,
+			Namespace: request.Project,
+			Labels:    infer.generateMatchLabels(),
+		},
+		Spec: v1.ServiceSpec{
+			Type:     v1.ServiceTypeClusterIP,
+			Ports:    []v1.ServicePort{port},
+			Selector: infer.generateMatchLabels(),
+		},
+	}
+
+	resourceID := module.KubernetesResourceID(service.TypeMeta, service.ObjectMeta)
+	resource, err := module.WrapK8sResourceToKusionResource(resourceID, service)
+	if err != nil {
+		return nil, svcName, err
+	}
+
+	return resource, svcName, nil
+}
+
+// generateMatchLabels generates the match labels for the Kubernetes resources of the Inference instance.
+func (infer *Inference) generateMatchLabels() map[string]string { + return map[string]string{ + "accessory": infer.Framework, + } +} diff --git a/modules/inference/src/inference_generator_test.go b/modules/inference/src/inference_generator_test.go new file mode 100644 index 0000000..801215d --- /dev/null +++ b/modules/inference/src/inference_generator_test.go @@ -0,0 +1,338 @@ +package main + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "kusionstack.io/kusion-module-framework/pkg/module" + apiv1 "kusionstack.io/kusion/pkg/apis/api.kusion.io/v1" + v1 "kusionstack.io/kusion/pkg/apis/api.kusion.io/v1" +) + +func TestInferenceModule_Generator(t *testing.T) { + r := &module.GeneratorRequest{ + Project: "test-project", + Stack: "test-stack", + App: "test-app", + Workload: &apiv1.Workload{ + Header: apiv1.Header{ + Type: "Service", + }, + Service: &apiv1.Service{}, + }, + DevConfig: apiv1.Accessory{ + "model": "llama3", + "framework": "Ollama", + }, + PlatformConfig: nil, + } + + infer := &Inference{} + res, err := infer.Generate(context.Background(), r) + + assert.NoError(t, err) + assert.NotNil(t, res) +} + +func TestInferenceModule_CompleteConfig(t *testing.T) { + testcases := []struct { + name string + devModuleConfig apiv1.Accessory + platformConfig apiv1.GenericConfig + expectedInference *Inference + }{ + { + name: "Default inference config", + devModuleConfig: apiv1.Accessory{ + "model": "qwen", + "framework": "Ollama", + }, + platformConfig: nil, + expectedInference: &Inference{ + Model: "qwen", + Framework: "Ollama", + System: "", + Template: "", + TopK: 40, + TopP: 0.9, + Temperature: 0.8, + NumPredict: 128, + NumCtx: 2048, + }, + }, + { + name: "Custom inference config", + devModuleConfig: apiv1.Accessory{ + "model": "qwen", + "framework": "Ollama", + "top_k": 50, + "top_p": 0.5, + "temperature": 0.5, + "num_predict": 256, + "num_ctx": 4096, + }, + platformConfig: nil, + expectedInference: &Inference{ + Model: "qwen", + 
Framework: "Ollama", + System: "", + Template: "", + TopK: 50, + TopP: 0.5, + Temperature: 0.5, + NumPredict: 256, + NumCtx: 4096, + }, + }, + } + + for _, tc := range testcases { + infer := &Inference{} + t.Run(tc.name, func(t *testing.T) { + _ = infer.CompleteConfig(tc.devModuleConfig, tc.platformConfig) + assert.Equal(t, tc.expectedInference, infer) + }) + } +} + +func TestInferenceModule_ValidateConfig(t *testing.T) { + t.Run("validate no error", func(t *testing.T) { + infer := &Inference{ + Model: "qwen", + Framework: "Ollama", + System: "", + Template: "", + TopK: 40, + TopP: 0.9, + Temperature: 0.8, + NumPredict: 128, + NumCtx: 2048, + } + err := infer.ValidateConfig() + assert.NoError(t, err) + }) + + t.Run("test framework", func(t *testing.T) { + infer := &Inference{ + Model: "qwen", + Framework: "unsupport_framework", + System: "", + Template: "", + TopK: 40, + TopP: 0.9, + Temperature: 0.8, + NumPredict: 128, + NumCtx: 2048, + } + err := infer.ValidateConfig() + assert.ErrorContains(t, err, ErrUnsupportFramework.Error()) + }) + + t.Run("test top_k", func(t *testing.T) { + infer := &Inference{ + Model: "qwen", + Framework: "Ollama", + System: "", + Template: "", + TopK: 0, + TopP: 0.9, + Temperature: 0.8, + NumPredict: 128, + NumCtx: 2048, + } + err := infer.ValidateConfig() + assert.ErrorContains(t, err, ErrRangeTopK.Error()) + }) + + t.Run("test top_p", func(t *testing.T) { + infer := &Inference{ + Model: "qwen", + Framework: "Ollama", + System: "", + Template: "", + TopK: 40, + TopP: 2, + Temperature: 0.8, + NumPredict: 128, + NumCtx: 2048, + } + err := infer.ValidateConfig() + assert.ErrorContains(t, err, ErrRangeTopP.Error()) + }) + + t.Run("test temperature", func(t *testing.T) { + infer := &Inference{ + Model: "qwen", + Framework: "Ollama", + System: "", + Template: "", + TopK: 40, + TopP: 0.9, + Temperature: 0, + NumPredict: 128, + NumCtx: 2048, + } + err := infer.ValidateConfig() + assert.ErrorContains(t, err, ErrRangeTemperature.Error()) + }) + 
+ t.Run("test num_predict", func(t *testing.T) { + infer := &Inference{ + Model: "qwen", + Framework: "Ollama", + System: "", + Template: "", + TopK: 40, + TopP: 0.9, + Temperature: 0.8, + NumPredict: -100, + NumCtx: 2048, + } + err := infer.ValidateConfig() + assert.ErrorContains(t, err, ErrRangeNumPredict.Error()) + }) + + t.Run("test num_ctx", func(t *testing.T) { + infer := &Inference{ + Model: "qwen", + Framework: "Ollama", + System: "", + Template: "", + TopK: 40, + TopP: 0.9, + Temperature: 0.8, + NumPredict: 128, + NumCtx: -100, + } + err := infer.ValidateConfig() + assert.ErrorContains(t, err, ErrRangeNumCtx.Error()) + }) +} + +func TestInferenceModule_GenerateInferenceResource(t *testing.T) { + r := &module.GeneratorRequest{ + Project: "test-project", + Stack: "test-stack", + App: "test-app", + Workload: &v1.Workload{ + Header: v1.Header{ + Type: "Service", + }, + Service: &v1.Service{}, + }, + } + + infer := &Inference{ + Model: "qwen", + Framework: "Ollama", + System: "", + Template: "", + TopK: 40, + TopP: 0.9, + Temperature: 0.8, + NumPredict: 128, + NumCtx: 2048, + } + + res, patch, err := infer.GenerateInferenceResource(r) + + assert.NotNil(t, res) + assert.NotNil(t, patch) + assert.NoError(t, err) +} + +func TestInferenceModule_GeneratePodSpec(t *testing.T) { + r := &module.GeneratorRequest{ + Project: "test-project", + Stack: "test-stack", + App: "test-app", + Workload: &v1.Workload{ + Header: v1.Header{ + Type: "Service", + }, + Service: &v1.Service{}, + }, + } + + infer := &Inference{ + Model: "qwen", + Framework: "Ollama", + System: "", + Template: "", + TopK: 40, + TopP: 0.9, + Temperature: 0.8, + NumPredict: 128, + NumCtx: 2048, + } + + res, err := infer.generatePodSpec(r) + + assert.NotNil(t, res) + assert.NoError(t, err) +} + +func TestInferenceModule_GenerateDeployment(t *testing.T) { + r := &module.GeneratorRequest{ + Project: "test-project", + Stack: "test-stack", + App: "test-app", + Workload: &v1.Workload{ + Header: v1.Header{ + Type: 
"Service", + }, + Service: &v1.Service{}, + }, + } + + infer := &Inference{ + Model: "qwen", + Framework: "Ollama", + System: "", + Template: "", + TopK: 40, + TopP: 0.9, + Temperature: 0.8, + NumPredict: 128, + NumCtx: 2048, + } + + res, err := infer.generateDeployment(r) + + assert.NotNil(t, res) + assert.NoError(t, err) +} + +func TestInferenceModule_GenerateService(t *testing.T) { + r := &module.GeneratorRequest{ + Project: "test-project", + Stack: "test-stack", + App: "test-app", + Workload: &v1.Workload{ + Header: v1.Header{ + Type: "Service", + }, + Service: &v1.Service{}, + }, + } + + infer := &Inference{ + Model: "qwen", + Framework: "Ollama", + System: "", + Template: "", + TopK: 40, + TopP: 0.9, + Temperature: 0.8, + NumPredict: 128, + NumCtx: 2048, + } + + res, svcName, err := infer.generateService(r) + + assert.NotNil(t, res) + assert.NotNil(t, svcName) + assert.Equal(t, infer.Framework+inferServiceSuffix, svcName) + assert.NoError(t, err) +} diff --git a/modules/inference/v1/inference.k b/modules/inference/v1/inference.k new file mode 100644 index 0000000..831218f --- /dev/null +++ b/modules/inference/v1/inference.k @@ -0,0 +1,60 @@ +schema Inference: + """ Inference is a module schema consisting of model, framework and so on + + Attributes + ---------- + model: str, default is Undefined, required. + The model name to be used for inference. + framework: "Ollama" | "KubeRay", default is Undefined, required. + The framework or environment in which the model operates. + system: str, default is "". + The system message, which will be set in the template. + template: str, default is "". + The full prompt template, which will be sent to the model. + top_k: int, default is 40. + A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. + top_p: float, default is 0.9. + A higher value (e.g. 0.9) will give more diverse answers, while a lower value (e.g. 0.5) will be more conservative. 
+ temperature: float, default is 0.8. + The temperature of the model. Increasing the temperature will make the model answer more creatively. + num_predict: int, default is 128. + Maximum number of tokens to predict when generating text. + num_ctx: int, default is 2048. + The size of the context window used to generate the next token. + + Examples + -------- + import inference.v1.inference as infer + + accessories: { + "inference@v0.1.0": infer.Inference { + model: "llama3" + framework: "Ollama" + + system: "You are Mario from super mario bros, acting as an assistant." + template: "{{ if .System }}<|im_start|>system {{ .System }}<|im_end|> {{ end }}{{ if .Prompt }}<|im_start|>user {{ .Prompt }}<|im_end|> {{ end }}<|im_start|>assistant" + + top_k: 40 + top_p: 0.9 + temperature: 0.8 + + num_predict: 128 + num_ctx: 2048 + } + } + """ + model: str + framework: "Ollama" | "KubeRay" + system?: str = "" + template?: str = "" + top_k?: int = 40 + top_p?: float = 0.9 + temperature?: float = 0.8 + num_predict?: int = 128 + num_ctx?: int = 2048 + + check: + 0 < top_k if top_k, "top_k must be more than 0" + 0 < top_p <= 1 if top_p, "top_p must be greater than 0 and less than or equal to 1" + 0 < temperature if temperature, "temperature must be more than 0" + -2 <= num_predict if num_predict, "num_predict must be greater than or equal to -2" + 0 < num_ctx if num_ctx, "num_ctx must be greater than 0" + From 11645cbb2c074685d7b0c467e46a271e9359ce61 Mon Sep 17 00:00:00 2001 From: YuKang Date: Wed, 4 Sep 2024 14:18:28 +0800 Subject: [PATCH 2/2] fix: correct inference module pod spec generate --- modules/inference/example/dev/example_workspace.yaml | 2 +- modules/inference/example/dev/kcl.mod | 2 +- modules/inference/kcl.mod | 2 +- modules/inference/src/Makefile | 2 +- modules/inference/src/inference_generator.go | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/inference/example/dev/example_workspace.yaml b/modules/inference/example/dev/example_workspace.yaml index ae02970..6d5c36a 100644 --- a/modules/inference/example/dev/example_workspace.yaml +++ b/modules/inference/example/dev/example_workspace.yaml
@@ -2,6 +2,6 @@ modules: inference: path: oci://ghcr.io/kusionstack/inference - version: 0.1.0-beta.1 + version: 0.1.0-beta.2 configs: default: {} \ No newline at end of file diff --git a/modules/inference/example/dev/kcl.mod b/modules/inference/example/dev/kcl.mod index d009ae0..9cc7df5 100644 --- a/modules/inference/example/dev/kcl.mod +++ b/modules/inference/example/dev/kcl.mod @@ -2,7 +2,7 @@ name = "example" [dependencies] -inference = { oci = "oci://ghcr.io/kusionstack/inference", tag = "0.1.0-beta.1" } +inference = { oci = "oci://ghcr.io/kusionstack/inference", tag = "0.1.0-beta.2" } service = {oci = "oci://ghcr.io/kusionstack/service", tag = "0.1.0" } kam = { git = "https://github.com/KusionStack/kam.git", tag = "0.2.0" } diff --git a/modules/inference/kcl.mod b/modules/inference/kcl.mod index 98b46f0..e6dabeb 100644 --- a/modules/inference/kcl.mod +++ b/modules/inference/kcl.mod @@ -1,3 +1,3 @@ [package] name = "inference" -version = "0.1.0-beta.1" +version = "0.1.0-beta.2" diff --git a/modules/inference/src/Makefile b/modules/inference/src/Makefile index a6c65de..4ab7e41 100644 --- a/modules/inference/src/Makefile +++ b/modules/inference/src/Makefile @@ -2,7 +2,7 @@ TEST?=$$(go list ./... 
| grep -v 'vendor') ###### chang variables below according to your own modules ### NAMESPACE=kusionstack NAME=inference -VERSION=0.1.0-beta.1 +VERSION=0.1.0-beta.2 BINARY=../bin/kusion-module-${NAME}_${VERSION} LOCAL_ARCH := $(shell uname -m) diff --git a/modules/inference/src/inference_generator.go b/modules/inference/src/inference_generator.go index 4b536fa..a9e12e4 100644 --- a/modules/inference/src/inference_generator.go +++ b/modules/inference/src/inference_generator.go @@ -208,8 +208,8 @@ func (infer *Inference) generatePodSpec(_ *module.GeneratorRequest) (v1.PodSpec, var mountPath string var modelPullCmd []string var containerPort int32 - switch infer.Framework { - case "ollama": + switch strings.ToLower(infer.Framework) { + case OllamaType: mountPath = "/root/.ollama" var builder strings.Builder