diff --git a/sdk/go/README.md b/sdk/go/README.md new file mode 100644 index 00000000..bb768469 --- /dev/null +++ b/sdk/go/README.md @@ -0,0 +1,61 @@ +# Go SDK + +The DStack SDK for Go. + +# Installation + +```bash +go get github.com/Dstack-TEE/dstack/sdk/go +``` + +# Usage + +```go +package main + +import ( + "context" + "fmt" + "log/slog" + + "github.com/Dstack-TEE/dstack/sdk/go/tappd" +) + +func main() { + client := tappd.NewTappdClient( + // tappd.WithEndpoint("http://localhost"), + // tappd.WithLogger(slog.Default()), + ) + + deriveKeyResp, err := client.DeriveKey(context.Background(), "/") + if err != nil { + fmt.Println(err) + return + } + fmt.Println(deriveKeyResp) // &{-----BEGIN PRIVATE KEY--- ... + + tdxQuoteResp, err := client.TdxQuote(context.Background(), []byte("test")) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(tdxQuoteResp) // &{0x0000000000000000000 ... + + rtmrs, err := tdxQuoteResp.ReplayRTMRs() + if err != nil { + fmt.Println(err) + return + } + fmt.Println(rtmrs) // map[0:00000000000000000 ... +} +``` + +# For Development + +Set up [Go](https://go.dev/doc/install). + +Running the unit tests with local simulator via `/tmp/tappd.sock`: + +```bash +DSTACK_SIMULATOR_ENDPOINT=/tmp/tappd.sock go test ./... +``` diff --git a/sdk/go/go.mod b/sdk/go/go.mod new file mode 100644 index 00000000..95ac1855 --- /dev/null +++ b/sdk/go/go.mod @@ -0,0 +1,3 @@ +module github.com/Dstack-TEE/dstack/sdk/go + +go 1.21.6 diff --git a/sdk/go/go.sum b/sdk/go/go.sum new file mode 100644 index 00000000..e69de29b diff --git a/sdk/go/tappd/client.go b/sdk/go/tappd/client.go new file mode 100644 index 00000000..398e67bf --- /dev/null +++ b/sdk/go/tappd/client.go @@ -0,0 +1,302 @@ +// Provides a Dstack SDK Tappd client and related utilities +// +// Author: Franco Barpp Gomes +package tappd + +import ( + "bytes" + "context" + "crypto/sha512" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "os" + "strings" +) + +// Represents the hash algorithm used in TDX quote generation. +type QuoteHashAlgorithm string + +const ( + SHA256 QuoteHashAlgorithm = "sha256" + SHA384 QuoteHashAlgorithm = "sha384" + SHA512 QuoteHashAlgorithm = "sha512" + SHA3_256 QuoteHashAlgorithm = "sha3-256" + SHA3_384 QuoteHashAlgorithm = "sha3-384" + SHA3_512 QuoteHashAlgorithm = "sha3-512" + KECCAK256 QuoteHashAlgorithm = "keccak256" + KECCAK384 QuoteHashAlgorithm = "keccak384" + KECCAK512 QuoteHashAlgorithm = "keccak512" + RAW QuoteHashAlgorithm = "raw" +) + +// Represents the response from a key derivation request. +type DeriveKeyResponse struct { + Key string `json:"key"` + CertificateChain []string `json:"certificate_chain"` +} + +// Decodes the key to bytes, optionally truncating to maxLength. If maxLength +// < 0, the key is not truncated. +func (d *DeriveKeyResponse) ToBytes(maxLength int) ([]byte, error) { + content := d.Key + + content = strings.Replace(content, "-----BEGIN PRIVATE KEY-----", "", 1) + content = strings.Replace(content, "-----END PRIVATE KEY-----", "", 1) + content = strings.Replace(content, "\n", "", -1) + + binary, err := base64.StdEncoding.DecodeString(content) + if err != nil { + return nil, err + } + + if maxLength >= 0 && len(binary) > maxLength { + return binary[:maxLength], nil + } + return binary, nil +} + +// Represents the response from a TDX quote request. +type TdxQuoteResponse struct { + Quote string `json:"quote"` + EventLog string `json:"event_log"` +} + +const INIT_MR = "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + +// Replays the RTMR history to calculate final RTMR values +func replayRTMR(history []string) (string, error) { + if len(history) == 0 { + return INIT_MR, nil + } + + mr := make([]byte, 48) + + for _, content := range history { + contentBytes, err := hex.DecodeString(content) + if err != nil { + return "", err + } + + if len(contentBytes) < 48 { + padding := make([]byte, 48-len(contentBytes)) + contentBytes = append(contentBytes, padding...) + } + + h := sha512.New384() + h.Write(append(mr, contentBytes...)) + mr = h.Sum(nil) + } + + return hex.EncodeToString(mr), nil +} + +// Replays the RTMR history to calculate final RTMR values +func (r *TdxQuoteResponse) ReplayRTMRs() (map[int]string, error) { + var eventLog []struct { + IMR int `json:"imr"` + Digest string `json:"digest"` + } + json.Unmarshal([]byte(r.EventLog), &eventLog) + + rtmrs := make(map[int]string, 4) + for idx := 0; idx < 4; idx++ { + history := make([]string, 0) + for _, event := range eventLog { + if event.IMR == idx { + history = append(history, event.Digest) + } + } + + rtmr, err := replayRTMR(history) + if err != nil { + return nil, err + } + + rtmrs[idx] = rtmr + } + + return rtmrs, nil +} + +// Handles communication with the Tappd service. +type TappdClient struct { + endpoint string + baseURL string + httpClient *http.Client + logger *slog.Logger +} + +// Functional option for configuring a TappdClient. +type TappdClientOption func(*TappdClient) + +// Sets the endpoint for the TappdClient. +func WithEndpoint(endpoint string) TappdClientOption { + return func(c *TappdClient) { + c.endpoint = endpoint + } +} + +// Sets the logger for the TappdClient +func WithLogger(logger *slog.Logger) TappdClientOption { + return func(c *TappdClient) { + c.logger = logger + } +} + +// Creates a new TappdClient instance based on the provided endpoint. +// If the endpoint is empty, it will use the simulator endpoint if it is +// set in the environment through DSTACK_SIMULATOR_ENDPOINT. Otherwise, it +// will use the default endpoint at /var/run/tappd.sock. +func NewTappdClient(opts ...TappdClientOption) *TappdClient { + client := &TappdClient{ + endpoint: "", + baseURL: "", + httpClient: &http.Client{}, + logger: slog.Default(), + } + + for _, opt := range opts { + opt(client) + } + + client.endpoint = client.getEndpoint() + + if strings.HasPrefix(client.endpoint, "http://") || strings.HasPrefix(client.endpoint, "https://") { + client.baseURL = client.endpoint + } else { + client.baseURL = "http://localhost" + client.httpClient = &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", client.endpoint) + }, + }, + } + } + + return client +} + +// Returns the appropriate endpoint based on environment and input. If the +// endpoint is empty, it will use the simulator endpoint if it is set in the +// environment through DSTACK_SIMULATOR_ENDPOINT. Otherwise, it will use the +// default endpoint at /var/run/tappd.sock. +func (c *TappdClient) getEndpoint() string { + if c.endpoint != "" { + return c.endpoint + } + if simEndpoint, exists := os.LookupEnv("DSTACK_SIMULATOR_ENDPOINT"); exists { + c.logger.Info("using simulator endpoint", "endpoint", simEndpoint) + return simEndpoint + } + return "/var/run/tappd.sock" +} + +// Sends an RPC request to the Tappd service. +func (c *TappdClient) sendRPCRequest(ctx context.Context, path string, payload interface{}) ([]byte, error) { + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+path, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} + +// Derives a key from the Tappd service. This wraps +// DeriveKeyWithSubjectAndAltNames using the path as the subject and an empty +// altNames. +func (c *TappdClient) DeriveKey(ctx context.Context, path string) (*DeriveKeyResponse, error) { + return c.DeriveKeyWithSubjectAndAltNames(ctx, path, path, nil) +} + +// Derives a key from the Tappd service. This wraps +// DeriveKeyWithSubjectAndAltNames using an empty altNames. +func (c *TappdClient) DeriveKeyWithSubject(ctx context.Context, path string, subject string) (*DeriveKeyResponse, error) { + return c.DeriveKeyWithSubjectAndAltNames(ctx, path, subject, nil) +} + +// Derives a key from the Tappd service, explicitly setting the subject and +// altNames. +func (c *TappdClient) DeriveKeyWithSubjectAndAltNames(ctx context.Context, path string, subject string, altNames []string) (*DeriveKeyResponse, error) { + if subject == "" { + subject = path + } + + payload := map[string]interface{}{ + "path": path, + "subject": subject, + } + if len(altNames) > 0 { + payload["alt_names"] = altNames + } + + data, err := c.sendRPCRequest(ctx, "/prpc/Tappd.DeriveKey", payload) + if err != nil { + return nil, err + } + + var response DeriveKeyResponse + if err := json.Unmarshal(data, &response); err != nil { + return nil, err + } + return &response, nil +} + +// Sends a TDX quote request to the Tappd service using SHA512 as the report +// data hash algorithm. +func (c *TappdClient) TdxQuote(ctx context.Context, reportData []byte) (*TdxQuoteResponse, error) { + return c.TdxQuoteWithHashAlgorithm(ctx, reportData, SHA512) +} + +// Sends a TDX quote request to the Tappd service with a specific hash +// report data hash algorithm. If the hash algorithm is RAW, the report data +// must be at most 64 bytes - if it's below that, it will be left-padded with +// zeros. +func (c *TappdClient) TdxQuoteWithHashAlgorithm(ctx context.Context, reportData []byte, hashAlgorithm QuoteHashAlgorithm) (*TdxQuoteResponse, error) { + if hashAlgorithm == RAW { + if len(reportData) > 64 { + return nil, fmt.Errorf("report data is too large, it should be at most 64 bytes when hashAlgorithm is RAW") + } + if len(reportData) < 64 { + reportData = append(make([]byte, 64-len(reportData)), reportData...) + } + } + + payload := map[string]interface{}{ + "report_data": hex.EncodeToString(reportData), + "hash_algorithm": string(hashAlgorithm), + } + + data, err := c.sendRPCRequest(ctx, "/prpc/Tappd.TdxQuote", payload) + if err != nil { + return nil, err + } + + var response TdxQuoteResponse + if err := json.Unmarshal(data, &response); err != nil { + return nil, err + } + return &response, nil +} diff --git a/sdk/go/tappd/client_test.go b/sdk/go/tappd/client_test.go new file mode 100644 index 00000000..00d39827 --- /dev/null +++ b/sdk/go/tappd/client_test.go @@ -0,0 +1,134 @@ +package tappd_test + +import ( + "bytes" + "context" + "encoding/hex" + "encoding/json" + "strings" + "testing" + + "github.com/Dstack-TEE/dstack/sdk/go/tappd" +) + +func TestDeriveKey(t *testing.T) { + client := tappd.NewTappdClient() + resp, err := client.DeriveKeyWithSubjectAndAltNames(context.Background(), "/", "test", nil) + if err != nil { + t.Fatal(err) + } + + if resp.Key == "" { + t.Error("expected key to not be empty") + } + + if len(resp.CertificateChain) == 0 { + t.Error("expected certificate chain to not be empty") + } + + // Test ToBytes + key, err := resp.ToBytes(-1) + if err != nil { + t.Fatal(err) + } + if len(key) == 0 { + t.Error("expected key bytes to not be empty") + } + + // Test ToBytes with max length + key, err = resp.ToBytes(32) + if err != nil { + t.Fatal(err) + } + if len(key) != 32 { + t.Errorf("expected key length to be 32, got %d", len(key)) + } +} + +func TestTdxQuote(t *testing.T) { + client := tappd.NewTappdClient() + resp, err := client.TdxQuote(context.Background(), []byte("test")) + if err != nil { + t.Fatal(err) + } + + if resp.Quote == "" { + t.Error("expected quote to not be empty") + } + + if !strings.HasPrefix(resp.Quote, "0x") { + t.Error("expected quote to start with 0x") + } + + if resp.EventLog == "" { + t.Error("expected event log to not be empty") + } + + var eventLog []map[string]interface{} + err = json.Unmarshal([]byte(resp.EventLog), &eventLog) + if err != nil { + t.Errorf("expected event log to be a valid JSON object: %v", err) + } + + quoteBytes, err := hex.DecodeString(resp.Quote[2:]) + if err != nil { + t.Errorf("expected quote to be a valid hex string: %v", err) + } + + // Get quote RTMRs manually + quoteRtmrs := [4][48]byte{ + [48]byte(quoteBytes[376:424]), + [48]byte(quoteBytes[424:472]), + [48]byte(quoteBytes[472:520]), + [48]byte(quoteBytes[520:568]), + } + + // Test ReplayRTMRs + rtmrs, err := resp.ReplayRTMRs() + if err != nil { + t.Fatal(err) + } + + if len(rtmrs) != 4 { + t.Errorf("expected 4 RTMRs, got %d", len(rtmrs)) + } + + // Verify RTMRs + for i := 0; i < 4; i++ { + if rtmrs[i] == "" { + t.Errorf("expected RTMR %d to not be empty", i) + } + + rtmrBytes, err := hex.DecodeString(rtmrs[i]) + if err != nil { + t.Errorf("expected RTMR %d to be valid hex: %v", i, err) + } + + if !bytes.Equal(rtmrBytes, quoteRtmrs[i][:]) { + t.Errorf("expected RTMR %d to be %s, got %s", i, hex.EncodeToString(quoteRtmrs[i][:]), rtmrs[i]) + } + } +} + +func TestTdxQuoteRawHash(t *testing.T) { + client := tappd.NewTappdClient() + + // Test valid raw hash + resp, err := client.TdxQuoteWithHashAlgorithm(context.Background(), []byte("test"), tappd.RAW) + if err != nil { + t.Fatal(err) + } + if resp.Quote == "" { + t.Error("expected quote to not be empty") + } + + // Test too large raw hash + largeData := make([]byte, 65) + _, err = client.TdxQuoteWithHashAlgorithm(context.Background(), largeData, tappd.RAW) + if err == nil { + t.Error("expected error for large raw hash data") + } + if !strings.Contains(err.Error(), "report data is too large") { + t.Errorf("unexpected error message: %v", err) + } +}