Skip to content

Commit

Permalink
feat: support Postgres (#2569)
Browse files Browse the repository at this point in the history
* skeleton of postgres

skeleton

* Adding Postgres specific db schema sql

* user test passed

* memo store test passed

* tag is working

* update user setting test done

* activity test done

* idp test passed

* inbox test done

* memo_organizer, UNTESTED

* memo relation test passed

* webhook test passed

* system setting test passed

* passed storage test

* pass resource test

* migration_history done

* fix memo_relation_test

* fixing server memo_relation test

* passes memo relation server test

* paess memo test

* final manual testing done

* final fixes

* final fixes cleanup

* sync schema

* lint

* lint

* lint

* lint

* lint
  • Loading branch information
Irvingouj authored Dec 3, 2023
1 parent 484efbb commit 9c18960
Show file tree
Hide file tree
Showing 28 changed files with 2,980 additions and 0 deletions.
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/usememos/memos
go 1.21

require (
github.com/Masterminds/squirrel v1.5.4
github.com/aws/aws-sdk-go-v2 v1.22.1
github.com/aws/aws-sdk-go-v2/config v1.22.1
github.com/aws/aws-sdk-go-v2/credentials v1.15.1
Expand All @@ -16,6 +17,7 @@ require (
github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1
github.com/improbable-eng/grpc-web v0.15.0
github.com/labstack/echo/v4 v4.11.2
github.com/lib/pq v1.10.9
github.com/microcosm-cc/bluemonday v1.0.26
github.com/pkg/errors v0.9.1
github.com/spf13/cobra v1.8.0
Expand Down Expand Up @@ -50,6 +52,8 @@ require (
github.com/gorilla/css v1.0.1 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 // indirect
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rs/cors v1.10.1 // indirect
Expand Down
8 changes: 8 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym
github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0=
github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc=
github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE=
github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM=
github.com/Masterminds/squirrel v1.5.4/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10=
github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g=
Expand Down Expand Up @@ -368,7 +370,13 @@ github.com/labstack/echo/v4 v4.11.2 h1:T+cTLQxWCDfqDEoydYm5kCobjmHwOwcv4OJAPHilm
github.com/labstack/echo/v4 v4.11.2/go.mod h1:UcGuQ8V6ZNRmSweBIJkPvGfwCMIlFmiqrPqiEBfPYws=
github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8=
github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM=
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw=
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o=
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk=
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM=
github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4=
github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ=
Expand Down
3 changes: 3 additions & 0 deletions store/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store"
"github.com/usememos/memos/store/db/mysql"
"github.com/usememos/memos/store/db/postgres"
"github.com/usememos/memos/store/db/sqlite"
)

Expand All @@ -19,6 +20,8 @@ func NewDBDriver(profile *profile.Profile) (store.Driver, error) {
driver, err = sqlite.NewDB(profile)
case "mysql":
driver, err = mysql.NewDB(profile)
case "postgres":
driver, err = postgres.NewDB(profile)
default:
return nil, errors.New("unknown db driver")
}
Expand Down
117 changes: 117 additions & 0 deletions store/db/postgres/activity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package postgres

import (
"context"
"time"

"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"

storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)

func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) {
payloadString := "{}"
if create.Payload != nil {
bytes, err := protojson.Marshal(create.Payload)
if err != nil {
return nil, errors.Wrap(err, "failed to marshal activity payload")
}
payloadString = string(bytes)
}

qb := squirrel.Insert("activity").
Columns("creator_id", "type", "level", "payload").
PlaceholderFormat(squirrel.Dollar)

values := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}

if create.ID != 0 {
qb = qb.Columns("id")
values = append(values, create.ID)
}

if create.CreatedTs != 0 {
qb = qb.Columns("created_ts")
values = append(values, squirrel.Expr("TO_TIMESTAMP(?)", create.CreatedTs))
}

qb = qb.Values(values...).Suffix("RETURNING id")

stmt, args, err := qb.ToSql()
if err != nil {
return nil, errors.Wrap(err, "failed to construct query")
}

var id int32
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil {
return nil, errors.Wrap(err, "failed to execute statement and retrieve ID")
}

list, err := d.ListActivities(ctx, &store.FindActivity{ID: &id})
if err != nil || len(list) == 0 {
return nil, errors.Wrap(err, "failed to find activity")
}

return list[0], nil
}

func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
qb := squirrel.Select("id", "creator_id", "type", "level", "payload", "created_ts").
From("activity").
Where("1 = 1").
PlaceholderFormat(squirrel.Dollar)

if find.ID != nil {
qb = qb.Where(squirrel.Eq{"id": *find.ID})
}
if find.Type != nil {
qb = qb.Where(squirrel.Eq{"type": find.Type.String()})
}

query, args, err := qb.ToSql()
if err != nil {
return nil, err
}

rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()

list := []*store.Activity{}
for rows.Next() {
activity := &store.Activity{}
var payloadBytes []byte
createdTsPlaceHolder := time.Time{}
if err := rows.Scan(
&activity.ID,
&activity.CreatorID,
&activity.Type,
&activity.Level,
&payloadBytes,
&createdTsPlaceHolder,
); err != nil {
return nil, err
}

activity.CreatedTs = createdTsPlaceHolder.Unix()

payload := &storepb.ActivityPayload{}
if err := protojson.Unmarshal(payloadBytes, payload); err != nil {
return nil, err
}
activity.Payload = payload
list = append(list, activity)
}

if err := rows.Err(); err != nil {
return nil, err
}

return list, nil
}
9 changes: 9 additions & 0 deletions store/db/postgres/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package postgres

import "google.golang.org/protobuf/encoding/protojson"

var (
protojsonUnmarshaler = protojson.UnmarshalOptions{
DiscardUnknown: true,
}
)
178 changes: 178 additions & 0 deletions store/db/postgres/idp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package postgres

import (
"context"
"encoding/json"

"github.com/Masterminds/squirrel"
"github.com/pkg/errors"

"github.com/usememos/memos/store"
)

func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
var configBytes []byte
if create.Type == store.IdentityProviderOAuth2Type {
bytes, err := json.Marshal(create.Config.OAuth2Config)
if err != nil {
return nil, err
}
configBytes = bytes
} else {
return nil, errors.Errorf("unsupported idp type %s", string(create.Type))
}

qb := squirrel.Insert("idp").Columns("name", "type", "identifier_filter", "config")
values := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)}

if create.ID != 0 {
qb = qb.Columns("id")
values = append(values, create.ID)
}

qb = qb.Values(values...).PlaceholderFormat(squirrel.Dollar)
qb = qb.Suffix("RETURNING id")

stmt, args, err := qb.ToSql()
if err != nil {
return nil, err
}

var id int32
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil {
return nil, err
}

create.ID = id
return create, nil
}
func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
qb := squirrel.Select("id", "name", "type", "identifier_filter", "config").
From("idp").
Where("1 = 1").
PlaceholderFormat(squirrel.Dollar)

if v := find.ID; v != nil {
qb = qb.Where(squirrel.Eq{"id": *v})
}

query, args, err := qb.ToSql()
if err != nil {
return nil, err
}

rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()

var identityProviders []*store.IdentityProvider
for rows.Next() {
var identityProvider store.IdentityProvider
var identityProviderConfig string
if err := rows.Scan(
&identityProvider.ID,
&identityProvider.Name,
&identityProvider.Type,
&identityProvider.IdentifierFilter,
&identityProviderConfig,
); err != nil {
return nil, err
}

if identityProvider.Type == store.IdentityProviderOAuth2Type {
oauth2Config := &store.IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err
}
identityProvider.Config = &store.IdentityProviderConfig{
OAuth2Config: oauth2Config,
}
} else {
return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
}
identityProviders = append(identityProviders, &identityProvider)
}

if err := rows.Err(); err != nil {
return nil, err
}

return identityProviders, nil
}

func (d *DB) GetIdentityProvider(ctx context.Context, find *store.FindIdentityProvider) (*store.IdentityProvider, error) {
list, err := d.ListIdentityProviders(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}

return list[0], nil
}

func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
qb := squirrel.Update("idp").
PlaceholderFormat(squirrel.Dollar)
var err error

if v := update.Name; v != nil {
qb = qb.Set("name", *v)
}
if v := update.IdentifierFilter; v != nil {
qb = qb.Set("identifier_filter", *v)
}
if v := update.Config; v != nil {
var configBytes []byte
if update.Type == store.IdentityProviderOAuth2Type {
bytes, err := json.Marshal(update.Config.OAuth2Config)
if err != nil {
return nil, err
}
configBytes = bytes
} else {
return nil, errors.Errorf("unsupported idp type %s", string(update.Type))
}
qb = qb.Set("config", string(configBytes))
}

qb = qb.Where(squirrel.Eq{"id": update.ID})

stmt, args, err := qb.ToSql()
if err != nil {
return nil, err
}

_, err = d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return nil, err
}

return d.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &update.ID})
}

func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
qb := squirrel.Delete("idp").
Where(squirrel.Eq{"id": delete.ID}).
PlaceholderFormat(squirrel.Dollar)

stmt, args, err := qb.ToSql()
if err != nil {
return err
}

result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}

if _, err = result.RowsAffected(); err != nil {
return err
}

return nil
}
Loading

0 comments on commit 9c18960

Please sign in to comment.