-
-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* skeleton of postgres skeleton * Adding Postgres specific db schema sql * user test passed * memo store test passed * tag is working * update user setting test done * activity test done * idp test passed * inbox test done * memo_organizer, UNTESTED * memo relation test passed * webhook test passed * system setting test passed * passed storage test * pass resource test * migration_history done * fix memo_relation_test * fixing server memo_relation test * passes memo relation server test * paess memo test * final manual testing done * final fixes * final fixes cleanup * sync schema * lint * lint * lint * lint * lint
- Loading branch information
Showing
28 changed files
with
2,980 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
package postgres | ||
|
||
import ( | ||
"context" | ||
"time" | ||
|
||
"github.com/Masterminds/squirrel" | ||
"github.com/pkg/errors" | ||
"google.golang.org/protobuf/encoding/protojson" | ||
|
||
storepb "github.com/usememos/memos/proto/gen/store" | ||
"github.com/usememos/memos/store" | ||
) | ||
|
||
func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) { | ||
payloadString := "{}" | ||
if create.Payload != nil { | ||
bytes, err := protojson.Marshal(create.Payload) | ||
if err != nil { | ||
return nil, errors.Wrap(err, "failed to marshal activity payload") | ||
} | ||
payloadString = string(bytes) | ||
} | ||
|
||
qb := squirrel.Insert("activity"). | ||
Columns("creator_id", "type", "level", "payload"). | ||
PlaceholderFormat(squirrel.Dollar) | ||
|
||
values := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString} | ||
|
||
if create.ID != 0 { | ||
qb = qb.Columns("id") | ||
values = append(values, create.ID) | ||
} | ||
|
||
if create.CreatedTs != 0 { | ||
qb = qb.Columns("created_ts") | ||
values = append(values, squirrel.Expr("TO_TIMESTAMP(?)", create.CreatedTs)) | ||
} | ||
|
||
qb = qb.Values(values...).Suffix("RETURNING id") | ||
|
||
stmt, args, err := qb.ToSql() | ||
if err != nil { | ||
return nil, errors.Wrap(err, "failed to construct query") | ||
} | ||
|
||
var id int32 | ||
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id) | ||
if err != nil { | ||
return nil, errors.Wrap(err, "failed to execute statement and retrieve ID") | ||
} | ||
|
||
list, err := d.ListActivities(ctx, &store.FindActivity{ID: &id}) | ||
if err != nil || len(list) == 0 { | ||
return nil, errors.Wrap(err, "failed to find activity") | ||
} | ||
|
||
return list[0], nil | ||
} | ||
|
||
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) { | ||
qb := squirrel.Select("id", "creator_id", "type", "level", "payload", "created_ts"). | ||
From("activity"). | ||
Where("1 = 1"). | ||
PlaceholderFormat(squirrel.Dollar) | ||
|
||
if find.ID != nil { | ||
qb = qb.Where(squirrel.Eq{"id": *find.ID}) | ||
} | ||
if find.Type != nil { | ||
qb = qb.Where(squirrel.Eq{"type": find.Type.String()}) | ||
} | ||
|
||
query, args, err := qb.ToSql() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
rows, err := d.db.QueryContext(ctx, query, args...) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer rows.Close() | ||
|
||
list := []*store.Activity{} | ||
for rows.Next() { | ||
activity := &store.Activity{} | ||
var payloadBytes []byte | ||
createdTsPlaceHolder := time.Time{} | ||
if err := rows.Scan( | ||
&activity.ID, | ||
&activity.CreatorID, | ||
&activity.Type, | ||
&activity.Level, | ||
&payloadBytes, | ||
&createdTsPlaceHolder, | ||
); err != nil { | ||
return nil, err | ||
} | ||
|
||
activity.CreatedTs = createdTsPlaceHolder.Unix() | ||
|
||
payload := &storepb.ActivityPayload{} | ||
if err := protojson.Unmarshal(payloadBytes, payload); err != nil { | ||
return nil, err | ||
} | ||
activity.Payload = payload | ||
list = append(list, activity) | ||
} | ||
|
||
if err := rows.Err(); err != nil { | ||
return nil, err | ||
} | ||
|
||
return list, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
package postgres | ||
|
||
import "google.golang.org/protobuf/encoding/protojson" | ||
|
||
var ( | ||
protojsonUnmarshaler = protojson.UnmarshalOptions{ | ||
DiscardUnknown: true, | ||
} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
package postgres | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
|
||
"github.com/Masterminds/squirrel" | ||
"github.com/pkg/errors" | ||
|
||
"github.com/usememos/memos/store" | ||
) | ||
|
||
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) { | ||
var configBytes []byte | ||
if create.Type == store.IdentityProviderOAuth2Type { | ||
bytes, err := json.Marshal(create.Config.OAuth2Config) | ||
if err != nil { | ||
return nil, err | ||
} | ||
configBytes = bytes | ||
} else { | ||
return nil, errors.Errorf("unsupported idp type %s", string(create.Type)) | ||
} | ||
|
||
qb := squirrel.Insert("idp").Columns("name", "type", "identifier_filter", "config") | ||
values := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)} | ||
|
||
if create.ID != 0 { | ||
qb = qb.Columns("id") | ||
values = append(values, create.ID) | ||
} | ||
|
||
qb = qb.Values(values...).PlaceholderFormat(squirrel.Dollar) | ||
qb = qb.Suffix("RETURNING id") | ||
|
||
stmt, args, err := qb.ToSql() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
var id int32 | ||
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
create.ID = id | ||
return create, nil | ||
} | ||
func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) { | ||
qb := squirrel.Select("id", "name", "type", "identifier_filter", "config"). | ||
From("idp"). | ||
Where("1 = 1"). | ||
PlaceholderFormat(squirrel.Dollar) | ||
|
||
if v := find.ID; v != nil { | ||
qb = qb.Where(squirrel.Eq{"id": *v}) | ||
} | ||
|
||
query, args, err := qb.ToSql() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
rows, err := d.db.QueryContext(ctx, query, args...) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer rows.Close() | ||
|
||
var identityProviders []*store.IdentityProvider | ||
for rows.Next() { | ||
var identityProvider store.IdentityProvider | ||
var identityProviderConfig string | ||
if err := rows.Scan( | ||
&identityProvider.ID, | ||
&identityProvider.Name, | ||
&identityProvider.Type, | ||
&identityProvider.IdentifierFilter, | ||
&identityProviderConfig, | ||
); err != nil { | ||
return nil, err | ||
} | ||
|
||
if identityProvider.Type == store.IdentityProviderOAuth2Type { | ||
oauth2Config := &store.IdentityProviderOAuth2Config{} | ||
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { | ||
return nil, err | ||
} | ||
identityProvider.Config = &store.IdentityProviderConfig{ | ||
OAuth2Config: oauth2Config, | ||
} | ||
} else { | ||
return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type)) | ||
} | ||
identityProviders = append(identityProviders, &identityProvider) | ||
} | ||
|
||
if err := rows.Err(); err != nil { | ||
return nil, err | ||
} | ||
|
||
return identityProviders, nil | ||
} | ||
|
||
func (d *DB) GetIdentityProvider(ctx context.Context, find *store.FindIdentityProvider) (*store.IdentityProvider, error) { | ||
list, err := d.ListIdentityProviders(ctx, find) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if len(list) == 0 { | ||
return nil, nil | ||
} | ||
|
||
return list[0], nil | ||
} | ||
|
||
func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) { | ||
qb := squirrel.Update("idp"). | ||
PlaceholderFormat(squirrel.Dollar) | ||
var err error | ||
|
||
if v := update.Name; v != nil { | ||
qb = qb.Set("name", *v) | ||
} | ||
if v := update.IdentifierFilter; v != nil { | ||
qb = qb.Set("identifier_filter", *v) | ||
} | ||
if v := update.Config; v != nil { | ||
var configBytes []byte | ||
if update.Type == store.IdentityProviderOAuth2Type { | ||
bytes, err := json.Marshal(update.Config.OAuth2Config) | ||
if err != nil { | ||
return nil, err | ||
} | ||
configBytes = bytes | ||
} else { | ||
return nil, errors.Errorf("unsupported idp type %s", string(update.Type)) | ||
} | ||
qb = qb.Set("config", string(configBytes)) | ||
} | ||
|
||
qb = qb.Where(squirrel.Eq{"id": update.ID}) | ||
|
||
stmt, args, err := qb.ToSql() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
_, err = d.db.ExecContext(ctx, stmt, args...) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return d.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &update.ID}) | ||
} | ||
|
||
func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error { | ||
qb := squirrel.Delete("idp"). | ||
Where(squirrel.Eq{"id": delete.ID}). | ||
PlaceholderFormat(squirrel.Dollar) | ||
|
||
stmt, args, err := qb.ToSql() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
result, err := d.db.ExecContext(ctx, stmt, args...) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
if _, err = result.RowsAffected(); err != nil { | ||
return err | ||
} | ||
|
||
return nil | ||
} |
Oops, something went wrong.