diff --git a/common/database.go b/common/database.go index fa2bffe..0c8f575 100644 --- a/common/database.go +++ b/common/database.go @@ -10,7 +10,7 @@ import ( var DB *gorm.DB -func init() { +func InitDB() *gorm.DB { // 配置 MySQL 连接参数 username := "root" // 账号 password := "jiaobaba" // 密码 @@ -37,6 +37,7 @@ func init() { db.AutoMigrate(&model.User{}) DB = db + return db } func GetDB() *gorm.DB { diff --git a/common/jwt.go b/common/jwt.go index 5685c62..c7b55ff 100644 --- a/common/jwt.go +++ b/common/jwt.go @@ -34,3 +34,12 @@ func ReleaseToken(user model.User) (string, error) { return tokenString, nil } + +func ParseToken(tokenString string) (*jwt.Token, *Claims, error) { + claims := &Claims{} + + token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { + return jwtKey, nil + }) + return token, claims, err +} diff --git a/controller/info.go b/controller/info.go new file mode 100644 index 0000000..fe1d50d --- /dev/null +++ b/controller/info.go @@ -0,0 +1,16 @@ +package controller + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +func Info(ctx *gin.Context) { + user, _ := ctx.Get("user") + + ctx.JSON(http.StatusOK, gin.H{ + "code": http.StatusOK, + "data": gin.H{"user": user}, + }) +} diff --git a/controller/login.go b/controller/login.go index afa8150..c819f73 100644 --- a/controller/login.go +++ b/controller/login.go @@ -10,7 +10,7 @@ import ( ) func Login(ctx *gin.Context) { - DB := common.GetDB() + DB := common.InitDB() // 获取参数 telephone := ctx.Query("telephone") password := ctx.Query("password") diff --git a/middleware/authmiddleware.go b/middleware/authmiddleware.go new file mode 100644 index 0000000..03835bf --- /dev/null +++ b/middleware/authmiddleware.go @@ -0,0 +1,60 @@ +package middleware + +import ( + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/zggsong/gin-vue-demo/common" + "github.com/zggsong/gin-vue-demo/model" +) + +// 认证中间件 +func AuthMiddleWare() gin.HandlerFunc { + return func(c *gin.Context) { + // 获取 authorization header + tokenString := c.GetHeader("Authorization") + + // 验证token + if tokenString == "" || !strings.HasPrefix(tokenString, "Bearer ") { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": http.StatusUnauthorized, + "message": "请求未授权", + }) + c.Abort() + return + } + + tokenString = tokenString[7:] + + token, claims, err := common.ParseToken(tokenString) + if err != nil || !token.Valid { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": http.StatusUnauthorized, + "message": "请求未授权", + }) + c.Abort() + return + } + // 验证通过后获取Claim中的UserId + userId := claims.UserId + DB := common.GetDB() + var user model.User + DB.First(&user, userId) + + // 用户不存在 + if user.ID == 0 { + c.JSON(http.StatusUnauthorized, gin.H{ + "code": http.StatusUnauthorized, + "message": "请求未授权", + }) + c.Abort() + return + } + + // 用户信息存在,则将用户信息存入上下文 + c.Set("user", user) + + c.Next() + } +} diff --git a/router/router.go b/router/router.go index f9a1ced..8191589 100644 --- a/router/router.go +++ b/router/router.go @@ -3,10 +3,12 @@ package router import ( "github.com/gin-gonic/gin" "github.com/zggsong/gin-vue-demo/controller" + "github.com/zggsong/gin-vue-demo/middleware" ) func CollectRoute(r *gin.Engine) *gin.Engine { r.POST("/api/auth/register", controller.Register) r.POST("/api/auth/login", controller.Login) + r.GET("/api/auth/info", middleware.AuthMiddleWare(), controller.Info) // 认证中间件保护info接口 return r }