Skip to content

Commit

Permalink
- 新增后台登录接口
Browse files Browse the repository at this point in the history
- 新增 jwt 检测
- 代码优化
  • Loading branch information
chenmingyong0423 committed Feb 12, 2024
1 parent c15366a commit 0a126e9
Show file tree
Hide file tree
Showing 4 changed files with 396 additions and 0 deletions.
53 changes: 53 additions & 0 deletions server/internal/ioc/jwt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2024 chenmingyong0423

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ioc

import (
"strings"

"github.com/chenmingyong0423/fnote/server/internal/pkg/jwtutil"
"github.com/gin-gonic/gin"
)

// JwtParseMiddleware jwt 解析中间件
func JwtParseMiddleware() gin.HandlerFunc {
return func(ctx *gin.Context) {
uri := ctx.Request.RequestURI
// 非 admin 接口不需要 jwt
if !strings.HasPrefix(uri, "/admin") {
ctx.Next()
return
}
// 登录和初始化接口不需要 jwt
if uri == "/admin/login" || uri == "/admin/init" {
ctx.Next()
return
}

jwtStr := ctx.GetHeader("Authorization")
if jwtStr == "" {
ctx.AbortWithStatusJSON(401, nil)
return
}
// 解析 jwt
claims, err := jwtutil.ParseJwt(jwtStr)
if err != nil {
ctx.AbortWithStatusJSON(401, nil)
return
}
ctx.Set("jwtClaims", claims)
ctx.Next()
}
}
113 changes: 113 additions & 0 deletions server/internal/pkg/aesutil/aes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright 2024 chenmingyong0423

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package aesutil

import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
)

var key []byte

func init() {
var err error
key, err = generateRandomBytes(32) // 生成256位密钥
if err != nil {
panic(err)
}
}

// generateRandomBytes 生成指定长度的随机字节
func generateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
if _, err := io.ReadFull(rand.Reader, b); err != nil {
return nil, err
}
return b, nil
}

// AesEncrypt 加密给定的消息
func AesEncrypt(plainText []byte) (string, error) {
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}

// 填充原文以满足AES块大小
padding := block.BlockSize() - len(plainText)%block.BlockSize()
padText := bytes.Repeat([]byte{byte(padding)}, padding)
paddedText := append(plainText, padText...)

// 初始化向量IV必须是唯一的,但不需要保密
iv, err := generateRandomBytes(block.BlockSize())
if err != nil {
return "", err
}

// 加密
ciphertext := make([]byte, len(paddedText))
mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(ciphertext, paddedText)

// 将IV附加到密文前以便解密时使用
encrypted := base64.StdEncoding.EncodeToString(append(iv, ciphertext...))
return encrypted, nil
}

// AesDecrypt 解密给定的消息
func AesDecrypt(encrypted string) (string, error) {
encryptedBytes, err := base64.StdEncoding.DecodeString(encrypted)
if err != nil {
return "", err
}

block, err := aes.NewCipher(key)
if err != nil {
return "", err
}

if len(encryptedBytes) < block.BlockSize() {
return "", fmt.Errorf("ciphertext too short")
}

// 提取IV
iv := encryptedBytes[:block.BlockSize()]
encryptedBytes = encryptedBytes[block.BlockSize():]

// 解密
decrypted := make([]byte, len(encryptedBytes))
mode := cipher.NewCBCDecrypter(block, iv)
mode.CryptBlocks(decrypted, encryptedBytes)

// 移除填充
padding := decrypted[len(decrypted)-1]
if int(padding) > len(decrypted) || padding == 0 {
return "", fmt.Errorf("invalid padding")
}
padLen := int(padding)
for _, val := range decrypted[len(decrypted)-padLen:] {
if val != padding {
return "", fmt.Errorf("invalid padding")
}
}
decrypted = decrypted[:len(decrypted)-padLen]

return string(decrypted), nil
}
83 changes: 83 additions & 0 deletions server/internal/pkg/jwtutil/jwt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2024 chenmingyong0423

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package jwtutil

import (
"crypto/rand"
"time"

"github.com/chenmingyong0423/fnote/server/internal/pkg/aesutil"

"github.com/pkg/errors"

"github.com/golang-jwt/jwt/v5"
)

var (
jwtKey []byte
)

func init() {
jwtKey = make([]byte, 32) // 生成32字节(256位)的密钥
if _, err := rand.Read(jwtKey); err != nil {
panic(err) // 生成密钥时发生错误
}
}

// GenerateJwt 生成 JWT
func GenerateJwt() (string, int64, error) {
now := time.Now()
exp := now.Add(time.Hour * 12)
t := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
Issuer: "https://github.com/chenmingyong0423/fnote",
Subject: "",
Audience: nil,
ExpiresAt: jwt.NewNumericDate(exp),
NotBefore: jwt.NewNumericDate(now),
IssuedAt: jwt.NewNumericDate(now),
})
signedString, err := t.SignedString(jwtKey)
if err != nil {
return "", 0, errors.Wrap(err, "generate jwt failed")
}

// aes 加密
encrypt, err := aesutil.AesEncrypt([]byte(signedString))
if err != nil {
return "", 0, err
}
return encrypt, exp.Unix(), nil
}

func ParseJwt(jwtStr string) (jwt.Claims, error) {
claims := &jwt.RegisteredClaims{}
decrypt, err := aesutil.AesDecrypt(jwtStr)
if err != nil {
return nil, err
}
token, err := jwt.ParseWithClaims(decrypt, claims, func(token *jwt.Token) (interface{}, error) {
return jwtKey, nil
})

if err != nil {
return nil, err
}

if claims, ok := token.Claims.(*jwt.RegisteredClaims); ok && token.Valid {
return claims, nil
} else {
return nil, err
}
}
147 changes: 147 additions & 0 deletions server/internal/pkg/web/wrap/wrap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Copyright 2023 chenmingyong0423

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package apiwrap

import (
"errors"
"fmt"
"log/slog"
"net/http"

"github.com/gin-gonic/gin"
)

type ResponseBody[T any] struct {
Code int `json:"code"`
Message string `json:"message"`
Data T `json:"data,omitempty"`
}

func SuccessResponse() *ResponseBody[any] {
return &ResponseBody[any]{
Code: 0,
Message: "success",
}
}

func SuccessResponseWithData[T any](data T) *ResponseBody[T] {
return &ResponseBody[T]{
Code: 0,
Message: "success",
Data: data,
}
}

func NewResponseBody[T any](code int, message string, data T) *ResponseBody[T] {
return &ResponseBody[T]{
Code: code,
Message: message,
Data: data,
}
}

type HttpCodeError int

type ErrorResponseBody struct {
HttpCode int
Message string
}

func NewErrorResponseBody(httpCode int, message string) ErrorResponseBody {
return ErrorResponseBody{
HttpCode: httpCode,
Message: message,
}
}

func (er ErrorResponseBody) Error() string {
return er.Message
}

type ListVO[T any] struct {
List []T `json:"list"`
}

func NewListVO[T any](t []T) ListVO[T] {
return ListVO[T]{
List: t,
}
}

type Page struct {
// 当前页
PageNo int64 `form:"pageNo" binding:"required"`
// 每页数量
PageSize int64 `form:"pageSize" binding:"required"`
}

type PageVO[T any] struct {
Page
// 总页数
TotalPages int64 `json:"totalPages"`
// 总数量
TotalCount int64 `json:"totalCount"`
List []T `json:"list"`
}

func (p *PageVO[T]) SetTotalCountAndCalculateTotalPages(totalCount int64) {
if p.PageSize == 0 {
p.TotalPages = 0
} else {
p.TotalPages = (totalCount + p.PageSize - 1) / p.PageSize
}
p.TotalCount = totalCount
}

func Wrap[T any](fn func(ctx *gin.Context) (T, error)) gin.HandlerFunc {
return func(ctx *gin.Context) {
result, err := fn(ctx)
if err != nil {
ErrorHandler(ctx, err)
return
}
ctx.JSON(http.StatusOK, result)
}
}

func ErrorHandler(ctx *gin.Context, err error) {
l := slog.Default().With("X-Request-ID", ctx.GetString("X-Request-ID"))
var e ErrorResponseBody
switch {
case errors.As(err, &e):
l.ErrorContext(ctx, e.Error())
ctx.JSON(e.HttpCode, nil)
default:
l.ErrorContext(ctx, fmt.Sprintf("%+v", err))
ctx.JSON(http.StatusInternalServerError, nil)
}
}

func WrapWithBody[T any, R any](fn func(ctx *gin.Context, req R) (T, error)) gin.HandlerFunc {
return func(ctx *gin.Context) {
var req R
bodyErr := ctx.Bind(&req)
if bodyErr != nil {
ErrorHandler(ctx, bodyErr)
return
}
result, err := fn(ctx, req)
if err != nil {
ErrorHandler(ctx, err)
return
}
ctx.JSON(http.StatusOK, result)
}
}

0 comments on commit 0a126e9

Please sign in to comment.