Skip to content

Commit

Permalink
Add worker pool to WASM capability (#15088)
Browse files Browse the repository at this point in the history
* [chore] Add worker pool to compute capability

- Also add step-level timeout to engine. This was removed when we moved
  away from ExecuteSync().

* WIP

* Some more comments
  • Loading branch information
cedric-cordenier authored Nov 13, 2024
1 parent 3b3b86c commit 1a9f8cc
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 71 deletions.
119 changes: 103 additions & 16 deletions core/capabilities/compute/compute.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"net/http"
"strings"
"sync"
"time"

"github.com/google/uuid"
Expand All @@ -19,6 +20,7 @@ import (
capabilitiespb "github.com/smartcontractkit/chainlink-common/pkg/capabilities/pb"
"github.com/smartcontractkit/chainlink-common/pkg/custmsg"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/services"
coretypes "github.com/smartcontractkit/chainlink-common/pkg/types/core"
"github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host"
wasmpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/pb"
Expand Down Expand Up @@ -73,7 +75,8 @@ var (
var _ capabilities.ActionCapability = (*Compute)(nil)

type Compute struct {
log logger.Logger
stopCh services.StopChan
log logger.Logger

// emitter is used to emit messages from the WASM module to a configured collector.
emitter custmsg.MessageEmitter
Expand All @@ -82,9 +85,13 @@ type Compute struct {

// transformer is used to transform a values.Map into a ParsedConfig struct on each execution
// of a request.
transformer ConfigTransformer
transformer *transformer
outgoingConnectorHandler *webapi.OutgoingConnectorHandler
idGenerator func() string

numWorkers int
queue chan request
wg sync.WaitGroup
}

func (c *Compute) RegisterToWorkflow(ctx context.Context, request capabilities.RegisterToWorkflowRequest) error {
Expand All @@ -100,35 +107,76 @@ func generateID(binary []byte) string {
return fmt.Sprintf("%x", id)
}

func copyRequest(req capabilities.CapabilityRequest) capabilities.CapabilityRequest {
return capabilities.CapabilityRequest{
Metadata: req.Metadata,
Inputs: req.Inputs.CopyMap(),
Config: req.Config.CopyMap(),
func (c *Compute) Execute(ctx context.Context, request capabilities.CapabilityRequest) (capabilities.CapabilityResponse, error) {
ch, err := c.enqueueRequest(ctx, request)
if err != nil {
return capabilities.CapabilityResponse{}, err
}

select {
case <-c.stopCh:
return capabilities.CapabilityResponse{}, errors.New("service shutting down, aborting request")
case <-ctx.Done():
return capabilities.CapabilityResponse{}, fmt.Errorf("request cancelled by upstream: %w", ctx.Err())
case resp := <-ch:
return resp.resp, resp.err
}
}

func (c *Compute) Execute(ctx context.Context, request capabilities.CapabilityRequest) (capabilities.CapabilityResponse, error) {
copied := copyRequest(request)
type request struct {
ch chan response
req capabilities.CapabilityRequest
ctx func() context.Context
}

cfg, err := c.transformer.Transform(copied.Config)
type response struct {
resp capabilities.CapabilityResponse
err error
}

func (c *Compute) enqueueRequest(ctx context.Context, req capabilities.CapabilityRequest) (<-chan response, error) {
ch := make(chan response)
r := request{
ch: ch,
req: req,
ctx: func() context.Context { return ctx },
}
select {
case <-c.stopCh:
return nil, errors.New("service shutting down, aborting request")
case <-ctx.Done():
return nil, fmt.Errorf("could not enqueue request: %w", ctx.Err())
case c.queue <- r:
return ch, nil
}
}

func (c *Compute) execute(ctx context.Context, respCh chan response, req capabilities.CapabilityRequest) {
copiedReq, cfg, err := c.transformer.Transform(req)
if err != nil {
return capabilities.CapabilityResponse{}, fmt.Errorf("invalid request: could not transform config: %w", err)
respCh <- response{err: fmt.Errorf("invalid request: could not transform config: %w", err)}
return
}

id := generateID(cfg.Binary)

m, ok := c.modules.get(id)
if !ok {
mod, err := c.initModule(id, cfg.ModuleConfig, cfg.Binary, request.Metadata)
if err != nil {
return capabilities.CapabilityResponse{}, err
mod, innerErr := c.initModule(id, cfg.ModuleConfig, cfg.Binary, copiedReq.Metadata)
if innerErr != nil {
respCh <- response{err: innerErr}
return
}

m = mod
}

return c.executeWithModule(ctx, m.module, cfg.Config, request)
resp, err := c.executeWithModule(ctx, m.module, cfg.Config, copiedReq)
select {
case <-c.stopCh:
case <-ctx.Done():
case respCh <- response{resp: resp, err: err}:
}
}

func (c *Compute) initModule(id string, cfg *host.ModuleConfig, binary []byte, requestMetadata capabilities.RequestMetadata) (*module, error) {
Expand Down Expand Up @@ -196,11 +244,35 @@ func (c *Compute) Info(ctx context.Context) (capabilities.CapabilityInfo, error)

func (c *Compute) Start(ctx context.Context) error {
c.modules.start()

c.wg.Add(c.numWorkers)
for i := 0; i < c.numWorkers; i++ {
go func() {
innerCtx, cancel := c.stopCh.NewCtx()
defer cancel()

defer c.wg.Done()
c.worker(innerCtx)
}()
}
return c.registry.Add(ctx, c)
}

func (c *Compute) worker(ctx context.Context) {
for {
select {
case <-c.stopCh:
return
case req := <-c.queue:
c.execute(req.ctx(), req.ch, req.req)
}
}
}

func (c *Compute) Close() error {
c.modules.close()
close(c.stopCh)
c.wg.Wait()
return nil
}

Expand Down Expand Up @@ -270,25 +342,40 @@ func (c *Compute) createFetcher() func(ctx context.Context, req *wasmpb.FetchReq
}
}

const (
defaultNumWorkers = 3
)

type Config struct {
webapi.ServiceConfig
NumWorkers int
}

func NewAction(
config webapi.ServiceConfig,
config Config,
log logger.Logger,
registry coretypes.CapabilitiesRegistry,
handler *webapi.OutgoingConnectorHandler,
idGenerator func() string,
opts ...func(*Compute),
) *Compute {
if config.NumWorkers == 0 {
config.NumWorkers = defaultNumWorkers
}
var (
lggr = logger.Named(log, "CustomCompute")
labeler = custmsg.NewLabeler()
compute = &Compute{
stopCh: make(services.StopChan),
log: lggr,
emitter: labeler,
registry: registry,
modules: newModuleCache(clockwork.NewRealClock(), 1*time.Minute, 10*time.Minute, 3),
transformer: NewTransformer(lggr, labeler),
outgoingConnectorHandler: handler,
idGenerator: idGenerator,
queue: make(chan request),
numWorkers: defaultNumWorkers,
}
)

Expand Down
20 changes: 11 additions & 9 deletions core/capabilities/compute/compute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,32 @@ const (
validRequestUUID = "d2fe6db9-beb4-47c9-b2d6-d3065ace111e"
)

var defaultConfig = webapi.ServiceConfig{
RateLimiter: common.RateLimiterConfig{
GlobalRPS: 100.0,
GlobalBurst: 100,
PerSenderRPS: 100.0,
PerSenderBurst: 100,
var defaultConfig = Config{
ServiceConfig: webapi.ServiceConfig{
RateLimiter: common.RateLimiterConfig{
GlobalRPS: 100.0,
GlobalBurst: 100,
PerSenderRPS: 100.0,
PerSenderBurst: 100,
},
},
}

type testHarness struct {
registry *corecapabilities.Registry
connector *gcmocks.GatewayConnector
log logger.Logger
config webapi.ServiceConfig
config Config
connectorHandler *webapi.OutgoingConnectorHandler
compute *Compute
}

func setup(t *testing.T, config webapi.ServiceConfig) testHarness {
func setup(t *testing.T, config Config) testHarness {
log := logger.TestLogger(t)
registry := capabilities.NewRegistry(log)
connector := gcmocks.NewGatewayConnector(t)
idGeneratorFn := func() string { return validRequestUUID }
connectorHandler, err := webapi.NewOutgoingConnectorHandler(connector, config, ghcapabilities.MethodComputeAction, log)
connectorHandler, err := webapi.NewOutgoingConnectorHandler(connector, config.ServiceConfig, ghcapabilities.MethodComputeAction, log)
require.NoError(t, err)

compute := NewAction(config, log, registry, connectorHandler, idGeneratorFn)
Expand Down
54 changes: 31 additions & 23 deletions core/capabilities/compute/transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,13 @@ import (
"fmt"
"time"

"github.com/smartcontractkit/chainlink-common/pkg/capabilities"
"github.com/smartcontractkit/chainlink-common/pkg/custmsg"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/values"
"github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host"
)

type Transformer[T any, U any] interface {
// Transform changes a struct of type T into a struct of type U. Accepts a variadic list of options to modify the
// output struct.
Transform(T, ...func(*U)) (*U, error)
}

// ConfigTransformer is a Transformer that converts a values.Map into a ParsedConfig struct.
type ConfigTransformer = Transformer[*values.Map, ParsedConfig]

// ParsedConfig is a struct that contains the binary and config for a wasm module, as well as the module config.
type ParsedConfig struct {
Binary []byte
Expand All @@ -36,25 +28,41 @@ type transformer struct {
emitter custmsg.MessageEmitter
}

func shallowCopy(m *values.Map) *values.Map {
to := values.EmptyMap()

for k, v := range m.Underlying {
to.Underlying[k] = v
}

return to
}

// Transform attempts to read a valid ParsedConfig from an arbitrary values map. The map must
// contain the binary and config keys. Optionally the map may specify wasm module specific
// configuration values such as maxMemoryMBs, timeout, and tickInterval. Default logger and
// emitter for the module are taken from the transformer instance. Override these values with
// the functional options.
func (t *transformer) Transform(in *values.Map, opts ...func(*ParsedConfig)) (*ParsedConfig, error) {
binary, err := popValue[[]byte](in, binaryKey)
func (t *transformer) Transform(req capabilities.CapabilityRequest, opts ...func(*ParsedConfig)) (capabilities.CapabilityRequest, *ParsedConfig, error) {
copiedReq := capabilities.CapabilityRequest{
Inputs: req.Inputs,
Metadata: req.Metadata,
Config: shallowCopy(req.Config),
}

binary, err := popValue[[]byte](copiedReq.Config, binaryKey)
if err != nil {
return nil, NewInvalidRequestError(err)
return capabilities.CapabilityRequest{}, nil, NewInvalidRequestError(err)
}

config, err := popValue[[]byte](in, configKey)
config, err := popValue[[]byte](copiedReq.Config, configKey)
if err != nil {
return nil, NewInvalidRequestError(err)
return capabilities.CapabilityRequest{}, nil, NewInvalidRequestError(err)
}

maxMemoryMBs, err := popOptionalValue[int64](in, maxMemoryMBsKey)
maxMemoryMBs, err := popOptionalValue[int64](copiedReq.Config, maxMemoryMBsKey)
if err != nil {
return nil, NewInvalidRequestError(err)
return capabilities.CapabilityRequest{}, nil, NewInvalidRequestError(err)
}

mc := &host.ModuleConfig{
Expand All @@ -63,30 +71,30 @@ func (t *transformer) Transform(in *values.Map, opts ...func(*ParsedConfig)) (*P
Labeler: t.emitter,
}

timeout, err := popOptionalValue[string](in, timeoutKey)
timeout, err := popOptionalValue[string](copiedReq.Config, timeoutKey)
if err != nil {
return nil, NewInvalidRequestError(err)
return capabilities.CapabilityRequest{}, nil, NewInvalidRequestError(err)
}

var td time.Duration
if timeout != "" {
td, err = time.ParseDuration(timeout)
if err != nil {
return nil, NewInvalidRequestError(err)
return capabilities.CapabilityRequest{}, nil, NewInvalidRequestError(err)
}
mc.Timeout = &td
}

tickInterval, err := popOptionalValue[string](in, tickIntervalKey)
tickInterval, err := popOptionalValue[string](copiedReq.Config, tickIntervalKey)
if err != nil {
return nil, NewInvalidRequestError(err)
return capabilities.CapabilityRequest{}, nil, NewInvalidRequestError(err)
}

var ti time.Duration
if tickInterval != "" {
ti, err = time.ParseDuration(tickInterval)
if err != nil {
return nil, NewInvalidRequestError(err)
return capabilities.CapabilityRequest{}, nil, NewInvalidRequestError(err)
}
mc.TickInterval = ti
}
Expand All @@ -101,7 +109,7 @@ func (t *transformer) Transform(in *values.Map, opts ...func(*ParsedConfig)) (*P
opt(pc)
}

return pc, nil
return copiedReq, pc, nil
}

func NewTransformer(lggr logger.Logger, emitter custmsg.MessageEmitter) *transformer {
Expand Down
Loading

0 comments on commit 1a9f8cc

Please sign in to comment.