Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Go SDK implementation #47

Merged
merged 24 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
6ccc782
feat: Add Golang SDK
Hyodar Dec 9, 2024
7819400
feat: Rename go package to tappd
Hyodar Dec 9, 2024
66c417a
test: Add Go SDK unit tests
Hyodar Dec 9, 2024
a7e586b
docs: Add Go SDK README
Hyodar Dec 9, 2024
815b46d
feat: Add hash algorithms support to go SDK
Hyodar Dec 9, 2024
ef479f1
feat: Add RTMR replay to go SDK
Hyodar Dec 9, 2024
63faf21
refactor: Use info log instead of warn
Hyodar Dec 9, 2024
7dd4857
test: Add go SDK unit tests
Hyodar Dec 9, 2024
254b580
feat: Add logger to Tappd client creation in go SDK
Hyodar Dec 9, 2024
6622f2a
refactor: Avoid extra hex decoding in go SDK
Hyodar Dec 9, 2024
da6a14b
docs: Mention when altNames is included in the request
Hyodar Dec 9, 2024
7716689
test: Test parsing TDX quotes
Hyodar Dec 9, 2024
1730167
test: Move tests to test package
Hyodar Dec 10, 2024
49fdc6a
refactor: Add constructor options
Hyodar Dec 10, 2024
6c19a25
docs: Add installation instructions and snippet
Hyodar Dec 10, 2024
59dad92
test: Parse RTMRs manually and remove go-tdx-qpl dependency
Hyodar Dec 11, 2024
8eb1aaa
refactor: Improve report data size check error on raw report
Hyodar Dec 11, 2024
729aecd
docs: Fix typo
Hyodar Dec 11, 2024
864e4c2
refactor: Add TdxQuoteWithHashAlgorithm
Hyodar Dec 11, 2024
44106a9
refactor: Add DeriveKeyWithSubject and DeriveKeyWithSubjectAndAltNames
Hyodar Dec 11, 2024
25dfa33
refactor: Avoid unnecessary string operations
Hyodar Dec 11, 2024
00ab7b9
docs: Mention report data size and padding for raw hashing
Hyodar Dec 11, 2024
24da94d
docs: Add package summary and author
Hyodar Dec 11, 2024
9eaf1dc
test: Fix raw hashing test using SHA512
Hyodar Dec 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions sdk/go/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Go SDK

The DStack SDK for Go.

# Installation

```bash
go get github.com/Dstack-TEE/dstack/sdk/go
```

# Usage

```go
package main

import (
"context"
"fmt"
"log/slog"

"github.com/Dstack-TEE/dstack/sdk/go/tappd"
)

func main() {
client := tappd.NewTappdClient(
// tappd.WithEndpoint("http://localhost"),
// tappd.WithLogger(slog.Default()),
)

deriveKeyResp, err := client.DeriveKey(context.Background(), "/", "test", nil)
if err != nil {
fmt.Println(err)
return
}
fmt.Println(deriveKeyResp) // &{-----BEGIN PRIVATE KEY--- ...

tdxQuoteResp, err := client.TdxQuote(context.Background(), []byte("test"))
if err != nil {
fmt.Println(err)
return
}
fmt.Println(tdxQuoteResp) // &{0x0000000000000000000 ...

rtmrs, err := tdxQuoteResp.ReplayRTMRs()
if err != nil {
fmt.Println(err)
return
}
fmt.Println(rtmrs) // map[0:00000000000000000 ...
}
```

# For Development

Set up [Go](https://go.dev/doc/install).

Running the unit tests with local simulator via `/tmp/tappd.sock`:

```bash
DSTACK_SIMULATOR_ENDPOINT=/tmp/tappd.sock go test ./...
```
3 changes: 3 additions & 0 deletions sdk/go/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module github.com/Dstack-TEE/dstack/sdk/go

go 1.21.6
Empty file added sdk/go/go.sum
Empty file.
281 changes: 281 additions & 0 deletions sdk/go/tappd/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
package tappd

import (
"bytes"
"context"
"crypto/sha512"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"os"
"strings"
)

// Represents the hash algorithm used in TDX quote generation.
type QuoteHashAlgorithm string

const (
SHA256 QuoteHashAlgorithm = "sha256"
SHA384 QuoteHashAlgorithm = "sha384"
SHA512 QuoteHashAlgorithm = "sha512"
SHA3_256 QuoteHashAlgorithm = "sha3-256"
SHA3_384 QuoteHashAlgorithm = "sha3-384"
SHA3_512 QuoteHashAlgorithm = "sha3-512"
KECCAK256 QuoteHashAlgorithm = "keccak256"
KECCAK384 QuoteHashAlgorithm = "keccak384"
KECCAK512 QuoteHashAlgorithm = "keccak512"
RAW QuoteHashAlgorithm = "raw"
)

// Represents the response from a key derivation request.
type DeriveKeyResponse struct {
Key string `json:"key"`
CertificateChain []string `json:"certificate_chain"`
}

// Decodes the key to bytes, optionally truncating to maxLength. If maxLength
// < 0, the key is not truncated.
func (d *DeriveKeyResponse) ToBytes(maxLength int) ([]byte, error) {
content := d.Key

content = strings.Replace(content, "-----BEGIN PRIVATE KEY-----", "", 1)
content = strings.Replace(content, "-----END PRIVATE KEY-----", "", 1)
content = strings.Replace(content, "\n", "", -1)

binary, err := base64.StdEncoding.DecodeString(content)
if err != nil {
return nil, err
}

if maxLength >= 0 && len(binary) > maxLength {
return binary[:maxLength], nil
}
return binary, nil
}

// Represents the response from a TDX quote request.
type TdxQuoteResponse struct {
Quote string `json:"quote"`
EventLog string `json:"event_log"`
}

const INIT_MR = "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"

// Replays the RTMR history to calculate final RTMR values
func replayRTMR(history []string) (string, error) {
if len(history) == 0 {
return INIT_MR, nil
}

mr := make([]byte, 48)

for _, content := range history {
contentBytes, err := hex.DecodeString(content)
if err != nil {
return "", err
}

if len(contentBytes) < 48 {
padding := make([]byte, 48-len(contentBytes))
contentBytes = append(contentBytes, padding...)
}

h := sha512.New384()
h.Write(append(mr, contentBytes...))
mr = h.Sum(nil)
}

return hex.EncodeToString(mr), nil
}

// Replays the RTMR history to calculate final RTMR values
func (r *TdxQuoteResponse) ReplayRTMRs() (map[int]string, error) {
var eventLog []struct {
IMR int `json:"imr"`
Digest string `json:"digest"`
}
json.Unmarshal([]byte(r.EventLog), &eventLog)

rtmrs := make(map[int]string, 4)
for idx := 0; idx < 4; idx++ {
history := make([]string, 0)
for _, event := range eventLog {
if event.IMR == idx {
history = append(history, event.Digest)
}
}

rtmr, err := replayRTMR(history)
if err != nil {
return nil, err
}

rtmrs[idx] = rtmr
}

return rtmrs, nil
}

// Handles communication with the Tappd service.
type TappdClient struct {
endpoint string
baseURL string
httpClient *http.Client
logger *slog.Logger
}

// Functional option for configuring a TappdClient.
type TappdClientOption func(*TappdClient)

// Sets the endpoint for the TappdClient.
func WithEndpoint(endpoint string) TappdClientOption {
return func(c *TappdClient) {
c.endpoint = endpoint
}
}

// Sets the logger for the TappdClient
func WithLogger(logger *slog.Logger) TappdClientOption {
return func(c *TappdClient) {
c.logger = logger
}
}

// Creates a new TappdClient instance based on the provided endpoint.
// If the endpoint is empty, it will use the simulator endpoint if it is
// set in the environment through DSTACK_SIMULATOR_ENDPOINT. Otherwise, it
// will use the default endpoint at /var/run/tappd.sock.
func NewTappdClient(opts ...TappdClientOption) *TappdClient {
client := &TappdClient{
endpoint: "",
baseURL: "",
httpClient: &http.Client{},
logger: slog.Default(),
}

for _, opt := range opts {
opt(client)
}

client.endpoint = client.getEndpoint()

if strings.HasPrefix(client.endpoint, "http://") || strings.HasPrefix(client.endpoint, "https://") {
client.baseURL = client.endpoint
} else {
client.baseURL = "http://localhost"
client.httpClient = &http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", client.endpoint)
},
},
}
}

return client
}

// Returns the appropriate endpoint based on environment and input. If the
// endpoint is empty, it will use the simulator endpoint if it is set in the
// environment through DSTACK_SIMULATOR_ENDPOINT. Otherwise, it will use the
// default endpoint at /var/run/tappd.sock.
func (c *TappdClient) getEndpoint() string {
if c.endpoint != "" {
return c.endpoint
}
if simEndpoint, exists := os.LookupEnv("DSTACK_SIMULATOR_ENDPOINT"); exists {
c.logger.Info("using simulator endpoint", "endpoint", simEndpoint)
return simEndpoint
}
return "/var/run/tappd.sock"
}

// Sends an RPC request to the Tappd service.
func (c *TappdClient) sendRPCRequest(ctx context.Context, path string, payload interface{}) ([]byte, error) {
jsonData, err := json.Marshal(payload)
if err != nil {
return nil, err
}

req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+path, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}

req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

return io.ReadAll(resp.Body)
}

// Derives a key from the Tappd service. If altNames is empty or nil, it will
// not be included in the request.
func (c *TappdClient) DeriveKey(ctx context.Context, path string, subject string, altNames []string) (*DeriveKeyResponse, error) {
Leechael marked this conversation as resolved.
Show resolved Hide resolved
payload := map[string]interface{}{
"path": path,
"subject": subject,
}
if len(altNames) > 0 {
payload["alt_names"] = altNames
}

data, err := c.sendRPCRequest(ctx, "/prpc/Tappd.DeriveKey", payload)
if err != nil {
return nil, err
}

var response DeriveKeyResponse
if err := json.Unmarshal(data, &response); err != nil {
return nil, err
}
return &response, nil
}

// Sends a TDX quote request to the Tappd service using SHA512 as the report
// data hash algorithm.
func (c *TappdClient) TdxQuote(ctx context.Context, reportData []byte) (*TdxQuoteResponse, error) {
return c.TdxQuoteWithHashAlgorithm(ctx, reportData, SHA512)
}

// Sends a TDX quote request to the Tappd service with a specific hash
// report data hash algorithm.
func (c *TappdClient) TdxQuoteWithHashAlgorithm(ctx context.Context, reportData []byte, hashAlgorithm QuoteHashAlgorithm) (*TdxQuoteResponse, error) {
hexData := hex.EncodeToString(reportData)
if hashAlgorithm == RAW {
if len(hexData) > 128 {
return nil, fmt.Errorf("report data is too large, it should be at most 64 bytes when hashAlgorithm is RAW")
}
if len(hexData) < 128 {
hexData = strings.Repeat("0", 128-len(hexData)) + hexData
}
}

payload := map[string]interface{}{
"report_data": hexData,
"hash_algorithm": string(hashAlgorithm),
}

data, err := c.sendRPCRequest(ctx, "/prpc/Tappd.TdxQuote", payload)
if err != nil {
return nil, err
}

var response TdxQuoteResponse
if err := json.Unmarshal(data, &response); err != nil {
return nil, err
}
return &response, nil
}
Loading